Commit

nfx committed Mar 11, 2024
1 parent 2749b89 commit 4f37907
Showing 8 changed files with 218 additions and 193 deletions.
1 change: 0 additions & 1 deletion src/databricks/labs/lsql/__init__.py
@@ -1,4 +1,3 @@
from .core import Row

__all__ = ["Row"]

21 changes: 15 additions & 6 deletions src/databricks/labs/lsql/backends.py
@@ -155,7 +155,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
self.execute(sql)

@staticmethod
def _row_to_sql(row: DataclassInstance, fields: list[dataclasses.Field]):
def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
data = []
for f in fields:
value = getattr(row, f.name)
@@ -215,8 +215,11 @@ def __init__(self, debug_truncate_bytes: int | None = None):
msg = "Not in the Databricks Runtime"
raise RuntimeError(msg)
try:
# pylint: disable-next=import-error,import-outside-toplevel
from pyspark.sql.session import SparkSession # type: ignore[import-not-found]
# pylint: disable-next=import-error,import-outside-toplevel,useless-suppression
from pyspark.sql.session import ( # type: ignore[import-not-found]
SparkSession,
)

super().__init__(SparkSession.builder.getOrCreate(), debug_truncate_bytes)
except ImportError as e:
raise RuntimeError("pyspark is not available") from e
@@ -225,15 +228,21 @@ def __init__(self, debug_truncate_bytes: int | None = None):
class DatabricksConnectBackend(_SparkBackend):
def __init__(self, ws: WorkspaceClient):
try:
from databricks.connect import DatabricksSession
# pylint: disable-next=import-outside-toplevel
from databricks.connect import ( # type: ignore[import-untyped]
DatabricksSession,
)

spark = DatabricksSession.builder().sdk_config(ws.config).getOrCreate()
super().__init__(spark, ws.config.debug_truncate_bytes)
except ImportError as e:
raise RuntimeError("Please run `pip install databricks-connect`") from e


class MockBackend(SqlBackend):
def __init__(self, *, fails_on_first: dict[str,str] | None = None, rows: dict | None = None, debug_truncate_bytes=96):
def __init__(
self, *, fails_on_first: dict[str, str] | None = None, rows: dict | None = None, debug_truncate_bytes=96
):
self._fails_on_first = fails_on_first
if not rows:
rows = {}
@@ -286,4 +295,4 @@ def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]

@staticmethod
def _row_factory(klass: Dataclass) -> type:
return Row.factory([f.name for f in dataclasses.fields(klass)])
return Row.factory([f.name for f in dataclasses.fields(klass)])
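Note on the signature change above: dataclasses.fields() returns a tuple of Field objects rather than a list, so the new tuple[dataclasses.Field[Any], ...] annotation on _row_to_sql matches what callers actually pass in. A minimal sketch of that behaviour (the Foo dataclass is only an illustration, not part of this commit):

import dataclasses

@dataclasses.dataclass
class Foo:
    name: str
    count: int

fields = dataclasses.fields(Foo)   # returns a tuple of Field objects
print(type(fields))                # <class 'tuple'>
print([f.name for f in fields])    # ['name', 'count']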
79 changes: 43 additions & 36 deletions src/databricks/labs/lsql/core.py
@@ -1,15 +1,14 @@
import base64
import datetime
import functools
import json
import logging
import random
import threading
import time
import types
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from datetime import timedelta
from typing import Any, Callable
from typing import Any

import requests
import sqlglot
@@ -20,11 +19,11 @@
Disposition,
ExecuteStatementResponse,
Format,
ResultData,
ServiceError,
ServiceErrorCode,
State,
StatementState,
StatementStatus, State,
StatementStatus,
)

MAX_SLEEP_PER_ATTEMPT = 10
@@ -38,6 +37,7 @@

class Row(tuple):
"""Row is a tuple with named fields that resembles PySpark's SQL Row API."""

def __new__(cls, *args, **kwargs):
"""Create a new instance of Row."""
if args and kwargs:
@@ -47,7 +47,7 @@ def __new__(cls, *args, **kwargs):
row = tuple.__new__(cls, list(kwargs.values()))
row.__columns__ = list(kwargs.keys())
return row
if len(args) == 1 and hasattr(cls, '__columns__') and isinstance(args[0], (types.GeneratorType, list, tuple)):
if len(args) == 1 and hasattr(cls, "__columns__") and isinstance(args[0], (types.GeneratorType, list, tuple)):
# this type is returned by Row.factory() and we already know the column names
return cls(*args[0])
if len(args) == 2 and isinstance(args[0], (list, tuple)) and isinstance(args[1], (list, tuple)):
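Based on the constructor branches visible in this hunk and the Row.factory() calls elsewhere in the commit, Row can be built from keyword arguments or via a factory-produced subclass that already knows its column names. A rough usage sketch, assuming the PySpark-style named-field access the docstring refers to (column names are illustrative):

from databricks.labs.lsql.core import Row

# keyword form: values become the tuple, keys become __columns__
row = Row(pickup_zip=10282, dropoff_zip=10171)
assert row == (10282, 10171)    # still a plain tuple
assert row.pickup_zip == 10282  # named-field access, as in PySpark's Row

# factory form: Row.factory() (used by _result_schema and MockBackend._row_factory)
# returns a Row subclass with the column names pre-set
Trip = Row.factory(["pickup_zip", "dropoff_zip"])
trip = Trip([10282, 10171])     # single list/tuple/generator argument, per the branch above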
@@ -115,19 +115,21 @@ class StatementExecutionExt:
megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use
the stateful Databricks SQL Connector for Python."""

def __init__(self, ws: WorkspaceClient,
disposition: Disposition | None = None,
warehouse_id: str | None = None,
byte_limit: int | None = None,
catalog: str | None = None,
schema: str | None = None,
timeout: timedelta = timedelta(minutes=20),
disable_magic: bool = False,
http_session_factory: Callable[[], requests.Session] | None = None):
def __init__( # pylint: disable=too-many-arguments
self,
ws: WorkspaceClient,
disposition: Disposition | None = None,
warehouse_id: str | None = None,
byte_limit: int | None = None,
catalog: str | None = None,
schema: str | None = None,
timeout: timedelta = timedelta(minutes=20),
disable_magic: bool = False,
http_session_factory: Callable[[], requests.Session] | None = None,
):
if not http_session_factory:
http_session_factory = requests.Session
self._ws = ws
self._api = ws.api_client
self._http = http_session_factory()
self._lock = threading.Lock()
self._warehouse_id = warehouse_id
@@ -301,15 +303,13 @@ def fetch_all(
Timeout after which the query is cancelled. See :py:meth:`execute` for more details.
:return: Iterator[Row]
"""
execute_response = self.execute(statement,
warehouse_id=warehouse_id,
byte_limit=byte_limit,
catalog=catalog,
schema=schema,
timeout=timeout)
execute_response = self.execute(
statement, warehouse_id=warehouse_id, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout
)
assert execute_response.statement_id is not None
result_data = execute_response.result
if result_data is None:
return []
return
row_factory, col_conv = self._result_schema(execute_response)
while True:
if result_data.data_array:
@@ -318,6 +318,7 @@
next_chunk_index = result_data.next_chunk_index
if result_data.external_links:
for external_link in result_data.external_links:
assert external_link.external_link is not None
next_chunk_index = external_link.next_chunk_index
response = self._http.get(external_link.external_link)
response.raise_for_status()
@@ -326,8 +327,8 @@
if not next_chunk_index:
return
result_data = self._ws.statement_execution.get_statement_result_chunk_n(
execute_response.statement_id,
next_chunk_index)
execute_response.statement_id, next_chunk_index
)

def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Row | None:
"""Execute a query and fetch the first available record.
@@ -360,31 +361,32 @@ def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Ro

def fetch_value(self, statement: str, **kwargs) -> Any | None:
"""Execute a query and fetch the first available value."""
for v, in self.fetch_all(statement, **kwargs):
for (v,) in self.fetch_all(statement, **kwargs):
return v
return None

def _statement_timeouts(self, timeout):
def _statement_timeouts(self, timeout) -> tuple[timedelta, str | None]:
"""Set server-side and client-side timeouts for statement execution."""
if timeout is None:
timeout = self._timeout
wait_timeout = None
if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT:
# set server-side timeout
wait_timeout = f"{timeout.total_seconds()}s"
assert timeout is not None
return timeout, wait_timeout

@staticmethod
def _parse_date(value: str) -> datetime.date:
"""Parse date from string in ISO format."""
year, month, day = value.split('-')
year, month, day = value.split("-")
return datetime.date(int(year), int(month), int(day))

@staticmethod
def _parse_timestamp(value: str) -> datetime.datetime:
"""Parse timestamp from string in ISO format."""
# make it work with Python 3.7 to 3.10 as well
return datetime.datetime.fromisoformat(value.replace('Z', '+00:00'))
return datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))

@staticmethod
def _raise_if_needed(status: StatementStatus):
@@ -440,19 +442,22 @@ def _default_warehouse(self) -> str:
return self._ws.config.warehouse_id
ids = []
for v in self._ws.warehouses.list():
assert v.id is not None
if v.state in [State.DELETED, State.DELETING]:
continue
elif v.state == State.RUNNING:
if v.state == State.RUNNING:
self._ws.config.warehouse_id = v.id
return self._ws.config.warehouse_id
ids.append(v.id)
if len(ids) > 0:
# otherwise - first warehouse
self._ws.config.warehouse_id = ids[0]
return self._ws.config.warehouse_id
raise ValueError("no warehouse_id=... given, "
"neither it is set in the WorkspaceClient(..., warehouse_id=...), "
"nor in the DATABRICKS_WAREHOUSE_ID environment variable")
raise ValueError(
"no warehouse_id=... given, "
"neither it is set in the WorkspaceClient(..., warehouse_id=...), "
"nor in the DATABRICKS_WAREHOUSE_ID environment variable"
)

@staticmethod
def _add_limit(statement: str) -> str:
Expand All @@ -463,10 +468,10 @@ def _add_limit(statement: str) -> str:
statement_ast = statements[0]
if isinstance(statement_ast, sqlglot.expressions.Select):
if statement_ast.limit is not None:
limit = statement_ast.args.get('limit', None)
if limit and limit.text('expression') != '1':
limit = statement_ast.args.get("limit", None)
if limit and limit.text("expression") != "1":
raise ValueError(f"limit is not 1: {limit.text('expression')}")
return statement_ast.limit(expression=1).sql('databricks')
return statement_ast.limit(expression=1).sql("databricks")
return statement

def _result_schema(self, execute_response: ExecuteStatementResponse):
Expand All @@ -485,6 +490,8 @@ def _result_schema(self, execute_response: ExecuteStatementResponse):
if not columns:
columns = []
for col in columns:
assert col.name is not None
assert col.type_name is not None
col_names.append(col.name)
conv = self._type_converters.get(col.type_name, None)
if conv is None:
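Read together with the integration tests further down, the fetch_* changes in this file are exercised roughly like this (the warehouse id is a placeholder; the table names follow the samples catalog used in the tests):

from databricks.sdk import WorkspaceClient
from databricks.labs.lsql.core import StatementExecutionExt

ws = WorkspaceClient()
see = StatementExecutionExt(ws, warehouse_id="...")  # optional; _default_warehouse() picks a running one otherwise

# fetch_all is a generator of Row objects (hence `return` instead of `return []` above)
for pickup_zip, dropoff_zip in see.fetch_all(
    "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples"
):
    print(pickup_zip, dropoff_zip)

# fetch_one rewrites the statement with LIMIT 1 via _add_limit() and returns the first Row or None
row = see.fetch_one("SELECT * FROM samples.nyctaxi.trips")

# fetch_value returns the first column of the first row
count = see.fetch_value("SELECT COUNT(*) FROM samples.nyctaxi.trips")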
6 changes: 3 additions & 3 deletions tests/integration/conftest.py
@@ -4,14 +4,14 @@
import pathlib
import string
import sys
from typing import MutableMapping, Callable
from typing import Callable, MutableMapping

import pytest
from databricks.labs.blueprint.logger import install_logger
from databricks.sdk import WorkspaceClient
from pytest import fixture

from databricks.labs.lsql.__about__ import __version__
from databricks.labs.blueprint.logger import install_logger

install_logger()
logging.getLogger("databricks").setLevel("DEBUG")
@@ -89,4 +89,4 @@ def inner(var: str) -> str:
skip(f"Environment variable {var} is missing")
return debug_env[var]

return inner
return inner
29 changes: 14 additions & 15 deletions tests/integration/test_integration.py
@@ -1,9 +1,9 @@
import logging

import pytest
from databricks.sdk.service.sql import Disposition

from databricks.labs.lsql.core import StatementExecutionExt
from databricks.sdk.service.sql import Disposition

logger = logging.getLogger(__name__)

@@ -21,21 +21,20 @@ def test_sql_execution(ws, env_or_skip):
results = []
see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
for pickup_zip, dropoff_zip in see.fetch_all(
"SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10",
catalog="samples"
"SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples"
):
results.append((pickup_zip, dropoff_zip))
assert results == [
(10282, 10171),
(10110, 10110),
(10103, 10023),
(10022, 10017),
(10110, 10282),
(10009, 10065),
(10153, 10199),
(10112, 10069),
(10023, 10153),
(10012, 10003)
(10282, 10171),
(10110, 10110),
(10103, 10023),
(10022, 10017),
(10110, 10282),
(10009, 10065),
(10153, 10199),
(10112, 10069),
(10023, 10153),
(10012, 10003),
]


@@ -59,7 +58,7 @@ def test_sql_execution_partial(ws, env_or_skip):
(10153, 10199),
(10112, 10069),
(10023, 10153),
(10012, 10003)
(10012, 10003),
]


@@ -83,4 +82,4 @@ def test_fetch_one_works(ws):
def test_fetch_value(ws):
see = StatementExecutionExt(ws)
count = see.fetch_value("SELECT COUNT(*) FROM samples.nyctaxi.trips")
assert count == 21932
assert count == 21932