Skip to content

Commit

Permalink
feat: add typing on some public api functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pbabics committed Mar 7, 2024
1 parent 080b879 commit 0ab7b02
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 19 deletions.
28 changes: 25 additions & 3 deletions elasticapm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from elasticapm.base import Client, get_client # noqa: F401
from elasticapm.conf import setup_logging # noqa: F401
from elasticapm.contrib.asyncio.traces import async_capture_span # noqa: F401 E402
from elasticapm.contrib.serverless import capture_serverless # noqa: F401
from elasticapm.instrumentation.control import instrument, uninstrument # noqa: F401
from elasticapm.traces import ( # noqa: F401
Expand All @@ -49,7 +50,30 @@
)
from elasticapm.utils.disttracing import trace_parent_from_headers, trace_parent_from_string # noqa: F401

__all__ = ("VERSION", "Client")
__all__ = (
"VERSION",
"Client",
"get_client",
"setup_logging",
"capture_serverless",
"instrument",
"uninstrument",
"capture_span",
"get_span_id",
"get_trace_id",
"get_trace_parent_header",
"get_transaction_id",
"label",
"set_context",
"set_custom_context",
"set_transaction_name",
"set_transaction_outcome",
"set_transaction_result",
"set_user_context",
"trace_parent_from_headers",
"trace_parent_from_string",
"async_capture_span",
)

_activation_method = None

Expand All @@ -64,5 +88,3 @@

if sys.version_info <= (3, 5):
raise DeprecationWarning("The Elastic APM agent requires Python 3.6+")

from elasticapm.contrib.asyncio.traces import async_capture_span # noqa: F401 E402
20 changes: 15 additions & 5 deletions elasticapm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Optional, Sequence, Tuple
from types import TracebackType
from typing import Any, Optional, Sequence, Tuple, Type, Union

import elasticapm
from elasticapm.conf import Config, VersionedConfig, constants
from elasticapm.conf.constants import ERROR
from elasticapm.metrics.base_metrics import MetricsRegistry
from elasticapm.traces import DroppedSpan, Tracer, execution_context
from elasticapm.traces import DroppedSpan, Tracer, Transaction, execution_context
from elasticapm.utils import cgroup, cloud, compat, is_master_process, stacks, varmap
from elasticapm.utils.disttracing import TraceParent
from elasticapm.utils.encoding import enforce_label_format, keyword_field, shorten, transform
Expand Down Expand Up @@ -261,15 +262,22 @@ def capture(self, event_type, date=None, context=None, custom=None, stack=None,
self.queue(ERROR, data, flush=not handled)
return data["id"]

def capture_message(self, message=None, param_message=None, **kwargs):
def capture_message(self, message: Optional[str] = None, param_message=None, **kwargs: Any) -> str:
"""
Creates an event from ``message``.
>>> client.capture_message('My event just happened!')
"""
return self.capture("Message", message=message, param_message=param_message, **kwargs)

def capture_exception(self, exc_info=None, handled=True, **kwargs):
def capture_exception(
self,
exc_info: Union[
None, bool, Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]
] = None,
handled: bool = True,
**kwargs: Any
) -> str:
"""
Creates an event from an exception.
Expand Down Expand Up @@ -317,7 +325,9 @@ def begin_transaction(
transaction_type, trace_parent=trace_parent, start=start, auto_activate=auto_activate, links=links
)

def end_transaction(self, name=None, result="", duration=None):
def end_transaction(
self, name: Optional[str] = None, result: str = "", duration: Optional[Union[float, timedelta]] = None
) -> Transaction:
"""
End the current transaction.
Expand Down
25 changes: 14 additions & 11 deletions elasticapm/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import decimal
import functools
import random
import re
Expand Down Expand Up @@ -989,10 +990,12 @@ def begin_transaction(
execution_context.set_transaction(transaction)
return transaction

def end_transaction(self, result=None, transaction_name=None, duration=None):
def end_transaction(
self, result: Optional[str] = None, transaction_name: Optional[str] = None, duration: Optional[timedelta] = None
) -> Transaction:
"""
End the current transaction and queue it for sending
:param result: result of the transaction, e.g. "OK" or 200
:param result: result of the transaction, e.g. "OK" or "HTTP 2xx"
:param transaction_name: name of the transaction
:param duration: override duration, mostly useful for testing
:return:
Expand Down Expand Up @@ -1130,7 +1133,7 @@ def handle_exit(
logger.debug("ended non-existing span %s of type %s", self.name, self.type)


def label(**labels) -> None:
def label(**labels: Union[str, bool, int, float, decimal.Decimal]) -> None:
"""
Labels current transaction. Keys should be strings, values can be strings, booleans,
or numerical values (int, float, Decimal)
Expand Down Expand Up @@ -1211,43 +1214,43 @@ def set_transaction_outcome(outcome=None, http_status_code=None, override=True)
transaction.outcome = outcome


def get_transaction_id():
def get_transaction_id() -> Optional[str]:
"""
Returns the current transaction ID
"""
transaction = execution_context.get_transaction()
if not transaction:
return
return None
return transaction.id


def get_trace_parent_header():
def get_trace_parent_header() -> Optional[str]:
"""
Return the trace parent header for the current transaction.
"""
transaction = execution_context.get_transaction()
if not transaction or not transaction.trace_parent:
return
return None
return transaction.trace_parent.to_string()


def get_trace_id():
def get_trace_id() -> Optional[str]:
"""
Returns the current trace ID
"""
transaction = execution_context.get_transaction()
if not transaction:
return
return None
return transaction.trace_parent.trace_id if transaction.trace_parent else None


def get_span_id():
def get_span_id() -> Optional[str]:
"""
Returns the current span ID
"""
span = execution_context.get_span()
if not span:
return
return None
return span.id


Expand Down

0 comments on commit 0ab7b02

Please sign in to comment.