Skip to content

Commit

Permalink
Backport PR jupyterlab#572: LangChain v0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq authored and meeseeksmachine committed Jan 10, 2024
1 parent 240ba59 commit 495c812
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
Expand All @@ -8,15 +7,15 @@
Field,
MultiEnvAuthStrategy,
)
from langchain.embeddings import (
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
BedrockEmbeddings,
CohereEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
QianfanEmbeddingsEndpoint,
)
from langchain.pydantic_v1 import BaseModel, Extra


class BaseEmbeddingsProvider(BaseModel):
Expand Down
71 changes: 47 additions & 24 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,41 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union

from jsonpath_ng import parse
from langchain.chat_models import (
from langchain.chat_models.base import BaseChatModel
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import (
AzureChatOpenAI,
BedrockChat,
ChatAnthropic,
ChatOpenAI,
QianfanChatEndpoint,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms import (
from langchain_community.llms import (
AI21,
Anthropic,
Bedrock,
Cohere,
GPT4All,
HuggingFaceHub,
OpenAI,
OpenAIChat,
QianfanLLMEndpoint,
SagemakerEndpoint,
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import ChatOpenAI

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
# subpackage.
try:
from pydantic.v1.main import ModelMetaclass
except:
from pydantic.main import ModelMetaclass


class EnvAuthStrategy(BaseModel):
Expand Down Expand Up @@ -99,7 +95,34 @@ class IntegerField(BaseModel):
Field = Union[TextField, MultilineTextField, IntegerField]


class BaseProvider(BaseModel):
class ProviderMetaclass(ModelMetaclass):
"""
A metaclass that ensures all class attributes defined inline within the
class definition are accessible and included in `Class.__dict__`.
This is necessary because Pydantic drops any ClassVars that are defined as
an instance field by a parent class, even if they are defined inline within
the class definition. We encountered this case when `langchain` added a
`name` attribute to a parent class shared by all `Provider`s, which caused
`Provider.name` to be inaccessible. See #558 for more info.
"""

def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
for key in namespace:
# skip private class attributes
if key.startswith("_"):
continue
# skip class attributes already listed in `cls.__dict__`
if key in cls.__dict__:
continue

setattr(cls, key, namespace[key])

return cls


class BaseProvider(BaseModel, metaclass=ProviderMetaclass):
#
# pydantic config
#
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import ClassVar, Optional

from langchain.pydantic_v1 import BaseModel

from ..providers import ProviderMetaclass


def test_provider_metaclass():
"""
Asserts that the metaclass prevents class attributes from being omitted due
to parent classes defining an instance field of the same name.
You can reproduce the original issue by removing the
`metaclass=ProviderMetaclass` argument from the definition of `Child`.
"""

class Parent(BaseModel):
test: Optional[str]

class Base(BaseModel):
test: ClassVar[str]

class Child(Base, Parent, metaclass=ProviderMetaclass):
test: ClassVar[str] = "expected"

assert Child.test == "expected"
16 changes: 4 additions & 12 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,16 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
dependencies = [
"ipython",
"importlib_metadata>=5.2.0",
"langchain==0.0.350",
"langchain-core>=0.1.0,<0.1.4",
"langchain>=0.1.0,<0.2.0",
"typing_extensions>=4.5.0",
"click~=8.0",
"jsonpath-ng>=1.5.3,<2",
]

[project.optional-dependencies]
dev = [
"pre-commit>=3.3.3,<4"
]
dev = ["pre-commit>=3.3.3,<4"]

test = [
"coverage",
"pytest",
"pytest-asyncio",
"pytest-cov"
]
test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]

all = [
"ai21",
Expand All @@ -52,7 +44,7 @@ all = [
"pillow",
"openai~=1.6.1",
"boto3",
"qianfan"
"qianfan",
]

[project.entry-points."jupyter_ai.model_providers"]
Expand Down
7 changes: 3 additions & 4 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import argparse
import json
import os
from typing import Any, Awaitable, Coroutine, List, Optional, Tuple
from typing import Any, Coroutine, List, Optional, Tuple

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager
from jupyter_ai.document_loaders.directory import get_embeddings, split
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
from jupyter_ai.models import (
Expand All @@ -22,7 +21,7 @@
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
)
from langchain.vectorstores import FAISS
from langchain_community.vectorstores import FAISS

from .base import BaseChatHandler, SlashCommandRoutingType

Expand Down Expand Up @@ -143,7 +142,7 @@ def _build_list_response(self):
async def learn_dir(
self, path: str, chunk_size: int, chunk_overlap: int, all_files: bool
):
dask_client = await self.dask_client_future
dask_client: DaskClient = await self.dask_client_future
splitter_kwargs = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
splitters = {
".py": PythonCodeTextSplitter(**splitter_kwargs),
Expand Down
2 changes: 0 additions & 2 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ dependencies = [
"jupyterlab>=3.5,<4",
"aiosqlite>=0.18",
"importlib_metadata>=5.2.0",
"langchain==0.0.350",
"langchain-core>=0.1.0,<0.1.4",
"tiktoken", # required for OpenAIEmbeddings
"jupyter_ai_magics",
"dask[distributed]",
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ include_external_packages = true

[[tool.importlinter.contracts]]
key = "pydantic"
name = "Forbidden import of `pydantic` package. Please import from `langchain.pydantic_v1` instead for compatibility with both Pydantic v1 and v2."
name = "Forbid `pydantic`. (note: Developers should import Pydantic from `langchain.pydantic_v1` instead for compatibility.)"
type = "forbidden"
source_modules = ["jupyter_ai", "jupyter_ai_magics"]
forbidden_modules = ["pydantic"]
# TODO: get `langchain` to export `ModelMetaclass` to avoid needing this statement
ignore_imports = ["jupyter_ai_magics.providers -> pydantic"]

[tool.pytest.ini_options]
addopts = "--ignore packages/jupyter-ai-module-cookiecutter"

0 comments on commit 495c812

Please sign in to comment.