From c79f35554760b57b65060ad133c633346cb635a7 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 5 Oct 2023 16:16:37 -0700 Subject: [PATCH] Added missing code from backport. --- .../jupyter_ai_magics/providers.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index df4e22e26..5a77926c7 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -135,6 +135,10 @@ class Config: # instance attrs # model_id: str + prompt_templates: Dict[str, PromptTemplate] + """Prompt templates for each output type. Can be overridden with + `update_prompt_template`. The function `prompt_template`, in the base class, + refers to this.""" def __init__(self, *args, **kwargs): try: @@ -148,6 +152,36 @@ def __init__(self, *args, **kwargs): if self.__class__.model_id_key != "model_id": model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] + model_kwargs["prompt_templates"] = { + "code": PromptTemplate.from_template( + "{prompt}\n\nProduce output as source code only, " + "with no text or explanation before or after it." + ), + "html": PromptTemplate.from_template( + "{prompt}\n\nProduce output in HTML format only, " + "with no markup before or afterward." + ), + "image": PromptTemplate.from_template( + "{prompt}\n\nProduce output as an image only, " + "with no text before or after it." + ), + "markdown": PromptTemplate.from_template( + "{prompt}\n\nProduce output in markdown format only." + ), + "md": PromptTemplate.from_template( + "{prompt}\n\nProduce output in markdown format only." + ), + "math": PromptTemplate.from_template( + "{prompt}\n\nProduce output in LaTeX format only, " + "with $$ at the beginning and end." + ), + "json": PromptTemplate.from_template( + "{prompt}\n\nProduce output in JSON format only, " + "with nothing before or after it." + ), + "text": PromptTemplate.from_template("{prompt}"), # No customization + } + super().__init__(*args, **kwargs, **model_kwargs) async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]: