Skip to content

Commit

Permalink
Added model_kwargs input to Bedrock provider.
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Nov 1, 2023
1 parent 92dab10 commit 781693a
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,17 @@ class BedrockProvider(BaseProvider, Bedrock):
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
]

def __init__(self, *args, **kwargs):
model_kwargs = kwargs.pop("model_kwargs")
if model_kwargs and isinstance(model_kwargs, str):
model_kwargs = json.loads(model_kwargs)
super().__init__(*args, **kwargs, model_kwargs=model_kwargs)
else:
super().__init__(*args, **kwargs)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

Expand All @@ -670,8 +679,17 @@ class BedrockChatProvider(BaseProvider, BedrockChat):
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
]

def __init__(self, *args, **kwargs):
model_kwargs = kwargs.pop("model_kwargs")
if model_kwargs and isinstance(model_kwargs, str):
model_kwargs = json.loads(model_kwargs)
super().__init__(*args, **kwargs, model_kwargs=model_kwargs)
else:
super().__init__(*args, **kwargs)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

Expand Down

0 comments on commit 781693a

Please sign in to comment.