From 20c5cb6f4e56f7c37593510c84c29f7a970e7235 Mon Sep 17 00:00:00 2001 From: Mathan Date: Sat, 25 Nov 2023 23:33:16 -0500 Subject: [PATCH] string_agg implementation --- evadb/expression/abstract_expression.py | 1 + evadb/expression/aggregation_expression.py | 6 ++++++ evadb/models/storage/batch.py | 16 ++++++++++++++++ evadb/parser/evadb.lark | 4 ++++ evadb/parser/lark_visitor/_functions.py | 10 ++++++++++ 5 files changed, 37 insertions(+) diff --git a/evadb/expression/abstract_expression.py b/evadb/expression/abstract_expression.py index 9b72f32e6..8551c4ec4 100644 --- a/evadb/expression/abstract_expression.py +++ b/evadb/expression/abstract_expression.py @@ -56,6 +56,7 @@ class ExpressionType(IntEnum): AGGREGATION_FIRST = auto() AGGREGATION_LAST = auto() AGGREGATION_SEGMENT = auto() + AGGREGATION_STRING_AGG = auto() CASE = auto() # add other types diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index f1ba6b16c..c57cc73aa 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -54,6 +54,12 @@ def evaluate(self, *args, **kwargs): batch.aggregate("min") elif self.etype == ExpressionType.AGGREGATION_MAX: batch.aggregate("max") + elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: + # Assuming two children: the column and the delimiter + column_to_aggregate = self.get_child(0).evaluate(*args, **kwargs) + delimiter = kwargs.get('delimiter') + batch.aggregate_string_aggregation(column_to_aggregate, delimiter) + batch.reset_index() column_name = self.etype.name diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 43e69cc4f..e1fdddbe7 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -450,3 +450,19 @@ def to_numpy(self): def rename(self, columns) -> None: "Rename column names" self._frames.rename(columns=columns, inplace=True) + + def aggregate_string_aggregation(self, column_name:str, delimiter:str): +
# First, ensure the column data is in string format + string_column = self._frames[column_name].astype(str) + + def aggregate_column(data, sep): + # Join the data using the provided separator + aggregated_string = sep.join(data) + return aggregated_string + + aggregated_result = aggregate_column(string_column, delimiter) + + aggregated_dataframe = pd.DataFrame({column_name: [aggregated_result]}) + + # Update the original DataFrame with the new aggregated DataFrame + self._frames = aggregated_dataframe diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 4b96bf647..998a9360b 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -297,6 +297,7 @@ or_replace: OR REPLACE function_call: function ->function_call | aggregate_windowed_function ->aggregate_function_call + | string_agg_function function: simple_id "(" (STAR | function_args)? ")" dotted_id? @@ -306,6 +307,9 @@ aggregate_windowed_function: aggregate_function_name "(" function_arg ")" aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT +string_agg_function: "STRING_AGG" LR_BRACKET expression COMMA string_literal RR_BRACKET + + function_args: (function_arg) ("," function_arg)* function_arg: constant | expression diff --git a/evadb/parser/lark_visitor/_functions.py b/evadb/parser/lark_visitor/_functions.py index 2b2c18095..df1e68240 100644 --- a/evadb/parser/lark_visitor/_functions.py +++ b/evadb/parser/lark_visitor/_functions.py @@ -134,11 +134,14 @@ def get_aggregate_function_type(self, agg_func_name): agg_func_type = ExpressionType.AGGREGATION_LAST elif agg_func_name == "SEGMENT": agg_func_type = ExpressionType.AGGREGATION_SEGMENT + elif agg_func_name == "STRING_AGG": + agg_func_type = ExpressionType.AGGREGATION_STRING_AGG return agg_func_type def aggregate_windowed_function(self, tree): agg_func_arg = None agg_func_name = None + agg_func_args = [] for child in tree.children: if isinstance(child, Tree): @@ -156,6 +159,13 @@ def
aggregate_windowed_function(self, tree): else: agg_func_arg = TupleValueExpression(name="id") + if agg_func_name == "STRING_AGG": + if len(agg_func_args) != 2: + raise ValueError("String Agg requires exactly two arguments") + agg_func_type = self.get_aggregate_function_type(agg_func_name) + agg_expr = AggregationExpression(agg_func_type, None, *agg_func_args) + return agg_expr + agg_func_type = self.get_aggregate_function_type(agg_func_name) agg_expr = AggregationExpression(agg_func_type, None, agg_func_arg) return agg_expr
# ---------------------------------------------------------------------------
# Review notes (commentary only; nothing below is part of the change).
#
# Purpose of the patch: add a STRING_AGG aggregate. It introduces
# ExpressionType.AGGREGATION_STRING_AGG, a matching branch in
# AggregationExpression.evaluate(), a new aggregate_string_aggregation()
# method in evadb/models/storage/batch.py, a `string_agg_function` grammar
# rule, and STRING_AGG handling in the parser visitor.
#
# NOTE(review): the whitespace of this file has been collapsed; indentation
# and blank lines of the hunks are not recoverable from the text as shown.
#
# Issues visible in the hunks above:
#
# 1. evadb/parser/lark_visitor/_functions.py
#    `agg_func_args` is initialised to [] by this patch, and no hunk here
#    appends to it. When `agg_func_name == "STRING_AGG"` the
#    `len(agg_func_args) != 2` check therefore raises
#    ValueError("String Agg requires exactly two arguments") every time.
#    Suggested fix: collect each visited argument into `agg_func_args`
#    inside the `for child in tree.children` loop (loop body is not shown
#    in this patch -- TODO confirm where `agg_func_arg` is assigned).
#
# 2. evadb/parser/evadb.lark
#    The `aggregate_function_name` context line lists
#    AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT and is not extended, so
#    STRING_AGG does not appear to be reachable through
#    `aggregate_windowed_function`. The new `| string_agg_function`
#    alternative has no `-> alias`, and this patch adds no visitor method
#    named `string_agg_function`. TODO confirm how that rule is dispatched;
#    as written, the STRING_AGG branches in _functions.py look unreachable.
#
# 3. evadb/expression/aggregation_expression.py
#    `column_to_aggregate` is the value returned by
#    `self.get_child(0).evaluate(...)`, but it is passed as the
#    `column_name: str` parameter and used as `self._frames[column_name]`.
#    The delimiter is read with `kwargs.get('delimiter')` rather than from
#    a second child, despite the "two children" comment; when the key is
#    absent this is None and `sep.join(data)` raises AttributeError.
#    Suggested fix: pass the column's name (not the evaluated result) and
#    take the delimiter from the second child expression.
#
# 4. evadb/models/storage/batch.py
#    `.astype(str)` is applied before joining, so missing values are
#    stringified rather than skipped (presumably appearing as "nan" --
#    TODO confirm desired NULL semantics). The method replaces
#    `self._frames` with a one-row, one-column DataFrame, dropping every
#    other column. `pd` is assumed to be imported at the top of batch.py
#    (not shown in this hunk -- verify). Style: `column_name:str` and
#    `delimiter:str` are missing the space after the colon.
# ---------------------------------------------------------------------------