Refactor generate for better stability with all providers/models #407

Merged 4 commits on Oct 21, 2023

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py (16 additions, 0 deletions)

@@ -227,6 +227,10 @@ def get_prompt_template(self, format) -> PromptTemplate:
     def is_chat_provider(self):
         return isinstance(self, BaseChatModel)
 
+    @property
+    def allows_concurrency(self):
+        return True
+
 
 class AI21Provider(BaseProvider, AI21):
     id = "ai21"
@@ -267,6 +271,10 @@ class AnthropicProvider(BaseProvider, Anthropic):
     pypi_package_deps = ["anthropic"]
     auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
 
+    @property
+    def allows_concurrency(self):
+        return False
+
 
 class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
     id = "anthropic-chat"
@@ -285,6 +293,10 @@ class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
     pypi_package_deps = ["anthropic"]
     auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
 
+    @property
+    def allows_concurrency(self):
+        return False
+
 
 class CohereProvider(BaseProvider, Cohere):
     id = "cohere"
@@ -665,3 +677,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
 
     async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
         return await self._generate_in_executor(*args, **kwargs)
+
+    @property
+    def allows_concurrency(self):
+        return "anthropic" not in self.model_id

packages/jupyter-ai-magics/pyproject.toml (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ dependencies = [
     "ipython",
     "pydantic~=1.0",
     "importlib_metadata>=5.2.0",
-    "langchain==0.0.308",
+    "langchain==0.0.318",
     "typing_extensions>=4.5.0",
     "click~=8.0",
     "jsonpath-ng>=1.5.3,<2",

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py (19 additions, 9 deletions)

@@ -4,9 +4,19 @@
 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.prompts import PromptTemplate
 
 from .base import BaseChatHandler
 
+PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
+
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
+
 
 class AskChatHandler(BaseChatHandler):
     """Processes messages prefixed with /ask. This actor will
@@ -27,9 +37,15 @@ def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
         self.llm = provider(**provider_params)
-        self.chat_history = []
+        memory = ConversationBufferWindowMemory(
+            memory_key="chat_history", return_messages=True, k=2
+        )
         self.llm_chain = ConversationalRetrievalChain.from_llm(
-            self.llm, self._retriever, verbose=True
+            self.llm,
+            self._retriever,
+            memory=memory,
+            condense_question_prompt=CONDENSE_PROMPT,
+            verbose=False,
         )
 
     async def _process_message(self, message: HumanChatMessage):
@@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
 
         try:
-            # limit chat history to last 2 exchanges
-            self.chat_history = self.chat_history[-2:]
-
-            result = await self.llm_chain.acall(
-                {"question": query, "chat_history": self.chat_history}
-            )
+            result = await self.llm_chain.acall({"question": query})
             response = result["answer"]
-            self.chat_history.append((query, response))
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
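
The handler previously trimmed self.chat_history by hand on every message; ConversationBufferWindowMemory with k=2 enforces the same two-exchange bound inside the chain itself. A standalone sketch of the windowing behavior (the sample exchanges are invented):

from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(
    memory_key="chat_history", return_messages=True, k=2
)
memory.save_context({"input": "What is NumPy?"}, {"output": "An array library."})
memory.save_context({"input": "And pandas?"}, {"output": "A dataframe library."})
memory.save_context({"input": "And SciPy?"}, {"output": "A scientific toolbox."})

# Only the last k=2 exchanges survive; the NumPy exchange has been dropped.
print(memory.load_memory_variables({})["chat_history"])
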

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py (64 additions, 51 deletions)

@@ -1,67 +1,57 @@
 import asyncio
-import json
 import os
-from typing import Dict, Type
+from typing import Dict, List, Optional, Type
 
 import nbformat
 from jupyter_ai.chat_handlers import BaseChatHandler
 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
+from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import PromptTemplate
+from langchain.schema.output_parser import BaseOutputParser
+from pydantic import BaseModel
 
-schema = """{
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "object",
-    "properties": {
-        "description": {
-            "type": "string"
-        },
-        "sections": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "title": {
-                        "type": "string"
-                    },
-                    "content": {
-                        "type": "string"
-                    }
-                },
-                "required": ["title", "content"]
-            }
-        }
-    },
-    "required": ["sections"]
-}"""
+
+class OutlineSection(BaseModel):
+    title: str
+    content: str
+
+
+class Outline(BaseModel):
+    description: Optional[str] = None
+    sections: List[OutlineSection]
 
 
 class NotebookOutlineChain(LLMChain):
     """Chain to generate a notebook outline, with section titles and descriptions."""
 
     @classmethod
-    def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
+    def from_llm(
+        cls, llm: BaseLLM, parser: BaseOutputParser[Outline], verbose: bool = False
+    ) -> LLMChain:
         task_creation_template = (
             "You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n"
-            "Generate the outline as JSON data that will validate against this JSON schema:\n"
-            "{schema}\n"
+            "{format_instructions}\n"
            "Here is a description of the notebook you will create an outline for: {description}\n"
-            "Don't include an introduction or conclusion section in the outline, focus only on sections that will need code."
+            "Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n"
         )
         prompt = PromptTemplate(
             template=task_creation_template,
-            input_variables=["description", "schema"],
+            input_variables=["description"],
+            partial_variables={"format_instructions": parser.get_format_instructions()},
         )
         return cls(prompt=prompt, llm=llm, verbose=verbose)
 
 
 async def generate_outline(description, llm=None, verbose=False):
     """Generate an outline of sections given a description of a notebook."""
-    chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
-    outline = await chain.apredict(description=description, schema=schema)
-    return json.loads(outline)
+    parser = PydanticOutputParser(pydantic_object=Outline)
+    chain = NotebookOutlineChain.from_llm(llm=llm, parser=parser, verbose=verbose)
+    outline = await chain.apredict(description=description)
+    outline = parser.parse(outline)
+    return outline.dict()
 
 
 class CodeImproverChain(LLMChain):
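
Replacing the hand-written JSON schema string with PydanticOutputParser keeps the prompt's format instructions and the validation logic derived from a single Outline model. A minimal sketch of the parser round trip, assuming the pydantic v1 that this PR pins (the sample JSON is invented):

from typing import List, Optional

from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel


class OutlineSection(BaseModel):
    title: str
    content: str


class Outline(BaseModel):
    description: Optional[str] = None
    sections: List[OutlineSection]


parser = PydanticOutputParser(pydantic_object=Outline)

# This text is injected into the prompt via the format_instructions
# partial variable, replacing the old hand-written schema string.
print(parser.get_format_instructions())

# parse() validates the raw completion and raises OutputParserException
# on malformed output, instead of json.loads failing opaquely.
outline = parser.parse(
    '{"description": "Intro to NumPy", '
    '"sections": [{"title": "Arrays", "content": "Create and index arrays."}]}'
)
print(outline.dict())
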
@@ -128,7 +118,8 @@ class NotebookTitleChain(LLMChain):
     def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
         task_creation_template = (
             "Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n"
-            "Content:\n{content}"
+            "Content:\n{content}\n"
+            "Don't return anything other than the title."
         )
         prompt = PromptTemplate(
             template=task_creation_template,
@@ -165,7 +156,7 @@ async def generate_code(section, description, llm=None, verbose=False) -> None:
 
 
 async def generate_title(outline, llm=None, verbose: bool = False):
-    """Generate a title and summary of a notebook outline using an LLM."""
+    """Generate a title of a notebook outline using an LLM."""
     title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose)
     title = await title_chain.apredict(content=outline)
     title = title.strip()
@@ -174,12 +165,24 @@ async def generate_title(outline, llm=None, verbose: bool = False):
 
 
 async def generate_summary(outline, llm=None, verbose: bool = False):
+    """Generate a summary of a notebook using an LLM."""
     summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose)
     summary = await summary_chain.apredict(content=outline)
     outline["summary"] = summary
 
 
 async def fill_outline(outline, llm, verbose=False):
+    """Generate the title and content of a notebook's sections using an LLM."""
     shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}
 
+    await generate_title(**shared_kwargs)
+    await generate_summary(**shared_kwargs)
+    for section in outline["sections"]:
+        await generate_code(section, outline["description"], llm=llm, verbose=verbose)
+
+
+async def afill_outline(outline, llm, verbose=False):
+    """Concurrently generate the title and content of a notebook's sections using an LLM."""
+    shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}
+
     all_coros = []
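
The remainder of afill_outline is collapsed in this view, but the split between the two functions is the point of the refactor: fill_outline awaits each generation in turn, while afill_outline fans the coroutines out so the LLM calls overlap. A toy sketch of the latency difference, with a sleep standing in for an LLM call:

import asyncio
import time


async def fake_llm_call(x: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for one LLM request
    return 2 * x


async def main() -> None:
    # Sequential, like fill_outline: total time is the sum of the calls.
    start = time.perf_counter()
    sequential = [await fake_llm_call(x) for x in range(5)]
    print(f"sequential: {time.perf_counter() - start:.2f}s")

    # Concurrent, like afill_outline: total time is roughly one call.
    start = time.perf_counter()
    concurrent = await asyncio.gather(*(fake_llm_call(x) for x in range(5)))
    print(f"concurrent: {time.perf_counter() - start:.2f}s")

    assert sequential == list(concurrent)


asyncio.run(main())
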
@@ -224,29 +227,39 @@ def create_llm_chain(
         self.llm = llm
         return llm
 
-    async def _process_message(self, message: HumanChatMessage):
-        self.get_llm_chain()
+    async def _generate_notebook(self, prompt: str):
+        """Generate a notebook and save it to local disk."""
 
-        # first send a verification message to user
-        response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
-        self.reply(response, message)
-
-        # generate notebook outline
-        prompt = message.body
+        # create outline
         outline = await generate_outline(prompt, llm=self.llm, verbose=True)
+        # Save the user input prompt; the description property is now LLM-generated.
+        outline["prompt"] = prompt
 
-        # fill the outline concurrently
-        await fill_outline(outline, llm=self.llm, verbose=True)
+        if self.llm.allows_concurrency:
+            # fill the outline concurrently
+            await afill_outline(outline, llm=self.llm, verbose=True)
+        else:
+            # fill the outline sequentially
+            await fill_outline(outline, llm=self.llm, verbose=True)
 
         # create and write the notebook to disk
         notebook = create_notebook(outline)
         final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb")
         nbformat.write(notebook, final_path)
-        response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
-        self.reply(response, message)
+        return final_path
+
+    async def _process_message(self, message: HumanChatMessage):
+        self.get_llm_chain()
+
+        # first send a verification message to the user
+        response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
+        self.reply(response, message)
+
+        # generate the notebook, with error handling
+        try:
+            final_path = await self._generate_notebook(prompt=message.body)
+            response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
+        except Exception as e:
+            self.log.exception(e)
+            response = "An error occurred while generating the notebook. Try running the /generate task again."
+        finally:
+            self.reply(response, message)
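
Extracting _generate_notebook lets _process_message own the messaging contract: the acknowledgement goes out first, and the finally block guarantees exactly one follow-up reply whether generation succeeds or raises. A minimal sketch of that shape (the function and logger names are invented):

import asyncio
import logging

log = logging.getLogger("generate")


async def process(prompt: str, reply) -> None:
    reply("👍 Working on it...")  # immediate acknowledgement
    try:
        if not prompt:
            raise ValueError("empty prompt")  # stand-in for a provider error
        response = f"🎉 Notebook created for {prompt!r}."
    except Exception as e:
        log.exception(e)
        response = "An error occurred while generating the notebook."
    finally:
        # Runs on success and on failure: the user always hears back.
        reply(response)


asyncio.run(process("", print))
asyncio.run(process("Intro to NumPy", print))
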

packages/jupyter-ai/pyproject.toml (1 addition, 1 deletion)

@@ -28,7 +28,7 @@ dependencies = [
     "openai~=0.26",
     "aiosqlite>=0.18",
     "importlib_metadata>=5.2.0",
-    "langchain==0.0.308",
+    "langchain==0.0.318",
     "tiktoken", # required for OpenAIEmbeddings
     "jupyter_ai_magics",
     "dask[distributed]",