Skip to content

Commit

Permalink
Add SnowflakeTraceIdGenerator implementation (snowflakedb#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bdrutu authored Jul 18, 2024
1 parent b6cd806 commit a57ca02
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 19 deletions.
7 changes: 4 additions & 3 deletions src/snowflake/telemetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Stored Procedures.
"""

from opentelemetry import trace
from opentelemetry.trace import get_current_span
from opentelemetry.util import types

from snowflake.telemetry.version import VERSION
Expand All @@ -26,7 +26,8 @@ def add_event(
"""
Add an event name and associated attributes to the current span.
"""
trace.get_current_span().add_event(name, attributes)
get_current_span().add_event(name, attributes)


def set_span_attribute(
key: str,
Expand All @@ -35,4 +36,4 @@ def set_span_attribute(
"""
Set an attribute key, value pair on the current span.
"""
trace.get_current_span().set_attribute(key, value)
get_current_span().set_attribute(key, value)
29 changes: 29 additions & 0 deletions src/snowflake/telemetry/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import time
import random

from opentelemetry import trace
from opentelemetry.sdk.trace import RandomIdGenerator

# Generator that returns
# trace_id: the given (inherited) trace id on the first call to generate_trace_id, and a Snowflake trace_id on subsequent calls
# span_id: a random span_id
class SnowflakeTraceIdGenerator(RandomIdGenerator):
def generate_trace_id(self) -> int:
trace_id = trace.INVALID_TRACE_ID
while trace_id == trace.INVALID_TRACE_ID:
# Number of minutes since the epoch
timestamp_in_minutes = int(time.time()) // 60
# Convert and pad to 4 bytes
timestamp_bytes = timestamp_in_minutes.to_bytes(4, byteorder='big', signed=False)
suffix_bytes = random.getrandbits(96).to_bytes(12, byteorder='big', signed=False)
trace_id = int.from_bytes(timestamp_bytes + suffix_bytes, byteorder='big', signed=False)
return trace_id


__all__ = [
"SnowflakeTraceIdGenerator",
]
29 changes: 13 additions & 16 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

from snowflake import telemetry
from opentelemetry import trace
from opentelemetry.trace import get_current_span
from opentelemetry.sdk.trace import (
TracerProvider,
)
Expand All @@ -14,10 +14,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.trace import (
Status,
StatusCode,
)

from opentelemetry.trace.span import INVALID_SPAN


Expand All @@ -33,7 +30,7 @@ def test_api_without_current_span(self):
"""
Tests that no exceptions are raised by public API methods when called without a current span
"""
self.assertEqual(trace.get_current_span(), INVALID_SPAN)
self.assertEqual(get_current_span(), INVALID_SPAN)
telemetry.add_event("EventName1")
telemetry.add_event("EventName2",
{
Expand All @@ -48,9 +45,9 @@ def test_add_event(self):
"""
self.configure_open_telemetry()
with self.tracer.start_as_current_span("Auto-instrumented span"):
self.assertNotEqual(trace.get_current_span(), INVALID_SPAN)
self.assertNotEqual(get_current_span(), INVALID_SPAN)
telemetry.add_event("EventName1")
trace.get_current_span().end()
get_current_span().end()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(1, len(spans))
events = spans[0].events
Expand All @@ -64,9 +61,9 @@ def test_add_event_none_name(self):
"""
self.configure_open_telemetry()
with self.tracer.start_as_current_span("Auto-instrumented span"):
self.assertNotEqual(trace.get_current_span(), INVALID_SPAN)
self.assertNotEqual(get_current_span(), INVALID_SPAN)
telemetry.add_event(None)
trace.get_current_span().end()
get_current_span().end()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(1, len(spans))
events = spans[0].events
Expand All @@ -80,9 +77,9 @@ def test_add_event_empty_name(self):
"""
self.configure_open_telemetry()
with self.tracer.start_as_current_span("Auto-instrumented span"):
self.assertNotEqual(trace.get_current_span(), INVALID_SPAN)
self.assertNotEqual(get_current_span(), INVALID_SPAN)
telemetry.add_event("")
trace.get_current_span().end()
get_current_span().end()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(1, len(spans))
events = spans[0].events
Expand All @@ -93,7 +90,7 @@ def test_add_event_empty_name(self):
def test_add_event_with_attributes(self):
self.configure_open_telemetry()
with self.tracer.start_as_current_span("Auto-instrumented span"):
self.assertNotEqual(trace.get_current_span(), INVALID_SPAN)
self.assertNotEqual(get_current_span(), INVALID_SPAN)
telemetry.add_event("EventName2",
{
"some int": 42,
Expand All @@ -103,7 +100,7 @@ def test_add_event_with_attributes(self):
"a false value": False,
"a none value": None,
})
trace.get_current_span().end()
get_current_span().end()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(1, len(spans))
events = spans[0].events
Expand All @@ -120,14 +117,14 @@ def test_add_event_with_attributes(self):
def test_set_span_attribute(self):
self.configure_open_telemetry()
with self.tracer.start_as_current_span("Auto-instrumented span"):
self.assertNotEqual(trace.get_current_span(), INVALID_SPAN)
self.assertNotEqual(get_current_span(), INVALID_SPAN)
telemetry.set_span_attribute("some int", 42)
telemetry.set_span_attribute("some str", "Val1")
telemetry.set_span_attribute("some float", 3.14)
telemetry.set_span_attribute("a true value", True)
telemetry.set_span_attribute("a false value", False)
telemetry.set_span_attribute("a none value", None)
trace.get_current_span().end()
get_current_span().end()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(1, len(spans))
attributes = spans[0].attributes
Expand Down
43 changes: 43 additions & 0 deletions tests/test_snowflake_trace_id_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest
from unittest.mock import patch

from opentelemetry import trace
from snowflake.telemetry.trace import SnowflakeTraceIdGenerator

MOCK_TIMESTAMP = 1719588243.3379807
INVALID_TRACE_ID = 0x00000000000000000000000000000000
TRACE_ID_MAX_VALUE = 2**128 - 1
SPAN_ID_MAX_VALUE = 2**64 - 1


class TestSnowflakeTraceIdGenerator(unittest.TestCase):

def test_valid_span_id(self):
id_generator = SnowflakeTraceIdGenerator()
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE)
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE)
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE)
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE)
self.assertTrue(trace.INVALID_SPAN_ID < id_generator.generate_span_id() <= SPAN_ID_MAX_VALUE)

@patch('time.time', return_value=MOCK_TIMESTAMP)
def test_valid_snowflake_trace_id(self, mock_time):
id_generator = SnowflakeTraceIdGenerator()
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())
self._verify_snowflake_trace_id(id_generator.generate_trace_id())

def _verify_snowflake_trace_id(self, trace_id: int):
# https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-api/src/opentelemetry/trace/span.py
self.assertTrue(trace.INVALID_TRACE_ID < trace_id <= TRACE_ID_MAX_VALUE)

# Get the hex format of the snowflake_trace_id and pad it to 32 characters
# The timestamp prefix is the first 8 characters of this.
timestamp_prefix = f'{trace_id:x}'.zfill(32)[:8]

# the expected prefix is the timestamp (in minutes) in hex format padded to 8 characters.
mock_timestamp_minutes = int(MOCK_TIMESTAMP) // 60
mock_timestamp_prefix = f'{mock_timestamp_minutes:x}'.zfill(8)
self.assertEqual(timestamp_prefix, mock_timestamp_prefix)

0 comments on commit a57ca02

Please sign in to comment.