Skip to content

Commit

Permalink
Refactored generate for better stability with all providers/models.
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Oct 12, 2023
1 parent 21cf9a4 commit 7dd4bfa
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 63 deletions.
16 changes: 16 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def get_prompt_template(self, format) -> PromptTemplate:
def is_chat_provider(self):
    # True when the underlying LangChain model is a chat model
    # (affects which prompt format the caller should use).
    return isinstance(self, BaseChatModel)

@property
def allows_concurrency(self):
    # Whether this provider tolerates concurrent generate calls.
    # Base default is True; subclasses override to force serial execution.
    return True


class AI21Provider(BaseProvider, AI21):
id = "ai21"
Expand Down Expand Up @@ -267,6 +271,10 @@ class AnthropicProvider(BaseProvider, Anthropic):
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")

@property
def allows_concurrency(self):
    # Anthropic completions are invoked serially: concurrent calls are
    # disabled for this provider.
    return False


class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
id = "anthropic-chat"
Expand All @@ -285,6 +293,10 @@ class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")

@property
def allows_concurrency(self):
    # Anthropic chat models are invoked serially: concurrent calls are
    # disabled for this provider.
    return False


class CohereProvider(BaseProvider, Cohere):
id = "cohere"
Expand Down Expand Up @@ -665,3 +677,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:

async def _agenerate(self, *args, **kwargs) -> LLMResult:
    # Delegate async generation to the executor-based helper — presumably
    # runs the blocking generate call off the event loop; confirm against
    # _generate_in_executor's definition.
    return await self._generate_in_executor(*args, **kwargs)

@property
def allows_concurrency(self):
    """Whether this model may be called concurrently.

    Models whose id mentions "anthropic" are run serially; everything
    else may be fanned out concurrently.
    """
    # `x not in y` is the idiomatic membership test; `not x in y`
    # parses identically but is flagged by linters (pycodestyle E713).
    return "anthropic" not in self.model_id
130 changes: 67 additions & 63 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,57 @@
import asyncio
import json
import os
import re
from typing import Dict, Type
from typing import Dict, List, Optional, Type

import nbformat
from jupyter_ai.chat_handlers import BaseChatHandler
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.llms import BaseLLM
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import BaseOutputParser
from pydantic import BaseModel

schema = """{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"description": {
"type": "string"
},
"sections": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {
"type": "string"
},
"content": {
"type": "string"
}
},
"required": ["title", "content"]
}
}
},
"required": ["sections"]
}"""

class OutlineSection(BaseModel):
    """A single section of a generated notebook outline."""

    # Section heading.
    title: str
    # Prose content of the section; later passed to code generation.
    content: str


class Outline(BaseModel):
    """Structured notebook outline parsed from LLM output."""

    # Overall notebook description; optional because the model may omit it.
    description: Optional[str] = None
    # Ordered list of sections that make up the notebook.
    sections: List[OutlineSection]


class NotebookOutlineChain(LLMChain):
"""Chain to generate a notebook outline, with section titles and descriptions."""

@classmethod
def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
def from_llm(
cls, llm: BaseLLM, parser: BaseOutputParser[Outline], verbose: bool = False
) -> LLMChain:
task_creation_template = (
"You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n"
"Generate the outline as JSON data that will validate against this JSON schema:\n"
"{schema}\n"
"{format_instructions}\n"
"Here is a description of the notebook you will create an outline for: {description}\n"
"Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n"
)
prompt = PromptTemplate(
template=task_creation_template,
input_variables=["description", "schema"],
input_variables=["description"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
return cls(prompt=prompt, llm=llm, verbose=verbose)


def extract_json(text: str) -> str:
    """Extract a JSON payload from *text*.

    Returns the contents of the first fenced code block (```json ... ```
    or a bare ``` ... ``` fence, since models sometimes omit the language
    tag). If no fence is found, *text* is returned unchanged so the
    caller's JSON parser can attempt the raw string.
    """
    # DOTALL lets the payload span multiple lines; the non-greedy group
    # stops at the first closing fence.
    pattern = r"```(?:json)?\n(.*?)\n```"

    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


async def generate_outline(description, llm=None, verbose=False):
"""Generate an outline of sections given a description of a notebook."""
chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
outline = await chain.apredict(description=description, schema=schema)
outline = extract_json(outline)
return json.loads(outline)
parser = PydanticOutputParser(pydantic_object=Outline)
chain = NotebookOutlineChain.from_llm(llm=llm, parser=parser, verbose=verbose)
outline = await chain.apredict(description=description)
outline = parser.parse(outline)
return outline.dict()


class CodeImproverChain(LLMChain):
Expand Down Expand Up @@ -141,7 +118,8 @@ class NotebookTitleChain(LLMChain):
def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
task_creation_template = (
"Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n"
"Content:\n{content}"
"Content:\n{content}\n"
"Don't return anything other than the title."
)
prompt = PromptTemplate(
template=task_creation_template,
Expand Down Expand Up @@ -178,7 +156,7 @@ async def generate_code(section, description, llm=None, verbose=False) -> None:


async def generate_title(outline, llm=None, verbose: bool = False):
"""Generate a title and summary of a notebook outline using an LLM."""
"""Generate a title of a notebook outline using an LLM."""
title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose)
title = await title_chain.apredict(content=outline)
title = title.strip()
Expand All @@ -187,12 +165,14 @@ async def generate_title(outline, llm=None, verbose: bool = False):


async def generate_summary(outline, llm=None, verbose: bool = False):
    """Summarize a notebook outline with an LLM.

    The summary is stored under ``outline["summary"]``; nothing is returned.
    """
    chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose)
    outline["summary"] = await chain.apredict(content=outline)


async def fill_outline(outline, llm, verbose=False):
"""Generate title and content of a notebook sections using an LLM."""
shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}

await generate_title(**shared_kwargs)
Expand All @@ -201,6 +181,20 @@ async def fill_outline(outline, llm, verbose=False):
await generate_code(section, outline["description"], llm=llm, verbose=verbose)


async def afill_outline(outline, llm, verbose=False):
    """Concurrently generate title and content of notebook sections using an LLM.

    Schedules the title, summary, and per-section code generation as a
    single batch and waits for all of them to finish.
    """
    tasks = [
        generate_title(outline=outline, llm=llm, verbose=verbose),
        generate_summary(outline=outline, llm=llm, verbose=verbose),
    ]
    tasks.extend(
        generate_code(section, outline["description"], llm=llm, verbose=verbose)
        for section in outline["sections"]
    )
    # Run everything concurrently; the first failure propagates to the caller.
    await asyncio.gather(*tasks)


def create_notebook(outline):
"""Create an nbformat Notebook object for a notebook outline."""
nbf = nbformat.v4
Expand Down Expand Up @@ -233,29 +227,39 @@ def create_llm_chain(
self.llm = llm
return llm

async def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()

# first send a verification message to user
response = "πŸ‘ Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
self.reply(response, message)
async def _generate_notebook(self, prompt: str):
"""Generate a notebook and save to local disk"""

# generate notebook outline
prompt = message.body
# create outline
outline = await generate_outline(prompt, llm=self.llm, verbose=True)
# Save the user input prompt, the description property is now LLM generated.
outline["prompt"] = prompt

# fill the outline concurrently
await fill_outline(outline, llm=self.llm, verbose=True)
if self.llm.allows_concurrency:
# fill the outline concurrently
await afill_outline(outline, llm=self.llm, verbose=True)
else:
# fill outline
await fill_outline(outline, llm=self.llm, verbose=True)

# create and write the notebook to disk
notebook = create_notebook(outline)
final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb")
nbformat.write(notebook, final_path)
response = f"""πŸŽ‰ I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
self.reply(response, message)
return final_path

async def _process_message(self, message: HumanChatMessage):
    """Handle a /generate request: acknowledge, build the notebook, reply with the outcome."""
    self.get_llm_chain()

    # Acknowledge immediately — generation can take several minutes.
    ack = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
    self.reply(ack, message)

    # Generate the notebook, reporting either the saved path or a failure
    # notice back to the user.
    try:
        final_path = await self._generate_notebook(prompt=message.body)
    except Exception as e:
        self.log.exception(e)
        response = "An error occurred while generating the notebook. Try running the /generate task again."
    else:
        response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
    self.reply(response, message)

0 comments on commit 7dd4bfa

Please sign in to comment.