Skip to content

Commit

Permalink
Merge branch 'main' into fix-completion-fields
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq authored Dec 5, 2024
2 parents 0205824 + 3a8016a commit 5e19d8f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
31 changes: 30 additions & 1 deletion packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
import inspect
import json

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


def requires_no_arguments(func):
sig = inspect.signature(func)
for param in sig.parameters.values():
if param.default is param.empty and param.kind in (
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
param.KEYWORD_ONLY,
):
return False
return True


def convert_to_serializable(obj):
"""Convert an object to a JSON serializable format"""
if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict):
return obj.dict()
if hasattr(obj, "__dict__"):
return obj.__dict__
return str(obj)


class MetadataCallbackHandler(BaseCallbackHandler):
"""
When passed as a callback handler, this stores the LLMResult's
Expand All @@ -23,4 +47,9 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None:
if not (len(response.generations) and len(response.generations[0])):
return

self.jai_metadata = response.generations[0][0].generation_info or {}
metadata = response.generations[0][0].generation_info or {}

# Convert any non-serializable objects in metadata
self.jai_metadata = json.loads(
json.dumps(metadata, default=convert_to_serializable)
)
10 changes: 10 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics import Persona
Expand Down Expand Up @@ -128,6 +129,15 @@ class AgentStreamChunkMessage(BaseModel):
on `BaseAgentMessage.metadata` for information.
"""

@validator("metadata")
def validate_metadata(cls, v):
"""Ensure metadata values are JSON serializable"""
try:
json.dumps(v)
return v
except TypeError as e:
raise ValueError(f"Metadata must be JSON serializable: {str(e)}")


class HumanChatMessage(BaseModel):
type: Literal["human"] = "human"
Expand Down

0 comments on commit 5e19d8f

Please sign in to comment.