From bc3ba6b919be6354c6f9d1aa4d8c733327f4e770 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Nov 2024 10:58:26 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jupyter_ai/callback_handlers/metadata.py | 11 ++++++----- packages/jupyter-ai/jupyter_ai/models.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index 811155113..a618b2f26 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -1,13 +1,14 @@ +import json + from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult -import json def convert_to_serializable(obj): """Convert an object to a JSON serializable format""" - if hasattr(obj, 'dict') and callable(obj.dict): + if hasattr(obj, "dict") and callable(obj.dict): return obj.dict() - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): return obj.__dict__ return str(obj) @@ -34,7 +35,7 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None: return metadata = response.generations[0][0].generation_info or {} - + # Convert any non-serializable objects in metadata serializable_metadata = {} for key, value in metadata.items(): @@ -43,5 +44,5 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None: serializable_metadata[key] = value except (TypeError, ValueError): serializable_metadata[key] = convert_to_serializable(value) - + self.jai_metadata = serializable_metadata diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 4cb071ed9..6bd7d4e06 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, List, Literal, Optional, Union import json +from typing import Any, Dict, List, Literal, Optional, Union from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import AuthStrategy, Field @@ -128,7 +128,7 @@ class AgentStreamChunkMessage(BaseModel): chunk should override any metadata from previous chunks. See the docstring on `BaseAgentMessage.metadata` for information. """ - + @validator("metadata") def validate_metadata(cls, v): """Ensure metadata values are JSON serializable"""