Skip to content

Commit

Permalink
Ensure propagation of lsql version into User-Agent header when it…
Browse files Browse the repository at this point in the history
… is used as library (#206)

This PR ensures correct library attribution.
  • Loading branch information
nfx authored Jul 2, 2024
1 parent 56e7f70 commit 4990ce1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
]
dependencies = [
"databricks-labs-blueprint[yaml]>=0.4.2",
"databricks-sdk>=0.22.0",
"databricks-sdk>=0.29.0",
"sqlglot>=22.3.1"
]

Expand Down
8 changes: 7 additions & 1 deletion src/databricks/labs/lsql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from databricks.labs.lsql.core import Row
from databricks.sdk.core import with_user_agent_extra

from .__about__ import __version__
from .core import Row

__all__ = ["Row"]


with_user_agent_extra("lsql", __version__)
8 changes: 4 additions & 4 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ..
field_type = field_type.__args__[0]
if value is None:
data.append("NULL")
elif field_type == bool:
elif field_type is bool:
data.append("TRUE" if value else "FALSE")
elif field_type == str:
elif field_type is str:
value = str(value).replace("'", "''")
data.append(f"'{value}'")
elif field_type == int:
elif field_type is int:
data.append(f"{value}")
else:
msg = f"unknown type: {field_type}"
Expand Down Expand Up @@ -336,7 +336,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
rows = self._filter_none_rows(rows, klass)
if mode == "overwrite":
self._save_table = []
if klass.__class__ == type:
if klass.__class__ == type: # noqa: E721
row_factory = self._row_factory(klass)
rows = [row_factory(*dataclasses.astuple(r)) for r in rows]
self._save_table.append((full_name, rows, mode))
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_useragent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import contextlib
import functools
import typing
from http.server import BaseHTTPRequestHandler

from databricks.sdk import WorkspaceClient

from databricks.labs.lsql.__about__ import __version__
from databricks.labs.lsql.dashboards import Dashboards


@contextlib.contextmanager
def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]):
from http.server import HTTPServer
from threading import Thread

class _handler(BaseHTTPRequestHandler):
def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args):
self._handler = handler
super().__init__(*args)

def __getattr__(self, item):
if "do_" != item[0:3]:
raise AttributeError(f"method {item} not found")
return functools.partial(self._handler, self)

handler_factory = functools.partial(_handler, handler)
srv = HTTPServer(("localhost", 0), handler_factory)
t = Thread(target=srv.serve_forever)
try:
t.daemon = True
t.start()
yield "http://{0}:{1}".format(*srv.server_address)
finally:
srv.shutdown()


def test_user_agent_is_propagated():
user_agent = {}

def inner(h: BaseHTTPRequestHandler):
for pair in h.headers["User-Agent"].split(" "):
if "/" not in pair:
continue
k, v = pair.split("/")
user_agent[k] = v
h.send_response(200)
h.send_header("Content-Type", "application/json")
h.end_headers()
h.wfile.write(b"{}")
h.wfile.flush()

with http_fixture_server(inner) as host:
ws = WorkspaceClient(host=host, token="_")
d = Dashboards(ws)
d.get_dashboard("...")

assert "lsql" in user_agent
assert user_agent["lsql"] == __version__

0 comments on commit 4990ce1

Please sign in to comment.