From a4db2ba7830de53d1e370ab61860b76d53394cae Mon Sep 17 00:00:00 2001 From: duncanpo Date: Tue, 27 Aug 2024 10:09:07 -0400 Subject: [PATCH] Add proper support to TraceState, fixes #141 --- api/trace/+opentelemetry/+trace/SpanContext.m | 34 +++++++++++++--- api/trace/src/SpanContextProxy.cpp | 39 +++++++++++++++++-- test/tcontextPropagation.m | 10 +++-- test/ttrace.m | 13 +++++-- 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/api/trace/+opentelemetry/+trace/SpanContext.m b/api/trace/+opentelemetry/+trace/SpanContext.m index 04e840e..6abd72e 100644 --- a/api/trace/+opentelemetry/+trace/SpanContext.m +++ b/api/trace/+opentelemetry/+trace/SpanContext.m @@ -54,8 +54,9 @@ % default option values issampled = true; isremote = true; + includets = false; % whether TraceState is specified if nargin > 2 - optionnames = ["IsSampled", "IsRemote"]; + optionnames = ["IsSampled", "IsRemote", "TraceState"]; for i = 1:2:length(varargin) try namei = validatestring(varargin{i}, optionnames); @@ -68,17 +69,37 @@ if (isnumeric(valuei) || islogical(valuei)) && isscalar(valuei) issampled = logical(valuei); end - else % strcmp(namei, "IsRemote") + elseif strcmp(namei, "IsRemote") if (isnumeric(valuei) || islogical(valuei)) && isscalar(valuei) isremote = logical(valuei); end + else % strcmp(namei, "TraceState") + if isa(valuei, "dictionary") + try + tskeysi = string(keys(valuei)); + tsvaluesi = string(values(valuei)); + catch + % invalid TraceState, ignore + continue + end + tskeys = tskeysi; + tsvalues = tsvaluesi; + includets = true; + end end end end - obj.Proxy = libmexclass.proxy.Proxy("Name", ... - "libmexclass.opentelemetry.SpanContextProxy", ... - "ConstructorArguments", {traceid, spanid, issampled, isremote}); + if includets + obj.Proxy = libmexclass.proxy.Proxy("Name", ... + "libmexclass.opentelemetry.SpanContextProxy", ... + "ConstructorArguments", {traceid, spanid, issampled, ... + isremote, tskeys, tsvalues}); + else + obj.Proxy = libmexclass.proxy.Proxy("Name", ... + "libmexclass.opentelemetry.SpanContextProxy", ... + "ConstructorArguments", {traceid, spanid, issampled, isremote}); + end end end end @@ -93,7 +114,8 @@ end function tracestate = get.TraceState(obj) - tracestate = obj.Proxy.getTraceState(); + [keys, values] = obj.Proxy.getTraceState(); + tracestate = dictionary(keys, values); end function traceflags = get.TraceFlags(obj) diff --git a/api/trace/src/SpanContextProxy.cpp b/api/trace/src/SpanContextProxy.cpp index 6b40eab..8daa875 100644 --- a/api/trace/src/SpanContextProxy.cpp +++ b/api/trace/src/SpanContextProxy.cpp @@ -9,6 +9,7 @@ #include "opentelemetry/trace/default_span.h" #include "opentelemetry/trace/context.h" #include "opentelemetry/trace/trace_flags.h" +#include "opentelemetry/trace/trace_state.h" namespace common = opentelemetry::common; namespace context_api = opentelemetry::context; @@ -29,7 +30,23 @@ libmexclass::proxy::MakeResult SpanContextProxy::make(const libmexclass::proxy:: if (issampled) { traceflags |= trace_api::TraceFlags::kIsSampled; } - return std::make_shared(trace_api::SpanContext{traceid, spanid, trace_api::TraceFlags(traceflags), isremote}); + + if (constructor_arguments.getNumberOfElements() <= 4) { + return std::make_shared(trace_api::SpanContext{traceid, spanid, trace_api::TraceFlags(traceflags), isremote}); + } else { + auto tracestate = trace_api::TraceState::GetDefault(); + matlab::data::StringArray tracestatekeys_mda = constructor_arguments[4]; + matlab::data::StringArray tracestatevalues_mda = constructor_arguments[5]; + size_t ntskeys = tracestatekeys_mda.getNumberOfElements(); + for (size_t i = 0; i < ntskeys; ++i) { + std::string tracestatekeyi = static_cast(tracestatekeys_mda[i]), + tracestatevaluei = static_cast(tracestatevalues_mda[i]); + if (tracestate->IsValidKey(tracestatekeyi) && tracestate->IsValidValue(tracestatevaluei)) { + tracestate = tracestate->Set(tracestatekeyi, tracestatevaluei); + } + } + return std::make_shared(trace_api::SpanContext{traceid, spanid, trace_api::TraceFlags(traceflags), isremote, tracestate}); + } } void SpanContextProxy::getTraceId(libmexclass::proxy::method::Context& context) { @@ -71,11 +88,25 @@ void SpanContextProxy::getSpanId(libmexclass::proxy::method::Context& context) { } void SpanContextProxy::getTraceState(libmexclass::proxy::method::Context& context) { - nostd::shared_ptr tracestate = CppSpanContext.trace_state(); + std::list keys; + std::list values; + + // repeatedly invoke the callback lambda to retrieve each entry + bool success = CppSpanContext.trace_state()->GetAllEntries( + [&keys, &values](nostd::string_view currkey, nostd::string_view currvalue) { + keys.push_back(std::string(currkey)); + values.push_back(std::string(currvalue)); + + return true; + }); + size_t nkeys = keys.size(); + matlab::data::ArrayDimensions dims = {nkeys, 1}; matlab::data::ArrayFactory factory; - auto tracestate_mda = factory.createScalar(tracestate->ToHeader()); - context.outputs[0] = tracestate_mda; + auto keys_mda = factory.createArray(dims, keys.cbegin(), keys.cend()); + auto values_mda = factory.createArray(dims, values.cbegin(), values.cend()); + context.outputs[0] = keys_mda; + context.outputs[1] = values_mda; } void SpanContextProxy::getTraceFlags(libmexclass::proxy::method::Context& context) { diff --git a/test/tcontextPropagation.m b/test/tcontextPropagation.m index 7b503fc..095432e 100644 --- a/test/tcontextPropagation.m +++ b/test/tcontextPropagation.m @@ -17,6 +17,7 @@ TraceId SpanId TraceState + TraceStateString Headers BaggageKeys BaggageValues @@ -33,9 +34,10 @@ function setupOnce(testCase) % simulate an HTTP header with relevant fields, used for extraction testCase.TraceId = "0af7651916cd43dd8448eb211c80319c"; testCase.SpanId = "00f067aa0ba902b7"; - testCase.TraceState = "foo=00f067aa0ba902b7"; + testCase.TraceState = dictionary("foo", "00f067aa0ba902b7"); + testCase.TraceStateString = "foo=00f067aa0ba902b7"; testCase.Headers = ["traceparent", "00-" + testCase.TraceId + ... - "-" + testCase.SpanId + "-01"; "tracestate", testCase.TraceState]; + "-" + testCase.SpanId + "-01"; "tracestate", testCase.TraceStateString]; testCase.BaggageKeys = ["userId", "serverNode", "isProduction"]; testCase.BaggageValues = ["alice", "DF28", "false"]; testCase.BaggageHeaders = ["baggage", strjoin(strcat(testCase.BaggageKeys, ... @@ -71,11 +73,13 @@ function testExtract(testCase) results = readJsonResults(testCase); results = results{1}; - % check trace and parent IDs + % check trace and parent IDs, and span context verifyEqual(testCase, string(results.resourceSpans.scopeSpans.spans.traceId), ... testCase.TraceId); verifyEqual(testCase, string(results.resourceSpans.scopeSpans.spans.parentSpanId), ... testCase.SpanId); + verifyEqual(testCase, string(results.resourceSpans.scopeSpans.spans.traceState), ... + testCase.TraceStateString); % check trace state in span context spancontext = getSpanContext(sp); verifyEqual(testCase, spancontext.TraceState, testCase.TraceState); diff --git a/test/ttrace.m b/test/ttrace.m index bc62043..8271afd 100644 --- a/test/ttrace.m +++ b/test/ttrace.m @@ -254,7 +254,7 @@ function testGetSpanContext(testCase) results = readJsonResults(testCase); verifyEqual(testCase, ctxt.TraceId, string(results{1}.resourceSpans.scopeSpans.spans.traceId)); verifyEqual(testCase, ctxt.SpanId, string(results{1}.resourceSpans.scopeSpans.spans.spanId)); - verifyEqual(testCase, ctxt.TraceState, ""); + verifyTrue(testCase, isa(ctxt.TraceState, "dictionary") && numEntries(ctxt.TraceState) == 0); verifyEqual(testCase, ctxt.TraceFlags, "01"); % sampled flag should be on end @@ -265,13 +265,16 @@ function testSpanContext(testCase) spanid = "0000000000111122"; issampled = false; isremote = false; + tracestate = dictionary(["foo" "bar"], ["foo1" "bar1"]); + tracestate_str = "bar=bar1,foo=foo1"; sc = opentelemetry.trace.SpanContext(traceid, spanid, ... - "IsSampled", issampled, "IsRemote", isremote); + "IsSampled", issampled, "IsRemote", isremote, ... + "TraceState", tracestate); % verify SpanContext object created correctly verifyEqual(testCase, sc.TraceId, lower(traceid)); verifyEqual(testCase, sc.SpanId, spanid); - verifyEqual(testCase, sc.TraceState, ""); + verifyEqual(testCase, sc.TraceState, tracestate); verifyEqual(testCase, sc.TraceFlags, "00"); % sampled flag should be off verifyEqual(testCase, isRemote(sc), isremote); @@ -296,10 +299,14 @@ function testSpanContext(testCase) lower(traceid)); verifyEqual(testCase, string(results{1}.resourceSpans.scopeSpans.spans.parentSpanId), ... spanid); + verifyEqual(testCase, string(results{1}.resourceSpans.scopeSpans.spans.traceState), ... + tracestate_str); verifyEqual(testCase, string(results{2}.resourceSpans.scopeSpans.spans.traceId), ... lower(traceid)); verifyEqual(testCase, string(results{2}.resourceSpans.scopeSpans.spans.parentSpanId), ... spanid); + verifyEqual(testCase, string(results{2}.resourceSpans.scopeSpans.spans.traceState), ... + tracestate_str); end function testTime(testCase)