From 781693a7c05fc1636a3ff1da4f7fcf29cc9a93cf Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 1 Nov 2023 15:24:40 -0700 Subject: [PATCH] Added model_kwargs input to Bedrock provider. --- .../jupyter_ai_magics/providers.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 9fdbffa7a..725c57a38 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -646,8 +646,17 @@ class BedrockProvider(BaseProvider, Bedrock): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = kwargs.pop("model_kwargs") + if model_kwargs and isinstance(model_kwargs, str): + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + else: + super().__init__(*args, **kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) @@ -670,8 +679,17 @@ class BedrockChatProvider(BaseProvider, BedrockChat): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = kwargs.pop("model_kwargs") + if model_kwargs and isinstance(model_kwargs, str): + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + else: + super().__init__(*args, **kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs)