From 9fc738794b58bf57a073986c5b3d16474514dc9f Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 2 Nov 2023 14:59:48 +0100 Subject: [PATCH] feat: Add `.list.to_array` expression (#12192) --- .../polars-plan/src/dsl/function_expr/list.rs | 23 ++++++++++++ crates/polars-plan/src/dsl/list.rs | 7 ++++ .../source/reference/expressions/list.rst | 1 + .../docs/source/reference/series/list.rst | 1 + py-polars/polars/expr/list.py | 35 ++++++++++++++++++- py-polars/polars/series/array.py | 2 +- py-polars/polars/series/list.py | 27 ++++++++++++++ py-polars/src/expr/list.rs | 4 +++ py-polars/tests/unit/namespaces/test_list.py | 22 ++++++++++++ 9 files changed, 120 insertions(+), 2 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 70542cd9f874..a2d7aa711164 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -48,6 +48,8 @@ pub enum ListFunction { #[cfg(feature = "list_any_all")] All, Join, + #[cfg(feature = "dtype-array")] + ToArray(usize), } impl ListFunction { @@ -87,10 +89,21 @@ impl ListFunction { #[cfg(feature = "list_any_all")] All => mapper.with_dtype(DataType::Boolean), Join => mapper.with_dtype(DataType::Utf8), + #[cfg(feature = "dtype-array")] + ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), } } } +#[cfg(feature = "dtype-array")] +fn map_list_dtype_to_array_dtype(datatype: &DataType, width: usize) -> PolarsResult { + if let DataType::List(inner) = datatype { + Ok(DataType::Array(inner.clone(), width)) + } else { + polars_bail!(ComputeError: "expected List dtype") + } +} + impl Display for ListFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use ListFunction::*; @@ -141,6 +154,8 @@ impl Display for ListFunction { #[cfg(feature = "list_any_all")] All => "all", Join => "join", + #[cfg(feature = "dtype-array")] + ToArray(_) => "to_array", }; write!(f, "list.{name}") } @@ -194,6 +209,8 @@ impl From for SpecialEq> { #[cfg(feature = "list_any_all")] All => map!(lst_all), Join => map_as_slice!(join), + #[cfg(feature = "dtype-array")] + ToArray(width) => map!(to_array, width), } } } @@ -518,3 +535,9 @@ pub(super) fn join(s: &[Series]) -> PolarsResult { let separator = s[1].utf8()?; Ok(ca.lst_join(separator)?.into_series()) } + +#[cfg(feature = "dtype-array")] +pub(super) fn to_array(s: &Series, width: usize) -> PolarsResult { + let array_dtype = map_list_dtype_to_array_dtype(s.dtype(), width)?; + s.cast(&array_dtype) +} diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 6232c29c0a46..568a9866e268 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -231,6 +231,13 @@ impl ListNameSpace { self.slice(lit(0i64) - n.clone().cast(DataType::Int64), n) } + #[cfg(feature = "dtype-array")] + /// Convert a List column into an Array column with the same inner data type. + pub fn to_array(self, width: usize) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::ToArray(width))) + } + #[cfg(feature = "list_to_struct")] #[allow(clippy::wrong_self_convention)] /// Convert this `List` to a `Series` of type `Struct`. The width will be determined according to diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index f43401e20561..37c4ce1c6511 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -46,6 +46,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.symmetric_difference Expr.list.tail Expr.list.take + Expr.list.to_array Expr.list.to_struct Expr.list.union Expr.list.unique diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index ad766dd92eb9..9ce9d37fd911 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -46,6 +46,7 @@ The following methods are available under the `Series.list` attribute. Series.list.symmetric_difference Series.list.tail Series.list.take + Series.list.to_array Series.list.to_struct Series.list.union Series.list.unique diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 220cefe96a25..d78976ca5be5 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -905,6 +905,40 @@ def count_matches(self, element: IntoExpr) -> Expr: element = parse_as_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_count_matches(element)) + def to_array(self, width: int) -> Expr: + """ + Convert a List column into an Array column with the same inner data type. + + Parameters + ---------- + width + Width of the resulting Array column. + + Returns + ------- + Expr + Expression of data type :class:`Array`. + + Examples + -------- + >>> df = pl.DataFrame( + ... data={"a": [[1, 2], [3, 4]]}, + ... schema={"a": pl.List(pl.Int8)}, + ... ) + >>> df.select(pl.col("a").list.to_array(2)) + shape: (2, 1) + ┌──────────────┐ + │ a │ + │ --- │ + │ array[i8, 2] │ + ╞══════════════╡ + │ [1, 2] │ + │ [3, 4] │ + └──────────────┘ + + """ + return wrap_expr(self._pyexpr.list_to_array(width)) + def to_struct( self, n_field_strategy: ToStructStrategy = "first_non_null", @@ -1135,7 +1169,6 @@ def set_symmetric_difference(self, other: IntoExpr) -> Expr: other Right hand side of the set operation. - Examples -------- >>> df = pl.DataFrame( diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index d0fd834f8c98..855827231172 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -114,7 +114,7 @@ def to_list(self) -> Series: Returns ------- - Expr + Series Series of data type :class:`List`. Examples diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 3af52072abf2..70f65cae5454 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -541,6 +541,33 @@ def count_matches( """ + def to_array(self, width: int) -> Series: + """ + Convert a List column into an Array column with the same inner data type. + + Parameters + ---------- + width + Width of the resulting Array column. + + Returns + ------- + Series + Series of data type :class:`Array`. + + Examples + -------- + >>> s = pl.Series([[1, 2], [3, 4]], dtype=pl.List(pl.Int8)) + >>> s.list.to_array(2) + shape: (2,) + Series: '' [array[i8, 2]] + [ + [1, 2] + [3, 4] + ] + + """ + def to_struct( self, n_field_strategy: ToStructStrategy = "first_non_null", diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index a8a6db6613b9..547a3fb93359 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -154,6 +154,10 @@ impl PyExpr { .into() } + fn list_to_array(&self, width: usize) -> Self { + self.inner.clone().list().to_array(width).into() + } + #[pyo3(signature = (width_strat, name_gen, upper_bound))] fn list_to_struct( &self, diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 4b957dc3f704..457c7e4865e5 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -663,3 +663,25 @@ def test_list_lengths_deprecated() -> None: result = s.list.lengths() expected = pl.Series([3, 1], dtype=pl.UInt32) assert_series_equal(result, expected) + + +def test_list_to_array() -> None: + data = [[1.0, 2.0], [3.0, 4.0]] + s = pl.Series(data, dtype=pl.List(pl.Float32)) + + result = s.list.to_array(2) + + expected = pl.Series(data, dtype=pl.Array(inner=pl.Float32, width=2)) + assert_series_equal(result, expected) + + +def test_list_to_array_wrong_lengths() -> None: + s = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float32)) + with pytest.raises(pl.ComputeError, match="incompatible offsets in source list"): + s.list.to_array(3) + + +def test_list_to_array_wrong_dtype() -> None: + s = pl.Series([1.0, 2.0]) + with pytest.raises(pl.ComputeError, match="expected List dtype"): + s.list.to_array(2)