From 7dd4bfae5bfe0b26edd4b203facceede884ab4ed Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 11 Oct 2023 21:48:07 -0700
Subject: [PATCH] Refactored generate for better stability with all providers/models.

---
 .../jupyter_ai_magics/providers.py            |  16 +++
 .../jupyter_ai/chat_handlers/generate.py      | 130 +++++++++---------
 2 files changed, 83 insertions(+), 63 deletions(-)

diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 5a77926c7..9fdbffa7a 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -227,6 +227,10 @@ def get_prompt_template(self, format) -> PromptTemplate:
     def is_chat_provider(self):
         return isinstance(self, BaseChatModel)
 
+    @property
+    def allows_concurrency(self):
+        return True
+
 
 class AI21Provider(BaseProvider, AI21):
     id = "ai21"
@@ -267,6 +271,10 @@ class AnthropicProvider(BaseProvider, Anthropic):
     pypi_package_deps = ["anthropic"]
     auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
 
+    @property
+    def allows_concurrency(self):
+        return False
+
 
 class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
     id = "anthropic-chat"
@@ -285,6 +293,10 @@ class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
     pypi_package_deps = ["anthropic"]
     auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
 
+    @property
+    def allows_concurrency(self):
+        return False
+
 
 class CohereProvider(BaseProvider, Cohere):
     id = "cohere"
@@ -665,3 +677,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
 
     async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
         return await self._generate_in_executor(*args, **kwargs)
+
+    @property
+    def allows_concurrency(self):
+        return "anthropic" not in self.model_id
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 1934c6fa5..b3a7c4335 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -1,8 +1,6 @@
 import asyncio
-import json
 import os
-import re
-from typing import Dict, Type
+from typing import Dict, List, Optional, Type
 
 import nbformat
 from jupyter_ai.chat_handlers import BaseChatHandler
@@ -10,71 +8,50 @@
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
+from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import PromptTemplate
+from langchain.schema.output_parser import BaseOutputParser
+from pydantic import BaseModel
 
-schema = """{
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "object",
-    "properties": {
-        "description": {
-            "type": "string"
-        },
-        "sections": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "title": {
-                        "type": "string"
-                    },
-                    "content": {
-                        "type": "string"
-                    }
-                },
-                "required": ["title", "content"]
-            }
-        }
-    },
-    "required": ["sections"]
-}"""
+
+class OutlineSection(BaseModel):
+    title: str
+    content: str
+
+
+class Outline(BaseModel):
+    description: Optional[str] = None
+    sections: List[OutlineSection]
 
 
 class NotebookOutlineChain(LLMChain):
     """Chain to generate a notebook outline, with section titles and descriptions."""
 
     @classmethod
-    def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
+    def from_llm(
+        cls, llm: BaseLLM, parser: BaseOutputParser[Outline], verbose: bool = False
+    ) -> LLMChain:
         task_creation_template = (
             "You are an AI that creates a detailed content outline for a Jupyter notebook on a given topic.\n"
-            "Generate the outline as JSON data that will validate against this JSON schema:\n"
-            "{schema}\n"
+            "{format_instructions}\n"
             "Here is a description of the notebook you will create an outline for: {description}\n"
             "Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n"
         )
         prompt = PromptTemplate(
             template=task_creation_template,
-            input_variables=["description", "schema"],
+            input_variables=["description"],
+            partial_variables={"format_instructions": parser.get_format_instructions()},
         )
         return cls(prompt=prompt, llm=llm, verbose=verbose)
 
 
-def extract_json(text: str) -> str:
-    """Extract json from text using Regex."""
-    # The pattern to find json string enclosed in ```json````
-    pattern = r"```json\n(.*?)\n```"
-
-    # Find all matches in the input text
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    return matches[0] if matches else text
-
-
 async def generate_outline(description, llm=None, verbose=False):
     """Generate an outline of sections given a description of a notebook."""
-    chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
-    outline = await chain.apredict(description=description, schema=schema)
-    outline = extract_json(outline)
-    return json.loads(outline)
+    parser = PydanticOutputParser(pydantic_object=Outline)
+    chain = NotebookOutlineChain.from_llm(llm=llm, parser=parser, verbose=verbose)
+    outline = await chain.apredict(description=description)
+    outline = parser.parse(outline)
+    return outline.dict()
 
 
 class CodeImproverChain(LLMChain):
@@ -141,7 +118,8 @@ class NotebookTitleChain(LLMChain):
     def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
         task_creation_template = (
             "Create a short, few word, descriptive title for a Jupyter notebook with the following content.\n"
-            "Content:\n{content}"
+            "Content:\n{content}\n"
+            "Don't return anything other than the title."
         )
         prompt = PromptTemplate(
             template=task_creation_template,
@@ -178,7 +156,7 @@ async def generate_code(section, description, llm=None, verbose=False) -> None:
 
 
 async def generate_title(outline, llm=None, verbose: bool = False):
-    """Generate a title and summary of a notebook outline using an LLM."""
+    """Generate a title for a notebook outline using an LLM."""
     title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose)
     title = await title_chain.apredict(content=outline)
     title = title.strip()
@@ -187,12 +165,14 @@
 
 
 async def generate_summary(outline, llm=None, verbose: bool = False):
+    """Generate a summary of a notebook using an LLM."""
     summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose)
     summary = await summary_chain.apredict(content=outline)
     outline["summary"] = summary
 
 
 async def fill_outline(outline, llm, verbose=False):
+    """Sequentially generate the title, summary, and section content of a notebook using an LLM."""
     shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}
 
     await generate_title(**shared_kwargs)
@@ -201,6 +181,20 @@
         await generate_code(section, outline["description"], llm=llm, verbose=verbose)
 
 
+async def afill_outline(outline, llm, verbose=False):
+    """Concurrently generate the title, summary, and section content of a notebook using an LLM."""
+    shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}
+
+    all_coros = []
+    all_coros.append(generate_title(**shared_kwargs))
+    all_coros.append(generate_summary(**shared_kwargs))
+    for section in outline["sections"]:
+        all_coros.append(
+            generate_code(section, outline["description"], llm=llm, verbose=verbose)
+        )
+    await asyncio.gather(*all_coros)
+
+
 def create_notebook(outline):
     """Create an nbformat Notebook object for a notebook outline."""
     nbf = nbformat.v4
@@ -233,29 +227,39 @@ def create_llm_chain(
         self.llm = llm
         return llm
 
-    async def _process_message(self, message: HumanChatMessage):
-        self.get_llm_chain()
-
-        # first send a verification message to user
-        response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
-        self.reply(response, message)
+    async def _generate_notebook(self, prompt: str):
+        """Generate a notebook and save it to local disk."""
 
-        # generate notebook outline
-        prompt = message.body
+        # create outline
         outline = await generate_outline(prompt, llm=self.llm, verbose=True)
         # Save the user input prompt, the description property is now LLM generated.
         outline["prompt"] = prompt
 
-        # fill the outline concurrently
-        await fill_outline(outline, llm=self.llm, verbose=True)
+        if self.llm.allows_concurrency:
+            # fill the outline concurrently
+            await afill_outline(outline, llm=self.llm, verbose=True)
+        else:
+            # fill the outline sequentially
+            await fill_outline(outline, llm=self.llm, verbose=True)
 
         # create and write the notebook to disk
         notebook = create_notebook(outline)
         final_path = os.path.join(self.root_dir, outline["title"] + ".ipynb")
         nbformat.write(notebook, final_path)
-        response = f"""🎉 I have created your notebook and saved it to the location {final_path}.
-I am still learning how to create notebooks, so please review all code before running it."""
-        self.reply(response, message)
+        return final_path
+
+    async def _process_message(self, message: HumanChatMessage):
+        self.get_llm_chain()
 
+        # first send a verification message to user
+        response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
+        self.reply(response, message)
 
-# /generate notebook
-# Error handling
+        try:
+            final_path = await self._generate_notebook(prompt=message.body)
+            response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
+        except Exception as e:
+            self.log.exception(e)
+            response = "An error occurred while generating the notebook. Try running /generate again."
+        finally:
+            self.reply(response, message)
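
Note (not part of the patch): a minimal, standalone sketch of the
PydanticOutputParser flow this change adopts in place of the hand-written
JSON schema, extract_json(), and json.loads(). The canned "raw" completion
below is illustrative only.

    from typing import List, Optional

    from langchain.output_parsers import PydanticOutputParser
    from pydantic import BaseModel


    class OutlineSection(BaseModel):
        title: str
        content: str


    class Outline(BaseModel):
        description: Optional[str] = None
        sections: List[OutlineSection]


    parser = PydanticOutputParser(pydantic_object=Outline)

    # Rendered into the prompt via the {format_instructions} partial
    # variable; this replaces the hand-written schema string the patch removes.
    print(parser.get_format_instructions())

    # parse() validates the completion against Outline and raises
    # OutputParserException on malformed output, instead of regex extraction
    # followed by json.loads().
    raw = '{"description": "Intro to pandas", "sections": [{"title": "Load data", "content": "Read a CSV into a DataFrame."}]}'
    print(parser.parse(raw).dict())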
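
A second illustrative sketch of why _generate_notebook branches on
allows_concurrency: asyncio.gather schedules all coroutines at once (the
afill_outline path), while providers that reject concurrent requests, such
as the Anthropic ones above, fall back to awaiting each call in turn. The
fake_generate stand-in and its 1-second delay are hypothetical.

    import asyncio
    import time


    async def fake_generate(section: str) -> str:
        # Stand-in for one generate_code() call against a provider.
        await asyncio.sleep(1)
        return f"code for {section}"


    async def fill(sections, allows_concurrency: bool):
        if allows_concurrency:
            # afill_outline's pattern: build the coroutine list, then gather.
            return await asyncio.gather(*(fake_generate(s) for s in sections))
        # fill_outline's pattern: await each coroutine sequentially.
        return [await fake_generate(s) for s in sections]


    start = time.perf_counter()
    asyncio.run(fill(["intro", "load", "plot"], allows_concurrency=True))
    print(f"elapsed: {time.perf_counter() - start:.1f}s")  # ~1s concurrent, ~3s sequential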