diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..d9089f491
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,58 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: end-of-file-fixer
+      - id: check-case-conflict
+      - id: check-executables-have-shebangs
+      - id: requirements-txt-fixer
+      - id: check-added-large-files
+      - id: check-case-conflict
+      - id: check-toml
+        exclude: ^packages/jupyter-ai-module-cookiecutter
+      - id: check-yaml
+        exclude: ^packages/jupyter-ai-module-cookiecutter
+      - id: debug-statements
+      - id: forbid-new-submodules
+      - id: check-builtin-literals
+      - id: trailing-whitespace
+
+  - repo: https://github.com/psf/black
+    rev: 23.3.0
+    hooks:
+      - id: black
+
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+        files: \.py$
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.4.0
+    hooks:
+      - id: pyupgrade
+        args: [--py37-plus]
+
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+        additional_dependencies:
+          [
+            "flake8-bugbear==20.1.4",
+            "flake8-logging-format==0.6.0",
+            "flake8-implicit-str-concat==0.2.0",
+          ]
+        stages: [manual]
+
+  - repo: https://github.com/sirosen/check-jsonschema
+    rev: 0.23.1
+    hooks:
+      - id: check-jsonschema
+        name: "Check GitHub Workflows"
+        files: ^\.github/workflows/
+        types: [yaml]
+        args: ["--schemafile", "https://json.schemastore.org/github-workflow"]
+        stages: [manual]
diff --git a/README.md b/README.md
index 9b08ed418..e9327fa11 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Jupyter AI
-Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly
+Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly
and powerful way to explore generative AI models in notebooks and improve your productivity
in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers:
@@ -41,7 +41,7 @@ First, install [conda](https://conda.io/projects/conda/en/latest/user-guide/inst
If you are using an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.), you need to uninstall the `pip` provided version of `grpcio` and install the version provided by `conda` instead.
- $ pip uninstall grpcio; conda install grpcio
+ $ pip uninstall grpcio; conda install grpcio
If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, skip the `pip install jupyter_ai` step above, and instead, run:
@@ -59,7 +59,7 @@ Once you have installed the `%%ai` magic, you can enable it in any notebook or t
or:
%load_ext jupyter_ai
-
+
The screenshots below are from notebooks in the `examples/` directory of this package.
Then, you can use the `%%ai` magic command to specify a model and natural language prompt:
diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css
index 1858f4fee..e91d1bbcd 100644
--- a/docs/source/_static/css/custom.css
+++ b/docs/source/_static/css/custom.css
@@ -2,4 +2,4 @@ img.screenshot {
/* Copied from div.admonition */
box-shadow: 0 .2rem .5rem var(--pst-color-shadow),0 0 .0625rem var(--pst-color-shadow);
border-color: var(--pst-color-info);
-}
\ No newline at end of file
+}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8edc80417..bdd277e22 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -6,9 +6,9 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-project = 'Jupyter AI'
-copyright = '2023, Project Jupyter'
-author = 'Project Jupyter'
+project = "Jupyter AI"
+copyright = "2023, Project Jupyter"
+author = "Project Jupyter"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -16,17 +16,17 @@
extensions = ["myst_parser"]
myst_enable_extensions = ["colon_fence"]
-templates_path = ['_templates']
+templates_path = ["_templates"]
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "pydata_sphinx_theme"
-html_static_path = ['_static']
+html_static_path = ["_static"]
html_css_files = [
- 'css/custom.css',
+ "css/custom.css",
]
# -- Jupyter theming -------------------------------------------------
diff --git a/docs/source/contributors/index.md b/docs/source/contributors/index.md
index cb1bc68e9..c0823ef4b 100644
--- a/docs/source/contributors/index.md
+++ b/docs/source/contributors/index.md
@@ -21,7 +21,7 @@ Due to a compatibility issue with Webpack, Node.js 18.15.0 does not work with Ju
:::
## Development install
-After you have installed the prerequisites, create a new conda environment and activate it.
+After you have installed the prerequisites, create a new conda environment and activate it.
```
conda create -n jupyter-ai python=3.10
diff --git a/docs/source/index.md b/docs/source/index.md
index 118789088..b93973a96 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -1,6 +1,6 @@
# Jupyter AI
-Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly
+Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly
and powerful way to explore generative AI models in notebooks and improve your productivity
in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers:
@@ -19,4 +19,4 @@ maxdepth: 2
users/index
contributors/index
-```
\ No newline at end of file
+```
diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 40c8a4c62..3e1f5def0 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -2,7 +2,7 @@
Welcome to the user documentation for Jupyter AI.
-If you are interested in contributing to Jupyter AI,
+If you are interested in contributing to Jupyter AI,
please see our {doc}`contributor's guide </contributors/index>`.
## Prerequisites
@@ -101,7 +101,7 @@ First, install [conda](https://conda.io/projects/conda/en/latest/user-guide/inst
If you are using an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.), you need to uninstall the `pip` provided version of `grpcio` and install the version provided by `conda` instead.
- $ pip uninstall grpcio; conda install grpcio
+ $ pip uninstall grpcio; conda install grpcio
If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, you can run:
@@ -127,7 +127,7 @@ or
## The chat interface
-The easiest way to get started with Jupyter AI is to use the chat interface.
+The easiest way to get started with Jupyter AI is to use the chat interface.
:::{attention}
:name: open-ai-privacy-cost
@@ -449,7 +449,7 @@ will look like properly typeset equations.
Generate the 2D heat equation in LaTeX surrounded by `$$`. Do not include an explanation.
```
-This prompt will produce output as a code cell below the input cell.
+This prompt will produce output as a code cell below the input cell.
:::{warning}
:name: run-code
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
index ff35147db..c3dce919a 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -1,4 +1,11 @@
from ._version import __version__
+
+# expose embedding model providers on the package root
+from .embedding_providers import (
+ CohereEmbeddingsProvider,
+ HfHubEmbeddingsProvider,
+ OpenAIEmbeddingsProvider,
+)
from .exception import store_exception
from .magics import AiMagics
@@ -6,24 +13,20 @@
from .providers import (
AI21Provider,
AnthropicProvider,
+ BaseProvider,
+ ChatOpenAINewProvider,
+ ChatOpenAIProvider,
CohereProvider,
HfHubProvider,
OpenAIProvider,
- ChatOpenAIProvider,
- ChatOpenAINewProvider,
- SmEndpointProvider
-)
-# expose embedding model providers on the package root
-from .embedding_providers import (
- OpenAIEmbeddingsProvider,
- CohereEmbeddingsProvider,
- HfHubEmbeddingsProvider
+ SmEndpointProvider,
)
-from .providers import BaseProvider
+
def load_ipython_extension(ipython):
ipython.register_magics(AiMagics)
ipython.set_custom_exc((BaseException,), store_exception)
+
def unload_ipython_extension(ipython):
ipython.set_custom_exc((BaseException,), ipython.CustomTB)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
index 74be10485..ab383af32 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
@@ -3,4 +3,4 @@
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
-}
\ No newline at end of file
+}
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
index 5b31725bf..0d3f866ae 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -1,8 +1,13 @@
from typing import ClassVar, List, Type
+
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field
-from pydantic import BaseModel, Extra
-from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings
+from langchain.embeddings import (
+ CohereEmbeddings,
+ HuggingFaceHubEmbeddings,
+ OpenAIEmbeddings,
+)
from langchain.embeddings.base import Embeddings
+from pydantic import BaseModel, Extra
class BaseEmbeddingsProvider(BaseModel):
@@ -33,7 +38,7 @@ class Config:
model_id: str
- provider_klass: ClassVar[Type[Embeddings]]
+ provider_klass: ClassVar[Type[Embeddings]]
registry: ClassVar[bool] = False
"""Whether this provider is a registry provider."""
@@ -41,14 +46,12 @@ class Config:
fields: ClassVar[List[Field]] = []
"""Fields expected by this provider in its constructor. Each `Field` `f`
should be passed as a keyword argument, keyed by `f.key`."""
-
+
class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
id = "openai"
name = "OpenAI"
- models = [
- "text-embedding-ada-002"
- ]
+ models = ["text-embedding-ada-002"]
model_id_key = "model"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
@@ -58,11 +61,7 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
class CohereEmbeddingsProvider(BaseEmbeddingsProvider):
id = "cohere"
name = "Cohere"
- models = [
- 'large',
- 'multilingual-22-12',
- 'small'
- ]
+ models = ["large", "multilingual-22-12", "small"]
model_id_key = "model"
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py b/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py
index 733481ce1..126eedf52 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py
@@ -1,12 +1,13 @@
-from IPython.core.magic import register_line_magic
+import traceback
+
from IPython.core.getipython import get_ipython
+from IPython.core.magic import register_line_magic
from IPython.core.ultratb import ListTB
-import traceback
def store_exception(shell, etype, evalue, tb, tb_offset=None):
# A structured traceback (a list of strings) or None
-
+
if issubclass(etype, SyntaxError):
# Disable ANSI color strings
shell.SyntaxTB.color_toggle()
@@ -18,7 +19,9 @@ def store_exception(shell, etype, evalue, tb, tb_offset=None):
else:
# Disable ANSI color strings
shell.InteractiveTB.color_toggle()
- stb = shell.InteractiveTB.structured_traceback(etype, evalue, tb, tb_offset=tb_offset)
+ stb = shell.InteractiveTB.structured_traceback(
+ etype, evalue, tb, tb_offset=tb_offset
+ )
stb_text = shell.InteractiveTB.stb2text(stb)
# Re-enable ANSI color strings
shell.InteractiveTB.color_toggle()
@@ -31,6 +34,6 @@ def store_exception(shell, etype, evalue, tb, tb_offset=None):
err = shell.user_ns.get("Err", {})
err[prompt_number] = styled_exception
shell.user_ns["Err"] = err
-
- # Return
+
+ # Return
return etraceback
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
index f9eebcac1..c42f5efdd 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -9,22 +9,23 @@
import click
from IPython import get_ipython
-from IPython.core.magic import Magics, magics_class, line_cell_magic
+from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai_magics.utils import decompose_model_id, load_providers
+from langchain.chains import LLMChain
-from .providers import BaseProvider
-from .parsers import (cell_magic_parser,
- line_magic_parser,
+from .parsers import (
CellArgs,
DeleteArgs,
ErrorArgs,
HelpArgs,
ListArgs,
RegisterArgs,
- UpdateArgs)
-
-from langchain.chains import LLMChain
+ UpdateArgs,
+ cell_magic_parser,
+ line_magic_parser,
+)
+from .providers import BaseProvider
MODEL_ID_ALIASES = {
"gpt2": "huggingface_hub:gpt2",
@@ -33,35 +34,30 @@
"gpt4": "openai-chat:gpt-4",
}
-class TextOrMarkdown(object):
+class TextOrMarkdown:
def __init__(self, text, markdown):
self.text = text
self.markdown = markdown
def _repr_mimebundle_(self, include=None, exclude=None):
- return (
- {
- 'text/plain': self.text,
- 'text/markdown': self.markdown
- }
- )
+ return {"text/plain": self.text, "text/markdown": self.markdown}
-class TextWithMetadata(object):
+class TextWithMetadata:
def __init__(self, text, metadata):
self.text = text
self.metadata = metadata
def _repr_mimebundle_(self, include=None, exclude=None):
- return ({'text/plain': self.text}, self.metadata)
+ return ({"text/plain": self.text}, self.metadata)
class Base64Image:
def __init__(self, mimeData, metadata):
- mimeDataParts = mimeData.split(',')
- self.data = base64.b64decode(mimeDataParts[1]);
- self.mimeType = re.sub(r';base64$', '', mimeDataParts[0])
+ mimeDataParts = mimeData.split(",")
+ self.data = base64.b64decode(mimeDataParts[1])
+ self.mimeType = re.sub(r";base64$", "", mimeDataParts[0])
self.metadata = metadata
def _repr_mimebundle_(self, include=None, exclude=None):
@@ -76,14 +72,14 @@ def _repr_mimebundle_(self, include=None, exclude=None):
"math": Math,
"md": Markdown,
"json": JSON,
- "text": TextWithMetadata
+ "text": TextWithMetadata,
}
NA_MESSAGE = 'N/A'
-MARKDOWN_PROMPT_TEMPLATE = '{prompt}\n\nProduce output in markdown format only.'
+MARKDOWN_PROMPT_TEMPLATE = "{prompt}\n\nProduce output in markdown format only."
-PROVIDER_NO_MODELS = 'This provider does not define a list of models.'
+PROVIDER_NO_MODELS = "This provider does not define a list of models."
CANNOT_DETERMINE_MODEL_TEXT = """Cannot determine model provider from model ID '{0}'.
@@ -95,132 +91,150 @@ def _repr_mimebundle_(self, include=None, exclude=None):
PROMPT_TEMPLATES_BY_FORMAT = {
- "code": '{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.',
- "html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.',
- "image": '{prompt}\n\nProduce output as an image only, with no text before or after it.',
+ "code": "{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.",
+ "html": "{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.",
+ "image": "{prompt}\n\nProduce output as an image only, with no text before or after it.",
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
- "math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.',
- "json": '{prompt}\n\nProduce output in JSON format only, with nothing before or after it.',
- "text": '{prompt}' # No customization
+ "math": "{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.",
+ "json": "{prompt}\n\nProduce output in JSON format only, with nothing before or after it.",
+ "text": "{prompt}", # No customization
}
-AI_COMMANDS = { "delete", "error", "help", "list", "register", "update" }
+AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}
+
class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
leaves replacement field unchanged if replacement field is not associated
with a value."""
- def __missing__(self, key):
+
+ def __missing__(self, key):
return key.join("{}")
+
class EnvironmentError(BaseException):
pass
+
class CellMagicError(BaseException):
pass
+
@magics_class
class AiMagics(Magics):
def __init__(self, shell):
- super(AiMagics, self).__init__(shell)
+ super().__init__(shell)
self.transcript_openai = []
# suppress warning when using old OpenAIChat provider
- warnings.filterwarnings("ignore", message="You are trying to use a chat model. This way of initializing it is "
+ warnings.filterwarnings(
+ "ignore",
+ message="You are trying to use a chat model. This way of initializing it is "
"no longer supported. Instead, please use: "
- "`from langchain.chat_models import ChatOpenAI`")
+ "`from langchain.chat_models import ChatOpenAI`",
+ )
self.providers = load_providers()
-
+
# initialize a registry of custom model/chain names
self.custom_model_registry = MODEL_ID_ALIASES
-
+
def _ai_bulleted_list_models_for_provider(self, provider_id, Provider):
output = ""
- if (len(Provider.models) == 1 and Provider.models[0] == "*"):
+ if len(Provider.models) == 1 and Provider.models[0] == "*":
output += f"* {PROVIDER_NO_MODELS}\n"
else:
for model_id in Provider.models:
- output += f"* {provider_id}:{model_id}\n";
- output += "\n" # End of bulleted list
-
+ output += f"* {provider_id}:{model_id}\n"
+ output += "\n" # End of bulleted list
+
return output
def _ai_inline_list_models_for_provider(self, provider_id, Provider):
output = ""
- if (len(Provider.models) == 1 and Provider.models[0] == "*"):
+ if len(Provider.models) == 1 and Provider.models[0] == "*":
return PROVIDER_NO_MODELS
for model_id in Provider.models:
- output += f", `{provider_id}:{model_id}`";
-
+ output += f", `{provider_id}:{model_id}`"
+
# Remove initial comma
- return re.sub(r'^, ', '', output)
-
+ return re.sub(r"^, ", "", output)
+
# Is the required environment variable set?
def _ai_env_status_for_provider_markdown(self, provider_id):
- na_message = 'Not applicable. | ' + NA_MESSAGE
+ na_message = "Not applicable. | " + NA_MESSAGE
+
+ if (
+ provider_id not in self.providers
+ or self.providers[provider_id].auth_strategy == None
+ ):
+ return na_message # No emoji
- if (provider_id not in self.providers or
- self.providers[provider_id].auth_strategy == None):
- return na_message # No emoji
-
try:
env_var = self.providers[provider_id].auth_strategy.name
- except AttributeError: # No "name" attribute
+ except AttributeError: # No "name" attribute
return na_message
-
+
output = f"`{env_var}` | "
- if (os.getenv(env_var) == None):
- output += ("❌");
+ if os.getenv(env_var) == None:
+            output += "❌"
else:
- output += ("✅");
-
+            output += "✅"
+
return output
def _ai_env_status_for_provider_text(self, provider_id):
- if (provider_id not in self.providers or
- self.providers[provider_id].auth_strategy == None):
- return '' # No message necessary
-
- try:
+ if (
+ provider_id not in self.providers
+ or self.providers[provider_id].auth_strategy == None
+ ):
+ return "" # No message necessary
+
+ try:
env_var = self.providers[provider_id].auth_strategy.name
- except AttributeError: # No "name" attribute
- return ''
-
+ except AttributeError: # No "name" attribute
+ return ""
+
output = f"Requires environment variable {env_var} "
- if (os.getenv(env_var) != None):
+ if os.getenv(env_var) != None:
output += "(set)"
else:
output += "(not set)"
-
+
return output + "\n"
# Is this a name of a Python variable that can be called as a LangChain chain?
def _is_langchain_chain(self, name):
# Reserved word in Python?
- if (keyword.iskeyword(name)):
- return False;
-
- acceptable_name = re.compile('^[a-zA-Z0-9_]+$')
- if (not acceptable_name.match(name)):
- return False;
-
+ if keyword.iskeyword(name):
+ return False
+
+ acceptable_name = re.compile("^[a-zA-Z0-9_]+$")
+ if not acceptable_name.match(name):
+ return False
+
ipython = get_ipython()
- return(name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain))
+ return name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain)
# Is this an acceptable name for an alias?
def _validate_name(self, register_name):
# A registry name contains ASCII letters, numbers, hyphens, underscores,
# and periods. No other characters, including a colon, are permitted
- acceptable_name = re.compile('^[a-zA-Z0-9._-]+$')
- if (not acceptable_name.match(register_name)):
- raise ValueError('A registry name may contain ASCII letters, numbers, hyphens, underscores, '
- + 'and periods. No other characters, including a colon, are permitted')
+ acceptable_name = re.compile("^[a-zA-Z0-9._-]+$")
+ if not acceptable_name.match(register_name):
+ raise ValueError(
+ "A registry name may contain ASCII letters, numbers, hyphens, underscores, "
+ + "and periods. No other characters, including a colon, are permitted"
+ )
# Initially set or update an alias to a target
def _safely_set_target(self, register_name, target):
@@ -230,33 +244,38 @@ def _safely_set_target(self, register_name, target):
self.custom_model_registry[register_name] = ip.user_ns[target]
else:
# Ensure that the destination is properly formatted
- if (':' not in target):
+ if ":" not in target:
raise ValueError(
- 'Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format')
+ "Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format"
+ )
self.custom_model_registry[register_name] = target
def handle_delete(self, args: DeleteArgs):
- if (args.name in AI_COMMANDS):
- raise ValueError(f"Reserved command names, including {args.name}, cannot be deleted")
-
- if (args.name not in self.custom_model_registry):
+ if args.name in AI_COMMANDS:
+ raise ValueError(
+ f"Reserved command names, including {args.name}, cannot be deleted"
+ )
+
+ if args.name not in self.custom_model_registry:
raise ValueError(f"There is no alias called {args.name}")
-
+
del self.custom_model_registry[args.name]
output = f"Deleted alias `{args.name}`"
return TextOrMarkdown(output, output)
def handle_register(self, args: RegisterArgs):
# Existing command names are not allowed
- if (args.name in AI_COMMANDS):
+ if args.name in AI_COMMANDS:
raise ValueError(f"The name {args.name} is reserved for a command")
-
+
# Existing registered names are not allowed
- if (args.name in self.custom_model_registry):
- raise ValueError(f"The name {args.name} is already associated with a custom model; "
- + 'use %ai update to change its target')
-
+ if args.name in self.custom_model_registry:
+ raise ValueError(
+ f"The name {args.name} is already associated with a custom model; "
+ + "use %ai update to change its target"
+ )
+
# Does the new name match expected format?
self._validate_name(args.name)
@@ -265,62 +284,75 @@ def handle_register(self, args: RegisterArgs):
return TextOrMarkdown(output, output)
def handle_update(self, args: UpdateArgs):
- if (args.name in AI_COMMANDS):
- raise ValueError(f"Reserved command names, including {args.name}, cannot be updated")
-
- if (args.name not in self.custom_model_registry):
+ if args.name in AI_COMMANDS:
+ raise ValueError(
+ f"Reserved command names, including {args.name}, cannot be updated"
+ )
+
+ if args.name not in self.custom_model_registry:
raise ValueError(f"There is no alias called {args.name}")
-
+
self._safely_set_target(args.name, args.target)
output = f"Updated target of alias `{args.name}`"
return TextOrMarkdown(output, output)
def _ai_list_command_markdown(self, single_provider=None):
- output = ("| Provider | Environment variable | Set? | Models |\n"
- + "|----------|----------------------|------|--------|\n")
- if (single_provider is not None and single_provider not in self.providers):
- return f"There is no model provider with ID `{single_provider}`.";
+ output = (
+ "| Provider | Environment variable | Set? | Models |\n"
+ + "|----------|----------------------|------|--------|\n"
+ )
+ if single_provider is not None and single_provider not in self.providers:
+ return f"There is no model provider with ID `{single_provider}`."
for provider_id, Provider in self.providers.items():
- if (single_provider is not None and provider_id != single_provider):
- continue;
+ if single_provider is not None and provider_id != single_provider:
+ continue
- output += (f"| `{provider_id}` | "
- + self._ai_env_status_for_provider_markdown(provider_id) + " | "
+ output += (
+ f"| `{provider_id}` | "
+ + self._ai_env_status_for_provider_markdown(provider_id)
+ + " | "
+ self._ai_inline_list_models_for_provider(provider_id, Provider)
- + " |\n")
+ + " |\n"
+ )
# Also list aliases.
- if (single_provider is None and len(self.custom_model_registry) > 0):
- output += ("\nAliases and custom commands:\n\n"
+ if single_provider is None and len(self.custom_model_registry) > 0:
+ output += (
+ "\nAliases and custom commands:\n\n"
+ "| Name | Target |\n"
- + "|------|--------|\n")
+ + "|------|--------|\n"
+ )
for key, value in self.custom_model_registry.items():
output += f"| `{key}` | "
if isinstance(value, str):
output += f"`{value}`"
else:
output += "*custom chain*"
-
+
output += " |\n"
return output
def _ai_list_command_text(self, single_provider=None):
output = ""
- if (single_provider is not None and single_provider not in self.providers):
- return f"There is no model provider with ID '{single_provider}'.";
+ if single_provider is not None and single_provider not in self.providers:
+ return f"There is no model provider with ID '{single_provider}'."
for provider_id, Provider in self.providers.items():
- if (single_provider is not None and provider_id != single_provider):
- continue;
-
- output += (f"{provider_id}\n"
- + self._ai_env_status_for_provider_text(provider_id) # includes \n if nonblank
- + self._ai_bulleted_list_models_for_provider(provider_id, Provider))
+ if single_provider is not None and provider_id != single_provider:
+ continue
+
+ output += (
+ f"{provider_id}\n"
+ + self._ai_env_status_for_provider_text(
+ provider_id
+ ) # includes \n if nonblank
+ + self._ai_bulleted_list_models_for_provider(provider_id, Provider)
+ )
# Also list aliases.
- if (single_provider is None and len(self.custom_model_registry) > 0):
+ if single_provider is None and len(self.custom_model_registry) > 0:
output += "\nAliases and custom commands:\n"
for key, value in self.custom_model_registry.items():
output += f"{key} - "
@@ -328,7 +360,7 @@ def _ai_list_command_text(self, single_provider=None):
output += value
else:
output += "custom chain"
-
+
output += "\n"
return output
@@ -338,42 +370,34 @@ def handle_error(self, args: ErrorArgs):
# Find the most recent error.
ip = get_ipython()
- if ('Err' not in ip.user_ns):
+ if "Err" not in ip.user_ns:
return TextOrMarkdown(no_errors, no_errors)
- err = ip.user_ns['Err']
+ err = ip.user_ns["Err"]
# Start from the previous execution count
excount = ip.execution_count - 1
last_error = None
- while (excount >= 0 and last_error is None):
- if(excount in err):
+ while excount >= 0 and last_error is None:
+ if excount in err:
last_error = err[excount]
else:
- excount = excount - 1;
+ excount = excount - 1
- if (last_error is None):
+ if last_error is None:
return TextOrMarkdown(no_errors, no_errors)
prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
cell_args = CellArgs(
- type="root",
- model_id=args.model_id,
- format=args.format,
- reset=False)
+ type="root", model_id=args.model_id, format=args.format, reset=False
+ )
return self.run_ai_cell(cell_args, prompt)
def _append_exchange_openai(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
- self.transcript_openai.append({
- "role": "user",
- "content": prompt
- })
- self.transcript_openai.append({
- "role": "assistant",
- "content": output
- })
+ self.transcript_openai.append({"role": "user", "content": prompt})
+ self.transcript_openai.append({"role": "assistant", "content": output})
def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
@@ -388,29 +412,31 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
return None
return self.providers[provider_id]
-
+
def display_output(self, output, display_format, md):
# build output display
DisplayClass = DISPLAYS_BY_FORMAT[display_format]
# if the user wants code, add another cell with the output.
- if display_format == 'code':
+ if display_format == "code":
# Strip a leading language indicator and trailing triple-backticks
- lang_indicator = r'^```[a-zA-Z0-9]*\n'
- output = re.sub(lang_indicator, '', output)
- output = re.sub(r'\n```$', '', output)
+ lang_indicator = r"^```[a-zA-Z0-9]*\n"
+ output = re.sub(lang_indicator, "", output)
+ output = re.sub(r"\n```$", "", output)
new_cell_payload = dict(
- source='set_next_input',
+ source="set_next_input",
text=output,
replace=False,
)
ip = get_ipython()
ip.payload_manager.write_payload(new_cell_payload)
- return HTML('AI generated code inserted below ⬇️', metadata=md);
+ return HTML(
+ "AI generated code inserted below ⬇️", metadata=md
+ )
if DisplayClass is None:
return output
- if display_format == 'json':
+ if display_format == "json":
# JSON display expects a dict, not a JSON string
output = json.loads(output)
output_display = DisplayClass(output, metadata=md)
@@ -426,12 +452,12 @@ def handle_help(self, _: HelpArgs):
def handle_list(self, args: ListArgs):
return TextOrMarkdown(
self._ai_list_command_text(args.provider_id),
- self._ai_list_command_markdown(args.provider_id)
+ self._ai_list_command_markdown(args.provider_id),
)
def run_ai_cell(self, args: CellArgs, prompt: str):
# Apply a prompt template.
- prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt = prompt)
+ prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt=prompt)
# interpolate user namespace into prompt
ip = get_ipython()
@@ -439,30 +465,29 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
# Determine provider and local model IDs
# If this is a custom chain, send the message to the custom chain.
- if (args.model_id in self.custom_model_registry and
- isinstance(self.custom_model_registry[args.model_id], LLMChain)):
-
+ if args.model_id in self.custom_model_registry and isinstance(
+ self.custom_model_registry[args.model_id], LLMChain
+ ):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
- {
- "jupyter_ai": {
- "custom_chain_id": args.model_id
- }
- })
+ {"jupyter_ai": {"custom_chain_id": args.model_id}},
+ )
provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
if Provider is None:
return TextOrMarkdown(
- CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id) + "\n\n"
+ CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id)
+ + "\n\n"
+ "If you were trying to run a command, run '%ai help' to see a list of commands.",
- CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id) + "\n\n"
- + "If you were trying to run a command, run `%ai help` to see a list of commands."
+ CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id)
+ + "\n\n"
+ + "If you were trying to run a command, run `%ai help` to see a list of commands.",
)
# if `--reset` is specified, reset transcript and return early
- if (provider_id == "openai-chat" and args.reset):
+ if provider_id == "openai-chat" and args.reset:
self.transcript_openai = []
return
@@ -471,26 +496,26 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
if auth_strategy:
# TODO: handle auth strategies besides EnvAuthStrategy
if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
- raise EnvironmentError(
+ raise OSError(
f"Authentication environment variable {auth_strategy.name} not provided.\n"
f"An authentication token is required to use models from the {Provider.name} provider.\n"
f"Please specify it via `%env {auth_strategy.name}=token`. "
) from None
# configure and instantiate provider
- provider_params = { "model_id": local_model_id }
+ provider_params = {"model_id": local_model_id}
if provider_id == "openai-chat":
provider_params["prefix_messages"] = self.transcript_openai
# for SageMaker, validate that required params are specified
if provider_id == "sagemaker-endpoint":
if (
- args.region_name is None or
- args.request_schema is None or
- args.response_path is None
+ args.region_name is None
+ or args.request_schema is None
+ or args.response_path is None
):
raise ValueError(
- "When using the sagemaker-endpoint provider, you must specify all of " +
- "the --region-name, --request-schema, and --response-path options."
+ "When using the sagemaker-endpoint provider, you must specify all of "
+ + "the --region-name, --request-schema, and --response-path options."
)
provider_params["region_name"] = args.region_name
provider_params["request_schema"] = args.request_schema
@@ -506,18 +531,13 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
if provider_id == "openai-chat":
self._append_exchange_openai(prompt, output)
- md = {
- "jupyter_ai": {
- "provider_id": provider_id,
- "model_id": local_model_id
- }
- }
+ md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}}
return self.display_output(output, args.format, md)
@line_cell_magic
def ai(self, line, cell=None):
- raw_args = line.split(' ')
+ raw_args = line.split(" ")
if cell:
args = cell_magic_parser(raw_args, prog_name="%%ai", standalone_mode=False)
else:
@@ -554,10 +574,10 @@ def ai(self, line, cell=None):
"""[0.8+]: To invoke a language model, you must use the `%%ai`
cell magic. The `%ai` line magic is only for use with
subcommands."""
- )
+ )
prompt = cell.strip()
-
+
# interpolate user namespace into prompt
ip = get_ipython()
prompt = prompt.format_map(FormatDict(ip.user_ns))
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
index d2b0476b5..a6acf3525 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
@@ -1,27 +1,37 @@
+from typing import Literal, Optional, get_args
+
import click
from pydantic import BaseModel
-from typing import Optional, Literal, get_args
-FORMAT_CHOICES_TYPE = Literal["code", "html", "image", "json", "markdown", "math", "md", "text"]
+FORMAT_CHOICES_TYPE = Literal[
+ "code", "html", "image", "json", "markdown", "math", "md", "text"
+]
FORMAT_CHOICES = list(get_args(FORMAT_CHOICES_TYPE))
FORMAT_HELP = """IPython display to use when rendering output. [default="markdown"]"""
-REGION_NAME_SHORT_OPTION = '-n'
-REGION_NAME_LONG_OPTION = '--region-name'
-REGION_NAME_HELP = ("AWS region name, e.g. 'us-east-1'. Required for SageMaker provider; " +
- "does nothing with other providers.")
+REGION_NAME_SHORT_OPTION = "-n"
+REGION_NAME_LONG_OPTION = "--region-name"
+REGION_NAME_HELP = (
+ "AWS region name, e.g. 'us-east-1'. Required for SageMaker provider; "
+ + "does nothing with other providers."
+)
-REQUEST_SCHEMA_SHORT_OPTION = '-q'
-REQUEST_SCHEMA_LONG_OPTION = '--request-schema'
-REQUEST_SCHEMA_HELP = ("The JSON object the endpoint expects, with the prompt being " +
- "substituted into any value that matches the string literal ''. " +
- "Required for SageMaker provider; does nothing with other providers.")
+REQUEST_SCHEMA_SHORT_OPTION = "-q"
+REQUEST_SCHEMA_LONG_OPTION = "--request-schema"
+REQUEST_SCHEMA_HELP = (
+ "The JSON object the endpoint expects, with the prompt being "
+ + "substituted into any value that matches the string literal ''. "
+ + "Required for SageMaker provider; does nothing with other providers."
+)
+
+RESPONSE_PATH_SHORT_OPTION = "-p"
+RESPONSE_PATH_LONG_OPTION = "--response-path"
+RESPONSE_PATH_HELP = (
+ "A JSONPath string that retrieves the language model's output "
+ + "from the endpoint's JSON response. Required for SageMaker provider; "
+ + "does nothing with other providers."
+)
-RESPONSE_PATH_SHORT_OPTION = '-p'
-RESPONSE_PATH_LONG_OPTION = '--response-path'
-RESPONSE_PATH_HELP = ("A JSONPath string that retrieves the language model's output " +
- "from the endpoint's JSON response. Required for SageMaker provider; " +
- "does nothing with other providers.")
class CellArgs(BaseModel):
type: Literal["root"] = "root"
@@ -33,6 +43,7 @@ class CellArgs(BaseModel):
request_schema: Optional[str]
response_path: Optional[str]
+
# Should match CellArgs, but without "reset"
class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
@@ -43,51 +54,79 @@ class ErrorArgs(BaseModel):
request_schema: Optional[str]
response_path: Optional[str]
+
class HelpArgs(BaseModel):
type: Literal["help"] = "help"
+
class ListArgs(BaseModel):
type: Literal["list"] = "list"
provider_id: Optional[str]
+
class RegisterArgs(BaseModel):
type: Literal["register"] = "register"
name: str
target: str
+
class DeleteArgs(BaseModel):
type: Literal["delete"] = "delete"
name: str
+
class UpdateArgs(BaseModel):
type: Literal["update"] = "update"
name: str
target: str
+
class LineMagicGroup(click.Group):
"""Helper class to print the help string for cell magics as well when
`%ai --help` is called."""
+
def get_help(self, ctx):
with click.Context(cell_magic_parser, info_name="%%ai") as ctx:
click.echo(cell_magic_parser.get_help(ctx))
- click.echo('-' * 78)
+ click.echo("-" * 78)
with click.Context(line_magic_parser, info_name="%ai") as ctx:
click.echo(super().get_help(ctx))
+
@click.command()
-@click.argument('model_id')
-@click.option('-f', '--format',
+@click.argument("model_id")
+@click.option(
+ "-f",
+ "--format",
type=click.Choice(FORMAT_CHOICES, case_sensitive=False),
default="markdown",
- help=FORMAT_HELP
+ help=FORMAT_HELP,
)
-@click.option('-r', '--reset', is_flag=True,
+@click.option(
+ "-r",
+ "--reset",
+ is_flag=True,
help="""Clears the conversation transcript used when interacting with an
- OpenAI chat model provider. Does nothing with other providers."""
+ OpenAI chat model provider. Does nothing with other providers.""",
+)
+@click.option(
+ REGION_NAME_SHORT_OPTION,
+ REGION_NAME_LONG_OPTION,
+ required=False,
+ help=REGION_NAME_HELP,
+)
+@click.option(
+ REQUEST_SCHEMA_SHORT_OPTION,
+ REQUEST_SCHEMA_LONG_OPTION,
+ required=False,
+ help=REQUEST_SCHEMA_HELP,
+)
+@click.option(
+ RESPONSE_PATH_SHORT_OPTION,
+ RESPONSE_PATH_LONG_OPTION,
+ required=False,
+ help=RESPONSE_PATH_HELP,
)
-@click.option(REGION_NAME_SHORT_OPTION, REGION_NAME_LONG_OPTION, required=False, help=REGION_NAME_HELP)
-@click.option(REQUEST_SCHEMA_SHORT_OPTION, REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP)
-@click.option(RESPONSE_PATH_SHORT_OPTION, RESPONSE_PATH_LONG_OPTION, required=False, help=RESPONSE_PATH_HELP)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
@@ -99,22 +138,41 @@ def cell_magic_parser(**kwargs):
"""
return CellArgs(**kwargs)
+
@click.group(cls=LineMagicGroup)
def line_magic_parser():
"""
Invokes a subcommand.
"""
-@line_magic_parser.command(name='error')
-@click.argument('model_id')
-@click.option('-f', '--format',
+
+@line_magic_parser.command(name="error")
+@click.argument("model_id")
+@click.option(
+ "-f",
+ "--format",
type=click.Choice(FORMAT_CHOICES, case_sensitive=False),
default="markdown",
- help=FORMAT_HELP
+ help=FORMAT_HELP,
+)
+@click.option(
+ REGION_NAME_SHORT_OPTION,
+ REGION_NAME_LONG_OPTION,
+ required=False,
+ help=REGION_NAME_HELP,
+)
+@click.option(
+ REQUEST_SCHEMA_SHORT_OPTION,
+ REQUEST_SCHEMA_LONG_OPTION,
+ required=False,
+ help=REQUEST_SCHEMA_HELP,
+)
+@click.option(
+ RESPONSE_PATH_SHORT_OPTION,
+ RESPONSE_PATH_LONG_OPTION,
+ required=False,
+ help=RESPONSE_PATH_HELP,
)
-@click.option(REGION_NAME_SHORT_OPTION, REGION_NAME_LONG_OPTION, required=False, help=REGION_NAME_HELP)
-@click.option(REQUEST_SCHEMA_SHORT_OPTION, REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP)
-@click.option(RESPONSE_PATH_SHORT_OPTION, RESPONSE_PATH_LONG_OPTION, required=False, help=RESPONSE_PATH_HELP)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
@@ -122,41 +180,48 @@ def error_subparser(**kwargs):
"""
return ErrorArgs(**kwargs)
-@line_magic_parser.command(name='help')
+
+@line_magic_parser.command(name="help")
def help_subparser():
"""Show this message and exit."""
return HelpArgs()
-@line_magic_parser.command(name='list',
- short_help="List language models. See `%ai list --help` for options."
+
+@line_magic_parser.command(
+ name="list", short_help="List language models. See `%ai list --help` for options."
)
-@click.argument('provider_id', required=False)
+@click.argument("provider_id", required=False)
def list_subparser(**kwargs):
"""List language models, optionally scoped to PROVIDER_ID."""
return ListArgs(**kwargs)
-@line_magic_parser.command(name='register',
- short_help="Register a new alias. See `%ai register --help` for options."
+
+@line_magic_parser.command(
+ name="register",
+ short_help="Register a new alias. See `%ai register --help` for options.",
)
-@click.argument('name')
-@click.argument('target')
+@click.argument("name")
+@click.argument("target")
def register_subparser(**kwargs):
"""Register a new alias called NAME for the model or chain named TARGET."""
return RegisterArgs(**kwargs)
-@line_magic_parser.command(name='delete',
- short_help="Delete an alias. See `%ai delete --help` for options."
+
+@line_magic_parser.command(
+ name="delete", short_help="Delete an alias. See `%ai delete --help` for options."
)
-@click.argument('name')
+@click.argument("name")
def register_subparser(**kwargs):
"""Delete an alias called NAME."""
return DeleteArgs(**kwargs)
-@line_magic_parser.command(name='update',
- short_help="Update the target of an alias. See `%ai update --help` for options."
+
+@line_magic_parser.command(
+ name="update",
+ short_help="Update the target of an alias. See `%ai update --help` for options.",
)
-@click.argument('name')
-@click.argument('target')
+@click.argument("name")
+@click.argument("target")
def register_subparser(**kwargs):
"""Update an alias called NAME to refer to the model or chain named TARGET."""
return UpdateArgs(**kwargs)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 5a573b8b7..e5c68d5d7 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -1,11 +1,11 @@
-from typing import Any, ClassVar, Dict, List, Union, Literal, Optional
import base64
+import copy
import io
import json
-import copy
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from jsonpath_ng import jsonpath, parse
-from langchain.schema import BaseModel as BaseLangchainProvider
+from langchain.chat_models import ChatOpenAI
from langchain.llms import (
AI21,
Anthropic,
@@ -13,29 +13,32 @@
HuggingFaceHub,
OpenAI,
OpenAIChat,
- SagemakerEndpoint
+ SagemakerEndpoint,
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
-from langchain.utils import get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens
-
+from langchain.schema import BaseModel as BaseLangchainProvider
+from langchain.utils import get_from_dict_or_env
from pydantic import BaseModel, Extra, root_validator
-from langchain.chat_models import ChatOpenAI
+
class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""
+
type: Literal["env"] = "env"
name: str
class MultiEnvAuthStrategy(BaseModel):
"""Require multiple auth tokens via multiple environment variables."""
+
type: Literal["file"] = "file"
names: List[str]
class AwsAuthStrategy(BaseModel):
"""Require AWS authentication via Boto3"""
+
type: Literal["aws"] = "aws"
@@ -47,18 +50,22 @@ class AwsAuthStrategy(BaseModel):
]
]
+
class TextField(BaseModel):
type: Literal["text"] = "text"
key: str
label: str
+
class MultilineTextField(BaseModel):
type: Literal["text-multiline"] = "text-multiline"
key: str
label: str
+
Field = Union[TextField, MultilineTextField]
+
class BaseProvider(BaseLangchainProvider):
#
# pydantic config
@@ -93,7 +100,7 @@ class Config:
"""Whether this provider is a registry provider."""
fields: ClassVar[List[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
+ """User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""
#
@@ -105,14 +112,16 @@ def __init__(self, *args, **kwargs):
try:
assert kwargs["model_id"]
except:
- raise AssertionError("model_id was not specified. Please specify it as a keyword argument.")
+ raise AssertionError(
+ "model_id was not specified. Please specify it as a keyword argument."
+ )
model_kwargs = {}
model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]
super().__init__(*args, **kwargs, **model_kwargs)
-
+
class AI21Provider(BaseProvider, AI21):
id = "ai21"
name = "AI21"
@@ -131,6 +140,7 @@ class AI21Provider(BaseProvider, AI21):
pypi_package_deps = ["ai21"]
auth_strategy = EnvAuthStrategy(name="AI21_API_KEY")
+
class AnthropicProvider(BaseProvider, Anthropic):
id = "anthropic"
name = "Anthropic"
@@ -145,6 +155,7 @@ class AnthropicProvider(BaseProvider, Anthropic):
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
+
class CohereProvider(BaseProvider, Cohere):
id = "cohere"
name = "Cohere"
@@ -153,7 +164,13 @@ class CohereProvider(BaseProvider, Cohere):
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
-HUGGINGFACE_HUB_VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image")
+
+HUGGINGFACE_HUB_VALID_TASKS = (
+ "text2text-generation",
+ "text-generation",
+ "text-to-image",
+)
+
class HfHubProvider(BaseProvider, HuggingFaceHub):
id = "huggingface_hub"
@@ -195,7 +212,7 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please install it with `pip install huggingface_hub`."
)
return values
-
+
# Handle image outputs
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
@@ -220,21 +237,21 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
# Custom code for responding to image generation responses
if self.client.task == "text-to-image":
- imageFormat = response.format # Presume it's a PIL ImageFile
- mimeType = ''
- if (imageFormat == 'JPEG'):
- mimeType = 'image/jpeg'
- elif (imageFormat == 'PNG'):
- mimeType = 'image/png'
- elif (imageFormat == 'GIF'):
- mimeType = 'image/gif'
+ imageFormat = response.format # Presume it's a PIL ImageFile
+ mimeType = ""
+ if imageFormat == "JPEG":
+ mimeType = "image/jpeg"
+ elif imageFormat == "PNG":
+ mimeType = "image/png"
+ elif imageFormat == "GIF":
+ mimeType = "image/gif"
else:
raise ValueError(f"Unrecognized image format {imageFormat}")
-
+
buffer = io.BytesIO()
response.save(buffer, format=imageFormat)
# Encode image data to Base64 bytes, then decode bytes to str
- return (mimeType + ';base64,' + base64.b64encode(buffer.getvalue()).decode())
+ return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
if self.client.task == "text-generation":
# Text generation return includes the starter text.
@@ -252,6 +269,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
text = enforce_stop_tokens(text, stop)
return text
+
class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
@@ -270,6 +288,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+
class ChatOpenAIProvider(BaseProvider, OpenAIChat):
id = "openai-chat"
name = "OpenAI"
@@ -288,14 +307,9 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
def append_exchange(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
- self.prefix_messages.append({
- "role": "user",
- "content": prompt
- })
- self.prefix_messages.append({
- "role": "assistant",
- "content": output
- })
+ self.prefix_messages.append({"role": "user", "content": prompt})
+ self.prefix_messages.append({"role": "assistant", "content": output})
+
# uses the new OpenAIChat provider. temporarily living as a separate class until
# conflicts can be resolved
@@ -314,6 +328,7 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+
class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
@@ -330,20 +345,21 @@ def replace_values(self, old_val, new_val, d: Dict[str, Any]):
d[key] = new_val
if isinstance(val, dict):
self.replace_values(old_val, new_val, val)
-
+
return d
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
request_obj = copy.deepcopy(self.request_schema)
self.replace_values("", prompt, request_obj)
- request = json.dumps(request_obj).encode('utf-8')
+ request = json.dumps(request_obj).encode("utf-8")
return request
-
+
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
matches = self.response_parser.find(response_json)
return matches[0].value
+
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
@@ -364,12 +380,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
TextField(
key="response_path",
label="Response path",
- )
+ ),
]
-
+
def __init__(self, *args, **kwargs):
- request_schema = kwargs.pop('request_schema')
- response_path = kwargs.pop('response_path')
- content_handler = JsonContentHandler(request_schema=request_schema, response_path=response_path)
+ request_schema = kwargs.pop("request_schema")
+ response_path = kwargs.pop("response_path")
+ content_handler = JsonContentHandler(
+ request_schema=request_schema, response_path=response_path
+ )
super().__init__(*args, **kwargs, content_handler=content_handler)
-
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
index aab722240..a6c717889 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
@@ -1,13 +1,11 @@
import logging
from typing import Dict, Optional, Tuple, Union
+
from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
-
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
-
from jupyter_ai_magics.providers import BaseProvider
-
Logger = Union[logging.Logger, logging.LoggerAdapter]
@@ -23,15 +21,19 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]:
try:
provider = model_provider_ep.load()
except:
- log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.")
+ log.error(
+ f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
+ )
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")
-
+
return providers
-def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]:
+def load_embedding_providers(
+ log: Optional[Logger] = None,
+) -> Dict[str, BaseEmbeddingsProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
@@ -42,29 +44,34 @@ def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbe
try:
provider = model_provider_ep.load()
except:
- log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.")
+ log.error(
+ f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
+ )
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")
-
+
return providers
-def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]:
- """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
- if model_id in MODEL_ID_ALIASES:
- model_id = MODEL_ID_ALIASES[model_id]
- if ":" not in model_id:
- # case: model ID was not provided with a prefix indicating the provider
- # ID. try to infer the provider ID before returning (None, None).
+def decompose_model_id(
+ model_id: str, providers: Dict[str, BaseProvider]
+) -> Tuple[str, str]:
+ """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
+ if model_id in MODEL_ID_ALIASES:
+ model_id = MODEL_ID_ALIASES[model_id]
+
+ if ":" not in model_id:
+ # case: model ID was not provided with a prefix indicating the provider
+ # ID. try to infer the provider ID before returning (None, None).
+
+ # naively search through the dictionary and return the first provider
+ # that provides a model of the same ID.
+ for provider_id, provider in providers.items():
+ if model_id in provider.models:
+ return (provider_id, model_id)
- # naively search through the dictionary and return the first provider
- # that provides a model of the same ID.
- for provider_id, provider in providers.items():
- if model_id in provider.models:
- return (provider_id, model_id)
-
- return (None, None)
+ return (None, None)
- provider_id, local_model_id = model_id.split(":", 1)
- return (provider_id, local_model_id)
+ provider_id, local_model_id = model_id.split(":", 1)
+ return (provider_id, local_model_id)
diff --git a/packages/jupyter-ai-magics/package.json b/packages/jupyter-ai-magics/package.json
index 922d8a22b..94d527c8b 100644
--- a/packages/jupyter-ai-magics/package.json
+++ b/packages/jupyter-ai-magics/package.json
@@ -18,7 +18,7 @@
"url": "https://github.com/jupyterlab/jupyter-ai.git"
},
"scripts": {
- "dev-install": "pip install -e \".[all]\"",
+ "dev-install": "pip install -e \".[dev,all]\"",
"dev-uninstall": "pip uninstall jupyter_ai_magics -y"
}
}
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml
index 2d59c9567..464284679 100644
--- a/packages/jupyter-ai-magics/pyproject.toml
+++ b/packages/jupyter-ai-magics/pyproject.toml
@@ -31,6 +31,10 @@ dependencies = [
]
[project.optional-dependencies]
+dev = [
+ "pre-commit~=3.3.3"
+]
+
test = [
"coverage",
"pytest",
diff --git a/packages/jupyter-ai-magics/setup.py b/packages/jupyter-ai-magics/setup.py
index bea233743..aefdf20db 100644
--- a/packages/jupyter-ai-magics/setup.py
+++ b/packages/jupyter-ai-magics/setup.py
@@ -1 +1 @@
-__import__('setuptools').setup()
+__import__("setuptools").setup()
diff --git a/packages/jupyter-ai-module-cookiecutter/README.md b/packages/jupyter-ai-module-cookiecutter/README.md
index e80d6ac08..b99f3617d 100644
--- a/packages/jupyter-ai-module-cookiecutter/README.md
+++ b/packages/jupyter-ai-module-cookiecutter/README.md
@@ -2,8 +2,8 @@
A [cookiecutter](https://github.com/audreyr/cookiecutter) template for creating
an AI module. The AI module constructed from the template serves as a very simple
-example that can be extended however you wish.
-
+example that can be extended however you wish.
+
## Usage
Install cookiecutter.
diff --git a/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py b/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py
index 48b57e8be..aa816e397 100644
--- a/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py
+++ b/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py
@@ -6,7 +6,7 @@
def remove_path(path: Path) -> None:
"""Remove the provided path.
-
+
If the target path is a directory, remove it recursively.
"""
if not path.exists():
@@ -21,7 +21,6 @@ def remove_path(path: Path) -> None:
if __name__ == "__main__":
-
if not "{{ cookiecutter.has_settings }}".lower().startswith("y"):
remove_path(PROJECT_DIRECTORY / "schema")
@@ -30,7 +29,9 @@ def remove_path(path: Path) -> None:
remove_path(PROJECT_DIRECTORY / ".github/workflows/binder-on-pr.yml")
if not "{{ cookiecutter.test }}".lower().startswith("y"):
- remove_path(PROJECT_DIRECTORY / ".github" / "workflows" / "update-integration-tests.yml")
+ remove_path(
+ PROJECT_DIRECTORY / ".github" / "workflows" / "update-integration-tests.yml"
+ )
remove_path(PROJECT_DIRECTORY / "src" / "__tests__")
remove_path(PROJECT_DIRECTORY / "ui-tests")
remove_path(PROJECT_DIRECTORY / "{{ cookiecutter.python_name }}" / "tests")
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml
index a975d584d..e79041157 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml
@@ -25,7 +25,7 @@ jobs:
set -eux
jlpm
jlpm run lint:check
-{% if cookiecutter.test.lower().startswith('y') %}
+{% if cookiecutter.test.lower().startswith('y') %}
- name: Test the extension
run: |
set -eux
@@ -102,7 +102,7 @@ jobs:
uses: actions/download-artifact@v3
with:
name: extension-artifacts
-
+
- name: Install the extension
run: |
set -eux
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild
index 1a20c19a8..d4b13f656 100755
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild
@@ -15,10 +15,11 @@ from pathlib import Path
ROOT = Path.cwd()
+
def _(*args, **kwargs):
- """ Run a command, echoing the args
+ """Run a command, echoing the args
- fails hard if something goes wrong
+ fails hard if something goes wrong
"""
print("\n\t", " ".join(args), "\n")
return_code = subprocess.call(args, **kwargs)
@@ -26,6 +27,7 @@ def _(*args, **kwargs):
print("\nERROR", return_code, " ".join(args))
sys.exit(return_code)
+
# verify the environment is self-consistent before even starting
_(sys.executable, "-m", "pip", "check")
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py
index bea233743..aefdf20db 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py
@@ -1 +1 @@
-__import__('setuptools').setup()
+__import__("setuptools").setup()
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py
index 5ba7a914e..23d06f6dd 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py
@@ -10,7 +10,7 @@
c.ServerApp.port_retries = 0
c.ServerApp.open_browser = False
-c.ServerApp.root_dir = mkdtemp(prefix='galata-test-')
+c.ServerApp.root_dir = mkdtemp(prefix="galata-test-")
c.ServerApp.token = ""
c.ServerApp.password = ""
c.ServerApp.disable_check_xsrf = True
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py
index d14127453..db43ba181 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py
@@ -7,7 +7,4 @@
def _jupyter_labextension_paths():
- return [{
- "src": "labextension",
- "dest": "{{ cookiecutter.labextension_name }}"
- }]
+ return [{"src": "labextension", "dest": "{{ cookiecutter.labextension_name }}"}]
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py
index 420bde3d3..c32e86148 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py
@@ -3,6 +3,7 @@
from jupyter_ai.engine import BaseModelEngine
from jupyter_ai.models import DescribeTaskResponse
+
class TestModelEngine(BaseModelEngine):
name = "test"
input_type = "txt"
@@ -18,7 +19,9 @@ class TestModelEngine(BaseModelEngine):
# )
#
- async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
+ async def execute(
+ self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]
+ ):
# Core method that executes a model when provided with a task
# description and a dictionary of prompt variables. For example, to
# execute an OpenAI text completion model:
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py
index c3cb8b9dc..07226c447 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py
@@ -1,4 +1,5 @@
from typing import List
+
from jupyter_ai import DefaultTaskDefinition
# tasks your AI module exposes by default. declared in `pyproject.toml` as an
@@ -9,6 +10,6 @@
"name": "Test task",
"prompt_template": "{body}",
"modality": "txt2txt",
- "insertion_mode": "test"
+ "insertion_mode": "test",
}
]
diff --git a/packages/jupyter-ai/.gitignore b/packages/jupyter-ai/.gitignore
index 56891ff87..7c34e452b 100644
--- a/packages/jupyter-ai/.gitignore
+++ b/packages/jupyter-ai/.gitignore
@@ -123,4 +123,4 @@ dmypy.json
.vscode
# Coverage reports
-coverage/*
\ No newline at end of file
+coverage/*
diff --git a/packages/jupyter-ai/conftest.py b/packages/jupyter-ai/conftest.py
index 12f736c03..dc3b0e05c 100644
--- a/packages/jupyter-ai/conftest.py
+++ b/packages/jupyter-ai/conftest.py
@@ -1,6 +1,6 @@
import pytest
-pytest_plugins = ("jupyter_server.pytest_plugin", )
+pytest_plugins = ("jupyter_server.pytest_plugin",)
@pytest.fixture
diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py
index 7bb42b861..701dc2a86 100644
--- a/packages/jupyter-ai/jupyter_ai/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/__init__.py
@@ -1,24 +1,19 @@
-from ._version import __version__
-from .extension import AiExtension
-
# expose jupyter_ai_magics ipython extension
from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension
+from ._version import __version__
+
# imports to expose entry points. DO NOT REMOVE.
from .engine import GPT3ModelEngine
-from .tasks import tasks
+from .extension import AiExtension
# imports to expose types to other AI modules. DO NOT REMOVE.
-from .tasks import DefaultTaskDefinition
+from .tasks import DefaultTaskDefinition, tasks
+
def _jupyter_labextension_paths():
- return [{
- "src": "labextension",
- "dest": "@jupyter-ai/core"
- }]
+ return [{"src": "labextension", "dest": "@jupyter-ai/core"}]
+
def _jupyter_server_extension_points():
- return [{
- "module": "jupyter_ai",
- "app": AiExtension
- }]
+ return [{"module": "jupyter_ai", "app": AiExtension}]
diff --git a/packages/jupyter-ai/jupyter_ai/actors/__init__.py b/packages/jupyter-ai/jupyter_ai/actors/__init__.py
index a48ae5056..9cfbc9156 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/__init__.py
@@ -1 +1 @@
-"""Actor classes that process incoming chat messages"""
\ No newline at end of file
+"""Actor classes that process incoming chat messages"""
diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py
index e78837ca4..45f2d5b9b 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py
@@ -1,15 +1,13 @@
import argparse
from typing import Dict, List, Type
-from jupyter_ai_magics.providers import BaseProvider
import ray
-from ray.util.queue import Queue
-
+from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger
+from jupyter_ai.models import HumanChatMessage
+from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import BaseRetriever, Document
-
-from jupyter_ai.models import HumanChatMessage
-from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger
+from ray.util.queue import Queue
@ray.remote
@@ -24,38 +22,39 @@ class AskActor(BaseActor):
def __init__(self, reply_queue: Queue, log: Logger):
super().__init__(reply_queue=reply_queue, log=log)
- self.parser.prog = '/ask'
- self.parser.add_argument('query', nargs=argparse.REMAINDER)
+ self.parser.prog = "/ask"
+ self.parser.add_argument("query", nargs=argparse.REMAINDER)
- def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
+ def create_llm_chain(
+ self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+ ):
retriever = Retriever()
self.llm = provider(**provider_params)
self.chat_history = []
- self.llm_chain = ConversationalRetrievalChain.from_llm(
- self.llm,
- retriever
- )
+ self.llm_chain = ConversationalRetrievalChain.from_llm(self.llm, retriever)
def _process_message(self, message: HumanChatMessage):
args = self.parse_args(message)
if args is None:
return
- query = ' '.join(args.query)
+ query = " ".join(args.query)
if not query:
self.reply(f"{self.parser.format_usage()}", message)
return
-
+
self.get_llm_chain()
try:
- result = self.llm_chain({"question": query, "chat_history": self.chat_history})
- response = result['answer']
+ result = self.llm_chain(
+ {"question": query, "chat_history": self.chat_history}
+ )
+ response = result["answer"]
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
- response = """Sorry, an error occurred while reading the from the learned documents.
- If you have changed the embedding provider, try deleting the existing index by running
+ response = """Sorry, an error occurred while reading the from the learned documents.
+ If you have changed the embedding provider, try deleting the existing index by running
`/learn -d` command and then re-submitting `/learn` to learn the documents,
and then asking the question again.
"""
@@ -68,11 +67,11 @@ class Retriever(BaseRetriever):
of inconsistent de-serialization of index when it's
accessed directly from the ask actor.
"""
-
+
def get_relevant_documents(self, question: str):
index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value)
docs = ray.get(index_actor.get_relevant_documents.remote(question))
return docs
-
+
async def aget_relevant_documents(self, query: str) -> List[Document]:
- return await super().aget_relevant_documents(query)
\ No newline at end of file
+ return await super().aget_relevant_documents(query)
diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py
index 84a62bf8b..d1ddddcf3 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/base.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/base.py
@@ -1,20 +1,19 @@
import argparse
-from enum import Enum
-from uuid import uuid4
-import time
import logging
-from typing import Dict, Optional, Type, Union
+import time
import traceback
+from enum import Enum
+from typing import Dict, Optional, Type, Union
+from uuid import uuid4
-from jupyter_ai_magics.providers import BaseProvider
import ray
-
+from jupyter_ai.models import AgentChatMessage, HumanChatMessage
+from jupyter_ai_magics.providers import BaseProvider
from ray.util.queue import Queue
-from jupyter_ai.models import HumanChatMessage, AgentChatMessage
-
Logger = Union[logging.Logger, logging.LoggerAdapter]
+
class ACTOR_TYPE(str, Enum):
# the top level actor that routes incoming messages to the appropriate actor
ROUTER = "router"
@@ -23,28 +22,26 @@ class ACTOR_TYPE(str, Enum):
DEFAULT = "default"
ASK = "ask"
- LEARN = 'learn'
- MEMORY = 'memory'
- GENERATE = 'generate'
- PROVIDERS = 'providers'
- CONFIG = 'config'
- CHAT_PROVIDER = 'chat_provider'
- EMBEDDINGS_PROVIDER = 'embeddings_provider'
+ LEARN = "learn"
+ MEMORY = "memory"
+ GENERATE = "generate"
+ PROVIDERS = "providers"
+ CONFIG = "config"
+ CHAT_PROVIDER = "chat_provider"
+ EMBEDDINGS_PROVIDER = "embeddings_provider"
+
COMMANDS = {
- '/ask': ACTOR_TYPE.ASK,
- '/learn': ACTOR_TYPE.LEARN,
- '/generate': ACTOR_TYPE.GENERATE
+ "/ask": ACTOR_TYPE.ASK,
+ "/learn": ACTOR_TYPE.LEARN,
+ "/generate": ACTOR_TYPE.GENERATE,
}
-class BaseActor():
+
+class BaseActor:
"""Base actor implemented by actors that are called by the `Router`"""
- def __init__(
- self,
- log: Logger,
- reply_queue: Queue
- ):
+ def __init__(self, log: Logger, reply_queue: Queue):
self.log = log
self.reply_queue = reply_queue
self.parser = argparse.ArgumentParser()
@@ -63,17 +60,17 @@ def process_message(self, message: HumanChatMessage):
formatted_e = traceback.format_exc()
response = f"Sorry, something went wrong and I wasn't able to index that path.\n\n```\n{formatted_e}\n```"
self.reply(response, message)
-
+
def _process_message(self, message: HumanChatMessage):
"""Processes the message passed by the `Router`"""
raise NotImplementedError("Should be implemented by subclasses.")
-
+
def reply(self, response, message: Optional[HumanChatMessage] = None):
m = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
- reply_to=message.id if message else ""
+ reply_to=message.id if message else "",
)
self.reply_queue.put(m)
@@ -82,40 +79,50 @@ def get_llm_chain(self):
lm_provider = ray.get(actor.get_provider.remote())
lm_provider_params = ray.get(actor.get_provider_params.remote())
- curr_lm_id = f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
- next_lm_id = f'{lm_provider.id}:{lm_provider_params["model_id"]}' if lm_provider else None
+ curr_lm_id = (
+ f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
+ )
+ next_lm_id = (
+ f'{lm_provider.id}:{lm_provider_params["model_id"]}'
+ if lm_provider
+ else None
+ )
if not lm_provider:
return None
-
+
if curr_lm_id != next_lm_id:
- self.log.info(f"Switching chat language model from {curr_lm_id} to {next_lm_id}.")
+ self.log.info(
+ f"Switching chat language model from {curr_lm_id} to {next_lm_id}."
+ )
self.create_llm_chain(lm_provider, lm_provider_params)
return self.llm_chain
-
+
def get_embeddings(self):
actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER)
provider = ray.get(actor.get_provider.remote())
embedding_params = ray.get(actor.get_provider_params.remote())
embedding_model_id = ray.get(actor.get_model_id.remote())
-
+
if not provider:
return None
-
+
if embedding_model_id != self.embedding_model_id:
self.embeddings = provider(**embedding_params)
return self.embeddings
-
- def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
+
+ def create_llm_chain(
+ self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+ ):
raise NotImplementedError("Should be implemented by subclasses")
-
+
def parse_args(self, message):
- args = message.body.split(' ')
+ args = message.body.split(" ")
try:
args = self.parser.parse_args(args[1:])
except (argparse.ArgumentError, SystemExit) as e:
response = f"{self.parser.format_usage()}"
self.reply(response, message)
return None
- return args
\ No newline at end of file
+ return args
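The slash-command handling above is plain `argparse`: each actor sets `prog` to its command and collects the rest of the message with `nargs=argparse.REMAINDER`. A small self-contained sketch of the same pattern (the message text is made up):

```python
import argparse

# Mirror how the chat actors build their parsers: prog is the slash command,
# and everything after it in the message is captured as the remainder.
parser = argparse.ArgumentParser(prog="/ask")
parser.add_argument("query", nargs=argparse.REMAINDER)

message_body = "/ask What directories have been learned?"
args = parser.parse_args(message_body.split(" ")[1:])

query = " ".join(args.query)
print(query)  # -> What directories have been learned?
```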
diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
index a225beb7b..e108f9f2a 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
@@ -1,10 +1,10 @@
-from jupyter_ai.actors.base import Logger, ACTOR_TYPE
-from jupyter_ai.models import GlobalConfig
import ray
+from jupyter_ai.actors.base import ACTOR_TYPE, Logger
+from jupyter_ai.models import GlobalConfig
-@ray.remote
-class ChatProviderActor():
+@ray.remote
+class ChatProviderActor:
def __init__(self, log: Logger):
self.log = log
self.provider = None
@@ -16,26 +16,28 @@ def update(self, config: GlobalConfig):
local_model_id, provider = ray.get(
actor.get_model_provider_data.remote(model_id)
)
-
+
if not provider:
raise ValueError(f"No provider and model found with '{model_id}'")
-
+
fields = config.fields.get(model_id, {})
- provider_params = { "model_id": local_model_id, **fields }
-
+ provider_params = {"model_id": local_model_id, **fields}
+
auth_strategy = provider.auth_strategy
if auth_strategy and auth_strategy.type == "env":
api_keys = config.api_keys
name = auth_strategy.name
if name not in api_keys:
- raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
+ raise ValueError(
+ f"Missing value for '{auth_strategy.name}' in the config."
+ )
provider_params[name.lower()] = api_keys[name]
-
+
self.provider = provider
self.provider_params = provider_params
def get_provider(self):
return self.provider
-
+
def get_provider_params(self):
- return self.provider_params
\ No newline at end of file
+ return self.provider_params
diff --git a/packages/jupyter-ai/jupyter_ai/actors/config.py b/packages/jupyter-ai/jupyter_ai/actors/config.py
index f51c9e5f7..5ca3c518e 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/config.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/config.py
@@ -1,21 +1,22 @@
import json
import os
+
+import ray
from jupyter_ai.actors.base import ACTOR_TYPE, Logger
from jupyter_ai.models import GlobalConfig
-import ray
from jupyter_core.paths import jupyter_data_dir
@ray.remote
-class ConfigActor():
- """Provides model and embedding provider id along
+class ConfigActor:
+ """Provides model and embedding provider id along
with the credentials to authenticate providers.
"""
def __init__(self, log: Logger):
self.log = log
- self.save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai')
- self.save_path = os.path.join(self.save_dir, 'config.json')
+ self.save_dir = os.path.join(jupyter_data_dir(), "jupyter_ai")
+ self.save_path = os.path.join(self.save_dir, "config.json")
self.config = None
self._load()
@@ -25,7 +26,7 @@ def update(self, config: GlobalConfig, save_to_disk: bool = True):
if save_to_disk:
self._save(config)
self.config = config
-
+
def _update_chat_provider(self, config: GlobalConfig):
if not config.model_provider_id:
return
@@ -43,19 +44,19 @@ def _update_embeddings_provider(self, config: GlobalConfig):
def _save(self, config: GlobalConfig):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
-
- with open(self.save_path, 'w') as f:
+
+ with open(self.save_path, "w") as f:
f.write(config.json())
def _load(self):
if os.path.exists(self.save_path):
- with open(self.save_path, 'r', encoding='utf-8') as f:
+ with open(self.save_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
self.update(config, False)
return
-
+
# otherwise, create a new empty config file
self.update(GlobalConfig(), True)
def get_config(self):
- return self.config
\ No newline at end of file
+ return self.config
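For reference, the config written above lands in a predictable place under the Jupyter data directory; a short sketch of reading it back outside the actor (it assumes the extension has already written `config.json`):

```python
import json
import os

from jupyter_core.paths import jupyter_data_dir

config_path = os.path.join(jupyter_data_dir(), "jupyter_ai", "config.json")

if os.path.exists(config_path):
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    # Fields mirror the GlobalConfig model, e.g. the selected chat model provider.
    print(config.get("model_provider_id"))
```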
diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py
index f5434c8b6..1792a0b31 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/default.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/default.py
@@ -1,21 +1,18 @@
-from typing import Dict, Type, List
-import ray
+from typing import Dict, List, Type
+import ray
+from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor
+from jupyter_ai.actors.memory import RemoteMemory
+from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
+from jupyter_ai_magics.providers import BaseProvider
from langchain import ConversationChain
from langchain.prompts import (
- ChatPromptTemplate,
- MessagesPlaceholder,
+ ChatPromptTemplate,
HumanMessagePromptTemplate,
- SystemMessagePromptTemplate
-)
-from langchain.schema import (
- AIMessage,
+ MessagesPlaceholder,
+ SystemMessagePromptTemplate,
)
-
-from jupyter_ai.actors.base import BaseActor, ACTOR_TYPE
-from jupyter_ai.actors.memory import RemoteMemory
-from jupyter_ai.models import HumanChatMessage, ClearMessage, ChatMessage
-from jupyter_ai_magics.providers import BaseProvider
+from langchain.schema import AIMessage
SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
@@ -28,6 +25,7 @@
The following is a friendly conversation between you and a human.
""".strip()
+
@ray.remote
class DefaultActor(BaseActor):
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
@@ -35,25 +33,27 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
self.memory = None
self.chat_history = chat_history
- def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
+ def create_llm_chain(
+ self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+ ):
llm = provider(**provider_params)
self.memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY)
- prompt_template = ChatPromptTemplate.from_messages([
- SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(provider_name=llm.name, local_model_id=llm.model_id),
- MessagesPlaceholder(variable_name="history"),
- HumanMessagePromptTemplate.from_template("{input}"),
- AIMessage(content="")
- ])
+ prompt_template = ChatPromptTemplate.from_messages(
+ [
+ SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
+ provider_name=llm.name, local_model_id=llm.model_id
+ ),
+ MessagesPlaceholder(variable_name="history"),
+ HumanMessagePromptTemplate.from_template("{input}"),
+ AIMessage(content=""),
+ ]
+ )
self.llm = llm
self.llm_chain = ConversationChain(
- llm=llm,
- prompt=prompt_template,
- verbose=True,
- memory=self.memory
+ llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)
-
+
def clear_memory(self):
-
# clear chain memory
if self.memory:
self.memory.clear()
@@ -68,8 +68,5 @@ def clear_memory(self):
def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()
- response = self.llm_chain.predict(
- input=message.body,
- stop=["\nHuman:"]
- )
+ response = self.llm_chain.predict(input=message.body, stop=["\nHuman:"])
self.reply(response, message)
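To make the reformatted prompt construction above easier to read in isolation: the default actor formats the system template once, baking in the provider and model names, then leaves a placeholder for the chat history followed by the human turn. A condensed sketch with a shortened system prompt and placeholder provider values, assuming the same LangChain version the package already depends on:

```python
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import AIMessage

# Shortened stand-in for the real SYSTEM_PROMPT.
SYSTEM_PROMPT = "You are Jupyternaut, powered by {provider_name} ({local_model_id})."

prompt_template = ChatPromptTemplate.from_messages(
    [
        # The system template is formatted immediately, so provider and model
        # names are fixed once per chain. The values below are placeholders.
        SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
            provider_name="OpenAI", local_model_id="gpt-3.5-turbo"
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
        AIMessage(content=""),
    ]
)
```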
diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
index 068ce0388..3067e85e8 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
@@ -1,10 +1,10 @@
-from jupyter_ai.actors.base import Logger, ACTOR_TYPE
-from jupyter_ai.models import GlobalConfig
import ray
+from jupyter_ai.actors.base import ACTOR_TYPE, Logger
+from jupyter_ai.models import GlobalConfig
-@ray.remote
-class EmbeddingsProviderActor():
+@ray.remote
+class EmbeddingsProviderActor:
def __init__(self, log: Logger):
self.log = log
self.provider = None
@@ -17,26 +17,28 @@ def update(self, config: GlobalConfig):
local_model_id, provider = ray.get(
actor.get_embeddings_provider_data.remote(model_id)
)
-
+
if not provider:
raise ValueError(f"No provider and model found with '{model_id}'")
-
+
provider_params = {}
provider_params[provider.model_id_key] = local_model_id
-
+
auth_strategy = provider.auth_strategy
if auth_strategy and auth_strategy.type == "env":
api_keys = config.api_keys
name = auth_strategy.name
if name not in api_keys:
- raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
+ raise ValueError(
+ f"Missing value for '{auth_strategy.name}' in the config."
+ )
provider_params[name.lower()] = api_keys[name]
-
+
self.provider = provider.provider_klass
self.provider_params = provider_params
previous_model_id = self.model_id
self.model_id = model_id
-
+
if previous_model_id and previous_model_id != model_id:
# delete the index
actor = ray.get_actor(ACTOR_TYPE.LEARN)
@@ -44,9 +46,9 @@ def update(self, config: GlobalConfig):
def get_provider(self):
return self.provider
-
+
def get_provider_params(self):
return self.provider_params
-
+
def get_model_id(self):
- return self.model_id
\ No newline at end of file
+ return self.model_id
diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py
index a240078b5..07efa63a8 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py
@@ -2,19 +2,15 @@
import os
from typing import Dict, Type
-import ray
-from ray.util.queue import Queue
-
-from langchain.llms import BaseLLM
-from langchain.prompts import PromptTemplate
-from langchain.llms import BaseLLM
-from langchain.chains import LLMChain
-
import nbformat
-
-from jupyter_ai.models import HumanChatMessage
+import ray
from jupyter_ai.actors.base import BaseActor, Logger
+from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider, ChatOpenAINewProvider
+from langchain.chains import LLMChain
+from langchain.llms import BaseLLM
+from langchain.prompts import PromptTemplate
+from ray.util.queue import Queue
schema = """{
"$schema": "http://json-schema.org/draft-07/schema#",
@@ -42,11 +38,12 @@
"required": ["sections"]
}"""
+
class NotebookOutlineChain(LLMChain):
"""Chain to generate a notebook outline, with section titles and descriptions."""
@classmethod
- def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
+ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n"
"Generate the outline as JSON data that will validate against this JSON schema:\n"
@@ -56,24 +53,23 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
)
prompt = PromptTemplate(
template=task_creation_template,
- input_variables=[
- "description",
- "schema"
- ],
+ input_variables=["description", "schema"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
+
def generate_outline(description, llm=None, verbose=False):
"""Generate an outline of sections given a description of a notebook."""
chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
outline = chain.predict(description=description, schema=schema)
return json.loads(outline)
+
class CodeImproverChain(LLMChain):
"""Chain to improve source code."""
@classmethod
- def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
+ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"Improve the following code and make sure it is valid. Make sure to return the improved code only - don't give an explanation of the improvements.\n"
"{code}"
@@ -86,18 +82,22 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
+
def improve_code(code, llm=None, verbose=False):
"""Improve source code using an LLM."""
chain = CodeImproverChain.from_llm(llm=llm, verbose=verbose)
improved_code = chain.predict(code=code)
- improved_code = '\n'.join([line for line in improved_code.split('/n') if not line.startswith("```")])
+ improved_code = "\n".join(
+ [line for line in improved_code.split("/n") if not line.startswith("```")]
+ )
return improved_code
+
class NotebookSectionCodeChain(LLMChain):
"""Chain to generate source code for a notebook section."""
@classmethod
- def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
+ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"You are an AI that writes code for a single section of a Jupyter notebook.\n"
"Overall topic of the notebook: {description}\n"
@@ -110,35 +110,32 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
)
prompt = PromptTemplate(
template=task_creation_template,
- input_variables=[
- "description",
- "title",
- "content",
- "code_so_far"
- ],
+ input_variables=["description", "title", "content", "code_so_far"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
+
def generate_code(outline, llm=None, verbose=False):
"""Generate source code for a section given a description of the notebook and section."""
chain = NotebookSectionCodeChain.from_llm(llm=llm, verbose=verbose)
code_so_far = []
- for section in outline['sections']:
+ for section in outline["sections"]:
code = chain.predict(
- description=outline['description'],
- title=section['title'],
- content=section['content'],
- code_so_far='\n'.join(code_so_far)
+ description=outline["description"],
+ title=section["title"],
+ content=section["content"],
+ code_so_far="\n".join(code_so_far),
)
- section['code'] = improve_code(code, llm=llm, verbose=verbose)
- code_so_far.append(section['code'])
+ section["code"] = improve_code(code, llm=llm, verbose=verbose)
+ code_so_far.append(section["code"])
return outline
+
class NotebookSummaryChain(LLMChain):
"""Chain to generate a short summary of a notebook."""
@classmethod
- def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
+ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"Create a markdown summary for a Jupyter notebook with the following content."
" The summary should consist of a single paragraph.\n"
@@ -152,11 +149,12 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
+
class NotebookTitleChain(LLMChain):
"""Chain to generate the title of a notebook."""
@classmethod
- def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
+ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n"
"Content:\n{content}"
@@ -169,32 +167,35 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain:
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
+
def generate_title_and_summary(outline, llm=None, verbose=False):
"""Generate a title and summary of a notebook outline using an LLM."""
summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose)
title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose)
summary = summary_chain.predict(content=outline)
title = title_chain.predict(content=outline)
- outline['summary'] = summary
- outline['title'] = title.strip('"')
+ outline["summary"] = summary
+ outline["title"] = title.strip('"')
return outline
+
def create_notebook(outline):
"""Create an nbformat Notebook object for a notebook outline."""
nbf = nbformat.v4
nb = nbf.new_notebook()
- nb['cells'].append(nbf.new_markdown_cell('# ' + outline['title']))
- nb['cells'].append(nbf.new_markdown_cell('## Introduction'))
+ nb["cells"].append(nbf.new_markdown_cell("# " + outline["title"]))
+ nb["cells"].append(nbf.new_markdown_cell("## Introduction"))
disclaimer = f"This notebook was created by [Jupyter AI](https://github.com/jupyterlab/jupyter-ai) with the following prompt:\n\n> {outline['prompt']}"
- nb['cells'].append(nbf.new_markdown_cell(disclaimer))
- nb['cells'].append(nbf.new_markdown_cell(outline['summary']))
+ nb["cells"].append(nbf.new_markdown_cell(disclaimer))
+ nb["cells"].append(nbf.new_markdown_cell(outline["summary"]))
- for section in outline['sections'][1:]:
- nb['cells'].append(nbf.new_markdown_cell('## ' + section['title']))
- for code_block in section['code'].split('\n\n'):
- nb['cells'].append(nbf.new_code_cell(code_block))
+ for section in outline["sections"][1:]:
+ nb["cells"].append(nbf.new_markdown_cell("## " + section["title"]))
+ for code_block in section["code"].split("\n\n"):
+ nb["cells"].append(nbf.new_code_cell(code_block))
return nb
+
@ray.remote
class GenerateActor(BaseActor):
"""A Ray actor to generate a Jupyter notebook given a description."""
@@ -204,25 +205,27 @@ def __init__(self, reply_queue: Queue, root_dir: str, log: Logger):
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.llm = None
- def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
+ def create_llm_chain(
+ self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+ ):
llm = provider(**provider_params)
self.llm = llm
return llm
-
+
def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()
-
+
response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
self.reply(response, message)
prompt = message.body
outline = generate_outline(prompt, llm=self.llm, verbose=True)
# Save the user input prompt, the description property is now LLM generated.
- outline['prompt'] = prompt
+ outline["prompt"] = prompt
outline = generate_code(outline, llm=self.llm, verbose=True)
outline = generate_title_and_summary(outline, llm=self.llm)
notebook = create_notebook(outline)
- final_path = os.path.join(self.root_dir, outline['title'] + '.ipynb')
+ final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb")
nbformat.write(notebook, final_path)
response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
self.reply(response, message)
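The generate actor ultimately hands a plain dictionary outline to `create_notebook`; a small self-contained sketch of the same nbformat assembly, with a hand-written outline standing in for LLM output:

```python
import nbformat

# A hand-written outline in the same shape the LLM chains produce.
outline = {
    "title": "Plotting with Matplotlib",
    "summary": "A short tour of basic plotting.",
    "sections": [
        {"title": "Introduction", "content": "...", "code": ""},
        {"title": "Line plots", "content": "...", "code": "x = [1, 2, 3]\n\nprint(x)"},
    ],
}

nbf = nbformat.v4
nb = nbf.new_notebook()
nb["cells"].append(nbf.new_markdown_cell("# " + outline["title"]))
nb["cells"].append(nbf.new_markdown_cell(outline["summary"]))
# Skip the first section (it duplicates the introduction) and split each
# section's code on blank lines so every chunk becomes its own code cell.
for section in outline["sections"][1:]:
    nb["cells"].append(nbf.new_markdown_cell("## " + section["title"]))
    for code_block in section["code"].split("\n\n"):
        nb["cells"].append(nbf.new_code_cell(code_block))

nbformat.write(nb, "generated.ipynb")
```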
diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py
index 1754cb9da..ef4e352e0 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/learn.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py
@@ -1,59 +1,59 @@
+import argparse
import json
import os
-import argparse
import time
from typing import List
import ray
-from ray.util.queue import Queue
-
+from jupyter_ai.actors.base import BaseActor, Logger
+from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader
+from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
+from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
from jupyter_core.paths import jupyter_data_dir
-
from langchain import FAISS
+from langchain.schema import Document
from langchain.text_splitter import (
- RecursiveCharacterTextSplitter, PythonCodeTextSplitter,
- MarkdownTextSplitter, LatexTextSplitter
+ LatexTextSplitter,
+ MarkdownTextSplitter,
+ PythonCodeTextSplitter,
+ RecursiveCharacterTextSplitter,
)
-from langchain.schema import Document
-
-from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
-from jupyter_ai.actors.base import BaseActor, Logger
-from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader
-from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
+from ray.util.queue import Queue
+INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices")
+METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json")
-INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices')
-METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, 'metadata.json')
@ray.remote
class LearnActor(BaseActor):
-
def __init__(self, reply_queue: Queue, log: Logger, root_dir: str):
super().__init__(reply_queue=reply_queue, log=log)
self.root_dir = root_dir
self.chunk_size = 2000
self.chunk_overlap = 100
- self.parser.prog = '/learn'
- self.parser.add_argument('-v', '--verbose', action='store_true')
- self.parser.add_argument('-d', '--delete', action='store_true')
- self.parser.add_argument('-l', '--list', action='store_true')
- self.parser.add_argument('path', nargs=argparse.REMAINDER)
- self.index_name = 'default'
+ self.parser.prog = "/learn"
+ self.parser.add_argument("-v", "--verbose", action="store_true")
+ self.parser.add_argument("-d", "--delete", action="store_true")
+ self.parser.add_argument("-l", "--list", action="store_true")
+ self.parser.add_argument("path", nargs=argparse.REMAINDER)
+ self.index_name = "default"
self.index = None
self.metadata = IndexMetadata(dirs=[])
-
+
if not os.path.exists(INDEX_SAVE_DIR):
os.makedirs(INDEX_SAVE_DIR)
-
- self.load_or_create()
-
+
+ self.load_or_create()
+
def _process_message(self, message: HumanChatMessage):
if not self.index:
self.load_or_create()
# If index is not still there, embeddings are not present
if not self.index:
- self.reply("Sorry, please select an embedding provider before using the `/learn` command.")
+ self.reply(
+ "Sorry, please select an embedding provider before using the `/learn` command."
+ )
args = self.parse_args(message)
if args is None:
@@ -63,7 +63,7 @@ def _process_message(self, message: HumanChatMessage):
self.delete()
self.reply(f"👍 I have deleted everything I previously learned.", message)
return
-
+
if args.list:
self.reply(self._build_list_response())
return
@@ -81,18 +81,18 @@ def _process_message(self, message: HumanChatMessage):
if args.verbose:
self.reply(f"Loading and splitting files for {load_path}", message)
-
+
self.learn_dir(load_path)
self.save()
- response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
+ response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
You can ask questions about these docs by prefixing your message with **/ask**."""
self.reply(response, message)
def _build_list_response(self):
if not self.metadata.dirs:
return "There are no docs that have been learned yet."
-
+
dirs = [dir.path for dir in self.metadata.dirs]
dir_list = "\n- " + "\n- ".join(dirs) + "\n\n"
message = f"""I can answer questions from docs in these directories:
@@ -100,22 +100,32 @@ def _build_list_response(self):
return message
def learn_dir(self, path: str):
- splitters={
- '.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
- '.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
- '.tex': LatexTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
- '.ipynb': NotebookSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
+ splitters = {
+ ".py": PythonCodeTextSplitter(
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+ ),
+ ".md": MarkdownTextSplitter(
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+ ),
+ ".tex": LatexTextSplitter(
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+ ),
+ ".ipynb": NotebookSplitter(
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+ ),
}
splitter = ExtensionSplitter(
splitters=splitters,
- default_splitter=RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
+ default_splitter=RecursiveCharacterTextSplitter(
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+ ),
)
loader = RayRecursiveDirectoryLoader(path)
texts = loader.load_and_split(text_splitter=splitter)
self.index.add_documents(texts)
self._add_dir_to_metadata(path)
-
+
def _add_dir_to_metadata(self, path: str):
dirs = self.metadata.dirs
index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
@@ -128,10 +138,10 @@ def delete_and_relearn(self):
self.delete()
return
message = """🔔 Hi there, It seems like you have updated the embeddings model. For the **/ask**
- command to work with the new model, I have to re-learn the documents you had previously
+ command to work with the new model, I have to re-learn the documents you had previously
submitted for learning. Please wait to use the **/ask** command until I am done with this task."""
self.reply(message)
-
+
metadata = self.metadata
self.delete()
self.relearn(metadata)
@@ -139,7 +149,10 @@ def delete_and_relearn(self):
def delete(self):
self.index = None
self.metadata = IndexMetadata(dirs=[])
- paths = [os.path.join(INDEX_SAVE_DIR, self.index_name+ext) for ext in ['.pkl', '.faiss']]
+ paths = [
+ os.path.join(INDEX_SAVE_DIR, self.index_name + ext)
+ for ext in [".pkl", ".faiss"]
+ ]
for path in paths:
if os.path.isfile(path):
os.remove(path)
@@ -148,16 +161,18 @@ def delete(self):
def relearn(self, metadata: IndexMetadata):
# Index all dirs in the metadata
if not metadata.dirs:
- return
-
+ return
+
for dir in metadata.dirs:
self.learn_dir(dir.path)
-
+
self.save()
- dir_list = "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n"
+ dir_list = (
+ "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n"
+ )
message = f"""🎉 I am done learning docs in these directories:
- {dir_list} I am ready to answer questions about them.
+ {dir_list} I am ready to answer questions about them.
You can ask questions about these docs by prefixing your message with **/ask**."""
self.reply(message)
@@ -165,17 +180,22 @@ def create(self):
embeddings = self.get_embeddings()
if not embeddings:
return
- self.index = FAISS.from_texts(["Jupyternaut knows about your filesystem, to ask questions first use the /learn command."], embeddings)
+ self.index = FAISS.from_texts(
+ [
+ "Jupyternaut knows about your filesystem, to ask questions first use the /learn command."
+ ],
+ embeddings,
+ )
self.save()
def save(self):
if self.index is not None:
self.index.save_local(INDEX_SAVE_DIR, index_name=self.index_name)
-
+
self.save_metadata()
def save_metadata(self):
- with open(METADATA_SAVE_PATH, 'w') as f:
+ with open(METADATA_SAVE_PATH, "w") as f:
f.write(self.metadata.json())
def load_or_create(self):
@@ -184,17 +204,19 @@ def load_or_create(self):
return
if self.index is None:
try:
- self.index = FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=self.index_name)
+ self.index = FAISS.load_local(
+ INDEX_SAVE_DIR, embeddings, index_name=self.index_name
+ )
self.load_metadata()
except Exception as e:
self.create()
def load_metadata(self):
if not os.path.exists(METADATA_SAVE_PATH):
- return
-
- with open(METADATA_SAVE_PATH, 'r', encoding='utf-8') as f:
- j = json.loads(f.read())
+ return
+
+ with open(METADATA_SAVE_PATH, encoding="utf-8") as f:
+ j = json.loads(f.read())
self.metadata = IndexMetadata(**j)
def get_relevant_documents(self, question: str) -> List[Document]:
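To make the persistence logic above easier to follow: the learn actor keeps a FAISS index saved under the Jupyter data directory and reloads it on startup with the currently configured embeddings. A minimal round-trip sketch, assuming the LangChain and faiss versions the package already depends on; `FakeEmbeddings` and the temporary directory are stand-ins for the real embedding provider and `INDEX_SAVE_DIR`:

```python
from tempfile import mkdtemp

from langchain.embeddings import FakeEmbeddings  # stand-in embedding provider
from langchain.vectorstores import FAISS

save_dir = mkdtemp()  # stands in for <jupyter_data_dir>/jupyter_ai/indices
embeddings = FakeEmbeddings(size=64)

# Create, persist, and reload an index the same way the learn actor does.
index = FAISS.from_texts(["Jupyternaut knows about your filesystem."], embeddings)
index.save_local(save_dir, index_name="default")

restored = FAISS.load_local(save_dir, embeddings, index_name="default")
print(restored.similarity_search("filesystem", k=1)[0].page_content)
```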
diff --git a/packages/jupyter-ai/jupyter_ai/actors/memory.py b/packages/jupyter-ai/jupyter_ai/actors/memory.py
index 808537599..98ee9b144 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/memory.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/memory.py
@@ -1,91 +1,89 @@
-from typing import Dict, Any, List
+from typing import Any, Dict, List
import ray
+from jupyter_ai.actors.base import Logger
from langchain.schema import BaseMemory
from pydantic import PrivateAttr
-from jupyter_ai.actors.base import Logger
@ray.remote
-class MemoryActor(object):
+class MemoryActor:
"""Turns any LangChain memory into a Ray actor.
-
+
The resulting actor can be used as LangChain memory in chains
running in different actors by using RemoteMemory (below).
"""
-
+
def __init__(self, log: Logger, memory: BaseMemory):
self.memory = memory
self.log = log
-
+
def get_chat_memory(self):
return self.memory.chat_memory
-
+
def get_output_key(self):
return self.memory.output_key
-
+
def get_input_key(self):
return self.memory.input_key
def get_return_messages(self):
return self.memory.return_messages
-
+
def get_memory_variables(self):
return self.memory.memory_variables
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
return self.memory.save_context(inputs, outputs)
-
+
def clear(self):
return self.memory.clear()
-
+
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return self.memory.load_memory_variables(inputs)
-
-
class RemoteMemory(BaseMemory):
"""Wraps a MemoryActor into a LangChain memory class.
-
+
This enables you to use a single distributed memory across multiple
Ray actors running different chains.
"""
-
+
actor_name: str
_actor: Any = PrivateAttr()
def __init__(self, **data):
super().__init__(**data)
self._actor = ray.get_actor(self.actor_name)
-
+
@property
def memory_variables(self) -> List[str]:
o = self._actor.get_memory_variables.remote()
return ray.get(o)
-
+
@property
def output_key(self):
o = self._actor.get_output_key.remote()
return ray.get(o)
-
+
@property
def input_key(self):
o = self._actor.get_input_key.remote()
return ray.get(o)
-
+
@property
def return_messages(self):
o = self._actor.get_return_messages.remote()
return ray.get(o)
-
+
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
o = self._actor.save_context.remote(inputs, outputs)
return ray.get(o)
-
+
def clear(self):
self._actor.clear.remote()
-
+
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
o = self._actor.load_memory_variables.remote(inputs)
- return ray.get(o)
\ No newline at end of file
+ return ray.get(o)
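The point of the pair above is that one named `MemoryActor` holds the buffer while `RemoteMemory` lets chains in other actors read and write it. A rough usage sketch, assuming Ray and the package are installed locally; the `log=None` argument and the blocking warm-up call are illustrative choices, not part of the extension's startup path:

```python
import ray
from langchain.memory import ConversationBufferWindowMemory

from jupyter_ai.actors.memory import MemoryActor, RemoteMemory

ray.init(ignore_reinit_error=True)

# Create one named memory actor holding a LangChain buffer memory.
handle = MemoryActor.options(name="memory").remote(
    log=None, memory=ConversationBufferWindowMemory(return_messages=True, k=3)
)
ray.get(handle.get_memory_variables.remote())  # block until the actor is up

# Any chain, in any actor, can now attach to the same shared memory by name.
shared_memory = RemoteMemory(actor_name="memory")
shared_memory.save_context({"input": "hi"}, {"output": "hello!"})
print(shared_memory.load_memory_variables({}))
```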
diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py
index fd249ede9..7b11e1773 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/providers.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py
@@ -1,18 +1,24 @@
from typing import Optional, Tuple, Type
-from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
-from jupyter_ai_magics.providers import BaseProvider
-from jupyter_ai_magics.utils import decompose_model_id, load_embedding_providers, load_providers
+
import ray
from jupyter_ai.actors.base import BaseActor, Logger
+from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
+from jupyter_ai_magics.providers import BaseProvider
+from jupyter_ai_magics.utils import (
+ decompose_model_id,
+ load_embedding_providers,
+ load_providers,
+)
from ray.util.queue import Queue
+
@ray.remote
-class ProvidersActor():
+class ProvidersActor:
"""Actor that loads model and embedding providers from,
- entry points. Also provides utility functions to get the
+ entry points. Also provides utility functions to get the
providers and provider class matching a provider id.
"""
-
+
def __init__(self, log: Logger):
self.log = log
self.model_providers = load_providers(log=log)
@@ -21,22 +27,23 @@ def __init__(self, log: Logger):
def get_model_providers(self):
"""Returns dictionary of registered LLM providers"""
return self.model_providers
-
+
def get_model_provider_data(self, model_id: str) -> Tuple[str, Type[BaseProvider]]:
"""Returns the model provider class that matches the provider id"""
provider_id, local_model_id = decompose_model_id(model_id, self.model_providers)
provider = self.model_providers.get(provider_id, None)
return local_model_id, provider
-
+
def get_embeddings_providers(self):
"""Returns dictionary of registered embedding providers"""
return self.embeddings_providers
- def get_embeddings_provider_data(self, model_id: str) -> Tuple[str, Type[BaseEmbeddingsProvider]]:
+ def get_embeddings_provider_data(
+ self, model_id: str
+ ) -> Tuple[str, Type[BaseEmbeddingsProvider]]:
"""Returns the embedding provider class that matches the provider id"""
- provider_id, local_model_id = decompose_model_id(model_id, self.embeddings_providers)
+ provider_id, local_model_id = decompose_model_id(
+ model_id, self.embeddings_providers
+ )
provider = self.embeddings_providers.get(provider_id, None)
return local_model_id, provider
-
-
-
\ No newline at end of file
diff --git a/packages/jupyter-ai/jupyter_ai/actors/router.py b/packages/jupyter-ai/jupyter_ai/actors/router.py
index fbc3234da..c9129e58a 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/router.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/router.py
@@ -1,29 +1,28 @@
import ray
+from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, BaseActor, Logger
from ray.util.queue import Queue
-from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, Logger, BaseActor
@ray.remote
class Router(BaseActor):
def __init__(self, reply_queue: Queue, log: Logger):
"""Routes messages to the correct actor.
-
- To register new actors, add the actor type in the `ACTOR_TYPE` enum and
+
+ To register new actors, add the actor type in the `ACTOR_TYPE` enum and
add a corresponding command in the `COMMANDS` dictionary.
"""
super().__init__(reply_queue=reply_queue, log=log)
def _process_message(self, message):
-
# assign default actor
default = ray.get_actor(ACTOR_TYPE.DEFAULT)
if message.body.startswith("/"):
- command = message.body.split(' ', 1)[0]
+ command = message.body.split(" ", 1)[0]
if command in COMMANDS.keys():
actor = ray.get_actor(COMMANDS[command].value)
actor.process_message.remote(message)
- if command == '/clear':
+ if command == "/clear":
actor = ray.get_actor(ACTOR_TYPE.DEFAULT)
actor.clear_memory.remote()
else:
diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
index f87a74e30..98098da7e 100644
--- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
+++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
@@ -1,31 +1,31 @@
-from typing import List, Optional
-from pathlib import Path
import hashlib
import itertools
+from pathlib import Path
+from typing import List, Optional
import ray
-
from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
from langchain.document_loaders.directory import _is_visible
-from langchain.text_splitter import (
- RecursiveCharacterTextSplitter, TextSplitter,
-)
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+
@ray.remote
def path_to_doc(path):
- with open(str(path), 'r') as f:
+ with open(str(path)) as f:
text = f.read()
m = hashlib.sha256()
- m.update(text.encode('utf-8'))
- metadata = {'path': str(path), 'sha256': m.digest(), 'extension': path.suffix}
+ m.update(text.encode("utf-8"))
+ metadata = {"path": str(path), "sha256": m.digest(), "extension": path.suffix}
return Document(page_content=text, metadata=metadata)
+
class ExcludePattern(Exception):
pass
-
+
+
def iter_paths(path, extensions, exclude):
- for p in Path(path).rglob('*'):
+ for p in Path(path).rglob("*"):
if p.is_dir():
continue
if not _is_visible(p.relative_to(path)):
@@ -39,23 +39,43 @@ def iter_paths(path, extensions, exclude):
if p.suffix in extensions:
yield p
+
class RayRecursiveDirectoryLoader(BaseLoader):
-
def __init__(
self,
path,
- extensions={'.py', '.md', '.R', '.Rmd', '.jl', '.sh', '.ipynb', '.js', '.ts', '.jsx', '.tsx', '.txt'},
- exclude={'.ipynb_checkpoints', 'node_modules', 'lib', 'build', '.git', '.DS_Store'}
+ extensions={
+ ".py",
+ ".md",
+ ".R",
+ ".Rmd",
+ ".jl",
+ ".sh",
+ ".ipynb",
+ ".js",
+ ".ts",
+ ".jsx",
+ ".tsx",
+ ".txt",
+ },
+ exclude={
+ ".ipynb_checkpoints",
+ "node_modules",
+ "lib",
+ "build",
+ ".git",
+ ".DS_Store",
+ },
):
self.path = path
self.extensions = extensions
- self.exclude=exclude
-
+ self.exclude = exclude
+
def load(self) -> List[Document]:
paths = iter_paths(self.path, self.extensions, self.exclude)
doc_refs = list(map(path_to_doc.remote, paths))
return ray.get(doc_refs)
-
+
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
@@ -67,7 +87,7 @@ def load_and_split(
@ray.remote
def split(doc):
return _text_splitter.split_documents([doc])
-
+
paths = iter_paths(self.path, self.extensions, self.exclude)
doc_refs = map(split.remote, map(path_to_doc.remote, paths))
- return list(itertools.chain(*ray.get(list(doc_refs))))
\ No newline at end of file
+ return list(itertools.chain(*ray.get(list(doc_refs))))
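The directory loader's trick is that `path_to_doc` is a Ray task, so file reads and hashing fan out across workers; the generic submit-then-gather pattern, reduced to a standalone sketch:

```python
import hashlib

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def fingerprint(text: str) -> str:
    # Each call runs as a separate Ray task, so a long list is hashed in parallel.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

texts = ["first document", "second document", "third document"]
refs = [fingerprint.remote(t) for t in texts]  # submit all tasks
print(ray.get(refs))                           # gather results in order
```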
diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py b/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py
index d80289453..d3b32f478 100644
--- a/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py
+++ b/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py
@@ -1,44 +1,50 @@
+import copy
from typing import List, Optional
from langchain.schema import Document
-from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter, MarkdownTextSplitter
-import copy
+from langchain.text_splitter import (
+ MarkdownTextSplitter,
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+)
+
class ExtensionSplitter(TextSplitter):
-
def __init__(self, splitters, default_splitter=None):
self.splitters = splitters
if default_splitter is None:
self.default_splitter = RecursiveCharacterTextSplitter()
else:
self.default_splitter = default_splitter
-
+
def split_text(self, text: str, metadata=None):
- splitter = self.splitters.get(metadata['extension'], self.default_splitter)
+ splitter = self.splitters.get(metadata["extension"], self.default_splitter)
return splitter.split_text(text)
- def create_documents(self, texts: List[str], metadatas: Optional[List[dict]] = None) -> List[Document]:
+ def create_documents(
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
+ ) -> List[Document]:
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
metadata = copy.deepcopy(_metadatas[i])
for chunk in self.split_text(text, metadata):
- new_doc = Document(
- page_content=chunk, metadata=metadata
- )
+ new_doc = Document(page_content=chunk, metadata=metadata)
documents.append(new_doc)
return documents
+
import nbformat
+
class NotebookSplitter(TextSplitter):
-
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.markdown_splitter = MarkdownTextSplitter(chunk_size=self._chunk_size, chunk_overlap=self._chunk_overlap)
-
+ self.markdown_splitter = MarkdownTextSplitter(
+ chunk_size=self._chunk_size, chunk_overlap=self._chunk_overlap
+ )
+
def split_text(self, text: str):
nb = nbformat.reads(text, as_version=4)
- md = '\n\n'.join([cell.source for cell in nb.cells])
+ md = "\n\n".join([cell.source for cell in nb.cells])
return self.markdown_splitter.split_text(md)
-
diff --git a/packages/jupyter-ai/jupyter_ai/engine.py b/packages/jupyter-ai/jupyter_ai/engine.py
index 11e0f7571..d4dce4595 100644
--- a/packages/jupyter-ai/jupyter_ai/engine.py
+++ b/packages/jupyter-ai/jupyter_ai/engine.py
@@ -1,12 +1,16 @@
-from abc import abstractmethod, ABC, ABCMeta
+from abc import ABC, ABCMeta, abstractmethod
from typing import Dict
+
import openai
from traitlets.config import LoggingConfigurable, Unicode
+
from .task_manager import DescribeTaskResponse
+
class BaseModelEngineMetaclass(ABCMeta, type(LoggingConfigurable)):
pass
+
class BaseModelEngine(ABC, LoggingConfigurable, metaclass=BaseModelEngineMetaclass):
id: str
name: str
@@ -14,34 +18,33 @@ class BaseModelEngine(ABC, LoggingConfigurable, metaclass=BaseModelEngineMetacla
# these two attributes are currently reserved but unused.
input_type: str
output_type: str
-
+
@abstractmethod
- async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
+ async def execute(
+ self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]
+ ):
pass
+
class GPT3ModelEngine(BaseModelEngine):
id = "gpt3"
name = "GPT-3"
- modalities = [
- "txt2txt"
- ]
+ modalities = ["txt2txt"]
- api_key = Unicode(
- config=True,
- help="OpenAI API key",
- allow_none=False
- )
+ api_key = Unicode(config=True, help="OpenAI API key", allow_none=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
+ async def execute(
+ self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]
+ ):
if "body" not in prompt_variables:
raise Exception("Prompt body must be specified.")
prompt = task.prompt_template.format(**prompt_variables)
self.log.info(f"GPT3 prompt:\n{prompt}")
-
+
openai.api_key = self.api_key
response = await openai.Completion.acreate(
model="text-davinci-003",
@@ -50,6 +53,6 @@ async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str,
max_tokens=256,
top_p=1,
frequency_penalty=0,
- presence_penalty=0
+ presence_penalty=0,
)
- return response['choices'][0]['text']
\ No newline at end of file
+ return response["choices"][0]["text"]
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 837bb68f8..df40fcf24 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,41 +1,36 @@
import asyncio
+import inspect
+import ray
+from importlib_metadata import entry_points
+from jupyter_ai.actors.ask import AskActor
+from jupyter_ai.actors.base import ACTOR_TYPE
from jupyter_ai.actors.chat_provider import ChatProviderActor
from jupyter_ai.actors.config import ConfigActor
+from jupyter_ai.actors.default import DefaultActor
from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor
-from jupyter_ai.actors.providers import ProvidersActor
-
-from jupyter_ai_magics.utils import load_providers
-
-from langchain.memory import ConversationBufferWindowMemory
-from jupyter_ai.actors.default import DefaultActor
-from jupyter_ai.actors.ask import AskActor
+from jupyter_ai.actors.generate import GenerateActor
from jupyter_ai.actors.learn import LearnActor
-from jupyter_ai.actors.router import Router
from jupyter_ai.actors.memory import MemoryActor
-from jupyter_ai.actors.generate import GenerateActor
-from jupyter_ai.actors.base import ACTOR_TYPE
+from jupyter_ai.actors.providers import ProvidersActor
+from jupyter_ai.actors.router import Router
from jupyter_ai.reply_processor import ReplyProcessor
+from jupyter_ai_magics.utils import load_providers
from jupyter_server.extension.application import ExtensionApp
+from langchain.memory import ConversationBufferWindowMemory
+from ray.util.queue import Queue
+from .engine import BaseModelEngine
from .handlers import (
- ChatHandler,
- ChatHistoryHandler,
- EmbeddingsModelProviderHandler,
- ModelProviderHandler,
- PromptAPIHandler,
+ ChatHandler,
+ ChatHistoryHandler,
+ EmbeddingsModelProviderHandler,
+ GlobalConfigHandler,
+ ModelProviderHandler,
+ PromptAPIHandler,
TaskAPIHandler,
- GlobalConfigHandler
)
-from importlib_metadata import entry_points
-import inspect
-from .engine import BaseModelEngine
-
-import ray
-from ray.util.queue import Queue
-from jupyter_ai_magics.utils import load_providers
-
class AiExtension(ExtensionApp):
name = "jupyter_ai"
@@ -51,41 +46,48 @@ class AiExtension(ExtensionApp):
]
@property
- def ai_engines(self):
+ def ai_engines(self):
if "ai_engines" not in self.settings:
self.settings["ai_engines"] = {}
return self.settings["ai_engines"]
-
def initialize_settings(self):
ray.init()
# EP := entry point
eps = entry_points()
-
+
## step 1: instantiate model engines and bind them to settings
model_engine_class_eps = eps.select(group="jupyter_ai.model_engine_classes")
-
+
if not model_engine_class_eps:
- self.log.error("No model engines found for jupyter_ai.model_engine_classes group. One or more model engines are required for AI extension to work.")
+ self.log.error(
+ "No model engines found for jupyter_ai.model_engine_classes group. One or more model engines are required for AI extension to work."
+ )
return
for model_engine_class_ep in model_engine_class_eps:
try:
Engine = model_engine_class_ep.load()
except:
- self.log.error(f"Unable to load model engine class from entry point `{model_engine_class_ep.name}`.")
+ self.log.error(
+ f"Unable to load model engine class from entry point `{model_engine_class_ep.name}`."
+ )
continue
if not inspect.isclass(Engine) or not issubclass(Engine, BaseModelEngine):
- self.log.error(f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}` as it is not a subclass of `BaseModelEngine`.")
+ self.log.error(
+ f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}` as it is not a subclass of `BaseModelEngine`."
+ )
continue
try:
self.ai_engines[Engine.id] = Engine(config=self.config, log=self.log)
except:
- self.log.error(f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}`.")
+ self.log.error(
+ f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}`."
+ )
continue
self.log.info(f"Registered engine `{Engine.id}`.")
@@ -96,15 +98,17 @@ def initialize_settings(self):
if not module_default_tasks_eps:
self.settings["ai_default_tasks"] = []
return
-
+
default_tasks = []
for module_default_tasks_ep in module_default_tasks_eps:
try:
module_default_tasks = module_default_tasks_ep.load()
except:
- self.log.error(f"Unable to load task from entry point `{module_default_tasks_ep.name}`")
+ self.log.error(
+ f"Unable to load task from entry point `{module_default_tasks_ep.name}`"
+ )
continue
-
+
default_tasks += module_default_tasks
self.settings["ai_default_tasks"] = default_tasks
@@ -119,13 +123,12 @@ def initialize_settings(self):
# Store chat clients in a dictionary
self.settings["chat_clients"] = {}
self.settings["chat_handlers"] = {}
-
+
# store chat messages in memory for now
# this is only used to render the UI, and is not the conversational
# memory object used by the LM chain.
self.settings["chat_history"] = []
-
reply_queue = Queue()
self.settings["reply_queue"] = reply_queue
@@ -134,21 +137,27 @@ def initialize_settings(self):
log=self.log,
)
default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote(
- reply_queue=reply_queue,
+ reply_queue=reply_queue,
log=self.log,
- chat_history=self.settings["chat_history"]
+ chat_history=self.settings["chat_history"],
)
- providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote(
+ providers_actor = ProvidersActor.options(
+ name=ACTOR_TYPE.PROVIDERS.value
+ ).remote(
log=self.log,
)
config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote(
log=self.log,
)
- chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote(
+ chat_provider_actor = ChatProviderActor.options(
+ name=ACTOR_TYPE.CHAT_PROVIDER.value
+ ).remote(
log=self.log,
)
- embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote(
+ embeddings_provider_actor = EmbeddingsProviderActor.options(
+ name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value
+ ).remote(
log=self.log,
)
learn_actor = LearnActor.options(name=ACTOR_TYPE.LEARN.value).remote(
@@ -157,7 +166,7 @@ def initialize_settings(self):
root_dir=self.serverapp.root_dir,
)
ask_actor = AskActor.options(name=ACTOR_TYPE.ASK.value).remote(
- reply_queue=reply_queue,
+ reply_queue=reply_queue,
log=self.log,
)
memory_actor = MemoryActor.options(name=ACTOR_TYPE.MEMORY.value).remote(
@@ -165,22 +174,24 @@ def initialize_settings(self):
memory=ConversationBufferWindowMemory(return_messages=True, k=2),
)
generate_actor = GenerateActor.options(name=ACTOR_TYPE.GENERATE.value).remote(
- reply_queue=reply_queue,
+ reply_queue=reply_queue,
log=self.log,
- root_dir=self.settings['server_root_dir'],
+ root_dir=self.settings["server_root_dir"],
)
-
- self.settings['router'] = router
- self.settings['providers_actor'] = providers_actor
- self.settings['config_actor'] = config_actor
- self.settings['chat_provider_actor'] = chat_provider_actor
- self.settings['embeddings_provider_actor'] = embeddings_provider_actor
+
+ self.settings["router"] = router
+ self.settings["providers_actor"] = providers_actor
+ self.settings["config_actor"] = config_actor
+ self.settings["chat_provider_actor"] = chat_provider_actor
+ self.settings["embeddings_provider_actor"] = embeddings_provider_actor
self.settings["default_actor"] = default_actor
self.settings["learn_actor"] = learn_actor
self.settings["ask_actor"] = ask_actor
self.settings["memory_actor"] = memory_actor
self.settings["generate_actor"] = generate_actor
- reply_processor = ReplyProcessor(self.settings['chat_handlers'], reply_queue, log=self.log)
+ reply_processor = ReplyProcessor(
+ self.settings["chat_handlers"], reply_queue, log=self.log
+ )
loop = asyncio.get_event_loop()
- loop.create_task(reply_processor.start())
\ No newline at end of file
+ loop.create_task(reply_processor.start())
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index b141da604..7dc021aac 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -1,45 +1,43 @@
-from dataclasses import asdict
+import getpass
import json
+import time
+import uuid
+from dataclasses import asdict
from typing import Dict, List
-from jupyter_ai.actors.base import ACTOR_TYPE
+
import ray
import tornado
-import uuid
-import time
-import getpass
-
-from tornado.web import HTTPError
+from jupyter_ai.actors.base import ACTOR_TYPE
+from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
+from jupyter_server.base.handlers import JupyterHandler
+from jupyter_server.utils import ensure_async
from pydantic import ValidationError
-
from tornado import web, websocket
-
-from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler
-from jupyter_server.utils import ensure_async
-
-from .task_manager import TaskManager
+from tornado.web import HTTPError
from .models import (
- ChatHistory,
- ChatUser,
- ListProvidersEntry,
- ListProvidersResponse,
- PromptRequest,
- ChatRequest,
- ChatMessage,
- Message,
- AgentChatMessage,
- HumanChatMessage,
- ConnectionMessage,
+ AgentChatMessage,
ChatClient,
- GlobalConfig
+ ChatHistory,
+ ChatMessage,
+ ChatRequest,
+ ChatUser,
+ ConnectionMessage,
+ GlobalConfig,
+ HumanChatMessage,
+ ListProvidersEntry,
+ ListProvidersResponse,
+ Message,
+ PromptRequest,
)
+from .task_manager import TaskManager
class APIHandler(BaseAPIHandler):
@property
- def engines(self):
+ def engines(self):
return self.settings["ai_engines"]
-
+
@property
def default_tasks(self):
return self.settings["ai_default_tasks"]
@@ -49,9 +47,12 @@ def task_manager(self):
# we have to create the TaskManager lazily, since no event loop is
# running in ServerApp.initialize_settings().
if "task_manager" not in self.settings:
- self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks)
+ self.settings["task_manager"] = TaskManager(
+ engines=self.engines, default_tasks=self.default_tasks
+ )
return self.settings["task_manager"]
-
+
+
class PromptAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
@@ -68,13 +69,13 @@ async def post(self):
task = await self.task_manager.describe_task(request.task_id)
if not task:
raise HTTPError(404, f"Task not found with ID: {request.task_id}")
-
+
output = await ensure_async(engine.execute(task, request.prompt_variables))
- self.finish(json.dumps({
- "output": output,
- "insertion_mode": task.insertion_mode
- }))
+ self.finish(
+ json.dumps({"output": output, "insertion_mode": task.insertion_mode})
+ )
+
class TaskAPIHandler(APIHandler):
@tornado.web.authenticated
@@ -83,7 +84,7 @@ async def get(self, id=None):
list_tasks_response = await self.task_manager.list_tasks()
self.finish(json.dumps(list_tasks_response.dict()))
return
-
+
describe_task_response = await self.task_manager.describe_task(id)
if describe_task_response is None:
raise HTTPError(404, f"Task not found with ID: {id}")
@@ -95,35 +96,32 @@ class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""
_messages = []
-
+
@property
def chat_history(self):
return self.settings["chat_history"]
-
+
@chat_history.setter
def _chat_history_setter(self, new_history):
self.settings["chat_history"] = new_history
-
+
@tornado.web.authenticated
async def get(self):
history = ChatHistory(messages=self.chat_history)
self.finish(history.json())
-class ChatHandler(
- JupyterHandler,
- websocket.WebSocketHandler
-):
+class ChatHandler(JupyterHandler, websocket.WebSocketHandler):
"""
A websocket handler for chat.
"""
-
+
@property
- def chat_handlers(self) -> Dict[str, 'ChatHandler']:
+ def chat_handlers(self) -> Dict[str, "ChatHandler"]:
"""Dictionary mapping client IDs to their WebSocket handler
instances."""
return self.settings["chat_handlers"]
-
+
@property
def chat_clients(self) -> Dict[str, ChatClient]:
"""Dictionary mapping client IDs to their ChatClient objects that store
@@ -143,8 +141,7 @@ def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)
def pre_get(self):
- """Handles authentication/authorization.
- """
+ """Handles authentication/authorization."""
# authenticate the request before opening the websocket
user = self.current_user
if user is None:
@@ -168,8 +165,7 @@ def get_chat_user(self) -> ChatUser:
if collaborative:
return ChatUser(**asdict(self.current_user))
-
-
+
login = getpass.getuser()
return ChatUser(
username=login,
@@ -177,10 +173,9 @@ def get_chat_user(self) -> ChatUser:
name=login,
display_name=login,
color=None,
- avatar_url=None
+ avatar_url=None,
)
-
def generate_client_id(self):
"""Generates a client ID to identify the current WS connection."""
return uuid.uuid4().hex
@@ -201,25 +196,27 @@ def open(self):
self.log.debug("Clients are : %s", self.chat_handlers.keys())
def broadcast_message(self, message: Message):
- """Broadcasts message to all connected clients.
+ """Broadcasts message to all connected clients.
Appends message to `self.chat_history`.
"""
self.log.debug("Broadcasting message: %s to all clients...", message)
client_ids = self.chat_handlers.keys()
-
+
for client_id in client_ids:
client = self.chat_handlers[client_id]
if client:
client.write_message(message.dict())
-
+
# Only append ChatMessage instances to history, not control messages
- if isinstance(message, HumanChatMessage) or isinstance(message, AgentChatMessage):
+ if isinstance(message, HumanChatMessage) or isinstance(
+ message, AgentChatMessage
+ ):
self.chat_history.append(message)
async def on_message(self, message):
self.log.debug("Message recieved: %s", message)
-
+
try:
message = json.loads(message)
chat_request = ChatRequest(**message)
@@ -240,9 +237,9 @@ async def on_message(self, message):
self.broadcast_message(message=chat_message)
# Clear the message history if given the /clear command
- if chat_request.prompt.startswith('/'):
- command = chat_request.prompt.split(' ', 1)[0]
- if command == '/clear':
+ if chat_request.prompt.startswith("/"):
+ command = chat_request.prompt.split(" ", 1)[0]
+ if command == "/clear":
self.chat_history.clear()
# process through the router
@@ -261,11 +258,11 @@ def on_close(self):
class ModelProviderHandler(BaseAPIHandler):
@property
- def chat_providers(self):
+ def chat_providers(self):
actor = ray.get_actor("providers")
o = actor.get_model_providers.remote()
return ray.get(o)
-
+
@web.authenticated
def get(self):
providers = []
@@ -284,13 +281,14 @@ def get(self):
fields=provider.fields,
)
)
-
- response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name))
+
+ response = ListProvidersResponse(
+ providers=sorted(providers, key=lambda p: p.name)
+ )
self.finish(response.json())
class EmbeddingsModelProviderHandler(BaseAPIHandler):
-
@property
def embeddings_providers(self):
actor = ray.get_actor("providers")
@@ -311,8 +309,10 @@ def get(self):
fields=provider.fields,
)
)
-
- response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name))
+
+ response = ListProvidersResponse(
+ providers=sorted(providers, key=lambda p: p.name)
+ )
self.finish(response.json())
@@ -320,14 +320,14 @@ class GlobalConfigHandler(BaseAPIHandler):
"""API handler for fetching and setting the
    model and embeddings config.
"""
-
+
@web.authenticated
def get(self):
actor = ray.get_actor(ACTOR_TYPE.CONFIG)
config = ray.get(actor.get_config.remote())
if not config:
raise HTTPError(500, "No config found.")
-
+
self.finish(config.json())
@web.authenticated
@@ -345,10 +345,9 @@ def post(self):
raise HTTPError(500, str(e)) from e
except ValueError as e:
self.log.exception(e)
- raise HTTPError(500, str(e.cause) if hasattr(e, 'cause') else str(e))
+ raise HTTPError(500, str(e.cause) if hasattr(e, "cause") else str(e))
except Exception as e:
self.log.exception(e)
raise HTTPError(
500, "Unexpected error occurred while updating the config."
) from e
-
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index cfa6a8419..dcefc8934 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -1,17 +1,20 @@
+from typing import Any, Dict, List, Literal, Optional, Union
+
from jupyter_ai_magics.providers import AuthStrategy, Field
+from pydantic import BaseModel
-from pydantic import BaseModel
-from typing import Any, Dict, List, Union, Literal, Optional
class PromptRequest(BaseModel):
task_id: str
engine_id: str
prompt_variables: Dict[str, str]
+
# the type of message used to chat with the agent
class ChatRequest(BaseModel):
prompt: str
+
class ChatUser(BaseModel):
# User ID assigned by IdentityProvider.
username: str
@@ -21,12 +24,14 @@ class ChatUser(BaseModel):
color: Optional[str]
avatar_url: Optional[str]
+
class ChatClient(ChatUser):
# A unique client ID assigned to identify different JupyterLab clients on
# the same device (i.e. running on multiple tabs/windows), which may have
# the same username assigned to them by the IdentityProvider.
id: str
+
class AgentChatMessage(BaseModel):
type: Literal["agent"] = "agent"
id: str
@@ -35,6 +40,7 @@ class AgentChatMessage(BaseModel):
# message ID of the HumanChatMessage it is replying to
reply_to: str
+
class HumanChatMessage(BaseModel):
type: Literal["human"] = "human"
id: str
@@ -42,45 +48,49 @@ class HumanChatMessage(BaseModel):
body: str
client: ChatClient
+
class ConnectionMessage(BaseModel):
type: Literal["connection"] = "connection"
client_id: str
+
class ClearMessage(BaseModel):
type: Literal["clear"] = "clear"
+
# the type of messages being broadcast to clients
ChatMessage = Union[
AgentChatMessage,
HumanChatMessage,
]
-Message = Union[
- AgentChatMessage,
- HumanChatMessage,
- ConnectionMessage,
- ClearMessage
-]
+Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage]
+
class ListEnginesEntry(BaseModel):
id: str
name: str
+
class ListTasksEntry(BaseModel):
id: str
name: str
+
class ListTasksResponse(BaseModel):
tasks: List[ListTasksEntry]
+
class DescribeTaskResponse(BaseModel):
name: str
insertion_mode: str
prompt_template: str
engines: List[ListEnginesEntry]
+
class ChatHistory(BaseModel):
"""History of chat messages"""
+
messages: List[ChatMessage]
@@ -88,6 +98,7 @@ class ListProvidersEntry(BaseModel):
"""Model provider with supported models
and provider's authentication strategy
"""
+
id: str
name: str
models: List[str]
@@ -99,12 +110,15 @@ class ListProvidersEntry(BaseModel):
class ListProvidersResponse(BaseModel):
providers: List[ListProvidersEntry]
+
class IndexedDir(BaseModel):
path: str
+
class IndexMetadata(BaseModel):
dirs: List[IndexedDir]
+
class GlobalConfig(BaseModel):
model_provider_id: Optional[str] = None
embeddings_provider_id: Optional[str] = None
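
A small sketch (not from the diff) of how the broadcast `Message` union behaves under pydantic v1 semantics: the `type` literal on each model determines which member of the union a payload parses into.

from pydantic import parse_obj_as

from jupyter_ai.models import ClearMessage, ConnectionMessage, Message

# The "connection" literal only matches ConnectionMessage.
msg = parse_obj_as(Message, {"type": "connection", "client_id": "abc123"})
assert isinstance(msg, ConnectionMessage)

# The "clear" literal only matches ClearMessage, which carries no other fields.
msg = parse_obj_as(Message, {"type": "clear"})
assert isinstance(msg, ClearMessage)
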
diff --git a/packages/jupyter-ai/jupyter_ai/reply_processor.py b/packages/jupyter-ai/jupyter_ai/reply_processor.py
index 3df2bfb71..d8f215d02 100644
--- a/packages/jupyter-ai/jupyter_ai/reply_processor.py
+++ b/packages/jupyter-ai/jupyter_ai/reply_processor.py
@@ -1,10 +1,11 @@
import asyncio
from typing import Dict
+
from jupyter_ai.handlers import ChatHandler
from ray.util.queue import Queue
-class ReplyProcessor():
+class ReplyProcessor:
"""A single processor to distribute replies"""
def __init__(self, handlers: Dict[str, ChatHandler], queue: Queue, log):
@@ -13,11 +14,11 @@ def __init__(self, handlers: Dict[str, ChatHandler], queue: Queue, log):
self.log = log
def process(self, message):
- self.log.debug('Processing message %s in ReplyProcessor', message)
+ self.log.debug("Processing message %s in ReplyProcessor", message)
for handler in self.handlers.values():
if not handler:
continue
-
+
handler.broadcast_message(message)
break
diff --git a/packages/jupyter-ai/jupyter_ai/task_manager.py b/packages/jupyter-ai/jupyter_ai/task_manager.py
index e995a6f01..9c38573b3 100644
--- a/packages/jupyter-ai/jupyter_ai/task_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/task_manager.py
@@ -1,10 +1,17 @@
import asyncio
-import aiosqlite
import os
-from typing import Optional, List
+from typing import List, Optional
+import aiosqlite
from jupyter_core.paths import jupyter_data_dir
-from .models import ListTasksResponse, ListTasksEntry, DescribeTaskResponse, ListEnginesEntry
+
+from .models import (
+ DescribeTaskResponse,
+ ListEnginesEntry,
+ ListTasksEntry,
+ ListTasksResponse,
+)
+
class TaskManager:
db_path = os.path.join(jupyter_data_dir(), "ai_task_manager.db")
@@ -13,7 +20,7 @@ def __init__(self, engines, default_tasks):
self.engines = engines
self.default_tasks = default_tasks
self.db_initialized = asyncio.create_task(self.init_db())
-
+
async def init_db(self):
async with aiosqlite.connect(self.db_path) as con:
await con.execute(
@@ -33,36 +40,39 @@ async def init_db(self):
for task in self.default_tasks:
await con.execute(
"INSERT INTO tasks (id, name, prompt_template, modality, insertion_mode, is_default) VALUES (?, ?, ?, ?, ?, ?)",
- (task["id"], task["name"], task["prompt_template"], task["modality"], task["insertion_mode"], 1)
+ (
+ task["id"],
+ task["name"],
+ task["prompt_template"],
+ task["modality"],
+ task["insertion_mode"],
+ 1,
+ ),
)
-
+
await con.commit()
-
+
async def list_tasks(self) -> ListTasksResponse:
await self.db_initialized
async with aiosqlite.connect(self.db_path) as con:
- cursor = await con.execute(
- "SELECT id, name FROM tasks"
- )
+ cursor = await con.execute("SELECT id, name FROM tasks")
rows = await cursor.fetchall()
tasks = []
if not rows:
return tasks
-
+
for row in rows:
- tasks.append(ListTasksEntry(
- id=row[0],
- name=row[1]
- ))
-
+ tasks.append(ListTasksEntry(id=row[0], name=row[1]))
+
return ListTasksResponse(tasks=tasks)
-
+
async def describe_task(self, id: str) -> Optional[DescribeTaskResponse]:
await self.db_initialized
async with aiosqlite.connect(self.db_path) as con:
cursor = await con.execute(
- "SELECT name, prompt_template, modality, insertion_mode FROM tasks WHERE id = ?", (id,)
+ "SELECT name, prompt_template, modality, insertion_mode FROM tasks WHERE id = ?",
+ (id,),
)
row = await cursor.fetchone()
@@ -75,15 +85,14 @@ async def describe_task(self, id: str) -> Optional[DescribeTaskResponse]:
for engine in self.engines.values():
if modality in engine.modalities:
engines.append(ListEnginesEntry(id=engine.id, name=engine.name))
-
+
# sort engines A-Z
engines = sorted(engines, key=lambda engine: engine.name)
-
+
return DescribeTaskResponse(
name=row[0],
prompt_template=row[1],
modality=row[2],
insertion_mode=row[3],
- engines=engines
+ engines=engines,
)
-
\ No newline at end of file
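
A rough sketch (not part of the diff) of exercising TaskManager directly: the `db_path` override redirects the SQLite file to a temp directory so the demo does not touch the real Jupyter data dir, and `default_tasks` reuses the definitions from `jupyter_ai.tasks`. The empty engines dict is illustrative only.

import asyncio
import os
import tempfile

from jupyter_ai.task_manager import TaskManager
from jupyter_ai.tasks import tasks as default_tasks

# Point the task database at a throwaway location for the demo.
TaskManager.db_path = os.path.join(tempfile.mkdtemp(), "ai_task_manager.db")

async def main():
    # TaskManager schedules init_db() on construction, so it must be created
    # inside a running event loop.
    tm = TaskManager(engines={}, default_tasks=default_tasks)
    listing = await tm.list_tasks()
    for entry in listing.tasks:
        print(entry.id, "->", entry.name)

asyncio.run(main())
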
diff --git a/packages/jupyter-ai/jupyter_ai/tasks.py b/packages/jupyter-ai/jupyter_ai/tasks.py
index 9d3e8133c..b3127ad56 100644
--- a/packages/jupyter-ai/jupyter_ai/tasks.py
+++ b/packages/jupyter-ai/jupyter_ai/tasks.py
@@ -1,5 +1,6 @@
from typing import List, TypedDict
+
class DefaultTaskDefinition(TypedDict):
id: str
name: str
@@ -7,40 +8,41 @@ class DefaultTaskDefinition(TypedDict):
modality: str
insertion_mode: str
+
tasks: List[DefaultTaskDefinition] = [
{
"id": "explain-code",
"name": "Explain code",
- "prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}",
+ "prompt_template": 'Explain the following Python 3 code. The first sentence must begin with the phrase "The code below".\n{body}',
"modality": "txt2txt",
- "insertion_mode": "above"
+ "insertion_mode": "above",
},
{
"id": "generate-code",
"name": "Generate code",
"prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}",
"modality": "txt2txt",
- "insertion_mode": "below"
+ "insertion_mode": "below",
},
{
"id": "explain-code-in-cells-above",
"name": "Explain code in cells above",
- "prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}",
+ "prompt_template": 'Explain the following Python 3 code. The first sentence must begin with the phrase "The code below".\n{body}',
"modality": "txt2txt",
- "insertion_mode": "above-in-cells"
+ "insertion_mode": "above-in-cells",
},
{
"id": "generate-code-in-cells-below",
"name": "Generate code in cells below",
"prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}",
"modality": "txt2txt",
- "insertion_mode": "below-in-cells"
+ "insertion_mode": "below-in-cells",
},
{
"id": "freeform",
"name": "Freeform prompt",
"prompt_template": "{body}",
"modality": "txt2txt",
- "insertion_mode": "below"
- }
+ "insertion_mode": "below",
+ },
]
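
A tiny illustration (not part of the diff): each task's `prompt_template` is a plain `str.format` template, and the engine fills it from the prompt variables, where `body` carries the selected code or text.

from jupyter_ai.tasks import tasks

# Pick the "explain-code" default task and render its prompt for a code snippet.
explain = next(t for t in tasks if t["id"] == "explain-code")
prompt = explain["prompt_template"].format(body="def add(a, b):\n    return a + b")
print(prompt)
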
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
index 620ec8990..0194dec73 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
@@ -11,4 +11,4 @@
# payload = json.loads(response.body)
# assert payload == {
# "data": "This is /jupyter-ai/get_example endpoint!"
-# }
\ No newline at end of file
+# }
diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json
index f9d86ef3d..50b31ff26 100644
--- a/packages/jupyter-ai/package.json
+++ b/packages/jupyter-ai/package.json
@@ -54,7 +54,7 @@
"watch": "run-p watch:src watch:labextension",
"watch:src": "tsc -w",
"watch:labextension": "jupyter labextension watch .",
- "dev-install": "pip install -e . && jupyter labextension develop . --overwrite && jupyter server extension enable jupyter_ai",
+ "dev-install": "pip install -e \".[dev,all]\" && jupyter labextension develop . --overwrite && jupyter server extension enable jupyter_ai",
"dev-uninstall": "pip uninstall jupyter_ai -y"
},
"dependencies": {
diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml
index cff42637e..20a39e558 100644
--- a/packages/jupyter-ai/pyproject.toml
+++ b/packages/jupyter-ai/pyproject.toml
@@ -53,14 +53,12 @@ test = [
"pytest-tornasync"
]
+dev = [
+ "jupyter_ai_magics[dev]"
+]
+
all = [
- "ai21",
- "anthropic",
- "cohere",
- "huggingface_hub",
- "ipywidgets",
- "openai",
- "boto3"
+ "jupyter_ai_magics[all]"
]
[tool.hatch.version]
diff --git a/packages/jupyter-ai/setup.py b/packages/jupyter-ai/setup.py
index bea233743..aefdf20db 100644
--- a/packages/jupyter-ai/setup.py
+++ b/packages/jupyter-ai/setup.py
@@ -1 +1 @@
-__import__('setuptools').setup()
+__import__("setuptools").setup()
diff --git a/packages/jupyter-ai/style/icons/jupyternaut.svg b/packages/jupyter-ai/style/icons/jupyternaut.svg
index 20f02087b..d4367985d 100644
--- a/packages/jupyter-ai/style/icons/jupyternaut.svg
+++ b/packages/jupyter-ai/style/icons/jupyternaut.svg
@@ -12,4 +12,4 @@
-
\ No newline at end of file
+
diff --git a/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py b/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py
index 5ba7a914e..23d06f6dd 100644
--- a/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py
+++ b/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py
@@ -10,7 +10,7 @@
c.ServerApp.port_retries = 0
c.ServerApp.open_browser = False
-c.ServerApp.root_dir = mkdtemp(prefix='galata-test-')
+c.ServerApp.root_dir = mkdtemp(prefix="galata-test-")
c.ServerApp.token = ""
c.ServerApp.password = ""
c.ServerApp.disable_check_xsrf = True
diff --git a/playground/config.example.py b/playground/config.example.py
index 88b43074a..850ab215a 100644
--- a/playground/config.example.py
+++ b/playground/config.example.py
@@ -2,5 +2,5 @@
# Reference: https://jupyter-ai.readthedocs.io/en/latest/users/index.html#configuring-with-openai
# Specify full path to the notebook dir if running jupyter lab from
-# outside of the jupyter-ai project root directory
+# outside of the jupyter-ai project root directory
c.ServerApp.root_dir = "./playground"
diff --git a/scripts/bump-version.sh b/scripts/bump-version.sh
index 11d90177a..900211be9 100755
--- a/scripts/bump-version.sh
+++ b/scripts/bump-version.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# script that bumps version for all projects regardless of whether they were
# changed since last release. needed because `lerna version` only bumps versions for projects
# listed by `lerna changed` by default.