diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 9fdbffa7a..725c57a38 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -646,8 +646,17 @@ class BedrockProvider(BaseProvider, Bedrock): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = kwargs.pop("model_kwargs") + if model_kwargs and isinstance(model_kwargs, str): + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + else: + super().__init__(*args, **kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) @@ -670,8 +679,17 @@ class BedrockChatProvider(BaseProvider, BedrockChat): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = kwargs.pop("model_kwargs") + if model_kwargs and isinstance(model_kwargs, str): + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + else: + super().__init__(*args, **kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs)