diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 417c8ab3..a3b23307 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -23,11 +23,16 @@ class BitXor(Aggregate): class GroupConcat(Aggregate): + template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(separator)s)" function = "GROUP_CONCAT" + name = "GroupConcat" + output_field = CharField() + allow_distinct = True def __init__( self, expression: Expression, + filter: Any | None = None, distinct: bool = False, separator: str | None = None, ordering: str | None = None, @@ -38,7 +43,7 @@ def __init__( # This can/will be improved to SetTextField or ListTextField extra["output_field"] = CharField() - super().__init__(expression, **extra) + super().__init__(expression, filter=filter, **extra) self.distinct = distinct self.separator = separator @@ -53,18 +58,41 @@ def as_sql( connection: BaseDatabaseWrapper, **extra_context: Any, ) -> tuple[str, tuple[Any, ...]]: + def expr_sql(): + expr_parts = [] + params = [] + for arg in self.source_expressions: + arg_sql, arg_params = compiler.compile(arg) + expr_parts.append(arg_sql) + params.extend(arg_params) + return self.arg_joiner.join(expr_parts), params + + if self.filter: + extra_context["distinct"] = "DISTINCT " if self.distinct else "" + copy = self.copy() + copy.filter = None + source_expressions = copy.get_source_expressions() + condition = When(self.filter, then=source_expressions[0]) + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) + + expr_sql, _ = expr_sql() + + extra_context["order_by"] = ( + f" ORDER BY {expr_sql} {self.ordering}" if self.ordering else "" + ) + + extra_context["separator"] = ( + f" SEPARATOR '{self.separator}' " if self.separator else "" + ) + + return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) + connection.ops.check_expression_support(self) sql = ["GROUP_CONCAT("] if self.distinct: sql.append("DISTINCT ") - expr_parts = [] - params = [] - for arg in self.source_expressions: - arg_sql, arg_params = compiler.compile(arg) - expr_parts.append(arg_sql) - params.extend(arg_params) - expr_sql = self.arg_joiner.join(expr_parts) + expr_sql, params = expr_sql() sql.append(expr_sql)