Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 30, 2024
1 parent 5d8263f commit bc3ba6b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
import json


def convert_to_serializable(obj):
"""Convert an object to a JSON serializable format"""
if hasattr(obj, 'dict') and callable(obj.dict):
if hasattr(obj, "dict") and callable(obj.dict):
return obj.dict()
if hasattr(obj, '__dict__'):
if hasattr(obj, "__dict__"):
return obj.__dict__
return str(obj)

Expand All @@ -34,7 +35,7 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None:
return

metadata = response.generations[0][0].generation_info or {}

# Convert any non-serializable objects in metadata
serializable_metadata = {}
for key, value in metadata.items():
Expand All @@ -43,5 +44,5 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None:
serializable_metadata[key] = value
except (TypeError, ValueError):
serializable_metadata[key] = convert_to_serializable(value)

self.jai_metadata = serializable_metadata
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Dict, List, Literal, Optional, Union
import json
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import AuthStrategy, Field
Expand Down Expand Up @@ -128,7 +128,7 @@ class AgentStreamChunkMessage(BaseModel):
chunk should override any metadata from previous chunks. See the docstring
on `BaseAgentMessage.metadata` for information.
"""

@validator("metadata")
def validate_metadata(cls, v):
"""Ensure metadata values are JSON serializable"""
Expand Down

0 comments on commit bc3ba6b

Please sign in to comment.