Skip to content

Commit

Permalink
feat: Add .list.to_array expression (pola-rs#12192)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Nov 2, 2023
1 parent 90d3694 commit 9fc7387
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 2 deletions.
23 changes: 23 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub enum ListFunction {
#[cfg(feature = "list_any_all")]
All,
Join,
#[cfg(feature = "dtype-array")]
ToArray(usize),
}

impl ListFunction {
Expand Down Expand Up @@ -87,10 +89,21 @@ impl ListFunction {
#[cfg(feature = "list_any_all")]
All => mapper.with_dtype(DataType::Boolean),
Join => mapper.with_dtype(DataType::Utf8),
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
}
}
}

#[cfg(feature = "dtype-array")]
/// Derive the `Array` dtype that corresponds to a `List` dtype.
///
/// The inner dtype of the list is reused unchanged and paired with `width`.
///
/// # Errors
///
/// Returns a `ComputeError` when `datatype` is not a `DataType::List`.
fn map_list_dtype_to_array_dtype(datatype: &DataType, width: usize) -> PolarsResult<DataType> {
    match datatype {
        DataType::List(inner) => Ok(DataType::Array(inner.clone(), width)),
        _ => polars_bail!(ComputeError: "expected List dtype"),
    }
}

impl Display for ListFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use ListFunction::*;
Expand Down Expand Up @@ -141,6 +154,8 @@ impl Display for ListFunction {
#[cfg(feature = "list_any_all")]
All => "all",
Join => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
};
write!(f, "list.{name}")
}
Expand Down Expand Up @@ -194,6 +209,8 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "list_any_all")]
All => map!(lst_all),
Join => map_as_slice!(join),
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
}
}
}
Expand Down Expand Up @@ -518,3 +535,9 @@ pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
let separator = s[1].utf8()?;
Ok(ca.lst_join(separator)?.into_series())
}

#[cfg(feature = "dtype-array")]
/// Cast a `List` series to an `Array` series of the given `width`.
///
/// The target dtype is derived from the series' own dtype via
/// `map_list_dtype_to_array_dtype`, so the inner dtype is preserved.
///
/// # Errors
///
/// Propagates the dtype-mapping error when `s` is not a `List` series, as
/// well as any error reported by the cast itself.
pub(super) fn to_array(s: &Series, width: usize) -> PolarsResult<Series> {
    s.cast(&map_list_dtype_to_array_dtype(s.dtype(), width)?)
}
7 changes: 7 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ impl ListNameSpace {
self.slice(lit(0i64) - n.clone().cast(DataType::Int64), n)
}

#[cfg(feature = "dtype-array")]
/// Convert a List column into an Array column with the same inner data type.
///
/// `width` is forwarded to `ListFunction::ToArray` and becomes the width of
/// the resulting Array dtype.
pub fn to_array(self, width: usize) -> Expr {
    let function = FunctionExpr::ListExpr(ListFunction::ToArray(width));
    self.0.map_private(function)
}

#[cfg(feature = "list_to_struct")]
#[allow(clippy::wrong_self_convention)]
/// Convert this `List` to a `Series` of type `Struct`. The width will be determined according to
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The following methods are available under the `expr.list` attribute.
Expr.list.symmetric_difference
Expr.list.tail
Expr.list.take
Expr.list.to_array
Expr.list.to_struct
Expr.list.union
Expr.list.unique
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The following methods are available under the `Series.list` attribute.
Series.list.symmetric_difference
Series.list.tail
Series.list.take
Series.list.to_array
Series.list.to_struct
Series.list.union
Series.list.unique
35 changes: 34 additions & 1 deletion py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,40 @@ def count_matches(self, element: IntoExpr) -> Expr:
element = parse_as_expression(element, str_as_lit=True)
return wrap_expr(self._pyexpr.list_count_matches(element))

def to_array(self, width: int) -> Expr:
    """
    Convert a List column into an Array column with the same inner data type.

    Parameters
    ----------
    width
        Width of the resulting Array column.

    Returns
    -------
    Expr
        Expression of data type :class:`Array`.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     data={"a": [[1, 2], [3, 4]]},
    ...     schema={"a": pl.List(pl.Int8)},
    ... )
    >>> df.select(pl.col("a").list.to_array(2))
    shape: (2, 1)
    ┌──────────────┐
    │ a            │
    │ ---          │
    │ array[i8, 2] │
    ╞══════════════╡
    │ [1, 2]       │
    │ [3, 4]       │
    └──────────────┘

    """
    # Delegate to the native expression, then wrap it back into a Python Expr.
    pyexpr = self._pyexpr.list_to_array(width)
    return wrap_expr(pyexpr)

def to_struct(
self,
n_field_strategy: ToStructStrategy = "first_non_null",
Expand Down Expand Up @@ -1135,7 +1169,6 @@ def set_symmetric_difference(self, other: IntoExpr) -> Expr:
other
Right hand side of the set operation.
Examples
--------
>>> df = pl.DataFrame(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def to_list(self) -> Series:
Returns
-------
Expr
Series
Series of data type :class:`List`.
Examples
Expand Down
27 changes: 27 additions & 0 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,33 @@ def count_matches(
"""

def to_array(self, width: int) -> Series:
    """
    Convert a List column into an Array column with the same inner data type.

    Parameters
    ----------
    width
        Width of the resulting Array column.

    Returns
    -------
    Series
        Series of data type :class:`Array`.

    Examples
    --------
    >>> s = pl.Series([[1, 2], [3, 4]], dtype=pl.List(pl.Int8))
    >>> s.list.to_array(2)
    shape: (2,)
    Series: '' [array[i8, 2]]
    [
            [1, 2]
            [3, 4]
    ]

    """

def to_struct(
self,
n_field_strategy: ToStructStrategy = "first_non_null",
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ impl PyExpr {
.into()
}

/// Build the `list.to_array(width)` expression from a clone of the wrapped
/// expression and convert the result back into `Self`.
fn list_to_array(&self, width: usize) -> Self {
    let expr = self.inner.clone();
    let as_array = expr.list().to_array(width);
    as_array.into()
}

#[pyo3(signature = (width_strat, name_gen, upper_bound))]
fn list_to_struct(
&self,
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,25 @@ def test_list_lengths_deprecated() -> None:
result = s.list.lengths()
expected = pl.Series([3, 1], dtype=pl.UInt32)
assert_series_equal(result, expected)


def test_list_to_array() -> None:
    # The same nested values back both the List input and the Array expectation.
    values = [[1.0, 2.0], [3.0, 4.0]]
    list_series = pl.Series(values, dtype=pl.List(pl.Float32))
    expected = pl.Series(values, dtype=pl.Array(inner=pl.Float32, width=2))

    assert_series_equal(list_series.list.to_array(2), expected)


def test_list_to_array_wrong_lengths() -> None:
    # Each sub-list holds two values, so asking for width 3 must be rejected.
    rows = [[1.0, 2.0], [3.0, 4.0]]
    list_series = pl.Series(rows, dtype=pl.List(pl.Float32))
    expected_message = "incompatible offsets in source list"

    with pytest.raises(pl.ComputeError, match=expected_message):
        list_series.list.to_array(3)


def test_list_to_array_wrong_dtype() -> None:
    # A flat float Series has no List dtype, so the conversion must fail.
    flat_series = pl.Series([1.0, 2.0])
    expected_message = "expected List dtype"

    with pytest.raises(pl.ComputeError, match=expected_message):
        flat_series.list.to_array(2)

0 comments on commit 9fc7387

Please sign in to comment.