Skip to content

Commit

Permalink
string_agg implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mathanmahe committed Nov 26, 2023
1 parent fc08fd9 commit 20c5cb6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 0 deletions.
1 change: 1 addition & 0 deletions evadb/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ExpressionType(IntEnum):
AGGREGATION_FIRST = auto()
AGGREGATION_LAST = auto()
AGGREGATION_SEGMENT = auto()
AGGREGATION_STRING_AGG = auto()

CASE = auto()
# add other types
Expand Down
6 changes: 6 additions & 0 deletions evadb/expression/aggregation_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def evaluate(self, *args, **kwargs):
batch.aggregate("min")
elif self.etype == ExpressionType.AGGREGATION_MAX:
batch.aggregate("max")
elif self.etype == ExpressionType.AGGREGATION_STRING_AGG:
# Assuming two children: the column and the delimiter
column_to_aggregate = self.get_child(0).evaluate(*args, **kwargs)
delimiter = kwargs.get('delimiter')
batch.aggregate_string_aggregation(column_to_aggregate, delimiter)

batch.reset_index()

column_name = self.etype.name
Expand Down
16 changes: 16 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,19 @@ def to_numpy(self):
def rename(self, columns) -> None:
    """Rename columns of the underlying frame in place.

    ``columns`` is forwarded unchanged as the ``columns`` keyword of
    ``self._frames.rename``; the frame object itself is mutated and
    nothing is returned.
    """
    frames = self._frames
    frames.rename(columns=columns, inplace=True)

def aggregate_string_aggregation(self, column_name: str, delimiter: str) -> None:
    """Collapse one column of the batch into a single delimited string.

    Replaces ``self._frames`` with a one-row DataFrame whose only column is
    ``column_name`` and whose single value is that column's entries,
    converted to ``str`` and joined with ``delimiter`` in row order.

    Missing entries (``None`` / ``NaN``) are skipped instead of being
    rendered as the literal text ``"None"`` / ``"nan"``, matching the usual
    SQL ``STRING_AGG`` convention of ignoring NULL inputs. A column with no
    non-null entries aggregates to the empty string.

    Args:
        column_name: Name of the column in ``self._frames`` to aggregate.
        delimiter: Separator inserted between consecutive values.

    Raises:
        TypeError: If ``delimiter`` is not a ``str``.
    """
    # Fail early with a clear message; otherwise a missing delimiter surfaces
    # as an opaque AttributeError from ``None.join``.
    if not isinstance(delimiter, str):
        raise TypeError(
            "STRING_AGG delimiter must be a str, "
            f"got {type(delimiter).__name__}"
        )

    # Drop nulls *before* the string cast: astype(str) would otherwise turn
    # them into ordinary-looking "nan"/"None" text in the result.
    values = self._frames[column_name].dropna().astype(str)

    # The aggregate replaces the whole frame with a single-row result.
    self._frames = pd.DataFrame({column_name: [delimiter.join(values)]})
4 changes: 4 additions & 0 deletions evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ or_replace: OR REPLACE

function_call: function ->function_call
| aggregate_windowed_function ->aggregate_function_call
| string_agg_function

function: simple_id "(" (STAR | function_args)? ")" dotted_id?

Expand All @@ -306,6 +307,9 @@ aggregate_windowed_function: aggregate_function_name "(" function_arg ")"

aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT

// STRING_AGG call: the expression to aggregate, then a string-literal delimiter.
// Presumably LR_BRACKET / RR_BRACKET / COMMA are the "(" / ")" / "," terminals
// declared elsewhere in this grammar -- TODO confirm.
// NOTE(review): unlike the other function_call alternatives, this rule is
// referenced without a "-> alias"; confirm the visitor defines a handler named
// string_agg_function. The quoted "STRING_AGG" literal is matched
// case-sensitively by Lark unless an "i" flag is added -- verify that matches
// how the other aggregate keywords are tokenized.
string_agg_function: "STRING_AGG" LR_BRACKET expression COMMA string_literal RR_BRACKET


function_args: (function_arg) ("," function_arg)*

function_arg: constant | expression
Expand Down
10 changes: 10 additions & 0 deletions evadb/parser/lark_visitor/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,14 @@ def get_aggregate_function_type(self, agg_func_name):
agg_func_type = ExpressionType.AGGREGATION_LAST
elif agg_func_name == "SEGMENT":
agg_func_type = ExpressionType.AGGREGATION_SEGMENT
elif agg_func_name == "STRING_AGG":
agg_func_type = ExpressionType.AGGREGATION_STRING_AGG
return agg_func_type

def aggregate_windowed_function(self, tree):
agg_func_arg = None
agg_func_name = None
agg_func_args = []

for child in tree.children:
if isinstance(child, Tree):
Expand All @@ -156,6 +159,13 @@ def aggregate_windowed_function(self, tree):
else:
agg_func_arg = TupleValueExpression(name="id")

if agg_func_name == "STRING_AGG":
if len(agg_func_args) != 2:
raise ValueError("String Agg requires exactly two arguments")
agg_func_type = self.get_aggregate_function_type(agg_func_name)
agg_expr = AggregationExpression(agg_func_type, None, *agg_func_args)
return agg_expr

agg_func_type = self.get_aggregate_function_type(agg_func_name)
agg_expr = AggregationExpression(agg_func_type, None, agg_func_arg)
return agg_expr

0 comments on commit 20c5cb6

Please sign in to comment.