diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index a618b2f26..783af6362 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -1,12 +1,24 @@ import json +import inspect from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult +def requires_no_arguments(func): + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.default is param.empty and param.kind in ( + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY + ): + return False + return True + def convert_to_serializable(obj): """Convert an object to a JSON serializable format""" - if hasattr(obj, "dict") and callable(obj.dict): + if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict): return obj.dict() if hasattr(obj, "__dict__"): return obj.__dict__ @@ -37,12 +49,9 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None: metadata = response.generations[0][0].generation_info or {} # Convert any non-serializable objects in metadata - serializable_metadata = {} - for key, value in metadata.items(): - try: - json.dumps(value) - serializable_metadata[key] = value - except (TypeError, ValueError): - serializable_metadata[key] = convert_to_serializable(value) - - self.jai_metadata = serializable_metadata + self.jai_metadata = json.loads( + json.dumps( + metadata, + default=convert_to_serializable + ) + )