diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
index 04594c0405..56050411df 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py
@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
 from .config import OpenTelemetryConfig
 
 
 async def get_adapter_impl(config: OpenTelemetryConfig, deps):
     from .opentelemetry import OpenTelemetryAdapter
 
-    trace_store = JaegerTraceStore(config.jaeger_query_endpoint, config.service_name)
-    impl = OpenTelemetryAdapter(config, trace_store, deps)
+    impl = OpenTelemetryAdapter(config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
index 81c1aed4fa..9d829d110d 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
@@ -18,10 +18,18 @@ class OpenTelemetryConfig(BaseModel):
         default="llama-stack",
         description="The service name to use for telemetry",
     )
+    trace_store: str = Field(
+        default="postgres",
+        description="The trace store to use for telemetry",
+    )
     jaeger_query_endpoint: str = Field(
         default="http://localhost:16686/api/traces",
         description="The Jaeger query endpoint URL",
     )
+    postgres_conn_string: str = Field(
+        default="host=localhost dbname=llama_stack user=llama_stack password=llama_stack port=5432",
+        description="The PostgreSQL connection string to use for storing traces",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index 3bea3b921f..59094a0804 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -18,6 +18,11 @@
 from opentelemetry.semconv.resource import ResourceAttributes
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.remote.telemetry.opentelemetry.postgres_processor import (
+    PostgresSpanProcessor,
+)
+from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
+from llama_stack.providers.utils.telemetry.postgres import PostgresTraceStore
 
 from llama_stack.apis.telemetry import *  # noqa: F403
 
@@ -49,12 +54,18 @@ def is_tracing_enabled(tracer):
 
 
 class OpenTelemetryAdapter(Telemetry):
-    def __init__(
-        self, config: OpenTelemetryConfig, trace_store: TraceStore, deps
-    ) -> None:
+    def __init__(self, config: OpenTelemetryConfig, deps) -> None:
         self.config = config
         self.datasetio = deps[Api.datasetio]
-        self.trace_store = trace_store
+
+        if config.trace_store == "jaeger":
+            self.trace_store = JaegerTraceStore(
+                config.jaeger_query_endpoint, config.service_name
+            )
+        elif config.trace_store == "postgres":
+            self.trace_store = PostgresTraceStore(config.postgres_conn_string)
+        else:
+            raise ValueError(f"Invalid trace store: {config.trace_store}")
 
         resource = Resource.create(
             {
@@ -69,6 +80,9 @@ def __init__(
         )
         span_processor = BatchSpanProcessor(otlp_exporter)
         trace.get_tracer_provider().add_span_processor(span_processor)
+        trace.get_tracer_provider().add_span_processor(
+            PostgresSpanProcessor(self.config.postgres_conn_string)
+        )
         # Set up metrics
         metric_reader = PeriodicExportingMetricReader(
             OTLPMetricExporter(
@@ -252,8 +266,8 @@ def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]:
                 results.append(
                     EvalTrace(
                         step=child.span.name,
-                        input=child.span.attributes.get("input", ""),
-                        output=child.span.attributes.get("output", ""),
+                        input=str(child.span.attributes.get("input", "")),
+                        output=str(child.span.attributes.get("output", "")),
                         session_id=session_id,
                         expected_output="",
                     )
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py b/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py
new file mode 100644
index 0000000000..de8bf15b65
--- /dev/null
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/postgres_processor.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from datetime import datetime
+
+import psycopg2
+from opentelemetry.sdk.trace import SpanProcessor
+from opentelemetry.trace import Span
+
+
+class PostgresSpanProcessor(SpanProcessor):
+    def __init__(self, conn_string):
+        """Initialize the PostgreSQL span processor with a connection string."""
+        self.conn_string = conn_string
+        self.conn = None
+        self.setup_database()
+
+    def setup_database(self):
+        """Create the necessary table if it doesn't exist."""
+        with psycopg2.connect(self.conn_string) as conn:
+            with conn.cursor() as cur:
+                cur.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS traces (
+                        trace_id TEXT,
+                        span_id TEXT,
+                        parent_span_id TEXT,
+                        name TEXT,
+                        start_time TIMESTAMP,
+                        end_time TIMESTAMP,
+                        attributes JSONB,
+                        status TEXT,
+                        kind TEXT,
+                        service_name TEXT,
+                        session_id TEXT
+                    )
+                    """
+                )
+                conn.commit()
+
+    def on_start(self, span: Span, parent_context=None):
+        """Called when a span starts."""
+        pass
+
+    def on_end(self, span: Span):
+        """Called when a span ends. Export the span data to PostgreSQL."""
+        try:
+            with psycopg2.connect(self.conn_string) as conn:
+                with conn.cursor() as cur:
+
+                    cur.execute(
+                        """
+                        INSERT INTO traces (
+                            trace_id, span_id, parent_span_id, name,
+                            start_time, end_time, attributes, status,
+                            kind, service_name, session_id
+                        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+                        """,
+                        (
+                            format(span.get_span_context().trace_id, "032x"),
+                            format(span.get_span_context().span_id, "016x"),
+                            (
+                                format(span.parent.span_id, "016x")
+                                if span.parent
+                                else None
+                            ),
+                            span.name,
+                            datetime.fromtimestamp(span.start_time / 1e9),
+                            datetime.fromtimestamp(span.end_time / 1e9),
+                            json.dumps(dict(span.attributes)),
+                            span.status.status_code.name,
+                            span.kind.name,
+                            span.resource.attributes.get("service.name", "unknown"),
+                            span.attributes.get("session_id", None),
+                        ),
+                    )
+                    conn.commit()
+        except Exception as e:
+            print(f"Error exporting span to PostgreSQL: {e}")
+
+    def shutdown(self):
+        """Cleanup any resources."""
+        if self.conn:
+            self.conn.close()
+
+    def force_flush(self, timeout_millis=30000):
+        """Force export of spans."""
+        pass
diff --git a/llama_stack/providers/utils/telemetry/postgres.py b/llama_stack/providers/utils/telemetry/postgres.py
new file mode 100644
index 0000000000..ed68fc293b
--- /dev/null
+++ b/llama_stack/providers/utils/telemetry/postgres.py
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from datetime import datetime
+from typing import List, Optional
+
+import psycopg2
+
+from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
+
+
+class PostgresTraceStore(TraceStore):
+    def __init__(self, conn_string: str):
+        self.conn_string = conn_string
+
+    async def get_trace(self, trace_id: str) -> Optional[TraceTree]:
+        try:
+            with psycopg2.connect(self.conn_string) as conn:
+                with conn.cursor() as cur:
+                    # Fetch all spans for the trace
+                    cur.execute(
+                        """
+                        SELECT trace_id, span_id, parent_span_id, name,
+                               start_time, end_time, attributes
+                        FROM traces
+                        WHERE trace_id = %s
+                        """,
+                        (trace_id,),
+                    )
+                    spans_data = cur.fetchall()
+
+                    if not spans_data:
+                        return None
+
+                    # First pass: Build span map
+                    span_map = {}
+                    for span_data in spans_data:
+                        # Ensure attributes is a string before parsing
+                        attributes = span_data[6]
+                        if isinstance(attributes, dict):
+                            attributes = json.dumps(attributes)
+
+                        span = Span(
+                            span_id=span_data[1],
+                            trace_id=span_data[0],
+                            name=span_data[3],
+                            start_time=span_data[4],
+                            end_time=span_data[5],
+                            parent_span_id=span_data[2],
+                            attributes=json.loads(
+                                attributes
+                            ),  # Now safely parse the JSON string
+                        )
+                        span_map[span.span_id] = SpanNode(span=span)
+
+                    # Second pass: Build parent-child relationships
+                    root_node = None
+                    for span_node in span_map.values():
+                        parent_id = span_node.span.parent_span_id
+                        if parent_id and parent_id in span_map:
+                            span_map[parent_id].children.append(span_node)
+                        elif not parent_id:
+                            root_node = span_node
+
+                    trace = Trace(
+                        trace_id=trace_id,
+                        root_span_id=root_node.span.span_id if root_node else "",
+                        start_time=(
+                            root_node.span.start_time if root_node else datetime.now()
+                        ),
+                        end_time=root_node.span.end_time if root_node else None,
+                    )
+
+                    return TraceTree(trace=trace, root=root_node)
+
+        except Exception as e:
+            raise Exception(
+                f"Error querying PostgreSQL trace structure: {str(e)}"
+            ) from e
+
+    async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
+        traces = []
+        try:
+            with psycopg2.connect(self.conn_string) as conn:
+                with conn.cursor() as cur:
+                    # Query traces for all session IDs
+                    cur.execute(
+                        """
+                        SELECT DISTINCT trace_id, MIN(start_time) as start_time
+                        FROM traces
+                        WHERE attributes->>'session_id' = ANY(%s)
+                        GROUP BY trace_id
+                        """,
+                        (session_ids,),
+                    )
+                    traces_data = cur.fetchall()
+
+                    for trace_data in traces_data:
+                        traces.append(
+                            Trace(
+                                trace_id=trace_data[0],
+                                root_span_id="",
+                                start_time=trace_data[1],
+                            )
+                        )
+
+        except Exception as e:
+            raise Exception(f"Error querying PostgreSQL traces: {str(e)}") from e
+
+        return traces
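
Reviewer note: a minimal usage sketch, not part of the diff. It assumes the OpenTelemetryConfig fields shown in config.py above, that the remaining config fields keep their defaults, and uses the diff's default connection string; in practice these values would come from the distribution run config rather than being constructed by hand.

    # Hypothetical sketch: select the Postgres-backed trace store added in this
    # diff; passing trace_store="jaeger" would select JaegerTraceStore instead.
    from llama_stack.providers.remote.telemetry.opentelemetry.config import (
        OpenTelemetryConfig,
    )

    config = OpenTelemetryConfig(
        service_name="llama-stack",
        trace_store="postgres",
        postgres_conn_string=(
            "host=localhost dbname=llama_stack user=llama_stack "
            "password=llama_stack port=5432"
        ),
    )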