Fix opentelemetry adapter (#510)
# What does this PR do?

This PR fixes several issues in our telemetry setup so that logs and
traces are delivered to OpenTelemetry and Jaeger. Main fixes:
1) Updates the OpenTelemetry provider to use the latest OTLP exporters
instead of the deprecated ones (a minimal sketch of the new exporter
wiring follows this list).
2) Adds a tracing middleware that injects a root trace into each HTTP
request the server receives. Previously we did this in the
create_dynamic_typed_route method, which is really route configuration
rather than the actual execution flow, so traces ended prematurely.
With the middleware, the trace starts and ends at the right points in
the request lifecycle.
3) We manage our own methods to create traces and spans, which does not
fit well with the OpenTelemetry SDK: it does not provide a way to take
in traces and spans that are already created, and expects us to create
them through the SDK. For now, I have a hacky approach of maintaining a
map from our internal telemetry objects to the OpenTelemetry-specific
ones (rough sketch below). This is not the ideal solution; I will
explore other ways around this issue, but to have something that works
I am keeping it as is.
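
A minimal sketch of the new exporter wiring from (1), assuming the standard
opentelemetry-sdk and opentelemetry-exporter-otlp-proto-http packages; this
is illustrative, not the adapter's exact code:

```python
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Replaces the deprecated Jaeger thrift exporter; an OTel collector (or
# Jaeger's native OTLP intake) receives spans over HTTP.
provider = TracerProvider(resource=Resource.create({"service.name": "llama-stack"}))
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces"))
)
trace.set_tracer_provider(provider)
```

And a rough sketch of the interim workaround from (3); the function names and
module-level map are illustrative, not the adapter's actual interface:

```python
from typing import Dict

from opentelemetry import trace

# Hypothetical mapping from our internal span IDs to SDK-created spans, so a
# span started by our own telemetry layer can be ended on the OTel side later.
_span_map: Dict[str, trace.Span] = {}

def on_internal_span_start(internal_id: str, name: str) -> None:
    # The SDK cannot adopt externally created spans, so create a parallel one.
    _span_map[internal_id] = trace.get_tracer(__name__).start_span(name)

def on_internal_span_end(internal_id: str) -> None:
    span = _span_map.pop(internal_id, None)
    if span is not None:
        span.end()
```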

Addresses: #509
dineshyv authored Nov 23, 2024
1 parent beab798 commit 501e7c9
Showing 11 changed files with 185 additions and 217 deletions.
2 changes: 1 addition & 1 deletion llama_stack/apis/models/client.py
@@ -40,7 +40,7 @@ async def register_model(self, model: Model) -> None:
         response = await client.post(
             f"{self.base_url}/models/register",
             json={
-                "model": json.loads(model.json()),
+                "model": json.loads(model.model_dump_json()),
             },
             headers={"Content-Type": "application/json"},
         )
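
For context, the `.json()` -> `.model_dump_json()` rename here (and the
matching `.copy()` -> `.model_copy()` changes further down) are the standard
Pydantic v2 method renames; the v1 names still work but emit deprecation
warnings. A minimal illustration, with a stand-in model rather than the one
from this file:

```python
from pydantic import BaseModel

class ExampleModel(BaseModel):
    identifier: str

m = ExampleModel(identifier="llama-3")
m.model_dump_json()  # v2 replacement for m.json()
m.model_dump()       # v2 replacement for m.dict()
m.model_copy()       # v2 replacement for m.copy()
```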
89 changes: 16 additions & 73 deletions llama_stack/distribution/server/server.py
@@ -17,13 +17,11 @@

 from contextlib import asynccontextmanager
 from pathlib import Path
-from ssl import SSLError
-from typing import Any, Dict, Optional
+from typing import Any, Union

-import httpx
 import yaml

-from fastapi import Body, FastAPI, HTTPException, Request, Response
+from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
@@ -35,7 +33,6 @@
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
-    SpanStatus,
     start_trace,
 )
 from llama_stack.distribution.datatypes import * # noqa: F403
@@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
     )


-async def passthrough(
-    request: Request,
-    downstream_url: str,
-    downstream_headers: Optional[Dict[str, str]] = None,
-):
-    await start_trace(request.path, {"downstream_url": downstream_url})
-
-    headers = dict(request.headers)
-    headers.pop("host", None)
-    headers.update(downstream_headers or {})
-
-    content = await request.body()
-
-    client = httpx.AsyncClient()
-    erred = False
-    try:
-        req = client.build_request(
-            method=request.method,
-            url=downstream_url,
-            headers=headers,
-            content=content,
-            params=request.query_params,
-        )
-        response = await client.send(req, stream=True)
-
-        async def stream_response():
-            async for chunk in response.aiter_raw(chunk_size=64):
-                yield chunk
-
-            await response.aclose()
-            await client.aclose()
-
-        return StreamingResponse(
-            stream_response(),
-            status_code=response.status_code,
-            headers=dict(response.headers),
-            media_type=response.headers.get("content-type"),
-        )
-
-    except httpx.ReadTimeout:
-        erred = True
-        return Response(content="Downstream server timed out", status_code=504)
-    except httpx.NetworkError as e:
-        erred = True
-        return Response(content=f"Network error: {str(e)}", status_code=502)
-    except httpx.TooManyRedirects:
-        erred = True
-        return Response(content="Too many redirects", status_code=502)
-    except SSLError as e:
-        erred = True
-        return Response(content=f"SSL error: {str(e)}", status_code=502)
-    except httpx.HTTPStatusError as e:
-        erred = True
-        return Response(content=str(e), status_code=e.response.status_code)
-    except Exception as e:
-        erred = True
-        return Response(content=f"Unexpected error: {str(e)}", status_code=500)
-    finally:
-        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
-
-
 def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")

@@ -217,7 +153,6 @@ async def maybe_await(value):


 async def sse_generator(event_gen):
-    await start_trace("sse_generator")
     try:
         event_gen = await event_gen
         async for item in event_gen:
@@ -235,14 +170,10 @@ async def sse_generator(event_gen):
                 },
             }
         )
-    finally:
-        await end_trace()


 def create_dynamic_typed_route(func: Any, method: str):
     async def endpoint(request: Request, **kwargs):
-        await start_trace(func.__name__)
-
         set_request_provider_data(request.headers)

         is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -257,8 +188,6 @@ async def endpoint(request: Request, **kwargs):
         except Exception as e:
             traceback.print_exception(e)
             raise translate_exception(e) from e
-        finally:
-            await end_trace()

     sig = inspect.signature(func)
     new_params = [
@@ -282,6 +211,19 @@ async def endpoint(request: Request, **kwargs):
     return endpoint


+class TracingMiddleware:
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        path = scope["path"]
+        await start_trace(path, {"location": "server"})
+        try:
+            return await self.app(scope, receive, send)
+        finally:
+            await end_trace()
+
+
 def main():
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@@ -338,6 +280,7 @@ def main():
     print(yaml.dump(config.model_dump(), indent=2))

     app = FastAPI(lifespan=lifespan)
+    app.add_middleware(TracingMiddleware)

     try:
         impls = asyncio.run(construct_stack(config))
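
A quick way to see the middleware in action is to hit any endpoint and look
for a single root trace named after the request path. A hedged smoke-test
sketch; the host, port, and path here are assumptions, not fixed by this PR:

```python
import httpx

# Each request should now produce exactly one root trace (e.g. "/models/list")
# spanning the full request, including streamed SSE responses, since
# end_trace() only fires after the ASGI app finishes sending.
resp = httpx.get("http://localhost:5000/models/list")
print(resp.status_code)
```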
@@ -113,7 +113,7 @@ def turn_to_messages(self, turn: Turn) -> List[Message]:
         # May be this should be a parameter of the agentic instance
         # that can define its behavior in a custom way
         for m in turn.input_messages:
-            msg = m.copy()
+            msg = m.model_copy()
             if isinstance(msg, UserMessage):
                 msg.context = None
             messages.append(msg)
@@ -52,7 +52,7 @@ async def create_agent(

         await self.persistence_store.set(
             key=f"agent:{agent_id}",
-            value=agent_config.json(),
+            value=agent_config.model_dump_json(),
         )
         return AgentCreateResponse(
             agent_id=agent_id,
@@ -39,7 +39,7 @@ async def create_session(self, name: str) -> str:
         )
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.json(),
+            value=session_info.model_dump_json(),
         )
         return session_id

@@ -60,13 +60,13 @@ async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
         session_info.memory_bank_id = bank_id
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.json(),
+            value=session_info.model_dump_json(),
         )

     async def add_turn_to_session(self, session_id: str, turn: Turn):
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
-            value=turn.json(),
+            value=turn.model_dump_json(),
         )

     async def get_session_turns(self, session_id: str) -> List[Turn]:
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -72,7 +72,7 @@ async def register_eval_task(self, task_def: EvalTask) -> None:
         key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=task_def.json(),
+            value=task_def.model_dump_json(),
         )
         self.eval_tasks[task_def.identifier] = task_def
6 changes: 4 additions & 2 deletions llama_stack/providers/inline/memory/faiss/faiss.py
@@ -80,7 +80,9 @@ async def _save_index(self):
         np.savetxt(buffer, np_index)
         data = {
             "id_by_index": self.id_by_index,
-            "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
+            "chunk_by_index": {
+                k: v.model_dump_json() for k, v in self.chunk_by_index.items()
+            },
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }

@@ -162,7 +164,7 @@ async def register_memory_bank(
         key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
         await self.kvstore.set(
             key=key,
-            value=memory_bank.json(),
+            value=memory_bank.model_dump_json(),
         )

         # Store in cache
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/memory/chroma/chroma.py
@@ -107,7 +107,7 @@ async def register_memory_bank(

         collection = await self.client.get_or_create_collection(
             name=memory_bank.identifier,
-            metadata={"bank": memory_bank.json()},
+            metadata={"bank": memory_bank.model_dump_json()},
         )
         bank_index = BankWithIndex(
             bank=memory_bank, index=ChromaIndex(self.client, collection)
21 changes: 18 additions & 3 deletions llama_stack/providers/remote/telemetry/opentelemetry/config.py
@@ -4,9 +4,24 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field


 class OpenTelemetryConfig(BaseModel):
-    jaeger_host: str = "localhost"
-    jaeger_port: int = 6831
+    otel_endpoint: str = Field(
+        default="http://localhost:4318/v1/traces",
+        description="The OpenTelemetry collector endpoint URL",
+    )
+    service_name: str = Field(
+        default="llama-stack",
+        description="The service name to use for telemetry",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
+            "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
+        }
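
The `${env.VAR:default}` strings in sample_run_config are resolved during the
stack's run-config loading, substituting environment variables with a fallback
default. A hedged sketch of that substitution behavior (the helper name is
illustrative, not the actual loader API):

```python
import os
import re

def resolve_env(value: str) -> str:
    # Replace "${env.NAME:default}" with $NAME if set, else the default.
    def repl(m: re.Match) -> str:
        name, default = m.group(1), m.group(2)
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([^:}]+):([^}]*)\}", repl, value)

print(resolve_env("${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}"))
# -> "http://localhost:4318/v1/traces" unless OTEL_ENDPOINT is set
```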