Skip to content

Commit

Permalink
Updates for more stable generate feature
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins authored and Marchlak committed Oct 28, 2024
1 parent 7900aff commit bc25e42
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import os
import re
from typing import Dict, Type

import nbformat
Expand Down Expand Up @@ -48,7 +49,7 @@ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
"Generate the outline as JSON data that will validate against this JSON schema:\n"
"{schema}\n"
"Here is a description of the notebook you will create an outline for: {description}\n"
"Don't include an introduction or conclusion section in the outline, focus only on sections that will need code."
"Don't include an introduction or conclusion section in the outline, focus only on description and sections that will need code.\n"
)
prompt = PromptTemplate(
template=task_creation_template,
Expand All @@ -57,10 +58,22 @@ def from_llm(cls, llm: BaseLLM, verbose: bool = False) -> LLMChain:
return cls(prompt=prompt, llm=llm, verbose=verbose)


def extract_json(text: str) -> str:
"""Extract json from text using Regex."""
# The pattern to find json string enclosed in ```json````
pattern = r"```json\n(.*?)\n```"

# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)

return matches[0] if matches else text


async def generate_outline(description, llm=None, verbose=False):
"""Generate an outline of sections given a description of a notebook."""
chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose)
outline = await chain.apredict(description=description, schema=schema)
outline = extract_json(outline)
return json.loads(outline)


Expand Down Expand Up @@ -182,14 +195,10 @@ async def generate_summary(outline, llm=None, verbose: bool = False):
async def fill_outline(outline, llm, verbose=False):
shared_kwargs = {"outline": outline, "llm": llm, "verbose": verbose}

all_coros = []
all_coros.append(generate_title(**shared_kwargs))
all_coros.append(generate_summary(**shared_kwargs))
await generate_title(**shared_kwargs)
await generate_summary(**shared_kwargs)
for section in outline["sections"]:
all_coros.append(
generate_code(section, outline["description"], llm=llm, verbose=verbose)
)
await asyncio.gather(*all_coros)
await generate_code(section, outline["description"], llm=llm, verbose=verbose)


def create_notebook(outline):
Expand Down

0 comments on commit bc25e42

Please sign in to comment.