diff --git a/bindings/python/pymongoarrow/context.py b/bindings/python/pymongoarrow/context.py
index 8af48e4d..d7e4b352 100644
--- a/bindings/python/pymongoarrow/context.py
+++ b/bindings/python/pymongoarrow/context.py
@@ -121,4 +121,6 @@ def finish(self):
         for fname, builder in self.builder_map.items():
             arrays.append(builder.finish())
             names.append(fname.decode("utf-8"))
+        if self.schema is not None:
+            return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
         return Table.from_arrays(arrays=arrays, names=names)
diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx
index 1e4e516b..ac4c7716 100644
--- a/bindings/python/pymongoarrow/lib.pyx
+++ b/bindings/python/pymongoarrow/lib.pyx
@@ -710,12 +710,16 @@ cdef object get_field_builder(object field, object tzinfo):
         field_builder = DatetimeBuilder(field_type)
     elif _atypes.is_string(field_type):
         field_builder = StringBuilder()
+    elif _atypes.is_large_string(field_type):
+        field_builder = StringBuilder()
     elif _atypes.is_boolean(field_type):
         field_builder = BoolBuilder()
     elif _atypes.is_struct(field_type):
         field_builder = DocumentBuilder(field_type, tzinfo)
     elif _atypes.is_list(field_type):
         field_builder = ListBuilder(field_type, tzinfo)
+    elif _atypes.is_large_list(field_type):
+        field_builder = ListBuilder(field_type, tzinfo)
     elif getattr(field_type, '_type_marker') == _BsonArrowTypes.objectid:
         field_builder = ObjectIdBuilder()
     elif getattr(field_type, '_type_marker') == _BsonArrowTypes.decimal128:
@@ -799,8 +803,8 @@ cdef class ListBuilder(_ArrayBuilderBase):
         cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
         cdef shared_ptr[CArrayBuilder] grandchild_builder
         self.dtype = dtype
-        if not _atypes.is_list(dtype):
-            raise ValueError("dtype must be a list_()")
+        if not (_atypes.is_list(dtype) or _atypes.is_large_list(dtype)):
+            raise ValueError("dtype must be a list_() or large_list()")
         self.context = context = PyMongoArrowContext(None, {})
         self.context.tzinfo = tzinfo
         field_builder = get_field_builder(self.dtype.value_type, tzinfo)
diff --git a/bindings/python/pymongoarrow/schema.py b/bindings/python/pymongoarrow/schema.py
index 3e3398ad..940be840 100644
--- a/bindings/python/pymongoarrow/schema.py
+++ b/bindings/python/pymongoarrow/schema.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import collections.abc as abc
 
-from pyarrow import ListType, StructType
+import pyarrow as pa
 
 from pymongoarrow.types import _normalize_typeid
 
@@ -73,9 +73,9 @@ def _get_projection(self):
 
     def _get_field_projection_value(self, ftype):
         value = True
-        if isinstance(ftype, ListType):
+        if isinstance(ftype, pa.ListType):
             return self._get_field_projection_value(ftype.value_field.type)
-        if isinstance(ftype, StructType):
+        if isinstance(ftype, pa.StructType):
             projection = {}
             for nested_ftype in ftype:
                 projection[nested_ftype.name] = True
@@ -86,3 +86,22 @@ def __eq__(self, other):
         if isinstance(other, type(self)):
             return self.typemap == other.typemap
         return False
+
+    @classmethod
+    def from_arrow(cls, aschema: pa.Schema):
+        """Create a :class:`~pymongoarrow.schema.Schema` instance from a :class:`~pyarrow.Schema`
+
+        :Parameters:
+          - `aschema`: PyArrow Schema
+        """
+        self = cls({})
+        for field in aschema:
+            self.typemap[field.name] = field.type
+        return self
+
+    def to_arrow(self):
+        """Output the Schema as an instance of :class:`~pyarrow.Schema`."""
+        fields = []
+        for name, type_ in self.typemap.items():
+            fields.append(pa.field(name=name, type=type_))
+        return pa.schema(fields)
diff --git a/bindings/python/pymongoarrow/types.py b/bindings/python/pymongoarrow/types.py
index 0f87add6..3a692753 100644
--- a/bindings/python/pymongoarrow/types.py
+++ b/bindings/python/pymongoarrow/types.py
@@ -266,6 +266,8 @@ def get_numpy_type(type):
     _atypes.is_boolean: _BsonArrowTypes.bool,
     _atypes.is_struct: _BsonArrowTypes.document,
     _atypes.is_list: _BsonArrowTypes.array,
+    _atypes.is_large_string: _BsonArrowTypes.string,
+    _atypes.is_large_list: _BsonArrowTypes.array,
 }
 
 
@@ -296,6 +298,7 @@ def _get_internal_typemap(typemap):
         for checker, internal_id in _TYPE_CHECKER_TO_INTERNAL_TYPE.items():
             if checker(ftype):
                 internal_typemap[fname] = internal_id
+                break
 
         if fname not in internal_typemap:
             msg = f'Unsupported data type in schema for field "{fname}" of type "{ftype}"'
diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py
index 1882b763..628c93af 100644
--- a/bindings/python/test/test_arrow.py
+++ b/bindings/python/test/test_arrow.py
@@ -32,6 +32,8 @@
     field,
     int32,
     int64,
+    large_list,
+    large_string,
     list_,
     string,
     struct,
@@ -707,6 +709,46 @@ def test_nested_bson_extension_types(self):
         self.assertIsInstance(new_obj.type[2].type, BinaryType)
         self.assertIsInstance(new_obj.type[3].type, CodeType)
 
+    def test_large_string_type(self):
+        """Tests pyarrow.large_string() DataType"""
+        data = Table.from_pydict(
+            {"string": ["A", "B", "C"], "large_string": ["C", "D", "E"]},
+            ArrowSchema({"string": string(), "large_string": large_string()}),
+        )
+        self.round_trip(data, Schema({"string": str, "large_string": large_string()}))
+
+    def test_large_list_type(self):
+        """Tests pyarrow.large_list() DataType
+
+        1. Test large_list of large_string
+            - with schema in query, one gets full roundtrip consistency
+            - without schema, normal list and string will be inferred
+
+        2. Test nested as well
+        """
+
+        schema = ArrowSchema([field("_id", int32()), field("txns", large_list(large_string()))])
+
+        data = {
+            "_id": [1, 2, 3, 4],
+            "txns": [["A"], ["A", "B"], ["A", "B", "C"], ["A", "B", "C", "D"]],
+        }
+        table_orig = pa.Table.from_pydict(data, schema)
+        self.coll.drop()
+        res = write(self.coll, table_orig)
+        # 1a.
+        self.assertEqual(len(data["_id"]), res.raw_result["insertedCount"])
+        table_schema = find_arrow_all(self.coll, {}, schema=Schema.from_arrow(schema))
+        self.assertEqual(table_schema, table_orig)
+        # 1b.
+        table_none = find_arrow_all(self.coll, {}, schema=None)
+        self.assertEqual(table_none.schema.types, [int32(), list_(string())])
+        self.assertEqual(table_none.to_pydict(), data)
+
+        # 2. Test in sublist
+        schema, data = self._create_nested_data((large_list(int32()), list(range(3))))
+        self.round_trip(data, Schema(schema))
+
 
 class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
     def run_find(self, *args, **kwargs):
diff --git a/bindings/python/test/test_datetime.py b/bindings/python/test/test_datetime.py
index 01fe6522..181e8a75 100644
--- a/bindings/python/test/test_datetime.py
+++ b/bindings/python/test/test_datetime.py
@@ -98,38 +98,68 @@ def test_timezone_specified_in_schema(self):
         self.assertEqual(table, expected)
 
     def test_timezone_specified_in_codec_options(self):
-        # 1. When specified, CodecOptions.tzinfo will modify timestamp
-        # type specifiers in the schema to inherit the specified timezone
-        tz = pytz.timezone("US/Pacific")
-        codec_options = CodecOptions(tz_aware=True, tzinfo=tz)
-        expected = Table.from_pydict(
-            {"_id": [1, 2], "data": self.expected_times},
-            ArrowSchema([("_id", int32()), ("data", timestamp("ms", tz=tz))]),
+        """Test behavior of setting tzinfo CodecOptions in Collection.with_options.
+
+        When provided, timestamp type specifiers in the schema inherit the specified timezone.
+        Read values will maintain this information for timestamps whether schema is passed or not.
+
+        Note, this does not apply to datetimes.
+        We also test here that if one asks for a different timezone upon reading,
+        one gets the requested timezone.
+        """
+
+        # 1. We pass tzinfo to Collection.with_options, and same tzinfo in schema of find_arrow_all
+        tz_west = pytz.timezone("US/Pacific")
+        codec_options = CodecOptions(tz_aware=True, tzinfo=tz_west)
+        coll_west = self.coll.with_options(codec_options=codec_options)
+
+        schema_west = ArrowSchema([("_id", int32()), ("data", timestamp("ms", tz=tz_west))])
+        table_west = find_arrow_all(
+            collection=coll_west,
+            query={},
+            schema=Schema.from_arrow(schema_west),
+            sort=[("_id", ASCENDING)],
         )
 
-        schemas = [
-            Schema({"_id": int32(), "data": timestamp("ms")}),
-            Schema({"_id": int32(), "data": datetime}),
-        ]
-        for schema in schemas:
-            table = find_arrow_all(
-                self.coll.with_options(codec_options=codec_options),
-                {},
-                schema=schema,
-                sort=[("_id", ASCENDING)],
-            )
+        expected_west = Table.from_pydict(
+            {"_id": [1, 2], "data": self.expected_times}, schema=schema_west
+        )
+        self.assertTrue(table_west.equals(expected_west))
 
-            self.assertEqual(table, expected)
+        # 2. We pass tzinfo to Collection.with_options, but do NOT include a schema in find_arrow_all
+        table_none = find_arrow_all(
+            collection=coll_west,
+            query={},
+            schema=None,
+            sort=[("_id", ASCENDING)],
+        )
+        self.assertTrue(table_none.equals(expected_west))
 
-        # 2. CodecOptions.tzinfo will be ignored when tzinfo is specified
-        # in the original schema type specifier.
-        tz_east = pytz.timezone("US/Eastern")
-        codec_options = CodecOptions(tz_aware=True, tzinfo=tz_east)
-        schema = Schema({"_id": int32(), "data": timestamp("ms", tz=tz)})
-        table = find_arrow_all(
-            self.coll.with_options(codec_options=codec_options),
-            {},
-            schema=schema,
+        # 3. Now we pass a DIFFERENT timezone to the schema in find_arrow_all than we did to the Collection
+        schema_east = Schema(
+            {"_id": int32(), "data": timestamp("ms", tz=pytz.timezone("US/Eastern"))}
+        )
+        table_east = find_arrow_all(
+            collection=coll_west,
+            query={},
+            schema=schema_east,
             sort=[("_id", ASCENDING)],
         )
-        self.assertEqual(table, expected)
+        # Confirm that we get the timezone we requested
+        self.assertEqual(table_east.schema.types, [int32(), timestamp(unit="ms", tz="US/Eastern")])
+        # Confirm that the times have been adjusted
+        times_west = table_west["data"].to_pylist()
+        times_east = table_east["data"].to_pylist()
+        self.assertEqual(times_west, times_east)
+
+        # 4. Test behavior of datetime. Output will be pyarrow.timestamp("ms") without timezone
+        schema_dt = Schema({"_id": int32(), "data": datetime})
+        table_dt = find_arrow_all(
+            collection=coll_west,
+            query={},
+            schema=schema_dt,
+            sort=[("_id", ASCENDING)],
+        )
+        self.assertEqual(table_dt.schema.types, [int32(), timestamp(unit="ms")])
+        times = table_dt["data"].to_pylist()
+        self.assertEqual(times, self.expected_times)