From fb75b18fbeacefad77b391db8d0284c007f5acc7 Mon Sep 17 00:00:00 2001 From: laodouya Date: Thu, 18 May 2023 15:09:29 +0800 Subject: [PATCH 1/5] Add db api 2.0 driver --- chdb/dbapi/__init__.py | 84 ++++++++ chdb/dbapi/connections.py | 206 ++++++++++++++++++++ chdb/dbapi/constants/FIELD_TYPE.py | 32 ++++ chdb/dbapi/constants/__init__.py | 0 chdb/dbapi/converters.py | 292 ++++++++++++++++++++++++++++ chdb/dbapi/cursors.py | 298 +++++++++++++++++++++++++++++ chdb/dbapi/err.py | 61 ++++++ chdb/dbapi/times.py | 21 ++ examples/dbapi.py | 34 ++++ setup.py | 18 +- 10 files changed, 1039 insertions(+), 7 deletions(-) create mode 100644 chdb/dbapi/__init__.py create mode 100644 chdb/dbapi/connections.py create mode 100644 chdb/dbapi/constants/FIELD_TYPE.py create mode 100644 chdb/dbapi/constants/__init__.py create mode 100644 chdb/dbapi/converters.py create mode 100644 chdb/dbapi/cursors.py create mode 100644 chdb/dbapi/err.py create mode 100644 chdb/dbapi/times.py create mode 100644 examples/dbapi.py diff --git a/chdb/dbapi/__init__.py b/chdb/dbapi/__init__.py new file mode 100644 index 00000000000..4c15da98bd3 --- /dev/null +++ b/chdb/dbapi/__init__.py @@ -0,0 +1,84 @@ +from .converters import escape_dict, escape_sequence, escape_string +from .constants import FIELD_TYPE +from .err import ( + Warning, Error, InterfaceError, DataError, + DatabaseError, OperationalError, IntegrityError, InternalError, + NotSupportedError, ProgrammingError) +from . import connections as _orig_conn + +VERSION = (0, 1, 0, None) +if VERSION[3] is not None: + VERSION_STRING = "%d.%d.%d_%s" % VERSION +else: + VERSION_STRING = "%d.%d.%d" % VERSION[:3] + +threadsafety = 1 +apilevel = "2.0" +paramstyle = "format" + + +class DBAPISet(frozenset): + + def __ne__(self, other): + if isinstance(other, set): + return frozenset.__ne__(self, other) + else: + return other not in self + + def __eq__(self, other): + if isinstance(other, frozenset): + return frozenset.__eq__(self, other) + else: + return other in self + + def __hash__(self): + return frozenset.__hash__(self) + + +# TODO it's in pep249 find out meaning and usage of this +# https://www.python.org/dev/peps/pep-0249/#string +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, + FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) +NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, FIELD_TYPE.YEAR]) +DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) +TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) +DATETIME = TIMESTAMP +ROWID = DBAPISet() + + +def Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +def Connect(*args, **kwargs): + """ + Connect to the database; see connections.Connection.__init__() for + more information. + """ + from .connections import Connection + return Connection(*args, **kwargs) + + +if _orig_conn.Connection.__init__.__doc__ is not None: + Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ +del _orig_conn + + +def get_client_info(): # for MySQLdb compatibility + version = VERSION + if VERSION[3] is None: + version = VERSION[:3] + return '.'.join(map(str, version)) + + +connect = Connection = Connect + +NULL = "NULL" + +__version__ = get_client_info() diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py new file mode 100644 index 00000000000..5fca50dc209 --- /dev/null +++ b/chdb/dbapi/connections.py @@ -0,0 +1,206 @@ +import json +from . import err +from .cursors import Cursor +from . import converters + +DEBUG = False +VERBOSE = False + + +class Connection(object): + """ + Representation of a connection with chdb. + + The proper way to get an instance of this class is to call + connect(). + + Accepts several arguments: + + :param cursorclass: Custom cursor class to use. + + See `Connection `_ in the + specification. + """ + + _closed = False + + def __init__(self, cursorclass=Cursor): + + self._resp = None + + # 1. pre-process params in init + self.encoding = 'utf8' + + self.cursorclass = cursorclass + + self._result = None + self._affected_rows = 0 + + self.connect() + + def connect(self): + self._closed = False + self._execute_command("select 1;") + self._read_query_result() + + def close(self): + """ + Send the quit message and close the socket. + + See `Connection.close() `_ + in the specification. + + :raise Error: If the connection is already closed. + """ + if self._closed: + raise err.Error("Already closed") + self._closed = True + + @property + def open(self): + """Return True if the connection is open""" + return not self._closed + + def commit(self): + """ + Commit changes to stable storage. + + See `Connection.commit() `_ + in the specification. + """ + return + + def rollback(self): + """ + Roll back the current transaction. + + See `Connection.rollback() `_ + in the specification. + """ + return + + def cursor(self, cursor=None): + """ + Create a new cursor to execute queries with. + + :param cursor: The type of cursor to create; current only :py:class:`Cursor` + None means use Cursor. + """ + if cursor: + return cursor(self) + return self.cursorclass(self) + + # The following methods are INTERNAL USE ONLY (called from Cursor) + def query(self, sql): + if isinstance(sql, str): + sql = sql.encode(self.encoding, 'surrogateescape') + self._execute_command(sql) + self._affected_rows = self._read_query_result() + return self._affected_rows + + def _execute_command(self, sql): + """ + :raise InterfaceError: If the connection is closed. + :raise ValueError: If no username was specified. + """ + if self._closed: + raise err.InterfaceError("Connection closed") + + if isinstance(sql, str): + sql = sql.encode(self.encoding) + + if isinstance(sql, bytearray): + sql = bytes(sql) + + # drop last command return + if self._resp is not None: + self._resp = None + + if DEBUG: + print("DEBUG: query:", sql) + try: + import chdb + self._resp = chdb.query(sql, output_format="JSON").data() + except Exception as error: + raise err.InterfaceError("query err: %s" % error) + + def escape(self, obj, mapping=None): + """Escape whatever value you pass to it. + + Non-standard, for internal use; do not use this in your applications. + """ + if isinstance(obj, str): + return "'" + self.escape_string(obj) + "'" + if isinstance(obj, (bytes, bytearray)): + ret = self._quote_bytes(obj) + return ret + return converters.escape_item(obj, mapping=mapping) + + def escape_string(self, s): + return converters.escape_string(s) + + def _quote_bytes(self, s): + return converters.escape_bytes(s) + + def _read_query_result(self): + self._result = None + result = CHDBResult(self) + result.read() + self._result = result + return result.affected_rows + + def __enter__(self): + """Context manager that returns a Cursor""" + return self.cursor() + + def __exit__(self, exc, value, traceback): + """On successful exit, commit. On exception, rollback""" + if exc: + self.rollback() + else: + self.commit() + + @property + def resp(self): + return self._resp + + +class CHDBResult(object): + def __init__(self, connection): + """ + :type connection: Connection + """ + self.connection = connection + self.affected_rows = 0 + self.insert_id = None + self.warning_count = 0 + self.message = None + self.field_count = 0 + self.description = None + self.rows = None + self.has_next = None + + def read(self): + try: + data = json.loads(self.connection.resp) + except Exception as error: + raise err.InterfaceError("Unsupported response format:" % error) + + try: + self.field_count = len(data["meta"]) + description = [] + for meta in data["meta"]: + fields = [meta["name"], meta["type"]] + description.append(tuple(fields)) + self.description = tuple(description) + + rows = [] + for line in data["data"]: + row = [] + for i in range(self.field_count): + column_data = converters.convert_column_data(self.description[i][1], line[self.description[i][0]]) + row.append(column_data) + rows.append(tuple(row)) + self.rows = tuple(rows) + except Exception as error: + raise err.InterfaceError("Read return data err:" % error) diff --git a/chdb/dbapi/constants/FIELD_TYPE.py b/chdb/dbapi/constants/FIELD_TYPE.py new file mode 100644 index 00000000000..2bc7713424a --- /dev/null +++ b/chdb/dbapi/constants/FIELD_TYPE.py @@ -0,0 +1,32 @@ +DECIMAL = 0 +TINY = 1 +SHORT = 2 +LONG = 3 +FLOAT = 4 +DOUBLE = 5 +NULL = 6 +TIMESTAMP = 7 +LONGLONG = 8 +INT24 = 9 +DATE = 10 +TIME = 11 +DATETIME = 12 +YEAR = 13 +NEWDATE = 14 +VARCHAR = 15 +BIT = 16 +JSON = 245 +NEWDECIMAL = 246 +ENUM = 247 +SET = 248 +TINY_BLOB = 249 +MEDIUM_BLOB = 250 +LONG_BLOB = 251 +BLOB = 252 +VAR_STRING = 253 +STRING = 254 +GEOMETRY = 255 + +CHAR = TINY +INTERVAL = ENUM + diff --git a/chdb/dbapi/constants/__init__.py b/chdb/dbapi/constants/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/chdb/dbapi/converters.py b/chdb/dbapi/converters.py new file mode 100644 index 00000000000..17a210f219b --- /dev/null +++ b/chdb/dbapi/converters.py @@ -0,0 +1,292 @@ +import datetime +from decimal import Decimal +from .err import DataError +import re +import time +import arrow + + +def escape_item(val, mapping=None): + if mapping is None: + mapping = encoders + encoder = mapping.get(type(val)) + + # Fallback to default when no encoder found + if not encoder: + try: + encoder = mapping[str] + except KeyError: + raise TypeError("no default type converter defined") + + val = encoder(val, mapping) + return val + + +def escape_dict(val, mapping=None): + n = {} + for k, v in val.items(): + quoted = escape_item(v, mapping) + n[k] = quoted + return n + + +def escape_sequence(val, mapping=None): + n = [] + for item in val: + quoted = escape_item(item, mapping) + n.append(quoted) + return "(" + ",".join(n) + ")" + + +def escape_set(val, mapping=None): + return ','.join([escape_item(x, mapping) for x in val]) + + +def escape_bool(value, mapping=None): + return str(int(value)) + + +def escape_object(value, mapping=None): + return str(value) + + +def escape_int(value, mapping=None): + return str(value) + + +def escape_float(value, mapping=None): + return '%.15g' % value + + +_escape_table = [chr(x) for x in range(128)] +_escape_table[ord("'")] = u"''" + + +def _escape_unicode(value, mapping=None): + """escapes *value* with adding single quote. + + Value should be unicode + """ + return value.translate(_escape_table) + + +escape_string = _escape_unicode + +# On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow. +# (fixed in Python 3.6, http://bugs.python.org/issue24870) +# Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff. +# We can escape special chars and surrogateescape at once. +_escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)] + + +def escape_bytes(value, mapping=None): + return "'%s'" % value.decode('latin1').translate(_escape_bytes_table) + + +def escape_unicode(value, mapping=None): + return u"'%s'" % _escape_unicode(value) + + +def escape_str(value, mapping=None): + return "'%s'" % escape_string(str(value), mapping) + + +def escape_None(value, mapping=None): + return 'NULL' + + +def escape_timedelta(obj, mapping=None): + seconds = int(obj.seconds) % 60 + minutes = int(obj.seconds // 60) % 60 + hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24 + if obj.microseconds: + fmt = "'{0:02d}:{1:02d}:{2:02d}.{3:06d}'" + else: + fmt = "'{0:02d}:{1:02d}:{2:02d}'" + return fmt.format(hours, minutes, seconds, obj.microseconds) + + +def escape_time(obj, mapping=None): + return "'{}'".format(obj.isoformat(timespec='microseconds')) + + +def escape_datetime(obj, mapping=None): + return "'{}'".format(obj.isoformat(sep=' ', timespec='microseconds')) + # if obj.microsecond: + # fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" + # else: + # fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'" + # return fmt.format(obj) + + +def escape_date(obj, mapping=None): + return "'{}'".format(obj.isoformat()) + + +def escape_struct_time(obj, mapping=None): + return escape_datetime(datetime.datetime(*obj[:6])) + + +def _convert_second_fraction(s): + if not s: + return 0 + # Pad zeros to ensure the fraction length in microseconds + s = s.ljust(6, '0') + return int(s[:6]) + + +def convert_datetime(obj): + """Returns a DATETIME or TIMESTAMP column value as a datetime object: + + >>> datetime_or_None('2007-02-25 23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + >>> datetime_or_None('2007-02-25T23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + + Illegal values are raise DataError + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + try: + return arrow.get(obj).datetime + except Exception as err: + raise DataError("Not valid datetime struct: %s" % err) + + +TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") + + +def convert_timedelta(obj): + """Returns a TIME column as a timedelta object: + + >>> timedelta_or_None('25:06:17') + datetime.timedelta(1, 3977) + >>> timedelta_or_None('-25:06:17') + datetime.timedelta(-2, 83177) + + Illegal values are returned as None: + + >>> timedelta_or_None('random crap') is None + True + + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but + can accept values as (+|-)DD HH:MM:SS. The latter format will not + be parsed correctly by this function. + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + m = TIMEDELTA_RE.match(obj) + if not m: + return obj + + try: + groups = list(m.groups()) + groups[-1] = _convert_second_fraction(groups[-1]) + negate = -1 if groups[0] else 1 + hours, minutes, seconds, microseconds = groups[1:] + + tdelta = datetime.timedelta( + hours=int(hours), + minutes=int(minutes), + seconds=int(seconds), + microseconds=int(microseconds) + ) * negate + return tdelta + except ValueError as err: + raise DataError("Not valid time or timedelta struct: %s" % err) + + +def convert_time(obj): + """Returns a TIME column as a time object: + + >>> time_or_None('15:06:17') + datetime.time(15, 6, 17) + + Illegal values are returned DataError: + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + + try: + return arrow.get("1970-01-01T" + obj).time() + except Exception: + return convert_timedelta(obj) + + +def convert_date(obj): + """Returns a DATE column as a date object: + + >>> date_or_None('2007-02-26') + datetime.date(2007, 2, 26) + + Illegal values are returned as None: + + >>> date_or_None('2007-02-31') is None + True + >>> date_or_None('0000-00-00') is None + True + + """ + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode('ascii') + try: + return arrow.get(obj).date() + except Exception as err: + raise DataError("Not valid date struct: %s" % err) + + +def convert_set(s): + if isinstance(s, (bytes, bytearray)): + return set(s.split(b",")) + return set(s.split(",")) + + +def convert_characters(connection, data): + if connection.use_unicode: + data = data.decode("utf8") + return data + + +def convert_column_data(column_type, column_data): + data = column_data + + # Null + if data is None: + return data + + if not isinstance(column_type, str): + return data + + column_type = column_type.lower().strip() + if column_type == 'time': + data = convert_time(column_data) + elif column_type == 'date': + data = convert_date(column_data) + elif column_type == 'datetime': + data = convert_datetime(column_data) + + return data + + +encoders = { + bool: escape_bool, + int: escape_int, + float: escape_float, + str: escape_unicode, + tuple: escape_sequence, + list: escape_sequence, + set: escape_sequence, + frozenset: escape_sequence, + dict: escape_dict, + type(None): escape_None, + datetime.date: escape_date, + datetime.datetime: escape_datetime, + datetime.timedelta: escape_timedelta, + datetime.time: escape_time, + time.struct_time: escape_struct_time, + Decimal: escape_object, +} diff --git a/chdb/dbapi/cursors.py b/chdb/dbapi/cursors.py new file mode 100644 index 00000000000..9fa762b30bc --- /dev/null +++ b/chdb/dbapi/cursors.py @@ -0,0 +1,298 @@ +from . import err +import re + +# Regular expression for :meth:`Cursor.executemany`. +# executemany only supports simple bulk insert. +# You can use it to load large dataset. +RE_INSERT_VALUES = re.compile( + r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL) + + +class Cursor(object): + """ + This is the object you use to interact with the database. + + Do not create an instance of a Cursor yourself. Call + connections.Connection.cursor(). + + See `Cursor `_ in + the specification. + """ + + #: Max statement size which :meth:`executemany` generates. + #: + #: Default value is 1024000. + max_stmt_length = 1024000 + + def __init__(self, connection): + self.connection = connection + self.description = None + self.rowcount = -1 + self.rownumber = 0 + self.arraysize = 1 + self.lastrowid = None + self._result = None + self._rows = None + self._executed = None + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + del exc_info + self.close() + + def __iter__(self): + return iter(self.fetchone, None) + + def callproc(self, procname, args=()): + """Execute stored procedure procname with args + + procname -- string, name of procedure to execute on server + + args -- Sequence of parameters to use with procedure + + Returns the original args. + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. + The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + """ + + return args + + def close(self): + """ + Closing a cursor just exhausts all remaining data. + """ + conn = self.connection + if conn is None: + return + try: + while self.nextset(): + pass + finally: + self.connection = None + + def _get_db(self): + if not self.connection: + raise err.ProgrammingError("Cursor closed") + return self.connection + + def _escape_args(self, args, conn): + if isinstance(args, (tuple, list)): + return tuple(conn.escape(arg) for arg in args) + elif isinstance(args, dict): + return {key: conn.escape(val) for (key, val) in args.items()} + else: + # If it's not a dictionary let's try escaping it anyway. + # Worst case it will throw a Value error + return conn.escape(args) + + def mogrify(self, query, args=None): + """ + Returns the exact string that is sent to the database by calling the + execute() method. + + This method follows the extension to the DB API 2.0 followed by Psycopg. + """ + conn = self._get_db() + + if args is not None: + query = query % self._escape_args(args, conn) + + return query + + def _clear_result(self): + self.rownumber = 0 + self._result = None + + self.rowcount = 0 + self.description = None + self.lastrowid = None + self._rows = None + + def _do_get_result(self): + conn = self._get_db() + + self._result = result = conn._result + + self.rowcount = result.affected_rows + self.description = result.description + self.lastrowid = result.insert_id + self._rows = result.rows + + def _query(self, q): + conn = self._get_db() + self._last_executed = q + self._clear_result() + conn.query(q) + self._do_get_result() + return self.rowcount + + def execute(self, query, args=None): + """Execute a query + + :param str query: Query to execute. + + :param args: parameters used with query. (optional) + :type args: tuple, list or dict + + :return: Number of affected rows + :rtype: int + + If args is a list or tuple, %s can be used as a placeholder in the query. + If args is a dict, %(name)s can be used as a placeholder in the query. + """ + while self.nextset(): + pass + + query = self.mogrify(query, args) + + result = self._query(query) + self._executed = query + return result + + def executemany(self, query, args): + # type: (str, list) -> int + """Run several data against one query + + :param query: query to execute on server + :param args: Sequence of sequences or mappings. It is used as parameter. + :return: Number of rows affected, if any. + + This method improves performance on multiple-row INSERT and + REPLACE. Otherwise, it is equivalent to looping over args with + execute(). + """ + if not args: + return 0 + + m = RE_INSERT_VALUES.match(query) + if m: + q_prefix = m.group(1) % () + q_values = m.group(2).rstrip() + q_postfix = m.group(3) or '' + assert q_values[0] == '(' and q_values[-1] == ')' + return self._do_execute_many(q_prefix, q_values, q_postfix, args, + self.max_stmt_length, + self._get_db().encoding) + + self.rowcount = sum(self.execute(query, arg) for arg in args) + return self.rowcount + + def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + conn = self._get_db() + escape = self._escape_args + if isinstance(prefix, str): + prefix = prefix.encode(encoding) + if isinstance(postfix, str): + postfix = postfix.encode(encoding) + sql = str(prefix) + args = iter(args) + v = values % escape(next(args), conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + sql += v + rows = 0 + for arg in args: + v = values % escape(arg, conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: + rows += self.execute(sql + postfix) + sql = str(prefix) + else: + sql += ',' + sql += v + rows += self.execute(sql + postfix) + self.rowcount = rows + return rows + + def _check_executed(self): + if not self._executed: + raise err.ProgrammingError("execute() first") + + def fetchone(self): + """Fetch the next row""" + self._check_executed() + if self._rows is None or self.rownumber >= len(self._rows): + return None + result = self._rows[self.rownumber] + self.rownumber += 1 + return result + + def fetchmany(self, size=None): + """Fetch several rows""" + self._check_executed() + if self._rows is None: + return () + end = self.rownumber + (size or self.arraysize) + result = self._rows[self.rownumber:end] + self.rownumber = min(end, len(self._rows)) + return result + + def fetchall(self): + """Fetch all the rows""" + self._check_executed() + if self._rows is None: + return () + if self.rownumber: + result = self._rows[self.rownumber:] + else: + result = self._rows + self.rownumber = len(self._rows) + return result + + def nextset(self): + """Get the next query set""" + # Not support for now + return None + + def setinputsizes(self, *args): + """Does nothing, required by DB API.""" + + def setoutputsizes(self, *args): + """Does nothing, required by DB API.""" + + +class DictCursor(Cursor): + """A cursor which returns results as a dictionary""" + # You can override this to use OrderedDict or other dict-like types. + dict_type = dict + + def _do_get_result(self): + super()._do_get_result() + fields = [] + if self.description: + for f in self.description: + name = f[0] + fields.append(name) + self._fields = fields + + if fields and self._rows: + self._rows = [self._conv_row(r) for r in self._rows] + + def _conv_row(self, row): + if row is None: + return None + return self.dict_type(zip(self._fields, row)) + diff --git a/chdb/dbapi/err.py b/chdb/dbapi/err.py new file mode 100644 index 00000000000..df97a15e108 --- /dev/null +++ b/chdb/dbapi/err.py @@ -0,0 +1,61 @@ +class StandardError(Exception): + """Exception related to operation with chdb.""" + + +class Warning(StandardError): + """Exception raised for important warnings like data truncations + while inserting, etc.""" + + +class Error(StandardError): + """Exception that is the base class of all other error exceptions + (not Warning).""" + + +class InterfaceError(Error): + """Exception raised for errors that are related to the database + interface rather than the database itself.""" + + +class DatabaseError(Error): + """Exception raised for errors that are related to the + database.""" + + +class DataError(DatabaseError): + """Exception raised for errors that are due to problems with the + processed data like division by zero, numeric value out of range, + etc.""" + + +class OperationalError(DatabaseError): + """Exception raised for errors that are related to the database's + operation and not necessarily under the control of the programmer, + e.g. an unexpected disconnect occurs, the data source name is not + found, a transaction could not be processed, a memory allocation + error occurred during processing, etc.""" + + +class IntegrityError(DatabaseError): + """Exception raised when the relational integrity of the database + is affected, e.g. a foreign key check fails, duplicate key, + etc.""" + + +class InternalError(DatabaseError): + """Exception raised when the database encounters an internal + error, e.g. the cursor is not valid anymore, the transaction is + out of sync, etc.""" + + +class ProgrammingError(DatabaseError): + """Exception raised for programming errors, e.g. table not found + or already exists, syntax error in the SQL statement, wrong number + of parameters specified, etc.""" + + +class NotSupportedError(DatabaseError): + """Exception raised in case a method or database API was used + which is not supported by the database, e.g. requesting a + .rollback() on a connection that does not support transaction or + has transactions turned off.""" diff --git a/chdb/dbapi/times.py b/chdb/dbapi/times.py new file mode 100644 index 00000000000..9afa599677a --- /dev/null +++ b/chdb/dbapi/times.py @@ -0,0 +1,21 @@ +from time import localtime +from datetime import date, datetime, time, timedelta + + +Date = date +Time = time +TimeDelta = timedelta +Timestamp = datetime + + +def DateFromTicks(ticks): + return date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks): + return time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks): + return datetime(*localtime(ticks)[:6]) + diff --git a/examples/dbapi.py b/examples/dbapi.py new file mode 100644 index 00000000000..82baa6f6f37 --- /dev/null +++ b/examples/dbapi.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +from chdb import dbapi +from chdb.dbapi.cursors import DictCursor + +print("chdb driver version: {0}".format(dbapi.get_client_info())) + +conn1 = dbapi.connect() +cur1 = conn1.cursor() +cur1.execute('select version()') +print("description: ", cur1.description) +print("data: ", cur1.fetchone()) +cur1.close() +conn1.close() + +conn2 = dbapi.connect(cursorclass=DictCursor) +cur2 = conn2.cursor() +cur2.execute(''' +SELECT + town, + district, + count() AS c, + round(avg(price)) AS price +FROM url('https://datasets-documentation.s3.eu-west-3.amazonaws.com/house_parquet/house_0.parquet') +GROUP BY + town, + district +LIMIT 10 +''') +print("description", cur2.description) +for row in cur2: + print(row) + +cur2.close() +conn2.close() diff --git a/setup.py b/setup.py index 5a0e57e66e0..147c012c64f 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ def fix_version_init(version): f.seek(0) f.write(init_content) f.truncate() - + # As of Python 3.6, CCompiler has a `has_flag` method. # cf http://bugs.python.org/issue26689 @@ -117,12 +117,16 @@ def build_extensions(self): print("CC: " + os.environ.get('CC')) print("CXX: " + os.environ.get('CXX')) if sys.platform == 'darwin': - if os.system('which /usr/local/opt/llvm/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = '/usr/local/opt/llvm/bin/clang' - os.environ['CXX'] = '/usr/local/opt/llvm/bin/clang++' - elif os.system('which /usr/local/opt/llvm@15/bin/clang++ > /dev/null') == 0: - os.environ['CC'] = '/usr/local/opt/llvm@15/bin/clang' - os.environ['CXX'] = '/usr/local/opt/llvm@15/bin/clang++' + try: + brew_prefix = subprocess.check_output('brew --prefix', shell=True).decode("utf-8").strip("\n") + except Exception: + raise RuntimeError("Must install brew") + if os.system('which '+brew_prefix+'/opt/llvm/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/opt/llvm/bin/clang' + os.environ['CXX'] = brew_prefix + '/opt/llvm/bin/clang++' + elif os.system('which '+brew_prefix+'/opt/llvm@15/bin/clang++ > /dev/null') == 0: + os.environ['CC'] = brew_prefix + '/opt/llvm@15/bin/clang' + os.environ['CXX'] = brew_prefix + '/opt/llvm@15/bin/clang++' else: raise RuntimeError("Must use brew clang++") elif sys.platform == 'linux': From 02fa69acfc73c2cc20ded2d290ebf7092604bce2 Mon Sep 17 00:00:00 2001 From: auxten Date: Thu, 18 May 2023 11:03:40 +0000 Subject: [PATCH 2/5] Fix dbapi feature packaging --- MANIFEST.in | 2 -- chdb/dbapi/__init__.py | 14 +++++++------- gen_manifest.sh | 11 ++++++++++- setup.py | 11 ++++++++++- 4 files changed, 27 insertions(+), 11 deletions(-) delete mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 1e38966800c..00000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include chdb/*.py -include chdb/_chdb.cpython-39-darwin.so diff --git a/chdb/dbapi/__init__.py b/chdb/dbapi/__init__.py index 4c15da98bd3..dda61b73142 100644 --- a/chdb/dbapi/__init__.py +++ b/chdb/dbapi/__init__.py @@ -5,12 +5,12 @@ DatabaseError, OperationalError, IntegrityError, InternalError, NotSupportedError, ProgrammingError) from . import connections as _orig_conn +from .. import chdb_version -VERSION = (0, 1, 0, None) -if VERSION[3] is not None: - VERSION_STRING = "%d.%d.%d_%s" % VERSION +if chdb_version[3] is not None: + VERSION_STRING = "%d.%d.%d_%s" % chdb_version else: - VERSION_STRING = "%d.%d.%d" % VERSION[:3] + VERSION_STRING = "%d.%d.%d" % chdb_version[:3] threadsafety = 1 apilevel = "2.0" @@ -71,9 +71,9 @@ def Connect(*args, **kwargs): def get_client_info(): # for MySQLdb compatibility - version = VERSION - if VERSION[3] is None: - version = VERSION[:3] + version = chdb_version + if chdb_version[3] is None: + version = chdb_version[:3] return '.'.join(map(str, version)) diff --git a/gen_manifest.sh b/gen_manifest.sh index 1cd3ce35354..12d52796c05 100755 --- a/gen_manifest.sh +++ b/gen_manifest.sh @@ -6,6 +6,15 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" cd ${DIR} -echo "include chdb/*.py" > MANIFEST.in +rm -f MANIFEST.in + +echo "include README.md" >> MANIFEST.in +echo "include LICENSE.txt" >> MANIFEST.in +echo "graft chdb" >> MANIFEST.in +echo "global-exclude *.py[cod]" >> MANIFEST.in +echo "global-exclude __pycache__" >> MANIFEST.in +echo "global-exclude .DS_Store" >> MANIFEST.in +echo "global-exclude .git*" >> MANIFEST.in +echo "global-exclude ~*" >> MANIFEST.in export SO_SUFFIX=$(python3 -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))") echo "include chdb/_chdb${SO_SUFFIX}" >> MANIFEST.in \ No newline at end of file diff --git a/setup.py b/setup.py index 147c012c64f..a801c499516 100644 --- a/setup.py +++ b/setup.py @@ -170,10 +170,19 @@ def build_extensions(self): # fix the version in chdb/__init__.py versionStr = get_latest_git_tag() fix_version_init(versionStr) + + # scan the chdb directory and add all the .py files to the package + pkg_files = [] + for root, dirs, files in os.walk(libdir): + for file in files: + if file.endswith(".py"): + pkg_files.append(os.path.join(root, file)) + pkg_files.append(chdb_so) + setup( packages=['chdb'], version=versionStr, - package_data={'chdb': [chdb_so]}, + package_data={'chdb': pkg_files}, exclude_package_data={'': ['*.pyc', 'src/**']}, ext_modules=ext_modules, python_requires='>=3.7', From cbaf958b64dbfe4c6b10bd56b95b5ca5fa591d0b Mon Sep 17 00:00:00 2001 From: auxten Date: Wed, 24 May 2023 05:16:14 +0000 Subject: [PATCH 3/5] Fix minor bug --- chdb/dbapi/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chdb/dbapi/__init__.py b/chdb/dbapi/__init__.py index dda61b73142..b77d2fc9ed6 100644 --- a/chdb/dbapi/__init__.py +++ b/chdb/dbapi/__init__.py @@ -7,7 +7,7 @@ from . import connections as _orig_conn from .. import chdb_version -if chdb_version[3] is not None: +if len(chdb_version) > 3 and chdb_version[3] is not None: VERSION_STRING = "%d.%d.%d_%s" % chdb_version else: VERSION_STRING = "%d.%d.%d" % chdb_version[:3] @@ -72,7 +72,7 @@ def Connect(*args, **kwargs): def get_client_info(): # for MySQLdb compatibility version = chdb_version - if chdb_version[3] is None: + if len(chdb_version) > 3 and chdb_version[3] is None: version = chdb_version[:3] return '.'.join(map(str, version)) From efc33f02e6256cead67dcdc3687abfdfabeaf1da Mon Sep 17 00:00:00 2001 From: laodouya Date: Wed, 24 May 2023 20:38:51 +0800 Subject: [PATCH 4/5] Remove arrow in dbapi --- chdb/dbapi/converters.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/chdb/dbapi/converters.py b/chdb/dbapi/converters.py index 17a210f219b..f5e7e7cb341 100644 --- a/chdb/dbapi/converters.py +++ b/chdb/dbapi/converters.py @@ -3,7 +3,6 @@ from .err import DataError import re import time -import arrow def escape_item(val, mapping=None): @@ -140,8 +139,6 @@ def convert_datetime(obj): >>> datetime_or_None('2007-02-25 23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) - >>> datetime_or_None('2007-02-25T23:06:20') - datetime.datetime(2007, 2, 25, 23, 6, 20) Illegal values are raise DataError @@ -150,7 +147,8 @@ def convert_datetime(obj): obj = obj.decode('ascii') try: - return arrow.get(obj).datetime + time_obj = datetime.datetime.strptime(obj, '%Y-%m-%d %H:%M:%S') + return time_obj except Exception as err: raise DataError("Not valid datetime struct: %s" % err) @@ -212,7 +210,8 @@ def convert_time(obj): obj = obj.decode('ascii') try: - return arrow.get("1970-01-01T" + obj).time() + time_obj = datetime.datetime.strptime(obj, '%H:%M:%S') + return time_obj.time() except Exception: return convert_timedelta(obj) @@ -234,7 +233,8 @@ def convert_date(obj): if isinstance(obj, (bytes, bytearray)): obj = obj.decode('ascii') try: - return arrow.get(obj).date() + time_obj = datetime.datetime.strptime(obj, '%Y-%m-%d') + return time_obj.date() except Exception as err: raise DataError("Not valid date struct: %s" % err) From 58165f724d71c882626c1318d4debf1c8a8b800c Mon Sep 17 00:00:00 2001 From: auxten Date: Thu, 25 May 2023 07:03:41 +0000 Subject: [PATCH 5/5] Readme --- README-zh.md | 1 + README.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README-zh.md b/README-zh.md index 765d0122bec..6e917977c60 100644 --- a/README-zh.md +++ b/README-zh.md @@ -20,6 +20,7 @@ * 嵌入在 Python 中的 SQL OLAP 引擎,由 ClickHouse 驱动 * 不需要安装 ClickHouse * 支持 Parquet、CSV、JSON、Arrow、ORC 和其他 60 多种格式的[输入输出](https://clickhouse.com/docs/en/interfaces/formats),[示例](tests/format_output.py)。 +* 支持 Python DB API 2.0 标准, [example](examples/dbapi.py) ## 架构
diff --git a/README.md b/README.md index 67ad972c67a..ff56f12cabb 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ * No need to install ClickHouse * Minimized data copy from C++ to Python with [python memoryview](https://docs.python.org/3/c-api/memoryview.html) * Input&Output support Parquet, CSV, JSON, Arrow, ORC and 60+[more](https://clickhouse.com/docs/en/interfaces/formats) formats, [samples](tests/format_output.py) +* Support Python DB API 2.0, [example](examples/dbapi.py) ## Arch