Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracing postgres #556

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
from .config import OpenTelemetryConfig


async def get_adapter_impl(config: OpenTelemetryConfig, deps):
from .opentelemetry import OpenTelemetryAdapter

trace_store = JaegerTraceStore(config.jaeger_query_endpoint, config.service_name)
impl = OpenTelemetryAdapter(config, trace_store, deps)
impl = OpenTelemetryAdapter(config, deps)
await impl.initialize()
return impl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@ class OpenTelemetryConfig(BaseModel):
default="llama-stack",
description="The service name to use for telemetry",
)
trace_store: str = Field(
default="postgres",
description="The trace store to use for telemetry",
)
jaeger_query_endpoint: str = Field(
default="http://localhost:16686/api/traces",
description="The Jaeger query endpoint URL",
)
postgres_conn_string: str = Field(
default="host=localhost dbname=llama_stack user=llama_stack password=llama_stack port=5432",
description="The PostgreSQL connection string to use for storing traces",
)

@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from opentelemetry.semconv.resource import ResourceAttributes

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.remote.telemetry.opentelemetry.postgres_processor import (
PostgresSpanProcessor,
)
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
from llama_stack.providers.utils.telemetry.postgres import PostgresTraceStore


from llama_stack.apis.telemetry import * # noqa: F403
Expand Down Expand Up @@ -49,12 +54,18 @@ def is_tracing_enabled(tracer):


class OpenTelemetryAdapter(Telemetry):
def __init__(
self, config: OpenTelemetryConfig, trace_store: TraceStore, deps
) -> None:
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
self.config = config
self.datasetio = deps[Api.datasetio]
self.trace_store = trace_store

if config.trace_store == "jaeger":
self.trace_store = JaegerTraceStore(
config.jaeger_query_endpoint, config.service_name
)
elif config.trace_store == "postgres":
self.trace_store = PostgresTraceStore(config.postgres_conn_string)
else:
raise ValueError(f"Invalid trace store: {config.trace_store}")

resource = Resource.create(
{
Expand All @@ -69,6 +80,9 @@ def __init__(
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
trace.get_tracer_provider().add_span_processor(
PostgresSpanProcessor(self.config.postgres_conn_string)
)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
Expand Down Expand Up @@ -252,8 +266,8 @@ def find_execute_turn_children(node: SpanNode) -> List[EvalTrace]:
results.append(
EvalTrace(
step=child.span.name,
input=child.span.attributes.get("input", ""),
output=child.span.attributes.get("output", ""),
input=str(child.span.attributes.get("input", "")),
output=str(child.span.attributes.get("output", "")),
session_id=session_id,
expected_output="",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from datetime import datetime

import psycopg2
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span


class PostgresSpanProcessor(SpanProcessor):
def __init__(self, conn_string):
"""Initialize the PostgreSQL span processor with a connection string."""
self.conn_string = conn_string
self.conn = None
self.setup_database()

def setup_database(self):
"""Create the necessary table if it doesn't exist."""
with psycopg2.connect(self.conn_string) as conn:
with conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS traces (
trace_id TEXT,
span_id TEXT,
parent_span_id TEXT,
name TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
attributes JSONB,
status TEXT,
kind TEXT,
service_name TEXT,
session_id TEXT
)
"""
)
conn.commit()

def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
pass

def on_end(self, span: Span):
"""Called when a span ends. Export the span data to PostgreSQL."""
try:
with psycopg2.connect(self.conn_string) as conn:
with conn.cursor() as cur:

cur.execute(
"""
INSERT INTO traces (
trace_id, span_id, parent_span_id, name,
start_time, end_time, attributes, status,
kind, service_name, session_id
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(
format(span.get_span_context().trace_id, "032x"),
format(span.get_span_context().span_id, "016x"),
(
format(span.parent.span_id, "016x")
if span.parent
else None
),
span.name,
datetime.fromtimestamp(span.start_time / 1e9),
datetime.fromtimestamp(span.end_time / 1e9),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
span.resource.attributes.get("service.name", "unknown"),
span.attributes.get("session_id", None),
),
)
conn.commit()
except Exception as e:
print(f"Error exporting span to PostgreSQL: {e}")

def shutdown(self):
"""Cleanup any resources."""
if self.conn:
self.conn.close()

def force_flush(self, timeout_millis=30000):
"""Force export of spans."""
pass
114 changes: 114 additions & 0 deletions llama_stack/providers/utils/telemetry/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from datetime import datetime
from typing import List, Optional

import psycopg2

from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree


class PostgresTraceStore(TraceStore):
def __init__(self, conn_string: str):
self.conn_string = conn_string

async def get_trace(self, trace_id: str) -> Optional[TraceTree]:
try:
with psycopg2.connect(self.conn_string) as conn:
with conn.cursor() as cur:
# Fetch all spans for the trace
cur.execute(
"""
SELECT trace_id, span_id, parent_span_id, name,
start_time, end_time, attributes
FROM traces
WHERE trace_id = %s
""",
(trace_id,),
)
spans_data = cur.fetchall()

if not spans_data:
return None

# First pass: Build span map
span_map = {}
for span_data in spans_data:
# Ensure attributes is a string before parsing
attributes = span_data[6]
if isinstance(attributes, dict):
attributes = json.dumps(attributes)

span = Span(
span_id=span_data[1],
trace_id=span_data[0],
name=span_data[3],
start_time=span_data[4],
end_time=span_data[5],
parent_span_id=span_data[2],
attributes=json.loads(
attributes
), # Now safely parse the JSON string
)
span_map[span.span_id] = SpanNode(span=span)

# Second pass: Build parent-child relationships
root_node = None
for span_node in span_map.values():
parent_id = span_node.span.parent_span_id
if parent_id and parent_id in span_map:
span_map[parent_id].children.append(span_node)
elif not parent_id:
root_node = span_node

trace = Trace(
trace_id=trace_id,
root_span_id=root_node.span.span_id if root_node else "",
start_time=(
root_node.span.start_time if root_node else datetime.now()
),
end_time=root_node.span.end_time if root_node else None,
)

return TraceTree(trace=trace, root=root_node)

except Exception as e:
raise Exception(
f"Error querying PostgreSQL trace structure: {str(e)}"
) from e

async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
traces = []
try:
with psycopg2.connect(self.conn_string) as conn:
with conn.cursor() as cur:
# Query traces for all session IDs
cur.execute(
"""
SELECT DISTINCT trace_id, MIN(start_time) as start_time
FROM traces
WHERE attributes->>'session_id' = ANY(%s)
GROUP BY trace_id
""",
(session_ids,),
)
traces_data = cur.fetchall()

for trace_data in traces_data:
traces.append(
Trace(
trace_id=trace_data[0],
root_span_id="",
start_time=trace_data[1],
)
)

except Exception as e:
raise Exception(f"Error querying PostgreSQL traces: {str(e)}") from e

return traces
Loading