Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add STDEV() aggregate function #1553

Merged
merged 14 commits into from
Nov 13, 2024
3 changes: 3 additions & 0 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionGenerators.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "global/RuntimeParameters.h"
#include "index/Index.h"
#include "index/IndexImpl.h"
Expand Down Expand Up @@ -1026,6 +1027,8 @@ GroupBy::isSupportedAggregate(sparqlExpression::SparqlExpression* expr) {
if (auto val = dynamic_cast<GroupConcatExpression*>(expr)) {
return H{GROUP_CONCAT, val->getSeparator()};
}
// NOTE: The STDEV function is not suitable for lazy and hash map
// optimizations.
if (dynamic_cast<SampleExpression*>(expr)) return H{SAMPLE};

// `expr` is an unsupported aggregate
Expand Down
6 changes: 6 additions & 0 deletions src/engine/sparqlExpressions/AggregateExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "engine/sparqlExpressions/AggregateExpression.h"

#include "engine/sparqlExpressions/GroupConcatExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression::detail {

Expand Down Expand Up @@ -180,6 +181,11 @@ AggregateExpression<AggregateOperation, FinalOperation>::getVariableForCount()
// Explicit instantiation for the AVG expression.
template class AggregateExpression<AvgOperation, decltype(avgFinalOperation)>;

// Explicit instantiation for the STDEV expression.
template class AggregateExpression<AvgOperation, decltype(stdevFinalOperation)>;
template class DeviationAggExpression<AvgOperation,
decltype(stdevFinalOperation)>;

// Explicit instantiations for the other aggregate expressions.
#define INSTANTIATE_AGG_EXP(Function, ValueGetter) \
template class AggregateExpression< \
Expand Down
1 change: 1 addition & 0 deletions src/engine/sparqlExpressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(sparqlExpressions
SampleExpression.cpp
RelationalExpressions.cpp
AggregateExpression.cpp
StdevExpression.cpp
RegexExpression.cpp
NumericUnaryExpressions.cpp
NumericBinaryExpressions.cpp
Expand Down
74 changes: 74 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression {

namespace detail {

// _____________________________________________________________________________
ExpressionResult DeviationExpression::evaluate(
EvaluationContext* context) const {
auto impl = [context](SingleExpressionResult auto&& el) -> ExpressionResult {
// Prepare space for result
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
std::fill_n(std::back_inserter(exprResult), context->size(),
IdOrLiteralOrIri{Id::makeUndefined()});
bool undef = false;

auto devImpl = [&undef, &exprResult, context](auto generator) {
double sum = 0.0;
// Intermediate storage of the results returned from the child
// expression
VectorWithMemoryLimit<double> childResults{context->_allocator};

// Collect values as doubles
for (auto& inp : generator) {
const auto& n = detail::NumericValueGetter{}(std::move(inp), context);
auto v = std::visit(
[]<typename T>(T&& value) -> std::optional<double> {
if constexpr (ad_utility::isSimilar<T, double> ||
ad_utility::isSimilar<T, int64_t>) {
return static_cast<double>(value);
} else {
return std::nullopt;
}
},
n);
if (v.has_value()) {
childResults.push_back(v.value());
sum += v.value();
} else {
// There is a non-numeric value in the input. Therefore the entire
// result will be undef.
undef = true;
return;
}
context->cancellationHandle_->throwIfCancelled();
}

// Calculate squared deviation and save for result
double avg = sum / static_cast<double>(context->size());
for (size_t i = 0; i < childResults.size(); i++) {
exprResult.at(i) = IdOrLiteralOrIri{
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};
}
};

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);
devImpl(std::move(generator));

if (undef) {
return IdOrLiteralOrIri{Id::makeUndefined()};
}
return exprResult;
};
auto childRes = child_->evaluate(context);
return std::visit(impl, std::move(childRes));
};

} // namespace detail
} // namespace sparqlExpression
100 changes: 100 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#pragma once

#include <cmath>
#include <functional>
#include <memory>
#include <variant>

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/LiteralExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
#include "engine/sparqlExpressions/SparqlExpressionValueGetters.h"
#include "global/ValueId.h"

namespace sparqlExpression {

namespace detail {

/// The STDEV Expression

// Helper expression: The individual deviation squares. A DeviationExpression
// over X corresponds to the value (X - AVG(X))^2.
class DeviationExpression : public SparqlExpression {
private:
Ptr child_;

public:
DeviationExpression(Ptr&& child) : child_{std::move(child)} {}

// __________________________________________________________________________
ExpressionResult evaluate(EvaluationContext* context) const override;

// __________________________________________________________________________
AggregateStatus isAggregate() const override {
return SparqlExpression::AggregateStatus::NoAggregate;
}

// __________________________________________________________________________
[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override {
return absl::StrCat("[ SQ.DEVIATION ]", child_->getCacheKey(varColMap));
}

Check warning on line 47 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L45-L47

Added lines #L45 - L47 were not covered by tests

private:
// _________________________________________________________________________
std::span<SparqlExpression::Ptr> childrenImpl() override {
return {&child_, 1};
}
};

// Separate subclass of AggregateOperation, that replaces its child with a
// DeviationExpression of this child. Everything else is left untouched.
template <typename AggregateOperation,
typename FinalOperation = decltype(identity)>
class DeviationAggExpression
: public AggregateExpression<AggregateOperation, FinalOperation> {
public:
// __________________________________________________________________________
DeviationAggExpression(bool distinct, SparqlExpression::Ptr&& child,
AggregateOperation aggregateOp = AggregateOperation{})
: AggregateExpression<AggregateOperation, FinalOperation>(
distinct, std::make_unique<DeviationExpression>(std::move(child)),
aggregateOp){};
};

// The final operation for dividing by degrees of freedom and calculation square
// root after summing up the squared deviation
inline auto stdevFinalOperation = [](const NumericValue& aggregation,
size_t numElements) {
auto divAndRoot = [](double value, double degreesOfFreedom) {
if (degreesOfFreedom <= 0) {
return 0.0;
} else {
return std::sqrt(value / degreesOfFreedom);
}
};
return makeNumericExpressionForAggregate<decltype(divAndRoot)>()(
aggregation, NumericValue{static_cast<double>(numElements) - 1});
};

// The actual Standard Deviation Expression
// Mind the explicit instantiation of StdevExpressionBase in
// AggregateExpression.cpp
using StdevExpressionBase =
DeviationAggExpression<AvgOperation, decltype(stdevFinalOperation)>;
class StdevExpression : public StdevExpressionBase {
using StdevExpressionBase::StdevExpressionBase;
ValueId resultForEmptyGroup() const override { return Id::makeFromDouble(0); }
};

} // namespace detail

using detail::StdevExpression;

} // namespace sparqlExpression
3 changes: 3 additions & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "engine/sparqlExpressions/RegexExpression.h"
#include "engine/sparqlExpressions/RelationalExpressions.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "engine/sparqlExpressions/UuidExpressions.h"
#include "parser/GraphPatternOperation.h"
#include "parser/RdfParser.h"
Expand Down Expand Up @@ -2372,6 +2373,8 @@ ExpressionPtr Visitor::visit(Parser::AggregateContext* ctx) {
}

return makePtr.operator()<GroupConcatExpression>(std::move(separator));
} else if (functionName == "stdev") {
return makePtr.operator()<StdevExpression>();
} else {
AD_CORRECTNESS_CHECK(functionName == "sample");
return makePtr.operator()<SampleExpression>();
Expand Down
1 change: 1 addition & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "parser/data/GraphRef.h"
#undef EOF
#include "parser/sparqlParser/generated/SparqlAutomaticVisitor.h"
Expand Down
2 changes: 2 additions & 0 deletions src/parser/sparqlParser/generated/SparqlAutomatic.g4
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ aggregate : COUNT '(' DISTINCT? ( '*' | expression ) ')'
| MIN '(' DISTINCT? expression ')'
| MAX '(' DISTINCT? expression ')'
| AVG '(' DISTINCT? expression ')'
| STDEV '(' DISTINCT? expression ')'
| SAMPLE '(' DISTINCT? expression ')'
| GROUP_CONCAT '(' DISTINCT? expression ( ';' SEPARATOR '=' string )? ')' ;

Expand Down Expand Up @@ -763,6 +764,7 @@ SUM : S U M;
MIN : M I N;
MAX : M A X;
AVG : A V G;
STDEV : S T D E V ;
SAMPLE : S A M P L E;
SEPARATOR : S E P A R A T O R;

Expand Down
4 changes: 3 additions & 1 deletion src/parser/sparqlParser/generated/SparqlAutomatic.interp

Large diffs are not rendered by default.

75 changes: 38 additions & 37 deletions src/parser/sparqlParser/generated/SparqlAutomatic.tokens
Original file line number Diff line number Diff line change
Expand Up @@ -136,43 +136,44 @@ SUM=135
MIN=136
MAX=137
AVG=138
SAMPLE=139
SEPARATOR=140
IRI_REF=141
PNAME_NS=142
PNAME_LN=143
BLANK_NODE_LABEL=144
VAR1=145
VAR2=146
LANGTAG=147
PREFIX_LANGTAG=148
INTEGER=149
DECIMAL=150
DOUBLE=151
INTEGER_POSITIVE=152
DECIMAL_POSITIVE=153
DOUBLE_POSITIVE=154
INTEGER_NEGATIVE=155
DECIMAL_NEGATIVE=156
DOUBLE_NEGATIVE=157
EXPONENT=158
STRING_LITERAL1=159
STRING_LITERAL2=160
STRING_LITERAL_LONG1=161
STRING_LITERAL_LONG2=162
ECHAR=163
NIL=164
ANON=165
PN_CHARS_U=166
VARNAME=167
PN_PREFIX=168
PN_LOCAL=169
PLX=170
PERCENT=171
HEX=172
PN_LOCAL_ESC=173
WS=174
COMMENTS=175
STDEV=139
SAMPLE=140
SEPARATOR=141
IRI_REF=142
PNAME_NS=143
PNAME_LN=144
BLANK_NODE_LABEL=145
VAR1=146
VAR2=147
LANGTAG=148
PREFIX_LANGTAG=149
INTEGER=150
DECIMAL=151
DOUBLE=152
INTEGER_POSITIVE=153
DECIMAL_POSITIVE=154
DOUBLE_POSITIVE=155
INTEGER_NEGATIVE=156
DECIMAL_NEGATIVE=157
DOUBLE_NEGATIVE=158
EXPONENT=159
STRING_LITERAL1=160
STRING_LITERAL2=161
STRING_LITERAL_LONG1=162
STRING_LITERAL_LONG2=163
ECHAR=164
NIL=165
ANON=166
PN_CHARS_U=167
VARNAME=168
PN_PREFIX=169
PN_LOCAL=170
PLX=171
PERCENT=172
HEX=173
PN_LOCAL_ESC=174
WS=175
COMMENTS=176
'*'=1
'('=2
')'=3
Expand Down
Loading
Loading