diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..d9089f491 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,58 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: requirements-txt-fixer + - id: check-added-large-files + - id: check-case-conflict + - id: check-toml + exclude: ^packages/jupyter-ai-module-cookiecutter + - id: check-yaml + exclude: ^packages/jupyter-ai-module-cookiecutter + - id: debug-statements + - id: forbid-new-submodules + - id: check-builtin-literals + - id: trailing-whitespace + + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black"] + files: \.py$ + + - repo: https://github.com/asottile/pyupgrade + rev: v3.4.0 + hooks: + - id: pyupgrade + args: [--py37-plus] + + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + additional_dependencies: + [ + "flake8-bugbear==20.1.4", + "flake8-logging-format==0.6.0", + "flake8-implicit-str-concat==0.2.0", + ] + stages: [manual] + + - repo: https://github.com/sirosen/check-jsonschema + rev: 0.23.1 + hooks: + - id: check-jsonschema + name: "Check GitHub Workflows" + files: ^\.github/workflows/ + types: [yaml] + args: ["--schemafile", "https://json.schemastore.org/github-workflow"] + stages: [manual] diff --git a/README.md b/README.md index 9b08ed418..e9327fa11 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Jupyter AI -Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly +Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly and powerful way to explore generative AI models in notebooks and improve your productivity in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers: @@ -41,7 +41,7 @@ First, install [conda](https://conda.io/projects/conda/en/latest/user-guide/inst If you are using an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.), you need to uninstall the `pip` provided version of `grpcio` and install the version provided by `conda` instead. - $ pip uninstall grpcio; conda install grpcio + $ pip uninstall grpcio; conda install grpcio If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, skip the `pip install jupyter_ai` step above, and instead, run: @@ -59,7 +59,7 @@ Once you have installed the `%%ai` magic, you can enable it in any notebook or t or: %load_ext jupyter_ai - + The screenshots below are from notebooks in the `examples/` directory of this package. Then, you can use the `%%ai` magic command to specify a model and natural language prompt: diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css index 1858f4fee..e91d1bbcd 100644 --- a/docs/source/_static/css/custom.css +++ b/docs/source/_static/css/custom.css @@ -2,4 +2,4 @@ img.screenshot { /* Copied from div.admonition */ box-shadow: 0 .2rem .5rem var(--pst-color-shadow),0 0 .0625rem var(--pst-color-shadow); border-color: var(--pst-color-info); -} \ No newline at end of file +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 8edc80417..bdd277e22 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,9 +6,9 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'Jupyter AI' -copyright = '2023, Project Jupyter' -author = 'Project Jupyter' +project = "Jupyter AI" +copyright = "2023, Project Jupyter" +author = "Project Jupyter" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -16,17 +16,17 @@ extensions = ["myst_parser"] myst_enable_extensions = ["colon_fence"] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "pydata_sphinx_theme" -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'css/custom.css', + "css/custom.css", ] # -- Jupyter theming ------------------------------------------------- diff --git a/docs/source/contributors/index.md b/docs/source/contributors/index.md index cb1bc68e9..c0823ef4b 100644 --- a/docs/source/contributors/index.md +++ b/docs/source/contributors/index.md @@ -21,7 +21,7 @@ Due to a compatibility issue with Webpack, Node.js 18.15.0 does not work with Ju ::: ## Development install -After you have installed the prerequisites, create a new conda environment and activate it. +After you have installed the prerequisites, create a new conda environment and activate it. ``` conda create -n jupyter-ai python=3.10 diff --git a/docs/source/index.md b/docs/source/index.md index 118789088..b93973a96 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,6 +1,6 @@ # Jupyter AI -Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly +Welcome to Jupyter AI, which brings generative AI to Jupyter. Jupyter AI provides a user-friendly and powerful way to explore generative AI models in notebooks and improve your productivity in JupyterLab and the Jupyter Notebook. More specifically, Jupyter AI offers: @@ -19,4 +19,4 @@ maxdepth: 2 users/index contributors/index -``` \ No newline at end of file +``` diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 40c8a4c62..3e1f5def0 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -2,7 +2,7 @@ Welcome to the user documentation for Jupyter AI. -If you are interested in contributing to Jupyter AI, +If you are interested in contributing to Jupyter AI, please see our {doc}`contributor's guide `. ## Prerequisites @@ -101,7 +101,7 @@ First, install [conda](https://conda.io/projects/conda/en/latest/user-guide/inst If you are using an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.), you need to uninstall the `pip` provided version of `grpcio` and install the version provided by `conda` instead. - $ pip uninstall grpcio; conda install grpcio + $ pip uninstall grpcio; conda install grpcio If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, you can run: @@ -127,7 +127,7 @@ or ## The chat interface -The easiest way to get started with Jupyter AI is to use the chat interface. +The easiest way to get started with Jupyter AI is to use the chat interface. :::{attention} :name: open-ai-privacy-cost @@ -449,7 +449,7 @@ will look like properly typeset equations. Generate the 2D heat equation in LaTeX surrounded by `$$`. Do not include an explanation. ``` -This prompt will produce output as a code cell below the input cell. +This prompt will produce output as a code cell below the input cell. :::{warning} :name: run-code diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index ff35147db..c3dce919a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -1,4 +1,11 @@ from ._version import __version__ + +# expose embedding model providers on the package root +from .embedding_providers import ( + CohereEmbeddingsProvider, + HfHubEmbeddingsProvider, + OpenAIEmbeddingsProvider, +) from .exception import store_exception from .magics import AiMagics @@ -6,24 +13,20 @@ from .providers import ( AI21Provider, AnthropicProvider, + BaseProvider, + ChatOpenAINewProvider, + ChatOpenAIProvider, CohereProvider, HfHubProvider, OpenAIProvider, - ChatOpenAIProvider, - ChatOpenAINewProvider, - SmEndpointProvider -) -# expose embedding model providers on the package root -from .embedding_providers import ( - OpenAIEmbeddingsProvider, - CohereEmbeddingsProvider, - HfHubEmbeddingsProvider + SmEndpointProvider, ) -from .providers import BaseProvider + def load_ipython_extension(ipython): ipython.register_magics(AiMagics) ipython.set_custom_exc((BaseException,), store_exception) + def unload_ipython_extension(ipython): ipython.set_custom_exc((BaseException,), ipython.CustomTB) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py index 74be10485..ab383af32 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py @@ -3,4 +3,4 @@ "gpt3": "openai:text-davinci-003", "chatgpt": "openai-chat:gpt-3.5-turbo", "gpt4": "openai-chat:gpt-4", -} \ No newline at end of file +} diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index 5b31725bf..0d3f866ae 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -1,8 +1,13 @@ from typing import ClassVar, List, Type + from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field -from pydantic import BaseModel, Extra -from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings +from langchain.embeddings import ( + CohereEmbeddings, + HuggingFaceHubEmbeddings, + OpenAIEmbeddings, +) from langchain.embeddings.base import Embeddings +from pydantic import BaseModel, Extra class BaseEmbeddingsProvider(BaseModel): @@ -33,7 +38,7 @@ class Config: model_id: str - provider_klass: ClassVar[Type[Embeddings]] + provider_klass: ClassVar[Type[Embeddings]] registry: ClassVar[bool] = False """Whether this provider is a registry provider.""" @@ -41,14 +46,12 @@ class Config: fields: ClassVar[List[Field]] = [] """Fields expected by this provider in its constructor. Each `Field` `f` should be passed as a keyword argument, keyed by `f.key`.""" - + class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): id = "openai" name = "OpenAI" - models = [ - "text-embedding-ada-002" - ] + models = ["text-embedding-ada-002"] model_id_key = "model" pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") @@ -58,11 +61,7 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): class CohereEmbeddingsProvider(BaseEmbeddingsProvider): id = "cohere" name = "Cohere" - models = [ - 'large', - 'multilingual-22-12', - 'small' - ] + models = ["large", "multilingual-22-12", "small"] model_id_key = "model" pypi_package_deps = ["cohere"] auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py b/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py index 733481ce1..126eedf52 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/exception.py @@ -1,12 +1,13 @@ -from IPython.core.magic import register_line_magic +import traceback + from IPython.core.getipython import get_ipython +from IPython.core.magic import register_line_magic from IPython.core.ultratb import ListTB -import traceback def store_exception(shell, etype, evalue, tb, tb_offset=None): # A structured traceback (a list of strings) or None - + if issubclass(etype, SyntaxError): # Disable ANSI color strings shell.SyntaxTB.color_toggle() @@ -18,7 +19,9 @@ def store_exception(shell, etype, evalue, tb, tb_offset=None): else: # Disable ANSI color strings shell.InteractiveTB.color_toggle() - stb = shell.InteractiveTB.structured_traceback(etype, evalue, tb, tb_offset=tb_offset) + stb = shell.InteractiveTB.structured_traceback( + etype, evalue, tb, tb_offset=tb_offset + ) stb_text = shell.InteractiveTB.stb2text(stb) # Re-enable ANSI color strings shell.InteractiveTB.color_toggle() @@ -31,6 +34,6 @@ def store_exception(shell, etype, evalue, tb, tb_offset=None): err = shell.user_ns.get("Err", {}) err[prompt_number] = styled_exception shell.user_ns["Err"] = err - - # Return + + # Return return etraceback diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index f9eebcac1..c42f5efdd 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -9,22 +9,23 @@ import click from IPython import get_ipython -from IPython.core.magic import Magics, magics_class, line_cell_magic +from IPython.core.magic import Magics, line_cell_magic, magics_class from IPython.display import HTML, JSON, Markdown, Math from jupyter_ai_magics.utils import decompose_model_id, load_providers +from langchain.chains import LLMChain -from .providers import BaseProvider -from .parsers import (cell_magic_parser, - line_magic_parser, +from .parsers import ( CellArgs, DeleteArgs, ErrorArgs, HelpArgs, ListArgs, RegisterArgs, - UpdateArgs) - -from langchain.chains import LLMChain + UpdateArgs, + cell_magic_parser, + line_magic_parser, +) +from .providers import BaseProvider MODEL_ID_ALIASES = { "gpt2": "huggingface_hub:gpt2", @@ -33,35 +34,30 @@ "gpt4": "openai-chat:gpt-4", } -class TextOrMarkdown(object): +class TextOrMarkdown: def __init__(self, text, markdown): self.text = text self.markdown = markdown def _repr_mimebundle_(self, include=None, exclude=None): - return ( - { - 'text/plain': self.text, - 'text/markdown': self.markdown - } - ) + return {"text/plain": self.text, "text/markdown": self.markdown} -class TextWithMetadata(object): +class TextWithMetadata: def __init__(self, text, metadata): self.text = text self.metadata = metadata def _repr_mimebundle_(self, include=None, exclude=None): - return ({'text/plain': self.text}, self.metadata) + return ({"text/plain": self.text}, self.metadata) class Base64Image: def __init__(self, mimeData, metadata): - mimeDataParts = mimeData.split(',') - self.data = base64.b64decode(mimeDataParts[1]); - self.mimeType = re.sub(r';base64$', '', mimeDataParts[0]) + mimeDataParts = mimeData.split(",") + self.data = base64.b64decode(mimeDataParts[1]) + self.mimeType = re.sub(r";base64$", "", mimeDataParts[0]) self.metadata = metadata def _repr_mimebundle_(self, include=None, exclude=None): @@ -76,14 +72,14 @@ def _repr_mimebundle_(self, include=None, exclude=None): "math": Math, "md": Markdown, "json": JSON, - "text": TextWithMetadata + "text": TextWithMetadata, } NA_MESSAGE = 'N/A' -MARKDOWN_PROMPT_TEMPLATE = '{prompt}\n\nProduce output in markdown format only.' +MARKDOWN_PROMPT_TEMPLATE = "{prompt}\n\nProduce output in markdown format only." -PROVIDER_NO_MODELS = 'This provider does not define a list of models.' +PROVIDER_NO_MODELS = "This provider does not define a list of models." CANNOT_DETERMINE_MODEL_TEXT = """Cannot determine model provider from model ID '{0}'. @@ -95,132 +91,150 @@ def _repr_mimebundle_(self, include=None, exclude=None): PROMPT_TEMPLATES_BY_FORMAT = { - "code": '{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.', - "html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.', - "image": '{prompt}\n\nProduce output as an image only, with no text before or after it.', + "code": "{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.", + "html": "{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.", + "image": "{prompt}\n\nProduce output as an image only, with no text before or after it.", "markdown": MARKDOWN_PROMPT_TEMPLATE, "md": MARKDOWN_PROMPT_TEMPLATE, - "math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.', - "json": '{prompt}\n\nProduce output in JSON format only, with nothing before or after it.', - "text": '{prompt}' # No customization + "math": "{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.", + "json": "{prompt}\n\nProduce output in JSON format only, with nothing before or after it.", + "text": "{prompt}", # No customization } -AI_COMMANDS = { "delete", "error", "help", "list", "register", "update" } +AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"} + class FormatDict(dict): """Subclass of dict to be passed to str#format(). Suppresses KeyError and leaves replacement field unchanged if replacement field is not associated with a value.""" - def __missing__(self, key): + + def __missing__(self, key): return key.join("{}") + class EnvironmentError(BaseException): pass + class CellMagicError(BaseException): pass + @magics_class class AiMagics(Magics): def __init__(self, shell): - super(AiMagics, self).__init__(shell) + super().__init__(shell) self.transcript_openai = [] # suppress warning when using old OpenAIChat provider - warnings.filterwarnings("ignore", message="You are trying to use a chat model. This way of initializing it is " + warnings.filterwarnings( + "ignore", + message="You are trying to use a chat model. This way of initializing it is " "no longer supported. Instead, please use: " - "`from langchain.chat_models import ChatOpenAI`") + "`from langchain.chat_models import ChatOpenAI`", + ) self.providers = load_providers() - + # initialize a registry of custom model/chain names self.custom_model_registry = MODEL_ID_ALIASES - + def _ai_bulleted_list_models_for_provider(self, provider_id, Provider): output = "" - if (len(Provider.models) == 1 and Provider.models[0] == "*"): + if len(Provider.models) == 1 and Provider.models[0] == "*": output += f"* {PROVIDER_NO_MODELS}\n" else: for model_id in Provider.models: - output += f"* {provider_id}:{model_id}\n"; - output += "\n" # End of bulleted list - + output += f"* {provider_id}:{model_id}\n" + output += "\n" # End of bulleted list + return output def _ai_inline_list_models_for_provider(self, provider_id, Provider): output = "" - if (len(Provider.models) == 1 and Provider.models[0] == "*"): + if len(Provider.models) == 1 and Provider.models[0] == "*": return PROVIDER_NO_MODELS for model_id in Provider.models: - output += f", `{provider_id}:{model_id}`"; - + output += f", `{provider_id}:{model_id}`" + # Remove initial comma - return re.sub(r'^, ', '', output) - + return re.sub(r"^, ", "", output) + # Is the required environment variable set? def _ai_env_status_for_provider_markdown(self, provider_id): - na_message = 'Not applicable. | ' + NA_MESSAGE + na_message = "Not applicable. | " + NA_MESSAGE + + if ( + provider_id not in self.providers + or self.providers[provider_id].auth_strategy == None + ): + return na_message # No emoji - if (provider_id not in self.providers or - self.providers[provider_id].auth_strategy == None): - return na_message # No emoji - try: env_var = self.providers[provider_id].auth_strategy.name - except AttributeError: # No "name" attribute + except AttributeError: # No "name" attribute return na_message - + output = f"`{env_var}` | " - if (os.getenv(env_var) == None): - output += (""); + if os.getenv(env_var) == None: + output += ( + '" + ) else: - output += (""); - + output += ( + '" + ) + return output def _ai_env_status_for_provider_text(self, provider_id): - if (provider_id not in self.providers or - self.providers[provider_id].auth_strategy == None): - return '' # No message necessary - - try: + if ( + provider_id not in self.providers + or self.providers[provider_id].auth_strategy == None + ): + return "" # No message necessary + + try: env_var = self.providers[provider_id].auth_strategy.name - except AttributeError: # No "name" attribute - return '' - + except AttributeError: # No "name" attribute + return "" + output = f"Requires environment variable {env_var} " - if (os.getenv(env_var) != None): + if os.getenv(env_var) != None: output += "(set)" else: output += "(not set)" - + return output + "\n" # Is this a name of a Python variable that can be called as a LangChain chain? def _is_langchain_chain(self, name): # Reserved word in Python? - if (keyword.iskeyword(name)): - return False; - - acceptable_name = re.compile('^[a-zA-Z0-9_]+$') - if (not acceptable_name.match(name)): - return False; - + if keyword.iskeyword(name): + return False + + acceptable_name = re.compile("^[a-zA-Z0-9_]+$") + if not acceptable_name.match(name): + return False + ipython = get_ipython() - return(name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain)) + return name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain) # Is this an acceptable name for an alias? def _validate_name(self, register_name): # A registry name contains ASCII letters, numbers, hyphens, underscores, # and periods. No other characters, including a colon, are permitted - acceptable_name = re.compile('^[a-zA-Z0-9._-]+$') - if (not acceptable_name.match(register_name)): - raise ValueError('A registry name may contain ASCII letters, numbers, hyphens, underscores, ' - + 'and periods. No other characters, including a colon, are permitted') + acceptable_name = re.compile("^[a-zA-Z0-9._-]+$") + if not acceptable_name.match(register_name): + raise ValueError( + "A registry name may contain ASCII letters, numbers, hyphens, underscores, " + + "and periods. No other characters, including a colon, are permitted" + ) # Initially set or update an alias to a target def _safely_set_target(self, register_name, target): @@ -230,33 +244,38 @@ def _safely_set_target(self, register_name, target): self.custom_model_registry[register_name] = ip.user_ns[target] else: # Ensure that the destination is properly formatted - if (':' not in target): + if ":" not in target: raise ValueError( - 'Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format') + "Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format" + ) self.custom_model_registry[register_name] = target def handle_delete(self, args: DeleteArgs): - if (args.name in AI_COMMANDS): - raise ValueError(f"Reserved command names, including {args.name}, cannot be deleted") - - if (args.name not in self.custom_model_registry): + if args.name in AI_COMMANDS: + raise ValueError( + f"Reserved command names, including {args.name}, cannot be deleted" + ) + + if args.name not in self.custom_model_registry: raise ValueError(f"There is no alias called {args.name}") - + del self.custom_model_registry[args.name] output = f"Deleted alias `{args.name}`" return TextOrMarkdown(output, output) def handle_register(self, args: RegisterArgs): # Existing command names are not allowed - if (args.name in AI_COMMANDS): + if args.name in AI_COMMANDS: raise ValueError(f"The name {args.name} is reserved for a command") - + # Existing registered names are not allowed - if (args.name in self.custom_model_registry): - raise ValueError(f"The name {args.name} is already associated with a custom model; " - + 'use %ai update to change its target') - + if args.name in self.custom_model_registry: + raise ValueError( + f"The name {args.name} is already associated with a custom model; " + + "use %ai update to change its target" + ) + # Does the new name match expected format? self._validate_name(args.name) @@ -265,62 +284,75 @@ def handle_register(self, args: RegisterArgs): return TextOrMarkdown(output, output) def handle_update(self, args: UpdateArgs): - if (args.name in AI_COMMANDS): - raise ValueError(f"Reserved command names, including {args.name}, cannot be updated") - - if (args.name not in self.custom_model_registry): + if args.name in AI_COMMANDS: + raise ValueError( + f"Reserved command names, including {args.name}, cannot be updated" + ) + + if args.name not in self.custom_model_registry: raise ValueError(f"There is no alias called {args.name}") - + self._safely_set_target(args.name, args.target) output = f"Updated target of alias `{args.name}`" return TextOrMarkdown(output, output) def _ai_list_command_markdown(self, single_provider=None): - output = ("| Provider | Environment variable | Set? | Models |\n" - + "|----------|----------------------|------|--------|\n") - if (single_provider is not None and single_provider not in self.providers): - return f"There is no model provider with ID `{single_provider}`."; + output = ( + "| Provider | Environment variable | Set? | Models |\n" + + "|----------|----------------------|------|--------|\n" + ) + if single_provider is not None and single_provider not in self.providers: + return f"There is no model provider with ID `{single_provider}`." for provider_id, Provider in self.providers.items(): - if (single_provider is not None and provider_id != single_provider): - continue; + if single_provider is not None and provider_id != single_provider: + continue - output += (f"| `{provider_id}` | " - + self._ai_env_status_for_provider_markdown(provider_id) + " | " + output += ( + f"| `{provider_id}` | " + + self._ai_env_status_for_provider_markdown(provider_id) + + " | " + self._ai_inline_list_models_for_provider(provider_id, Provider) - + " |\n") + + " |\n" + ) # Also list aliases. - if (single_provider is None and len(self.custom_model_registry) > 0): - output += ("\nAliases and custom commands:\n\n" + if single_provider is None and len(self.custom_model_registry) > 0: + output += ( + "\nAliases and custom commands:\n\n" + "| Name | Target |\n" - + "|------|--------|\n") + + "|------|--------|\n" + ) for key, value in self.custom_model_registry.items(): output += f"| `{key}` | " if isinstance(value, str): output += f"`{value}`" else: output += "*custom chain*" - + output += " |\n" return output def _ai_list_command_text(self, single_provider=None): output = "" - if (single_provider is not None and single_provider not in self.providers): - return f"There is no model provider with ID '{single_provider}'."; + if single_provider is not None and single_provider not in self.providers: + return f"There is no model provider with ID '{single_provider}'." for provider_id, Provider in self.providers.items(): - if (single_provider is not None and provider_id != single_provider): - continue; - - output += (f"{provider_id}\n" - + self._ai_env_status_for_provider_text(provider_id) # includes \n if nonblank - + self._ai_bulleted_list_models_for_provider(provider_id, Provider)) + if single_provider is not None and provider_id != single_provider: + continue + + output += ( + f"{provider_id}\n" + + self._ai_env_status_for_provider_text( + provider_id + ) # includes \n if nonblank + + self._ai_bulleted_list_models_for_provider(provider_id, Provider) + ) # Also list aliases. - if (single_provider is None and len(self.custom_model_registry) > 0): + if single_provider is None and len(self.custom_model_registry) > 0: output += "\nAliases and custom commands:\n" for key, value in self.custom_model_registry.items(): output += f"{key} - " @@ -328,7 +360,7 @@ def _ai_list_command_text(self, single_provider=None): output += value else: output += "custom chain" - + output += "\n" return output @@ -338,42 +370,34 @@ def handle_error(self, args: ErrorArgs): # Find the most recent error. ip = get_ipython() - if ('Err' not in ip.user_ns): + if "Err" not in ip.user_ns: return TextOrMarkdown(no_errors, no_errors) - err = ip.user_ns['Err'] + err = ip.user_ns["Err"] # Start from the previous execution count excount = ip.execution_count - 1 last_error = None - while (excount >= 0 and last_error is None): - if(excount in err): + while excount >= 0 and last_error is None: + if excount in err: last_error = err[excount] else: - excount = excount - 1; + excount = excount - 1 - if (last_error is None): + if last_error is None: return TextOrMarkdown(no_errors, no_errors) prompt = f"Explain the following error:\n\n{last_error}" # Set CellArgs based on ErrorArgs cell_args = CellArgs( - type="root", - model_id=args.model_id, - format=args.format, - reset=False) + type="root", model_id=args.model_id, format=args.format, reset=False + ) return self.run_ai_cell(cell_args, prompt) def _append_exchange_openai(self, prompt: str, output: str): """Appends a conversational exchange between user and an OpenAI Chat model to a transcript that will be included in future exchanges.""" - self.transcript_openai.append({ - "role": "user", - "content": prompt - }) - self.transcript_openai.append({ - "role": "assistant", - "content": output - }) + self.transcript_openai.append({"role": "user", "content": prompt}) + self.transcript_openai.append({"role": "assistant", "content": output}) def _decompose_model_id(self, model_id: str): """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" @@ -388,29 +412,31 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: return None return self.providers[provider_id] - + def display_output(self, output, display_format, md): # build output display DisplayClass = DISPLAYS_BY_FORMAT[display_format] # if the user wants code, add another cell with the output. - if display_format == 'code': + if display_format == "code": # Strip a leading language indicator and trailing triple-backticks - lang_indicator = r'^```[a-zA-Z0-9]*\n' - output = re.sub(lang_indicator, '', output) - output = re.sub(r'\n```$', '', output) + lang_indicator = r"^```[a-zA-Z0-9]*\n" + output = re.sub(lang_indicator, "", output) + output = re.sub(r"\n```$", "", output) new_cell_payload = dict( - source='set_next_input', + source="set_next_input", text=output, replace=False, ) ip = get_ipython() ip.payload_manager.write_payload(new_cell_payload) - return HTML('AI generated code inserted below ⬇️', metadata=md); + return HTML( + "AI generated code inserted below ⬇️", metadata=md + ) if DisplayClass is None: return output - if display_format == 'json': + if display_format == "json": # JSON display expects a dict, not a JSON string output = json.loads(output) output_display = DisplayClass(output, metadata=md) @@ -426,12 +452,12 @@ def handle_help(self, _: HelpArgs): def handle_list(self, args: ListArgs): return TextOrMarkdown( self._ai_list_command_text(args.provider_id), - self._ai_list_command_markdown(args.provider_id) + self._ai_list_command_markdown(args.provider_id), ) def run_ai_cell(self, args: CellArgs, prompt: str): # Apply a prompt template. - prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt = prompt) + prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt=prompt) # interpolate user namespace into prompt ip = get_ipython() @@ -439,30 +465,29 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Determine provider and local model IDs # If this is a custom chain, send the message to the custom chain. - if (args.model_id in self.custom_model_registry and - isinstance(self.custom_model_registry[args.model_id], LLMChain)): - + if args.model_id in self.custom_model_registry and isinstance( + self.custom_model_registry[args.model_id], LLMChain + ): return self.display_output( self.custom_model_registry[args.model_id].run(prompt), args.format, - { - "jupyter_ai": { - "custom_chain_id": args.model_id - } - }) + {"jupyter_ai": {"custom_chain_id": args.model_id}}, + ) provider_id, local_model_id = self._decompose_model_id(args.model_id) Provider = self._get_provider(provider_id) if Provider is None: return TextOrMarkdown( - CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id) + "\n\n" + CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id) + + "\n\n" + "If you were trying to run a command, run '%ai help' to see a list of commands.", - CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id) + "\n\n" - + "If you were trying to run a command, run `%ai help` to see a list of commands." + CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id) + + "\n\n" + + "If you were trying to run a command, run `%ai help` to see a list of commands.", ) # if `--reset` is specified, reset transcript and return early - if (provider_id == "openai-chat" and args.reset): + if provider_id == "openai-chat" and args.reset: self.transcript_openai = [] return @@ -471,26 +496,26 @@ def run_ai_cell(self, args: CellArgs, prompt: str): if auth_strategy: # TODO: handle auth strategies besides EnvAuthStrategy if auth_strategy.type == "env" and auth_strategy.name not in os.environ: - raise EnvironmentError( + raise OSError( f"Authentication environment variable {auth_strategy.name} not provided.\n" f"An authentication token is required to use models from the {Provider.name} provider.\n" f"Please specify it via `%env {auth_strategy.name}=token`. " ) from None # configure and instantiate provider - provider_params = { "model_id": local_model_id } + provider_params = {"model_id": local_model_id} if provider_id == "openai-chat": provider_params["prefix_messages"] = self.transcript_openai # for SageMaker, validate that required params are specified if provider_id == "sagemaker-endpoint": if ( - args.region_name is None or - args.request_schema is None or - args.response_path is None + args.region_name is None + or args.request_schema is None + or args.response_path is None ): raise ValueError( - "When using the sagemaker-endpoint provider, you must specify all of " + - "the --region-name, --request-schema, and --response-path options." + "When using the sagemaker-endpoint provider, you must specify all of " + + "the --region-name, --request-schema, and --response-path options." ) provider_params["region_name"] = args.region_name provider_params["request_schema"] = args.request_schema @@ -506,18 +531,13 @@ def run_ai_cell(self, args: CellArgs, prompt: str): if provider_id == "openai-chat": self._append_exchange_openai(prompt, output) - md = { - "jupyter_ai": { - "provider_id": provider_id, - "model_id": local_model_id - } - } + md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}} return self.display_output(output, args.format, md) @line_cell_magic def ai(self, line, cell=None): - raw_args = line.split(' ') + raw_args = line.split(" ") if cell: args = cell_magic_parser(raw_args, prog_name="%%ai", standalone_mode=False) else: @@ -554,10 +574,10 @@ def ai(self, line, cell=None): """[0.8+]: To invoke a language model, you must use the `%%ai` cell magic. The `%ai` line magic is only for use with subcommands.""" - ) + ) prompt = cell.strip() - + # interpolate user namespace into prompt ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index d2b0476b5..a6acf3525 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -1,27 +1,37 @@ +from typing import Literal, Optional, get_args + import click from pydantic import BaseModel -from typing import Optional, Literal, get_args -FORMAT_CHOICES_TYPE = Literal["code", "html", "image", "json", "markdown", "math", "md", "text"] +FORMAT_CHOICES_TYPE = Literal[ + "code", "html", "image", "json", "markdown", "math", "md", "text" +] FORMAT_CHOICES = list(get_args(FORMAT_CHOICES_TYPE)) FORMAT_HELP = """IPython display to use when rendering output. [default="markdown"]""" -REGION_NAME_SHORT_OPTION = '-n' -REGION_NAME_LONG_OPTION = '--region-name' -REGION_NAME_HELP = ("AWS region name, e.g. 'us-east-1'. Required for SageMaker provider; " + - "does nothing with other providers.") +REGION_NAME_SHORT_OPTION = "-n" +REGION_NAME_LONG_OPTION = "--region-name" +REGION_NAME_HELP = ( + "AWS region name, e.g. 'us-east-1'. Required for SageMaker provider; " + + "does nothing with other providers." +) -REQUEST_SCHEMA_SHORT_OPTION = '-q' -REQUEST_SCHEMA_LONG_OPTION = '--request-schema' -REQUEST_SCHEMA_HELP = ("The JSON object the endpoint expects, with the prompt being " + - "substituted into any value that matches the string literal ''. " + - "Required for SageMaker provider; does nothing with other providers.") +REQUEST_SCHEMA_SHORT_OPTION = "-q" +REQUEST_SCHEMA_LONG_OPTION = "--request-schema" +REQUEST_SCHEMA_HELP = ( + "The JSON object the endpoint expects, with the prompt being " + + "substituted into any value that matches the string literal ''. " + + "Required for SageMaker provider; does nothing with other providers." +) + +RESPONSE_PATH_SHORT_OPTION = "-p" +RESPONSE_PATH_LONG_OPTION = "--response-path" +RESPONSE_PATH_HELP = ( + "A JSONPath string that retrieves the language model's output " + + "from the endpoint's JSON response. Required for SageMaker provider; " + + "does nothing with other providers." +) -RESPONSE_PATH_SHORT_OPTION = '-p' -RESPONSE_PATH_LONG_OPTION = '--response-path' -RESPONSE_PATH_HELP = ("A JSONPath string that retrieves the language model's output " + - "from the endpoint's JSON response. Required for SageMaker provider; " + - "does nothing with other providers.") class CellArgs(BaseModel): type: Literal["root"] = "root" @@ -33,6 +43,7 @@ class CellArgs(BaseModel): request_schema: Optional[str] response_path: Optional[str] + # Should match CellArgs, but without "reset" class ErrorArgs(BaseModel): type: Literal["error"] = "error" @@ -43,51 +54,79 @@ class ErrorArgs(BaseModel): request_schema: Optional[str] response_path: Optional[str] + class HelpArgs(BaseModel): type: Literal["help"] = "help" + class ListArgs(BaseModel): type: Literal["list"] = "list" provider_id: Optional[str] + class RegisterArgs(BaseModel): type: Literal["register"] = "register" name: str target: str + class DeleteArgs(BaseModel): type: Literal["delete"] = "delete" name: str + class UpdateArgs(BaseModel): type: Literal["update"] = "update" name: str target: str + class LineMagicGroup(click.Group): """Helper class to print the help string for cell magics as well when `%ai --help` is called.""" + def get_help(self, ctx): with click.Context(cell_magic_parser, info_name="%%ai") as ctx: click.echo(cell_magic_parser.get_help(ctx)) - click.echo('-' * 78) + click.echo("-" * 78) with click.Context(line_magic_parser, info_name="%ai") as ctx: click.echo(super().get_help(ctx)) + @click.command() -@click.argument('model_id') -@click.option('-f', '--format', +@click.argument("model_id") +@click.option( + "-f", + "--format", type=click.Choice(FORMAT_CHOICES, case_sensitive=False), default="markdown", - help=FORMAT_HELP + help=FORMAT_HELP, ) -@click.option('-r', '--reset', is_flag=True, +@click.option( + "-r", + "--reset", + is_flag=True, help="""Clears the conversation transcript used when interacting with an - OpenAI chat model provider. Does nothing with other providers.""" + OpenAI chat model provider. Does nothing with other providers.""", +) +@click.option( + REGION_NAME_SHORT_OPTION, + REGION_NAME_LONG_OPTION, + required=False, + help=REGION_NAME_HELP, +) +@click.option( + REQUEST_SCHEMA_SHORT_OPTION, + REQUEST_SCHEMA_LONG_OPTION, + required=False, + help=REQUEST_SCHEMA_HELP, +) +@click.option( + RESPONSE_PATH_SHORT_OPTION, + RESPONSE_PATH_LONG_OPTION, + required=False, + help=RESPONSE_PATH_HELP, ) -@click.option(REGION_NAME_SHORT_OPTION, REGION_NAME_LONG_OPTION, required=False, help=REGION_NAME_HELP) -@click.option(REQUEST_SCHEMA_SHORT_OPTION, REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP) -@click.option(RESPONSE_PATH_SHORT_OPTION, RESPONSE_PATH_LONG_OPTION, required=False, help=RESPONSE_PATH_HELP) def cell_magic_parser(**kwargs): """ Invokes a language model identified by MODEL_ID, with the prompt being @@ -99,22 +138,41 @@ def cell_magic_parser(**kwargs): """ return CellArgs(**kwargs) + @click.group(cls=LineMagicGroup) def line_magic_parser(): """ Invokes a subcommand. """ -@line_magic_parser.command(name='error') -@click.argument('model_id') -@click.option('-f', '--format', + +@line_magic_parser.command(name="error") +@click.argument("model_id") +@click.option( + "-f", + "--format", type=click.Choice(FORMAT_CHOICES, case_sensitive=False), default="markdown", - help=FORMAT_HELP + help=FORMAT_HELP, +) +@click.option( + REGION_NAME_SHORT_OPTION, + REGION_NAME_LONG_OPTION, + required=False, + help=REGION_NAME_HELP, +) +@click.option( + REQUEST_SCHEMA_SHORT_OPTION, + REQUEST_SCHEMA_LONG_OPTION, + required=False, + help=REQUEST_SCHEMA_HELP, +) +@click.option( + RESPONSE_PATH_SHORT_OPTION, + RESPONSE_PATH_LONG_OPTION, + required=False, + help=RESPONSE_PATH_HELP, ) -@click.option(REGION_NAME_SHORT_OPTION, REGION_NAME_LONG_OPTION, required=False, help=REGION_NAME_HELP) -@click.option(REQUEST_SCHEMA_SHORT_OPTION, REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP) -@click.option(RESPONSE_PATH_SHORT_OPTION, RESPONSE_PATH_LONG_OPTION, required=False, help=RESPONSE_PATH_HELP) def error_subparser(**kwargs): """ Explains the most recent error. Takes the same options (except -r) as @@ -122,41 +180,48 @@ def error_subparser(**kwargs): """ return ErrorArgs(**kwargs) -@line_magic_parser.command(name='help') + +@line_magic_parser.command(name="help") def help_subparser(): """Show this message and exit.""" return HelpArgs() -@line_magic_parser.command(name='list', - short_help="List language models. See `%ai list --help` for options." + +@line_magic_parser.command( + name="list", short_help="List language models. See `%ai list --help` for options." ) -@click.argument('provider_id', required=False) +@click.argument("provider_id", required=False) def list_subparser(**kwargs): """List language models, optionally scoped to PROVIDER_ID.""" return ListArgs(**kwargs) -@line_magic_parser.command(name='register', - short_help="Register a new alias. See `%ai register --help` for options." + +@line_magic_parser.command( + name="register", + short_help="Register a new alias. See `%ai register --help` for options.", ) -@click.argument('name') -@click.argument('target') +@click.argument("name") +@click.argument("target") def register_subparser(**kwargs): """Register a new alias called NAME for the model or chain named TARGET.""" return RegisterArgs(**kwargs) -@line_magic_parser.command(name='delete', - short_help="Delete an alias. See `%ai delete --help` for options." + +@line_magic_parser.command( + name="delete", short_help="Delete an alias. See `%ai delete --help` for options." ) -@click.argument('name') +@click.argument("name") def register_subparser(**kwargs): """Delete an alias called NAME.""" return DeleteArgs(**kwargs) -@line_magic_parser.command(name='update', - short_help="Update the target of an alias. See `%ai update --help` for options." + +@line_magic_parser.command( + name="update", + short_help="Update the target of an alias. See `%ai update --help` for options.", ) -@click.argument('name') -@click.argument('target') +@click.argument("name") +@click.argument("target") def register_subparser(**kwargs): """Update an alias called NAME to refer to the model or chain named TARGET.""" return UpdateArgs(**kwargs) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 5a573b8b7..e5c68d5d7 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -1,11 +1,11 @@ -from typing import Any, ClassVar, Dict, List, Union, Literal, Optional import base64 +import copy import io import json -import copy +from typing import Any, ClassVar, Dict, List, Literal, Optional, Union from jsonpath_ng import jsonpath, parse -from langchain.schema import BaseModel as BaseLangchainProvider +from langchain.chat_models import ChatOpenAI from langchain.llms import ( AI21, Anthropic, @@ -13,29 +13,32 @@ HuggingFaceHub, OpenAI, OpenAIChat, - SagemakerEndpoint + SagemakerEndpoint, ) from langchain.llms.sagemaker_endpoint import LLMContentHandler -from langchain.utils import get_from_dict_or_env from langchain.llms.utils import enforce_stop_tokens - +from langchain.schema import BaseModel as BaseLangchainProvider +from langchain.utils import get_from_dict_or_env from pydantic import BaseModel, Extra, root_validator -from langchain.chat_models import ChatOpenAI + class EnvAuthStrategy(BaseModel): """Require one auth token via an environment variable.""" + type: Literal["env"] = "env" name: str class MultiEnvAuthStrategy(BaseModel): """Require multiple auth tokens via multiple environment variables.""" + type: Literal["file"] = "file" names: List[str] class AwsAuthStrategy(BaseModel): """Require AWS authentication via Boto3""" + type: Literal["aws"] = "aws" @@ -47,18 +50,22 @@ class AwsAuthStrategy(BaseModel): ] ] + class TextField(BaseModel): type: Literal["text"] = "text" key: str label: str + class MultilineTextField(BaseModel): type: Literal["text-multiline"] = "text-multiline" key: str label: str + Field = Union[TextField, MultilineTextField] + class BaseProvider(BaseLangchainProvider): # # pydantic config @@ -93,7 +100,7 @@ class Config: """Whether this provider is a registry provider.""" fields: ClassVar[List[Field]] = [] - """User inputs expected by this provider when initializing it. Each `Field` `f` + """User inputs expected by this provider when initializing it. Each `Field` `f` should be passed in the constructor as a keyword argument, keyed by `f.key`.""" # @@ -105,14 +112,16 @@ def __init__(self, *args, **kwargs): try: assert kwargs["model_id"] except: - raise AssertionError("model_id was not specified. Please specify it as a keyword argument.") + raise AssertionError( + "model_id was not specified. Please specify it as a keyword argument." + ) model_kwargs = {} model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] super().__init__(*args, **kwargs, **model_kwargs) - + class AI21Provider(BaseProvider, AI21): id = "ai21" name = "AI21" @@ -131,6 +140,7 @@ class AI21Provider(BaseProvider, AI21): pypi_package_deps = ["ai21"] auth_strategy = EnvAuthStrategy(name="AI21_API_KEY") + class AnthropicProvider(BaseProvider, Anthropic): id = "anthropic" name = "Anthropic" @@ -145,6 +155,7 @@ class AnthropicProvider(BaseProvider, Anthropic): pypi_package_deps = ["anthropic"] auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + class CohereProvider(BaseProvider, Cohere): id = "cohere" name = "Cohere" @@ -153,7 +164,13 @@ class CohereProvider(BaseProvider, Cohere): pypi_package_deps = ["cohere"] auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") -HUGGINGFACE_HUB_VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image") + +HUGGINGFACE_HUB_VALID_TASKS = ( + "text2text-generation", + "text-generation", + "text-to-image", +) + class HfHubProvider(BaseProvider, HuggingFaceHub): id = "huggingface_hub" @@ -195,7 +212,7 @@ def validate_environment(cls, values: Dict) -> Dict: "Please install it with `pip install huggingface_hub`." ) return values - + # Handle image outputs def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to HuggingFace Hub's inference endpoint. @@ -220,21 +237,21 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: # Custom code for responding to image generation responses if self.client.task == "text-to-image": - imageFormat = response.format # Presume it's a PIL ImageFile - mimeType = '' - if (imageFormat == 'JPEG'): - mimeType = 'image/jpeg' - elif (imageFormat == 'PNG'): - mimeType = 'image/png' - elif (imageFormat == 'GIF'): - mimeType = 'image/gif' + imageFormat = response.format # Presume it's a PIL ImageFile + mimeType = "" + if imageFormat == "JPEG": + mimeType = "image/jpeg" + elif imageFormat == "PNG": + mimeType = "image/png" + elif imageFormat == "GIF": + mimeType = "image/gif" else: raise ValueError(f"Unrecognized image format {imageFormat}") - + buffer = io.BytesIO() response.save(buffer, format=imageFormat) # Encode image data to Base64 bytes, then decode bytes to str - return (mimeType + ';base64,' + base64.b64encode(buffer.getvalue()).decode()) + return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode() if self.client.task == "text-generation": # Text generation return includes the starter text. @@ -252,6 +269,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: text = enforce_stop_tokens(text, stop) return text + class OpenAIProvider(BaseProvider, OpenAI): id = "openai" name = "OpenAI" @@ -270,6 +288,7 @@ class OpenAIProvider(BaseProvider, OpenAI): pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + class ChatOpenAIProvider(BaseProvider, OpenAIChat): id = "openai-chat" name = "OpenAI" @@ -288,14 +307,9 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat): def append_exchange(self, prompt: str, output: str): """Appends a conversational exchange between user and an OpenAI Chat model to a transcript that will be included in future exchanges.""" - self.prefix_messages.append({ - "role": "user", - "content": prompt - }) - self.prefix_messages.append({ - "role": "assistant", - "content": output - }) + self.prefix_messages.append({"role": "user", "content": prompt}) + self.prefix_messages.append({"role": "assistant", "content": output}) + # uses the new OpenAIChat provider. temporarily living as a separate class until # conflicts can be resolved @@ -314,6 +328,7 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI): pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + class JsonContentHandler(LLMContentHandler): content_type = "application/json" accepts = "application/json" @@ -330,20 +345,21 @@ def replace_values(self, old_val, new_val, d: Dict[str, Any]): d[key] = new_val if isinstance(val, dict): self.replace_values(old_val, new_val, val) - + return d def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: request_obj = copy.deepcopy(self.request_schema) self.replace_values("", prompt, request_obj) - request = json.dumps(request_obj).encode('utf-8') + request = json.dumps(request_obj).encode("utf-8") return request - + def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) matches = self.response_parser.find(response_json) return matches[0].value + class SmEndpointProvider(BaseProvider, SagemakerEndpoint): id = "sagemaker-endpoint" name = "Sagemaker Endpoint" @@ -364,12 +380,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint): TextField( key="response_path", label="Response path", - ) + ), ] - + def __init__(self, *args, **kwargs): - request_schema = kwargs.pop('request_schema') - response_path = kwargs.pop('response_path') - content_handler = JsonContentHandler(request_schema=request_schema, response_path=response_path) + request_schema = kwargs.pop("request_schema") + response_path = kwargs.pop("response_path") + content_handler = JsonContentHandler( + request_schema=request_schema, response_path=response_path + ) super().__init__(*args, **kwargs, content_handler=content_handler) - diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index aab722240..a6c717889 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,13 +1,11 @@ import logging from typing import Dict, Optional, Tuple, Union + from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES - from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider - from jupyter_ai_magics.providers import BaseProvider - Logger = Union[logging.Logger, logging.LoggerAdapter] @@ -23,15 +21,19 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: try: provider = model_provider_ep.load() except: - log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") + log.error( + f"Unable to load model provider class from entry point `{model_provider_ep.name}`." + ) continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") - + return providers -def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]: +def load_embedding_providers( + log: Optional[Logger] = None, +) -> Dict[str, BaseEmbeddingsProvider]: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) @@ -42,29 +44,34 @@ def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbe try: provider = model_provider_ep.load() except: - log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.") + log.error( + f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." + ) continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") - + return providers -def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]: - """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in MODEL_ID_ALIASES: - model_id = MODEL_ID_ALIASES[model_id] - if ":" not in model_id: - # case: model ID was not provided with a prefix indicating the provider - # ID. try to infer the provider ID before returning (None, None). +def decompose_model_id( + model_id: str, providers: Dict[str, BaseProvider] +) -> Tuple[str, str]: + """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" + if model_id in MODEL_ID_ALIASES: + model_id = MODEL_ID_ALIASES[model_id] + + if ":" not in model_id: + # case: model ID was not provided with a prefix indicating the provider + # ID. try to infer the provider ID before returning (None, None). + + # naively search through the dictionary and return the first provider + # that provides a model of the same ID. + for provider_id, provider in providers.items(): + if model_id in provider.models: + return (provider_id, model_id) - # naively search through the dictionary and return the first provider - # that provides a model of the same ID. - for provider_id, provider in providers.items(): - if model_id in provider.models: - return (provider_id, model_id) - - return (None, None) + return (None, None) - provider_id, local_model_id = model_id.split(":", 1) - return (provider_id, local_model_id) + provider_id, local_model_id = model_id.split(":", 1) + return (provider_id, local_model_id) diff --git a/packages/jupyter-ai-magics/package.json b/packages/jupyter-ai-magics/package.json index 922d8a22b..94d527c8b 100644 --- a/packages/jupyter-ai-magics/package.json +++ b/packages/jupyter-ai-magics/package.json @@ -18,7 +18,7 @@ "url": "https://github.com/jupyterlab/jupyter-ai.git" }, "scripts": { - "dev-install": "pip install -e \".[all]\"", + "dev-install": "pip install -e \".[dev,all]\"", "dev-uninstall": "pip uninstall jupyter_ai_magics -y" } } diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 2d59c9567..464284679 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -31,6 +31,10 @@ dependencies = [ ] [project.optional-dependencies] +dev = [ + "pre-commit~=3.3.3" +] + test = [ "coverage", "pytest", diff --git a/packages/jupyter-ai-magics/setup.py b/packages/jupyter-ai-magics/setup.py index bea233743..aefdf20db 100644 --- a/packages/jupyter-ai-magics/setup.py +++ b/packages/jupyter-ai-magics/setup.py @@ -1 +1 @@ -__import__('setuptools').setup() +__import__("setuptools").setup() diff --git a/packages/jupyter-ai-module-cookiecutter/README.md b/packages/jupyter-ai-module-cookiecutter/README.md index e80d6ac08..b99f3617d 100644 --- a/packages/jupyter-ai-module-cookiecutter/README.md +++ b/packages/jupyter-ai-module-cookiecutter/README.md @@ -2,8 +2,8 @@ A [cookiecutter](https://github.com/audreyr/cookiecutter) template for creating a AI module. The AI module constructed from the template serves as a very simple -example that can be extended however you wish. - +example that can be extended however you wish. + ## Usage Install cookiecutter. diff --git a/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py b/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py index 48b57e8be..aa816e397 100644 --- a/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py +++ b/packages/jupyter-ai-module-cookiecutter/hooks/post_gen_project.py @@ -6,7 +6,7 @@ def remove_path(path: Path) -> None: """Remove the provided path. - + If the target path is a directory, remove it recursively. """ if not path.exists(): @@ -21,7 +21,6 @@ def remove_path(path: Path) -> None: if __name__ == "__main__": - if not "{{ cookiecutter.has_settings }}".lower().startswith("y"): remove_path(PROJECT_DIRECTORY / "schema") @@ -30,7 +29,9 @@ def remove_path(path: Path) -> None: remove_path(PROJECT_DIRECTORY / ".github/workflows/binder-on-pr.yml") if not "{{ cookiecutter.test }}".lower().startswith("y"): - remove_path(PROJECT_DIRECTORY / ".github" / "workflows" / "update-integration-tests.yml") + remove_path( + PROJECT_DIRECTORY / ".github" / "workflows" / "update-integration-tests.yml" + ) remove_path(PROJECT_DIRECTORY / "src" / "__tests__") remove_path(PROJECT_DIRECTORY / "ui-tests") remove_path(PROJECT_DIRECTORY / "{{ cookiecutter.python_name }}" / "tests") diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml index a975d584d..e79041157 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: set -eux jlpm jlpm run lint:check -{% if cookiecutter.test.lower().startswith('y') %} +{% if cookiecutter.test.lower().startswith('y') %} - name: Test the extension run: | set -eux @@ -102,7 +102,7 @@ jobs: uses: actions/download-artifact@v3 with: name: extension-artifacts - + - name: Install the extension run: | set -eux diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild index 1a20c19a8..d4b13f656 100755 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/binder/postBuild @@ -15,10 +15,11 @@ from pathlib import Path ROOT = Path.cwd() + def _(*args, **kwargs): - """ Run a command, echoing the args + """Run a command, echoing the args - fails hard if something goes wrong + fails hard if something goes wrong """ print("\n\t", " ".join(args), "\n") return_code = subprocess.call(args, **kwargs) @@ -26,6 +27,7 @@ def _(*args, **kwargs): print("\nERROR", return_code, " ".join(args)) sys.exit(return_code) + # verify the environment is self-consistent before even starting _(sys.executable, "-m", "pip", "check") diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py index bea233743..aefdf20db 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/setup.py @@ -1 +1 @@ -__import__('setuptools').setup() +__import__("setuptools").setup() diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py index 5ba7a914e..23d06f6dd 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/ui-tests/jupyter_server_test_config.py @@ -10,7 +10,7 @@ c.ServerApp.port_retries = 0 c.ServerApp.open_browser = False -c.ServerApp.root_dir = mkdtemp(prefix='galata-test-') +c.ServerApp.root_dir = mkdtemp(prefix="galata-test-") c.ServerApp.token = "" c.ServerApp.password = "" c.ServerApp.disable_check_xsrf = True diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py index d14127453..db43ba181 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/__init__.py @@ -7,7 +7,4 @@ def _jupyter_labextension_paths(): - return [{ - "src": "labextension", - "dest": "{{ cookiecutter.labextension_name }}" - }] + return [{"src": "labextension", "dest": "{{ cookiecutter.labextension_name }}"}] diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py index 420bde3d3..c32e86148 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/engine.py @@ -3,6 +3,7 @@ from jupyter_ai.engine import BaseModelEngine from jupyter_ai.models import DescribeTaskResponse + class TestModelEngine(BaseModelEngine): name = "test" input_type = "txt" @@ -18,7 +19,9 @@ class TestModelEngine(BaseModelEngine): # ) # - async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]): + async def execute( + self, task: DescribeTaskResponse, prompt_variables: Dict[str, str] + ): # Core method that executes a model when provided with a task # description and a dictionary of prompt variables. For example, to # execute an OpenAI text completion model: diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py index c3cb8b9dc..07226c447 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.labextension_name}}/{{cookiecutter.python_name}}/tasks.py @@ -1,4 +1,5 @@ from typing import List + from jupyter_ai import DefaultTaskDefinition # tasks your AI module exposes by default. declared in `pyproject.toml` as an @@ -9,6 +10,6 @@ "name": "Test task", "prompt_template": "{body}", "modality": "txt2txt", - "insertion_mode": "test" + "insertion_mode": "test", } ] diff --git a/packages/jupyter-ai/.gitignore b/packages/jupyter-ai/.gitignore index 56891ff87..7c34e452b 100644 --- a/packages/jupyter-ai/.gitignore +++ b/packages/jupyter-ai/.gitignore @@ -123,4 +123,4 @@ dmypy.json .vscode # Coverage reports -coverage/* \ No newline at end of file +coverage/* diff --git a/packages/jupyter-ai/conftest.py b/packages/jupyter-ai/conftest.py index 12f736c03..dc3b0e05c 100644 --- a/packages/jupyter-ai/conftest.py +++ b/packages/jupyter-ai/conftest.py @@ -1,6 +1,6 @@ import pytest -pytest_plugins = ("jupyter_server.pytest_plugin", ) +pytest_plugins = ("jupyter_server.pytest_plugin",) @pytest.fixture diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py index 7bb42b861..701dc2a86 100644 --- a/packages/jupyter-ai/jupyter_ai/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/__init__.py @@ -1,24 +1,19 @@ -from ._version import __version__ -from .extension import AiExtension - # expose jupyter_ai_magics ipython extension from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension +from ._version import __version__ + # imports to expose entry points. DO NOT REMOVE. from .engine import GPT3ModelEngine -from .tasks import tasks +from .extension import AiExtension # imports to expose types to other AI modules. DO NOT REMOVE. -from .tasks import DefaultTaskDefinition +from .tasks import DefaultTaskDefinition, tasks + def _jupyter_labextension_paths(): - return [{ - "src": "labextension", - "dest": "@jupyter-ai/core" - }] + return [{"src": "labextension", "dest": "@jupyter-ai/core"}] + def _jupyter_server_extension_points(): - return [{ - "module": "jupyter_ai", - "app": AiExtension - }] + return [{"module": "jupyter_ai", "app": AiExtension}] diff --git a/packages/jupyter-ai/jupyter_ai/actors/__init__.py b/packages/jupyter-ai/jupyter_ai/actors/__init__.py index a48ae5056..9cfbc9156 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/actors/__init__.py @@ -1 +1 @@ -"""Actor classes that process incoming chat messages""" \ No newline at end of file +"""Actor classes that process incoming chat messages""" diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index e78837ca4..45f2d5b9b 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -1,15 +1,13 @@ import argparse from typing import Dict, List, Type -from jupyter_ai_magics.providers import BaseProvider import ray -from ray.util.queue import Queue - +from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger +from jupyter_ai.models import HumanChatMessage +from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationalRetrievalChain from langchain.schema import BaseRetriever, Document - -from jupyter_ai.models import HumanChatMessage -from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger +from ray.util.queue import Queue @ray.remote @@ -24,38 +22,39 @@ class AskActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) - self.parser.prog = '/ask' - self.parser.add_argument('query', nargs=argparse.REMAINDER) + self.parser.prog = "/ask" + self.parser.add_argument("query", nargs=argparse.REMAINDER) - def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + def create_llm_chain( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): retriever = Retriever() self.llm = provider(**provider_params) self.chat_history = [] - self.llm_chain = ConversationalRetrievalChain.from_llm( - self.llm, - retriever - ) + self.llm_chain = ConversationalRetrievalChain.from_llm(self.llm, retriever) def _process_message(self, message: HumanChatMessage): args = self.parse_args(message) if args is None: return - query = ' '.join(args.query) + query = " ".join(args.query) if not query: self.reply(f"{self.parser.format_usage()}", message) return - + self.get_llm_chain() try: - result = self.llm_chain({"question": query, "chat_history": self.chat_history}) - response = result['answer'] + result = self.llm_chain( + {"question": query, "chat_history": self.chat_history} + ) + response = result["answer"] self.chat_history.append((query, response)) self.reply(response, message) except AssertionError as e: self.log.error(e) - response = """Sorry, an error occurred while reading the from the learned documents. - If you have changed the embedding provider, try deleting the existing index by running + response = """Sorry, an error occurred while reading the from the learned documents. + If you have changed the embedding provider, try deleting the existing index by running `/learn -d` command and then re-submitting the `learn ` to learn the documents, and then asking the question again. """ @@ -68,11 +67,11 @@ class Retriever(BaseRetriever): of inconsistent de-serialization of index when it's accessed directly from the ask actor. """ - + def get_relevant_documents(self, question: str): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) docs = ray.get(index_actor.get_relevant_documents.remote(question)) return docs - + async def aget_relevant_documents(self, query: str) -> List[Document]: - return await super().aget_relevant_documents(query) \ No newline at end of file + return await super().aget_relevant_documents(query) diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index 84a62bf8b..d1ddddcf3 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -1,20 +1,19 @@ import argparse -from enum import Enum -from uuid import uuid4 -import time import logging -from typing import Dict, Optional, Type, Union +import time import traceback +from enum import Enum +from typing import Dict, Optional, Type, Union +from uuid import uuid4 -from jupyter_ai_magics.providers import BaseProvider import ray - +from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai_magics.providers import BaseProvider from ray.util.queue import Queue -from jupyter_ai.models import HumanChatMessage, AgentChatMessage - Logger = Union[logging.Logger, logging.LoggerAdapter] + class ACTOR_TYPE(str, Enum): # the top level actor that routes incoming messages to the appropriate actor ROUTER = "router" @@ -23,28 +22,26 @@ class ACTOR_TYPE(str, Enum): DEFAULT = "default" ASK = "ask" - LEARN = 'learn' - MEMORY = 'memory' - GENERATE = 'generate' - PROVIDERS = 'providers' - CONFIG = 'config' - CHAT_PROVIDER = 'chat_provider' - EMBEDDINGS_PROVIDER = 'embeddings_provider' + LEARN = "learn" + MEMORY = "memory" + GENERATE = "generate" + PROVIDERS = "providers" + CONFIG = "config" + CHAT_PROVIDER = "chat_provider" + EMBEDDINGS_PROVIDER = "embeddings_provider" + COMMANDS = { - '/ask': ACTOR_TYPE.ASK, - '/learn': ACTOR_TYPE.LEARN, - '/generate': ACTOR_TYPE.GENERATE + "/ask": ACTOR_TYPE.ASK, + "/learn": ACTOR_TYPE.LEARN, + "/generate": ACTOR_TYPE.GENERATE, } -class BaseActor(): + +class BaseActor: """Base actor implemented by actors that are called by the `Router`""" - def __init__( - self, - log: Logger, - reply_queue: Queue - ): + def __init__(self, log: Logger, reply_queue: Queue): self.log = log self.reply_queue = reply_queue self.parser = argparse.ArgumentParser() @@ -63,17 +60,17 @@ def process_message(self, message: HumanChatMessage): formatted_e = traceback.format_exc() response = f"Sorry, something went wrong and I wasn't able to index that path.\n\n```\n{formatted_e}\n```" self.reply(response, message) - + def _process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" raise NotImplementedError("Should be implemented by subclasses.") - + def reply(self, response, message: Optional[HumanChatMessage] = None): m = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, - reply_to=message.id if message else "" + reply_to=message.id if message else "", ) self.reply_queue.put(m) @@ -82,40 +79,50 @@ def get_llm_chain(self): lm_provider = ray.get(actor.get_provider.remote()) lm_provider_params = ray.get(actor.get_provider_params.remote()) - curr_lm_id = f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None - next_lm_id = f'{lm_provider.id}:{lm_provider_params["model_id"]}' if lm_provider else None + curr_lm_id = ( + f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + ) + next_lm_id = ( + f'{lm_provider.id}:{lm_provider_params["model_id"]}' + if lm_provider + else None + ) if not lm_provider: return None - + if curr_lm_id != next_lm_id: - self.log.info(f"Switching chat language model from {curr_lm_id} to {next_lm_id}.") + self.log.info( + f"Switching chat language model from {curr_lm_id} to {next_lm_id}." + ) self.create_llm_chain(lm_provider, lm_provider_params) return self.llm_chain - + def get_embeddings(self): actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) provider = ray.get(actor.get_provider.remote()) embedding_params = ray.get(actor.get_provider_params.remote()) embedding_model_id = ray.get(actor.get_model_id.remote()) - + if not provider: return None - + if embedding_model_id != self.embedding_model_id: self.embeddings = provider(**embedding_params) return self.embeddings - - def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + + def create_llm_chain( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): raise NotImplementedError("Should be implemented by subclasses") - + def parse_args(self, message): - args = message.body.split(' ') + args = message.body.split(" ") try: args = self.parser.parse_args(args[1:]) except (argparse.ArgumentError, SystemExit) as e: response = f"{self.parser.format_usage()}" self.reply(response, message) return None - return args \ No newline at end of file + return args diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index a225beb7b..e108f9f2a 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -1,10 +1,10 @@ -from jupyter_ai.actors.base import Logger, ACTOR_TYPE -from jupyter_ai.models import GlobalConfig import ray +from jupyter_ai.actors.base import ACTOR_TYPE, Logger +from jupyter_ai.models import GlobalConfig -@ray.remote -class ChatProviderActor(): +@ray.remote +class ChatProviderActor: def __init__(self, log: Logger): self.log = log self.provider = None @@ -16,26 +16,28 @@ def update(self, config: GlobalConfig): local_model_id, provider = ray.get( actor.get_model_provider_data.remote(model_id) ) - + if not provider: raise ValueError(f"No provider and model found with '{model_id}'") - + fields = config.fields.get(model_id, {}) - provider_params = { "model_id": local_model_id, **fields } - + provider_params = {"model_id": local_model_id, **fields} + auth_strategy = provider.auth_strategy if auth_strategy and auth_strategy.type == "env": api_keys = config.api_keys name = auth_strategy.name if name not in api_keys: - raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") + raise ValueError( + f"Missing value for '{auth_strategy.name}' in the config." + ) provider_params[name.lower()] = api_keys[name] - + self.provider = provider self.provider_params = provider_params def get_provider(self): return self.provider - + def get_provider_params(self): - return self.provider_params \ No newline at end of file + return self.provider_params diff --git a/packages/jupyter-ai/jupyter_ai/actors/config.py b/packages/jupyter-ai/jupyter_ai/actors/config.py index f51c9e5f7..5ca3c518e 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/config.py +++ b/packages/jupyter-ai/jupyter_ai/actors/config.py @@ -1,21 +1,22 @@ import json import os + +import ray from jupyter_ai.actors.base import ACTOR_TYPE, Logger from jupyter_ai.models import GlobalConfig -import ray from jupyter_core.paths import jupyter_data_dir @ray.remote -class ConfigActor(): - """Provides model and embedding provider id along +class ConfigActor: + """Provides model and embedding provider id along with the credentials to authenticate providers. """ def __init__(self, log: Logger): self.log = log - self.save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai') - self.save_path = os.path.join(self.save_dir, 'config.json') + self.save_dir = os.path.join(jupyter_data_dir(), "jupyter_ai") + self.save_path = os.path.join(self.save_dir, "config.json") self.config = None self._load() @@ -25,7 +26,7 @@ def update(self, config: GlobalConfig, save_to_disk: bool = True): if save_to_disk: self._save(config) self.config = config - + def _update_chat_provider(self, config: GlobalConfig): if not config.model_provider_id: return @@ -43,19 +44,19 @@ def _update_embeddings_provider(self, config: GlobalConfig): def _save(self, config: GlobalConfig): if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) - - with open(self.save_path, 'w') as f: + + with open(self.save_path, "w") as f: f.write(config.json()) def _load(self): if os.path.exists(self.save_path): - with open(self.save_path, 'r', encoding='utf-8') as f: + with open(self.save_path, encoding="utf-8") as f: config = GlobalConfig(**json.loads(f.read())) self.update(config, False) return - + # otherwise, create a new empty config file self.update(GlobalConfig(), True) def get_config(self): - return self.config \ No newline at end of file + return self.config diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index f5434c8b6..1792a0b31 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -1,21 +1,18 @@ -from typing import Dict, Type, List -import ray +from typing import Dict, List, Type +import ray +from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor +from jupyter_ai.actors.memory import RemoteMemory +from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage +from jupyter_ai_magics.providers import BaseProvider from langchain import ConversationChain from langchain.prompts import ( - ChatPromptTemplate, - MessagesPlaceholder, + ChatPromptTemplate, HumanMessagePromptTemplate, - SystemMessagePromptTemplate -) -from langchain.schema import ( - AIMessage, + MessagesPlaceholder, + SystemMessagePromptTemplate, ) - -from jupyter_ai.actors.base import BaseActor, ACTOR_TYPE -from jupyter_ai.actors.memory import RemoteMemory -from jupyter_ai.models import HumanChatMessage, ClearMessage, ChatMessage -from jupyter_ai_magics.providers import BaseProvider +from langchain.schema import AIMessage SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. @@ -28,6 +25,7 @@ The following is a friendly conversation between you and a human. """.strip() + @ray.remote class DefaultActor(BaseActor): def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): @@ -35,25 +33,27 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): self.memory = None self.chat_history = chat_history - def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + def create_llm_chain( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): llm = provider(**provider_params) self.memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) - prompt_template = ChatPromptTemplate.from_messages([ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(provider_name=llm.name, local_model_id=llm.model_id), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - AIMessage(content="") - ]) + prompt_template = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( + provider_name=llm.name, local_model_id=llm.model_id + ), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=""), + ] + ) self.llm = llm self.llm_chain = ConversationChain( - llm=llm, - prompt=prompt_template, - verbose=True, - memory=self.memory + llm=llm, prompt=prompt_template, verbose=True, memory=self.memory ) - + def clear_memory(self): - # clear chain memory if self.memory: self.memory.clear() @@ -68,8 +68,5 @@ def clear_memory(self): def _process_message(self, message: HumanChatMessage): self.get_llm_chain() - response = self.llm_chain.predict( - input=message.body, - stop=["\nHuman:"] - ) + response = self.llm_chain.predict(input=message.body, stop=["\nHuman:"]) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py index 068ce0388..3067e85e8 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py @@ -1,10 +1,10 @@ -from jupyter_ai.actors.base import Logger, ACTOR_TYPE -from jupyter_ai.models import GlobalConfig import ray +from jupyter_ai.actors.base import ACTOR_TYPE, Logger +from jupyter_ai.models import GlobalConfig -@ray.remote -class EmbeddingsProviderActor(): +@ray.remote +class EmbeddingsProviderActor: def __init__(self, log: Logger): self.log = log self.provider = None @@ -17,26 +17,28 @@ def update(self, config: GlobalConfig): local_model_id, provider = ray.get( actor.get_embeddings_provider_data.remote(model_id) ) - + if not provider: raise ValueError(f"No provider and model found with '{model_id}'") - + provider_params = {} provider_params[provider.model_id_key] = local_model_id - + auth_strategy = provider.auth_strategy if auth_strategy and auth_strategy.type == "env": api_keys = config.api_keys name = auth_strategy.name if name not in api_keys: - raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") + raise ValueError( + f"Missing value for '{auth_strategy.name}' in the config." + ) provider_params[name.lower()] = api_keys[name] - + self.provider = provider.provider_klass self.provider_params = provider_params previous_model_id = self.model_id self.model_id = model_id - + if previous_model_id and previous_model_id != model_id: # delete the index actor = ray.get_actor(ACTOR_TYPE.LEARN) @@ -44,9 +46,9 @@ def update(self, config: GlobalConfig): def get_provider(self): return self.provider - + def get_provider_params(self): return self.provider_params - + def get_model_id(self): - return self.model_id \ No newline at end of file + return self.model_id diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py index a240078b5..07efa63a8 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/generate.py +++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py @@ -2,19 +2,15 @@ import os from typing import Dict, Type -import ray -from ray.util.queue import Queue - -from langchain.llms import BaseLLM -from langchain.prompts import PromptTemplate -from langchain.llms import BaseLLM -from langchain.chains import LLMChain - import nbformat - -from jupyter_ai.models import HumanChatMessage +import ray from jupyter_ai.actors.base import BaseActor, Logger +from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider, ChatOpenAINewProvider +from langchain.chains import LLMChain +from langchain.llms import BaseLLM +from langchain.prompts import PromptTemplate +from ray.util.queue import Queue schema = """{ "$schema": "http://json-schema.org/draft-07/schema#", @@ -42,11 +38,12 @@ "required": ["sections"] }""" + class NotebookOutlineChain(LLMChain): """Chain to generate a notebook outline, with section titles and descriptions.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n" "Generate the outline as JSON data that will validate against this JSON schema:\n" @@ -56,24 +53,23 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: ) prompt = PromptTemplate( template=task_creation_template, - input_variables=[ - "description", - "schema" - ], + input_variables=["description", "schema"], ) return cls(prompt=prompt, llm=llm, verbose=verbose) + def generate_outline(description, llm=None, verbose=False): """Generate an outline of sections given a description of a notebook.""" chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose) outline = chain.predict(description=description, schema=schema) return json.loads(outline) + class CodeImproverChain(LLMChain): """Chain to improve source code.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "Improve the following code and make sure it is valid. Make sure to return the improved code only - don't give an explanation of the improvements.\n" "{code}" @@ -86,18 +82,22 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: ) return cls(prompt=prompt, llm=llm, verbose=verbose) + def improve_code(code, llm=None, verbose=False): """Improve source code using an LLM.""" chain = CodeImproverChain.from_llm(llm=llm, verbose=verbose) improved_code = chain.predict(code=code) - improved_code = '\n'.join([line for line in improved_code.split('/n') if not line.startswith("```")]) + improved_code = "\n".join( + [line for line in improved_code.split("/n") if not line.startswith("```")] + ) return improved_code + class NotebookSectionCodeChain(LLMChain): """Chain to generate source code for a notebook section.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "You are an AI that writes code for a single section of a Jupyter notebook.\n" "Overall topic of the notebook: {description}\n" @@ -110,35 +110,32 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: ) prompt = PromptTemplate( template=task_creation_template, - input_variables=[ - "description", - "title", - "content", - "code_so_far" - ], + input_variables=["description", "title", "content", "code_so_far"], ) return cls(prompt=prompt, llm=llm, verbose=verbose) + def generate_code(outline, llm=None, verbose=False): """Generate source code for a section given a description of the notebook and section.""" chain = NotebookSectionCodeChain.from_llm(llm=llm, verbose=verbose) code_so_far = [] - for section in outline['sections']: + for section in outline["sections"]: code = chain.predict( - description=outline['description'], - title=section['title'], - content=section['content'], - code_so_far='\n'.join(code_so_far) + description=outline["description"], + title=section["title"], + content=section["content"], + code_so_far="\n".join(code_so_far), ) - section['code'] = improve_code(code, llm=llm, verbose=verbose) - code_so_far.append(section['code']) + section["code"] = improve_code(code, llm=llm, verbose=verbose) + code_so_far.append(section["code"]) return outline + class NotebookSummaryChain(LLMChain): """Chain to generate a short summary of a notebook.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "Create a markdown summary for a Jupyter notebook with the following content." " The summary should consist of a single paragraph.\n" @@ -152,11 +149,12 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: ) return cls(prompt=prompt, llm=llm, verbose=verbose) + class NotebookTitleChain(LLMChain): """Chain to generate the title of a notebook.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain: task_creation_template = ( "Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n" "Content:\n{content}" @@ -169,32 +167,35 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: ) return cls(prompt=prompt, llm=llm, verbose=verbose) + def generate_title_and_summary(outline, llm=None, verbose=False): """Generate a title and summary of a notebook outline using an LLM.""" summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose) title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose) summary = summary_chain.predict(content=outline) title = title_chain.predict(content=outline) - outline['summary'] = summary - outline['title'] = title.strip('"') + outline["summary"] = summary + outline["title"] = title.strip('"') return outline + def create_notebook(outline): """Create an nbformat Notebook object for a notebook outline.""" nbf = nbformat.v4 nb = nbf.new_notebook() - nb['cells'].append(nbf.new_markdown_cell('# ' + outline['title'])) - nb['cells'].append(nbf.new_markdown_cell('## Introduction')) + nb["cells"].append(nbf.new_markdown_cell("# " + outline["title"])) + nb["cells"].append(nbf.new_markdown_cell("## Introduction")) disclaimer = f"This notebook was created by [Jupyter AI](https://github.com/jupyterlab/jupyter-ai) with the following prompt:\n\n> {outline['prompt']}" - nb['cells'].append(nbf.new_markdown_cell(disclaimer)) - nb['cells'].append(nbf.new_markdown_cell(outline['summary'])) + nb["cells"].append(nbf.new_markdown_cell(disclaimer)) + nb["cells"].append(nbf.new_markdown_cell(outline["summary"])) - for section in outline['sections'][1:]: - nb['cells'].append(nbf.new_markdown_cell('## ' + section['title'])) - for code_block in section['code'].split('\n\n'): - nb['cells'].append(nbf.new_code_cell(code_block)) + for section in outline["sections"][1:]: + nb["cells"].append(nbf.new_markdown_cell("## " + section["title"])) + for code_block in section["code"].split("\n\n"): + nb["cells"].append(nbf.new_code_cell(code_block)) return nb + @ray.remote class GenerateActor(BaseActor): """A Ray actor to generate a Jupyter notebook given a description.""" @@ -204,25 +205,27 @@ def __init__(self, reply_queue: Queue, root_dir: str, log: Logger): self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) self.llm = None - def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + def create_llm_chain( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): llm = provider(**provider_params) self.llm = llm return llm - + def _process_message(self, message: HumanChatMessage): self.get_llm_chain() - + response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." self.reply(response, message) prompt = message.body outline = generate_outline(prompt, llm=self.llm, verbose=True) # Save the user input prompt, the description property is now LLM generated. - outline['prompt'] = prompt + outline["prompt"] = prompt outline = generate_code(outline, llm=self.llm, verbose=True) outline = generate_title_and_summary(outline, llm=self.llm) notebook = create_notebook(outline) - final_path = os.path.join(self.root_dir, outline['title'] + '.ipynb') + final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb") nbformat.write(notebook, final_path) response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py index 1754cb9da..ef4e352e0 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/learn.py +++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py @@ -1,59 +1,59 @@ +import argparse import json import os -import argparse import time from typing import List import ray -from ray.util.queue import Queue - +from jupyter_ai.actors.base import BaseActor, Logger +from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader +from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter +from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata from jupyter_core.paths import jupyter_data_dir - from langchain import FAISS +from langchain.schema import Document from langchain.text_splitter import ( - RecursiveCharacterTextSplitter, PythonCodeTextSplitter, - MarkdownTextSplitter, LatexTextSplitter + LatexTextSplitter, + MarkdownTextSplitter, + PythonCodeTextSplitter, + RecursiveCharacterTextSplitter, ) -from langchain.schema import Document - -from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata -from jupyter_ai.actors.base import BaseActor, Logger -from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader -from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter +from ray.util.queue import Queue +INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices") +METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json") -INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices') -METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, 'metadata.json') @ray.remote class LearnActor(BaseActor): - def __init__(self, reply_queue: Queue, log: Logger, root_dir: str): super().__init__(reply_queue=reply_queue, log=log) self.root_dir = root_dir self.chunk_size = 2000 self.chunk_overlap = 100 - self.parser.prog = '/learn' - self.parser.add_argument('-v', '--verbose', action='store_true') - self.parser.add_argument('-d', '--delete', action='store_true') - self.parser.add_argument('-l', '--list', action='store_true') - self.parser.add_argument('path', nargs=argparse.REMAINDER) - self.index_name = 'default' + self.parser.prog = "/learn" + self.parser.add_argument("-v", "--verbose", action="store_true") + self.parser.add_argument("-d", "--delete", action="store_true") + self.parser.add_argument("-l", "--list", action="store_true") + self.parser.add_argument("path", nargs=argparse.REMAINDER) + self.index_name = "default" self.index = None self.metadata = IndexMetadata(dirs=[]) - + if not os.path.exists(INDEX_SAVE_DIR): os.makedirs(INDEX_SAVE_DIR) - - self.load_or_create() - + + self.load_or_create() + def _process_message(self, message: HumanChatMessage): if not self.index: self.load_or_create() # If index is not still there, embeddings are not present if not self.index: - self.reply("Sorry, please select an embedding provider before using the `/learn` command.") + self.reply( + "Sorry, please select an embedding provider before using the `/learn` command." + ) args = self.parse_args(message) if args is None: @@ -63,7 +63,7 @@ def _process_message(self, message: HumanChatMessage): self.delete() self.reply(f"👍 I have deleted everything I previously learned.", message) return - + if args.list: self.reply(self._build_list_response()) return @@ -81,18 +81,18 @@ def _process_message(self, message: HumanChatMessage): if args.verbose: self.reply(f"Loading and splitting files for {load_path}", message) - + self.learn_dir(load_path) self.save() - response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. + response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. You can ask questions about these docs by prefixing your message with **/ask**.""" self.reply(response, message) def _build_list_response(self): if not self.metadata.dirs: return "There are no docs that have been learned yet." - + dirs = [dir.path for dir in self.metadata.dirs] dir_list = "\n- " + "\n- ".join(dirs) + "\n\n" message = f"""I can answer questions from docs in these directories: @@ -100,22 +100,32 @@ def _build_list_response(self): return message def learn_dir(self, path: str): - splitters={ - '.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap), - '.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap), - '.tex': LatexTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap), - '.ipynb': NotebookSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + splitters = { + ".py": PythonCodeTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ), + ".md": MarkdownTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ), + ".tex": LatexTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ), + ".ipynb": NotebookSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ), } splitter = ExtensionSplitter( splitters=splitters, - default_splitter=RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + default_splitter=RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ), ) loader = RayRecursiveDirectoryLoader(path) texts = loader.load_and_split(text_splitter=splitter) self.index.add_documents(texts) self._add_dir_to_metadata(path) - + def _add_dir_to_metadata(self, path: str): dirs = self.metadata.dirs index = next((i for i, dir in enumerate(dirs) if dir.path == path), None) @@ -128,10 +138,10 @@ def delete_and_relearn(self): self.delete() return message = """🔔 Hi there, It seems like you have updated the embeddings model. For the **/ask** - command to work with the new model, I have to re-learn the documents you had previously + command to work with the new model, I have to re-learn the documents you had previously submitted for learning. Please wait to use the **/ask** command until I am done with this task.""" self.reply(message) - + metadata = self.metadata self.delete() self.relearn(metadata) @@ -139,7 +149,10 @@ def delete_and_relearn(self): def delete(self): self.index = None self.metadata = IndexMetadata(dirs=[]) - paths = [os.path.join(INDEX_SAVE_DIR, self.index_name+ext) for ext in ['.pkl', '.faiss']] + paths = [ + os.path.join(INDEX_SAVE_DIR, self.index_name + ext) + for ext in [".pkl", ".faiss"] + ] for path in paths: if os.path.isfile(path): os.remove(path) @@ -148,16 +161,18 @@ def delete(self): def relearn(self, metadata: IndexMetadata): # Index all dirs in the metadata if not metadata.dirs: - return - + return + for dir in metadata.dirs: self.learn_dir(dir.path) - + self.save() - dir_list = "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n" + dir_list = ( + "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n" + ) message = f"""🎉 I am done learning docs in these directories: - {dir_list} I am ready to answer questions about them. + {dir_list} I am ready to answer questions about them. You can ask questions about these docs by prefixing your message with **/ask**.""" self.reply(message) @@ -165,17 +180,22 @@ def create(self): embeddings = self.get_embeddings() if not embeddings: return - self.index = FAISS.from_texts(["Jupyternaut knows about your filesystem, to ask questions first use the /learn command."], embeddings) + self.index = FAISS.from_texts( + [ + "Jupyternaut knows about your filesystem, to ask questions first use the /learn command." + ], + embeddings, + ) self.save() def save(self): if self.index is not None: self.index.save_local(INDEX_SAVE_DIR, index_name=self.index_name) - + self.save_metadata() def save_metadata(self): - with open(METADATA_SAVE_PATH, 'w') as f: + with open(METADATA_SAVE_PATH, "w") as f: f.write(self.metadata.json()) def load_or_create(self): @@ -184,17 +204,19 @@ def load_or_create(self): return if self.index is None: try: - self.index = FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=self.index_name) + self.index = FAISS.load_local( + INDEX_SAVE_DIR, embeddings, index_name=self.index_name + ) self.load_metadata() except Exception as e: self.create() def load_metadata(self): if not os.path.exists(METADATA_SAVE_PATH): - return - - with open(METADATA_SAVE_PATH, 'r', encoding='utf-8') as f: - j = json.loads(f.read()) + return + + with open(METADATA_SAVE_PATH, encoding="utf-8") as f: + j = json.loads(f.read()) self.metadata = IndexMetadata(**j) def get_relevant_documents(self, question: str) -> List[Document]: diff --git a/packages/jupyter-ai/jupyter_ai/actors/memory.py b/packages/jupyter-ai/jupyter_ai/actors/memory.py index 808537599..98ee9b144 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/memory.py +++ b/packages/jupyter-ai/jupyter_ai/actors/memory.py @@ -1,91 +1,89 @@ -from typing import Dict, Any, List +from typing import Any, Dict, List import ray +from jupyter_ai.actors.base import Logger from langchain.schema import BaseMemory from pydantic import PrivateAttr -from jupyter_ai.actors.base import Logger @ray.remote -class MemoryActor(object): +class MemoryActor: """Turns any LangChain memory into a Ray actor. - + The resulting actor can be used as LangChain memory in chains running in different actors by using RemoteMemory (below). """ - + def __init__(self, log: Logger, memory: BaseMemory): self.memory = memory self.log = log - + def get_chat_memory(self): return self.memory.chat_memory - + def get_output_key(self): return self.memory.output_key - + def get_input_key(self): return self.memory.input_key def get_return_messages(self): return self.memory.return_messages - + def get_memory_variables(self): return self.memory.memory_variables def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: return self.memory.save_context(inputs, outputs) - + def clear(self): return self.memory.clear() - + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return self.memory.load_memory_variables(inputs) - - class RemoteMemory(BaseMemory): """Wraps a MemoryActor into a LangChain memory class. - + This enables you to use a single distributed memory across multiple Ray actors running different chains. """ - + actor_name: str _actor: Any = PrivateAttr() def __init__(self, **data): super().__init__(**data) self._actor = ray.get_actor(self.actor_name) - + @property def memory_variables(self) -> List[str]: o = self._actor.get_memory_variables.remote() return ray.get(o) - + @property def output_key(self): o = self._actor.get_output_key.remote() return ray.get(o) - + @property def input_key(self): o = self._actor.get_input_key.remote() return ray.get(o) - + @property def return_messages(self): o = self._actor.get_return_messages.remote() return ray.get(o) - + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: o = self._actor.save_context.remote(inputs, outputs) return ray.get(o) - + def clear(self): self._actor.clear.remote() - + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: o = self._actor.load_memory_variables.remote(inputs) - return ray.get(o) \ No newline at end of file + return ray.get(o) diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py index fd249ede9..7b11e1773 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/providers.py +++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py @@ -1,18 +1,24 @@ from typing import Optional, Tuple, Type -from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider -from jupyter_ai_magics.providers import BaseProvider -from jupyter_ai_magics.utils import decompose_model_id, load_embedding_providers, load_providers + import ray from jupyter_ai.actors.base import BaseActor, Logger +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider +from jupyter_ai_magics.providers import BaseProvider +from jupyter_ai_magics.utils import ( + decompose_model_id, + load_embedding_providers, + load_providers, +) from ray.util.queue import Queue + @ray.remote -class ProvidersActor(): +class ProvidersActor: """Actor that loads model and embedding providers from, - entry points. Also provides utility functions to get the + entry points. Also provides utility functions to get the providers and provider class matching a provider id. """ - + def __init__(self, log: Logger): self.log = log self.model_providers = load_providers(log=log) @@ -21,22 +27,23 @@ def __init__(self, log: Logger): def get_model_providers(self): """Returns dictionary of registered LLM providers""" return self.model_providers - + def get_model_provider_data(self, model_id: str) -> Tuple[str, Type[BaseProvider]]: """Returns the model provider class that matches the provider id""" provider_id, local_model_id = decompose_model_id(model_id, self.model_providers) provider = self.model_providers.get(provider_id, None) return local_model_id, provider - + def get_embeddings_providers(self): """Returns dictionary of registered embedding providers""" return self.embeddings_providers - def get_embeddings_provider_data(self, model_id: str) -> Tuple[str, Type[BaseEmbeddingsProvider]]: + def get_embeddings_provider_data( + self, model_id: str + ) -> Tuple[str, Type[BaseEmbeddingsProvider]]: """Returns the embedding provider class that matches the provider id""" - provider_id, local_model_id = decompose_model_id(model_id, self.embeddings_providers) + provider_id, local_model_id = decompose_model_id( + model_id, self.embeddings_providers + ) provider = self.embeddings_providers.get(provider_id, None) return local_model_id, provider - - - \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/router.py b/packages/jupyter-ai/jupyter_ai/actors/router.py index fbc3234da..c9129e58a 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/router.py +++ b/packages/jupyter-ai/jupyter_ai/actors/router.py @@ -1,29 +1,28 @@ import ray +from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, BaseActor, Logger from ray.util.queue import Queue -from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, Logger, BaseActor @ray.remote class Router(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): """Routes messages to the correct actor. - - To register new actors, add the actor type in the `ACTOR_TYPE` enum and + + To register new actors, add the actor type in the `ACTOR_TYPE` enum and add a corresponding command in the `COMMANDS` dictionary. """ super().__init__(reply_queue=reply_queue, log=log) def _process_message(self, message): - # assign default actor default = ray.get_actor(ACTOR_TYPE.DEFAULT) if message.body.startswith("/"): - command = message.body.split(' ', 1)[0] + command = message.body.split(" ", 1)[0] if command in COMMANDS.keys(): actor = ray.get_actor(COMMANDS[command].value) actor.process_message.remote(message) - if command == '/clear': + if command == "/clear": actor = ray.get_actor(ACTOR_TYPE.DEFAULT) actor.clear_memory.remote() else: diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index f87a74e30..98098da7e 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -1,31 +1,31 @@ -from typing import List, Optional -from pathlib import Path import hashlib import itertools +from pathlib import Path +from typing import List, Optional import ray - from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document from langchain.document_loaders.directory import _is_visible -from langchain.text_splitter import ( - RecursiveCharacterTextSplitter, TextSplitter, -) +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter + @ray.remote def path_to_doc(path): - with open(str(path), 'r') as f: + with open(str(path)) as f: text = f.read() m = hashlib.sha256() - m.update(text.encode('utf-8')) - metadata = {'path': str(path), 'sha256': m.digest(), 'extension': path.suffix} + m.update(text.encode("utf-8")) + metadata = {"path": str(path), "sha256": m.digest(), "extension": path.suffix} return Document(page_content=text, metadata=metadata) + class ExcludePattern(Exception): pass - + + def iter_paths(path, extensions, exclude): - for p in Path(path).rglob('*'): + for p in Path(path).rglob("*"): if p.is_dir(): continue if not _is_visible(p.relative_to(path)): @@ -39,23 +39,43 @@ def iter_paths(path, extensions, exclude): if p.suffix in extensions: yield p + class RayRecursiveDirectoryLoader(BaseLoader): - def __init__( self, path, - extensions={'.py', '.md', '.R', '.Rmd', '.jl', '.sh', '.ipynb', '.js', '.ts', '.jsx', '.tsx', '.txt'}, - exclude={'.ipynb_checkpoints', 'node_modules', 'lib', 'build', '.git', '.DS_Store'} + extensions={ + ".py", + ".md", + ".R", + ".Rmd", + ".jl", + ".sh", + ".ipynb", + ".js", + ".ts", + ".jsx", + ".tsx", + ".txt", + }, + exclude={ + ".ipynb_checkpoints", + "node_modules", + "lib", + "build", + ".git", + ".DS_Store", + }, ): self.path = path self.extensions = extensions - self.exclude=exclude - + self.exclude = exclude + def load(self) -> List[Document]: paths = iter_paths(self.path, self.extensions, self.exclude) doc_refs = list(map(path_to_doc.remote, paths)) return ray.get(doc_refs) - + def load_and_split( self, text_splitter: Optional[TextSplitter] = None ) -> List[Document]: @@ -67,7 +87,7 @@ def load_and_split( @ray.remote def split(doc): return _text_splitter.split_documents([doc]) - + paths = iter_paths(self.path, self.extensions, self.exclude) doc_refs = map(split.remote, map(path_to_doc.remote, paths)) - return list(itertools.chain(*ray.get(list(doc_refs)))) \ No newline at end of file + return list(itertools.chain(*ray.get(list(doc_refs)))) diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py b/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py index d80289453..d3b32f478 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/splitter.py @@ -1,44 +1,50 @@ +import copy from typing import List, Optional from langchain.schema import Document -from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter, MarkdownTextSplitter -import copy +from langchain.text_splitter import ( + MarkdownTextSplitter, + RecursiveCharacterTextSplitter, + TextSplitter, +) + class ExtensionSplitter(TextSplitter): - def __init__(self, splitters, default_splitter=None): self.splitters = splitters if default_splitter is None: self.default_splitter = RecursiveCharacterTextSplitter() else: self.default_splitter = default_splitter - + def split_text(self, text: str, metadata=None): - splitter = self.splitters.get(metadata['extension'], self.default_splitter) + splitter = self.splitters.get(metadata["extension"], self.default_splitter) return splitter.split_text(text) - def create_documents(self, texts: List[str], metadatas: Optional[List[dict]] = None) -> List[Document]: + def create_documents( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[Document]: _metadatas = metadatas or [{}] * len(texts) documents = [] for i, text in enumerate(texts): metadata = copy.deepcopy(_metadatas[i]) for chunk in self.split_text(text, metadata): - new_doc = Document( - page_content=chunk, metadata=metadata - ) + new_doc = Document(page_content=chunk, metadata=metadata) documents.append(new_doc) return documents + import nbformat + class NotebookSplitter(TextSplitter): - def __init__(self, **kwargs): super().__init__(**kwargs) - self.markdown_splitter = MarkdownTextSplitter(chunk_size=self._chunk_size, chunk_overlap=self._chunk_overlap) - + self.markdown_splitter = MarkdownTextSplitter( + chunk_size=self._chunk_size, chunk_overlap=self._chunk_overlap + ) + def split_text(self, text: str): nb = nbformat.reads(text, as_version=4) - md = '\n\n'.join([cell.source for cell in nb.cells]) + md = "\n\n".join([cell.source for cell in nb.cells]) return self.markdown_splitter.split_text(md) - diff --git a/packages/jupyter-ai/jupyter_ai/engine.py b/packages/jupyter-ai/jupyter_ai/engine.py index 11e0f7571..d4dce4595 100644 --- a/packages/jupyter-ai/jupyter_ai/engine.py +++ b/packages/jupyter-ai/jupyter_ai/engine.py @@ -1,12 +1,16 @@ -from abc import abstractmethod, ABC, ABCMeta +from abc import ABC, ABCMeta, abstractmethod from typing import Dict + import openai from traitlets.config import LoggingConfigurable, Unicode + from .task_manager import DescribeTaskResponse + class BaseModelEngineMetaclass(ABCMeta, type(LoggingConfigurable)): pass + class BaseModelEngine(ABC, LoggingConfigurable, metaclass=BaseModelEngineMetaclass): id: str name: str @@ -14,34 +18,33 @@ class BaseModelEngine(ABC, LoggingConfigurable, metaclass=BaseModelEngineMetacla # these two attributes are currently reserved but unused. input_type: str output_type: str - + @abstractmethod - async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]): + async def execute( + self, task: DescribeTaskResponse, prompt_variables: Dict[str, str] + ): pass + class GPT3ModelEngine(BaseModelEngine): id = "gpt3" name = "GPT-3" - modalities = [ - "txt2txt" - ] + modalities = ["txt2txt"] - api_key = Unicode( - config=True, - help="OpenAI API key", - allow_none=False - ) + api_key = Unicode(config=True, help="OpenAI API key", allow_none=False) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]): + async def execute( + self, task: DescribeTaskResponse, prompt_variables: Dict[str, str] + ): if "body" not in prompt_variables: raise Exception("Prompt body must be specified.") prompt = task.prompt_template.format(**prompt_variables) self.log.info(f"GPT3 prompt:\n{prompt}") - + openai.api_key = self.api_key response = await openai.Completion.acreate( model="text-davinci-003", @@ -50,6 +53,6 @@ async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, max_tokens=256, top_p=1, frequency_penalty=0, - presence_penalty=0 + presence_penalty=0, ) - return response['choices'][0]['text'] \ No newline at end of file + return response["choices"][0]["text"] diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 837bb68f8..df40fcf24 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,41 +1,36 @@ import asyncio +import inspect +import ray +from importlib_metadata import entry_points +from jupyter_ai.actors.ask import AskActor +from jupyter_ai.actors.base import ACTOR_TYPE from jupyter_ai.actors.chat_provider import ChatProviderActor from jupyter_ai.actors.config import ConfigActor +from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor -from jupyter_ai.actors.providers import ProvidersActor - -from jupyter_ai_magics.utils import load_providers - -from langchain.memory import ConversationBufferWindowMemory -from jupyter_ai.actors.default import DefaultActor -from jupyter_ai.actors.ask import AskActor +from jupyter_ai.actors.generate import GenerateActor from jupyter_ai.actors.learn import LearnActor -from jupyter_ai.actors.router import Router from jupyter_ai.actors.memory import MemoryActor -from jupyter_ai.actors.generate import GenerateActor -from jupyter_ai.actors.base import ACTOR_TYPE +from jupyter_ai.actors.providers import ProvidersActor +from jupyter_ai.actors.router import Router from jupyter_ai.reply_processor import ReplyProcessor +from jupyter_ai_magics.utils import load_providers from jupyter_server.extension.application import ExtensionApp +from langchain.memory import ConversationBufferWindowMemory +from ray.util.queue import Queue +from .engine import BaseModelEngine from .handlers import ( - ChatHandler, - ChatHistoryHandler, - EmbeddingsModelProviderHandler, - ModelProviderHandler, - PromptAPIHandler, + ChatHandler, + ChatHistoryHandler, + EmbeddingsModelProviderHandler, + GlobalConfigHandler, + ModelProviderHandler, + PromptAPIHandler, TaskAPIHandler, - GlobalConfigHandler ) -from importlib_metadata import entry_points -import inspect -from .engine import BaseModelEngine - -import ray -from ray.util.queue import Queue -from jupyter_ai_magics.utils import load_providers - class AiExtension(ExtensionApp): name = "jupyter_ai" @@ -51,41 +46,48 @@ class AiExtension(ExtensionApp): ] @property - def ai_engines(self): + def ai_engines(self): if "ai_engines" not in self.settings: self.settings["ai_engines"] = {} return self.settings["ai_engines"] - def initialize_settings(self): ray.init() # EP := entry point eps = entry_points() - + ## step 1: instantiate model engines and bind them to settings model_engine_class_eps = eps.select(group="jupyter_ai.model_engine_classes") - + if not model_engine_class_eps: - self.log.error("No model engines found for jupyter_ai.model_engine_classes group. One or more model engines are required for AI extension to work.") + self.log.error( + "No model engines found for jupyter_ai.model_engine_classes group. One or more model engines are required for AI extension to work." + ) return for model_engine_class_ep in model_engine_class_eps: try: Engine = model_engine_class_ep.load() except: - self.log.error(f"Unable to load model engine class from entry point `{model_engine_class_ep.name}`.") + self.log.error( + f"Unable to load model engine class from entry point `{model_engine_class_ep.name}`." + ) continue if not inspect.isclass(Engine) or not issubclass(Engine, BaseModelEngine): - self.log.error(f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}` as it is not a subclass of `BaseModelEngine`.") + self.log.error( + f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}` as it is not a subclass of `BaseModelEngine`." + ) continue try: self.ai_engines[Engine.id] = Engine(config=self.config, log=self.log) except: - self.log.error(f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}`.") + self.log.error( + f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}`." + ) continue self.log.info(f"Registered engine `{Engine.id}`.") @@ -96,15 +98,17 @@ def initialize_settings(self): if not module_default_tasks_eps: self.settings["ai_default_tasks"] = [] return - + default_tasks = [] for module_default_tasks_ep in module_default_tasks_eps: try: module_default_tasks = module_default_tasks_ep.load() except: - self.log.error(f"Unable to load task from entry point `{module_default_tasks_ep.name}`") + self.log.error( + f"Unable to load task from entry point `{module_default_tasks_ep.name}`" + ) continue - + default_tasks += module_default_tasks self.settings["ai_default_tasks"] = default_tasks @@ -119,13 +123,12 @@ def initialize_settings(self): # Store chat clients in a dictionary self.settings["chat_clients"] = {} self.settings["chat_handlers"] = {} - + # store chat messages in memory for now # this is only used to render the UI, and is not the conversational # memory object used by the LM chain. self.settings["chat_history"] = [] - reply_queue = Queue() self.settings["reply_queue"] = reply_queue @@ -134,21 +137,27 @@ def initialize_settings(self): log=self.log, ) default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( - reply_queue=reply_queue, + reply_queue=reply_queue, log=self.log, - chat_history=self.settings["chat_history"] + chat_history=self.settings["chat_history"], ) - providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote( + providers_actor = ProvidersActor.options( + name=ACTOR_TYPE.PROVIDERS.value + ).remote( log=self.log, ) config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote( log=self.log, ) - chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote( + chat_provider_actor = ChatProviderActor.options( + name=ACTOR_TYPE.CHAT_PROVIDER.value + ).remote( log=self.log, ) - embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote( + embeddings_provider_actor = EmbeddingsProviderActor.options( + name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value + ).remote( log=self.log, ) learn_actor = LearnActor.options(name=ACTOR_TYPE.LEARN.value).remote( @@ -157,7 +166,7 @@ def initialize_settings(self): root_dir=self.serverapp.root_dir, ) ask_actor = AskActor.options(name=ACTOR_TYPE.ASK.value).remote( - reply_queue=reply_queue, + reply_queue=reply_queue, log=self.log, ) memory_actor = MemoryActor.options(name=ACTOR_TYPE.MEMORY.value).remote( @@ -165,22 +174,24 @@ def initialize_settings(self): memory=ConversationBufferWindowMemory(return_messages=True, k=2), ) generate_actor = GenerateActor.options(name=ACTOR_TYPE.GENERATE.value).remote( - reply_queue=reply_queue, + reply_queue=reply_queue, log=self.log, - root_dir=self.settings['server_root_dir'], + root_dir=self.settings["server_root_dir"], ) - - self.settings['router'] = router - self.settings['providers_actor'] = providers_actor - self.settings['config_actor'] = config_actor - self.settings['chat_provider_actor'] = chat_provider_actor - self.settings['embeddings_provider_actor'] = embeddings_provider_actor + + self.settings["router"] = router + self.settings["providers_actor"] = providers_actor + self.settings["config_actor"] = config_actor + self.settings["chat_provider_actor"] = chat_provider_actor + self.settings["embeddings_provider_actor"] = embeddings_provider_actor self.settings["default_actor"] = default_actor self.settings["learn_actor"] = learn_actor self.settings["ask_actor"] = ask_actor self.settings["memory_actor"] = memory_actor self.settings["generate_actor"] = generate_actor - reply_processor = ReplyProcessor(self.settings['chat_handlers'], reply_queue, log=self.log) + reply_processor = ReplyProcessor( + self.settings["chat_handlers"], reply_queue, log=self.log + ) loop = asyncio.get_event_loop() - loop.create_task(reply_processor.start()) \ No newline at end of file + loop.create_task(reply_processor.start()) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index b141da604..7dc021aac 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -1,45 +1,43 @@ -from dataclasses import asdict +import getpass import json +import time +import uuid +from dataclasses import asdict from typing import Dict, List -from jupyter_ai.actors.base import ACTOR_TYPE + import ray import tornado -import uuid -import time -import getpass - -from tornado.web import HTTPError +from jupyter_ai.actors.base import ACTOR_TYPE +from jupyter_server.base.handlers import APIHandler as BaseAPIHandler +from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.utils import ensure_async from pydantic import ValidationError - from tornado import web, websocket - -from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler -from jupyter_server.utils import ensure_async - -from .task_manager import TaskManager +from tornado.web import HTTPError from .models import ( - ChatHistory, - ChatUser, - ListProvidersEntry, - ListProvidersResponse, - PromptRequest, - ChatRequest, - ChatMessage, - Message, - AgentChatMessage, - HumanChatMessage, - ConnectionMessage, + AgentChatMessage, ChatClient, - GlobalConfig + ChatHistory, + ChatMessage, + ChatRequest, + ChatUser, + ConnectionMessage, + GlobalConfig, + HumanChatMessage, + ListProvidersEntry, + ListProvidersResponse, + Message, + PromptRequest, ) +from .task_manager import TaskManager class APIHandler(BaseAPIHandler): @property - def engines(self): + def engines(self): return self.settings["ai_engines"] - + @property def default_tasks(self): return self.settings["ai_default_tasks"] @@ -49,9 +47,12 @@ def task_manager(self): # we have to create the TaskManager lazily, since no event loop is # running in ServerApp.initialize_settings(). if "task_manager" not in self.settings: - self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks) + self.settings["task_manager"] = TaskManager( + engines=self.engines, default_tasks=self.default_tasks + ) return self.settings["task_manager"] - + + class PromptAPIHandler(APIHandler): @tornado.web.authenticated async def post(self): @@ -68,13 +69,13 @@ async def post(self): task = await self.task_manager.describe_task(request.task_id) if not task: raise HTTPError(404, f"Task not found with ID: {request.task_id}") - + output = await ensure_async(engine.execute(task, request.prompt_variables)) - self.finish(json.dumps({ - "output": output, - "insertion_mode": task.insertion_mode - })) + self.finish( + json.dumps({"output": output, "insertion_mode": task.insertion_mode}) + ) + class TaskAPIHandler(APIHandler): @tornado.web.authenticated @@ -83,7 +84,7 @@ async def get(self, id=None): list_tasks_response = await self.task_manager.list_tasks() self.finish(json.dumps(list_tasks_response.dict())) return - + describe_task_response = await self.task_manager.describe_task(id) if describe_task_response is None: raise HTTPError(404, f"Task not found with ID: {id}") @@ -95,35 +96,32 @@ class ChatHistoryHandler(BaseAPIHandler): """Handler to return message history""" _messages = [] - + @property def chat_history(self): return self.settings["chat_history"] - + @chat_history.setter def _chat_history_setter(self, new_history): self.settings["chat_history"] = new_history - + @tornado.web.authenticated async def get(self): history = ChatHistory(messages=self.chat_history) self.finish(history.json()) -class ChatHandler( - JupyterHandler, - websocket.WebSocketHandler -): +class ChatHandler(JupyterHandler, websocket.WebSocketHandler): """ A websocket handler for chat. """ - + @property - def chat_handlers(self) -> Dict[str, 'ChatHandler']: + def chat_handlers(self) -> Dict[str, "ChatHandler"]: """Dictionary mapping client IDs to their WebSocket handler instances.""" return self.settings["chat_handlers"] - + @property def chat_clients(self) -> Dict[str, ChatClient]: """Dictionary mapping client IDs to their ChatClient objects that store @@ -143,8 +141,7 @@ def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) def pre_get(self): - """Handles authentication/authorization. - """ + """Handles authentication/authorization.""" # authenticate the request before opening the websocket user = self.current_user if user is None: @@ -168,8 +165,7 @@ def get_chat_user(self) -> ChatUser: if collaborative: return ChatUser(**asdict(self.current_user)) - - + login = getpass.getuser() return ChatUser( username=login, @@ -177,10 +173,9 @@ def get_chat_user(self) -> ChatUser: name=login, display_name=login, color=None, - avatar_url=None + avatar_url=None, ) - def generate_client_id(self): """Generates a client ID to identify the current WS connection.""" return uuid.uuid4().hex @@ -201,25 +196,27 @@ def open(self): self.log.debug("Clients are : %s", self.chat_handlers.keys()) def broadcast_message(self, message: Message): - """Broadcasts message to all connected clients. + """Broadcasts message to all connected clients. Appends message to `self.chat_history`. """ self.log.debug("Broadcasting message: %s to all clients...", message) client_ids = self.chat_handlers.keys() - + for client_id in client_ids: client = self.chat_handlers[client_id] if client: client.write_message(message.dict()) - + # Only append ChatMessage instances to history, not control messages - if isinstance(message, HumanChatMessage) or isinstance(message, AgentChatMessage): + if isinstance(message, HumanChatMessage) or isinstance( + message, AgentChatMessage + ): self.chat_history.append(message) async def on_message(self, message): self.log.debug("Message recieved: %s", message) - + try: message = json.loads(message) chat_request = ChatRequest(**message) @@ -240,9 +237,9 @@ async def on_message(self, message): self.broadcast_message(message=chat_message) # Clear the message history if given the /clear command - if chat_request.prompt.startswith('/'): - command = chat_request.prompt.split(' ', 1)[0] - if command == '/clear': + if chat_request.prompt.startswith("/"): + command = chat_request.prompt.split(" ", 1)[0] + if command == "/clear": self.chat_history.clear() # process through the router @@ -261,11 +258,11 @@ def on_close(self): class ModelProviderHandler(BaseAPIHandler): @property - def chat_providers(self): + def chat_providers(self): actor = ray.get_actor("providers") o = actor.get_model_providers.remote() return ray.get(o) - + @web.authenticated def get(self): providers = [] @@ -284,13 +281,14 @@ def get(self): fields=provider.fields, ) ) - - response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) + + response = ListProvidersResponse( + providers=sorted(providers, key=lambda p: p.name) + ) self.finish(response.json()) class EmbeddingsModelProviderHandler(BaseAPIHandler): - @property def embeddings_providers(self): actor = ray.get_actor("providers") @@ -311,8 +309,10 @@ def get(self): fields=provider.fields, ) ) - - response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) + + response = ListProvidersResponse( + providers=sorted(providers, key=lambda p: p.name) + ) self.finish(response.json()) @@ -320,14 +320,14 @@ class GlobalConfigHandler(BaseAPIHandler): """API handler for fetching and setting the model and emebddings config. """ - + @web.authenticated def get(self): actor = ray.get_actor(ACTOR_TYPE.CONFIG) config = ray.get(actor.get_config.remote()) if not config: raise HTTPError(500, "No config found.") - + self.finish(config.json()) @web.authenticated @@ -345,10 +345,9 @@ def post(self): raise HTTPError(500, str(e)) from e except ValueError as e: self.log.exception(e) - raise HTTPError(500, str(e.cause) if hasattr(e, 'cause') else str(e)) + raise HTTPError(500, str(e.cause) if hasattr(e, "cause") else str(e)) except Exception as e: self.log.exception(e) raise HTTPError( 500, "Unexpected error occurred while updating the config." ) from e - diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index cfa6a8419..dcefc8934 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,17 +1,20 @@ +from typing import Any, Dict, List, Literal, Optional, Union + from jupyter_ai_magics.providers import AuthStrategy, Field +from pydantic import BaseModel -from pydantic import BaseModel -from typing import Any, Dict, List, Union, Literal, Optional class PromptRequest(BaseModel): task_id: str engine_id: str prompt_variables: Dict[str, str] + # the type of message used to chat with the agent class ChatRequest(BaseModel): prompt: str + class ChatUser(BaseModel): # User ID assigned by IdentityProvider. username: str @@ -21,12 +24,14 @@ class ChatUser(BaseModel): color: Optional[str] avatar_url: Optional[str] + class ChatClient(ChatUser): # A unique client ID assigned to identify different JupyterLab clients on # the same device (i.e. running on multiple tabs/windows), which may have # the same username assigned to them by the IdentityProvider. id: str + class AgentChatMessage(BaseModel): type: Literal["agent"] = "agent" id: str @@ -35,6 +40,7 @@ class AgentChatMessage(BaseModel): # message ID of the HumanChatMessage it is replying to reply_to: str + class HumanChatMessage(BaseModel): type: Literal["human"] = "human" id: str @@ -42,45 +48,49 @@ class HumanChatMessage(BaseModel): body: str client: ChatClient + class ConnectionMessage(BaseModel): type: Literal["connection"] = "connection" client_id: str + class ClearMessage(BaseModel): type: Literal["clear"] = "clear" + # the type of messages being broadcast to clients ChatMessage = Union[ AgentChatMessage, HumanChatMessage, ] -Message = Union[ - AgentChatMessage, - HumanChatMessage, - ConnectionMessage, - ClearMessage -] +Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage] + class ListEnginesEntry(BaseModel): id: str name: str + class ListTasksEntry(BaseModel): id: str name: str + class ListTasksResponse(BaseModel): tasks: List[ListTasksEntry] + class DescribeTaskResponse(BaseModel): name: str insertion_mode: str prompt_template: str engines: List[ListEnginesEntry] + class ChatHistory(BaseModel): """History of chat messages""" + messages: List[ChatMessage] @@ -88,6 +98,7 @@ class ListProvidersEntry(BaseModel): """Model provider with supported models and provider's authentication strategy """ + id: str name: str models: List[str] @@ -99,12 +110,15 @@ class ListProvidersEntry(BaseModel): class ListProvidersResponse(BaseModel): providers: List[ListProvidersEntry] + class IndexedDir(BaseModel): path: str + class IndexMetadata(BaseModel): dirs: List[IndexedDir] + class GlobalConfig(BaseModel): model_provider_id: Optional[str] = None embeddings_provider_id: Optional[str] = None diff --git a/packages/jupyter-ai/jupyter_ai/reply_processor.py b/packages/jupyter-ai/jupyter_ai/reply_processor.py index 3df2bfb71..d8f215d02 100644 --- a/packages/jupyter-ai/jupyter_ai/reply_processor.py +++ b/packages/jupyter-ai/jupyter_ai/reply_processor.py @@ -1,10 +1,11 @@ import asyncio from typing import Dict + from jupyter_ai.handlers import ChatHandler from ray.util.queue import Queue -class ReplyProcessor(): +class ReplyProcessor: """A single processor to distribute replies""" def __init__(self, handlers: Dict[str, ChatHandler], queue: Queue, log): @@ -13,11 +14,11 @@ def __init__(self, handlers: Dict[str, ChatHandler], queue: Queue, log): self.log = log def process(self, message): - self.log.debug('Processing message %s in ReplyProcessor', message) + self.log.debug("Processing message %s in ReplyProcessor", message) for handler in self.handlers.values(): if not handler: continue - + handler.broadcast_message(message) break diff --git a/packages/jupyter-ai/jupyter_ai/task_manager.py b/packages/jupyter-ai/jupyter_ai/task_manager.py index e995a6f01..9c38573b3 100644 --- a/packages/jupyter-ai/jupyter_ai/task_manager.py +++ b/packages/jupyter-ai/jupyter_ai/task_manager.py @@ -1,10 +1,17 @@ import asyncio -import aiosqlite import os -from typing import Optional, List +from typing import List, Optional +import aiosqlite from jupyter_core.paths import jupyter_data_dir -from .models import ListTasksResponse, ListTasksEntry, DescribeTaskResponse, ListEnginesEntry + +from .models import ( + DescribeTaskResponse, + ListEnginesEntry, + ListTasksEntry, + ListTasksResponse, +) + class TaskManager: db_path = os.path.join(jupyter_data_dir(), "ai_task_manager.db") @@ -13,7 +20,7 @@ def __init__(self, engines, default_tasks): self.engines = engines self.default_tasks = default_tasks self.db_initialized = asyncio.create_task(self.init_db()) - + async def init_db(self): async with aiosqlite.connect(self.db_path) as con: await con.execute( @@ -33,36 +40,39 @@ async def init_db(self): for task in self.default_tasks: await con.execute( "INSERT INTO tasks (id, name, prompt_template, modality, insertion_mode, is_default) VALUES (?, ?, ?, ?, ?, ?)", - (task["id"], task["name"], task["prompt_template"], task["modality"], task["insertion_mode"], 1) + ( + task["id"], + task["name"], + task["prompt_template"], + task["modality"], + task["insertion_mode"], + 1, + ), ) - + await con.commit() - + async def list_tasks(self) -> ListTasksResponse: await self.db_initialized async with aiosqlite.connect(self.db_path) as con: - cursor = await con.execute( - "SELECT id, name FROM tasks" - ) + cursor = await con.execute("SELECT id, name FROM tasks") rows = await cursor.fetchall() tasks = [] if not rows: return tasks - + for row in rows: - tasks.append(ListTasksEntry( - id=row[0], - name=row[1] - )) - + tasks.append(ListTasksEntry(id=row[0], name=row[1])) + return ListTasksResponse(tasks=tasks) - + async def describe_task(self, id: str) -> Optional[DescribeTaskResponse]: await self.db_initialized async with aiosqlite.connect(self.db_path) as con: cursor = await con.execute( - "SELECT name, prompt_template, modality, insertion_mode FROM tasks WHERE id = ?", (id,) + "SELECT name, prompt_template, modality, insertion_mode FROM tasks WHERE id = ?", + (id,), ) row = await cursor.fetchone() @@ -75,15 +85,14 @@ async def describe_task(self, id: str) -> Optional[DescribeTaskResponse]: for engine in self.engines.values(): if modality in engine.modalities: engines.append(ListEnginesEntry(id=engine.id, name=engine.name)) - + # sort engines A-Z engines = sorted(engines, key=lambda engine: engine.name) - + return DescribeTaskResponse( name=row[0], prompt_template=row[1], modality=row[2], insertion_mode=row[3], - engines=engines + engines=engines, ) - \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tasks.py b/packages/jupyter-ai/jupyter_ai/tasks.py index 9d3e8133c..b3127ad56 100644 --- a/packages/jupyter-ai/jupyter_ai/tasks.py +++ b/packages/jupyter-ai/jupyter_ai/tasks.py @@ -1,5 +1,6 @@ from typing import List, TypedDict + class DefaultTaskDefinition(TypedDict): id: str name: str @@ -7,40 +8,41 @@ class DefaultTaskDefinition(TypedDict): modality: str insertion_mode: str + tasks: List[DefaultTaskDefinition] = [ { "id": "explain-code", "name": "Explain code", - "prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}", + "prompt_template": 'Explain the following Python 3 code. The first sentence must begin with the phrase "The code below".\n{body}', "modality": "txt2txt", - "insertion_mode": "above" + "insertion_mode": "above", }, { "id": "generate-code", "name": "Generate code", "prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}", "modality": "txt2txt", - "insertion_mode": "below" + "insertion_mode": "below", }, { "id": "explain-code-in-cells-above", "name": "Explain code in cells above", - "prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}", + "prompt_template": 'Explain the following Python 3 code. The first sentence must begin with the phrase "The code below".\n{body}', "modality": "txt2txt", - "insertion_mode": "above-in-cells" + "insertion_mode": "above-in-cells", }, { "id": "generate-code-in-cells-below", "name": "Generate code in cells below", "prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}", "modality": "txt2txt", - "insertion_mode": "below-in-cells" + "insertion_mode": "below-in-cells", }, { "id": "freeform", "name": "Freeform prompt", "prompt_template": "{body}", "modality": "txt2txt", - "insertion_mode": "below" - } + "insertion_mode": "below", + }, ] diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 620ec8990..0194dec73 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -11,4 +11,4 @@ # payload = json.loads(response.body) # assert payload == { # "data": "This is /jupyter-ai/get_example endpoint!" -# } \ No newline at end of file +# } diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index f9d86ef3d..50b31ff26 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -54,7 +54,7 @@ "watch": "run-p watch:src watch:labextension", "watch:src": "tsc -w", "watch:labextension": "jupyter labextension watch .", - "dev-install": "pip install -e . && jupyter labextension develop . --overwrite && jupyter server extension enable jupyter_ai", + "dev-install": "pip install -e \".[dev,all]\" && jupyter labextension develop . --overwrite && jupyter server extension enable jupyter_ai", "dev-uninstall": "pip uninstall jupyter_ai -y" }, "dependencies": { diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index cff42637e..20a39e558 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -53,14 +53,12 @@ test = [ "pytest-tornasync" ] +dev = [ + "jupyter_ai_magics[dev]" +] + all = [ - "ai21", - "anthropic", - "cohere", - "huggingface_hub", - "ipywidgets", - "openai", - "boto3" + "jupyter_ai_magics[all]" ] [tool.hatch.version] diff --git a/packages/jupyter-ai/setup.py b/packages/jupyter-ai/setup.py index bea233743..aefdf20db 100644 --- a/packages/jupyter-ai/setup.py +++ b/packages/jupyter-ai/setup.py @@ -1 +1 @@ -__import__('setuptools').setup() +__import__("setuptools").setup() diff --git a/packages/jupyter-ai/style/icons/jupyternaut.svg b/packages/jupyter-ai/style/icons/jupyternaut.svg index 20f02087b..d4367985d 100644 --- a/packages/jupyter-ai/style/icons/jupyternaut.svg +++ b/packages/jupyter-ai/style/icons/jupyternaut.svg @@ -12,4 +12,4 @@ - \ No newline at end of file + diff --git a/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py b/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py index 5ba7a914e..23d06f6dd 100644 --- a/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py +++ b/packages/jupyter-ai/ui-tests/jupyter_server_test_config.py @@ -10,7 +10,7 @@ c.ServerApp.port_retries = 0 c.ServerApp.open_browser = False -c.ServerApp.root_dir = mkdtemp(prefix='galata-test-') +c.ServerApp.root_dir = mkdtemp(prefix="galata-test-") c.ServerApp.token = "" c.ServerApp.password = "" c.ServerApp.disable_check_xsrf = True diff --git a/playground/config.example.py b/playground/config.example.py index 88b43074a..850ab215a 100644 --- a/playground/config.example.py +++ b/playground/config.example.py @@ -2,5 +2,5 @@ # Reference: https://jupyter-ai.readthedocs.io/en/latest/users/index.html#configuring-with-openai # Specify full path to the notebook dir if running jupyter lab from -# outside of the jupyter-ai project root directory +# outside of the jupyter-ai project root directory c.ServerApp.root_dir = "./playground" diff --git a/scripts/bump-version.sh b/scripts/bump-version.sh index 11d90177a..900211be9 100755 --- a/scripts/bump-version.sh +++ b/scripts/bump-version.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # script that bumps version for all projects regardless of whether they were # changed since last release. needed because `lerna version` only bumps versions for projects # listed by `lerna changed` by default.