Skip to content

Commit

Permalink
fix bug and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ullingerc committed Oct 20, 2024
1 parent f90d91c commit 91d1c15
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 19 deletions.
25 changes: 6 additions & 19 deletions src/engine/sparqlExpressions/StdevExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#ifndef QLEVER_STDEVEXPRESSION_H
#define QLEVER_STDEVEXPRESSION_H
#pragma once

#include <cmath>
#include <functional>
Expand Down Expand Up @@ -40,11 +39,9 @@ auto inline numValToDouble =
class DeviationExpression : public SparqlExpression {
private:
Ptr child_;
bool distinct_;

public:
DeviationExpression(bool distinct, Ptr&& child)
: child_{std::move(child)}, distinct_{distinct} {}
DeviationExpression(Ptr&& child) : child_{std::move(child)} {}

// __________________________________________________________________________
ExpressionResult evaluate(EvaluationContext* context) const override {
Expand Down Expand Up @@ -88,13 +85,7 @@ class DeviationExpression : public SparqlExpression {

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);
if (distinct_) {
context->cancellationHandle_->throwIfCancelled();
devImpl(detail::getUniqueElements(context, context->size(),
std::move(generator)));
} else {
devImpl(std::move(generator));
}
devImpl(std::move(generator));

if (undef) {
return IdOrLiteralOrIri{Id::makeUndefined()};
Expand All @@ -113,8 +104,7 @@ class DeviationExpression : public SparqlExpression {
// __________________________________________________________________________
[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override {
return absl::StrCat("[ SQ.DEVIATION ", distinct_ ? " DISTINCT " : "", "]",
child_->getCacheKey(varColMap));
return absl::StrCat("[ SQ.DEVIATION ]", child_->getCacheKey(varColMap));
}

private:
Expand All @@ -135,8 +125,7 @@ class DeviationAggExpression
DeviationAggExpression(bool distinct, SparqlExpression::Ptr&& child,
AggregateOperation aggregateOp = AggregateOperation{})
: AggregateExpression<AggregateOperation, FinalOperation>(
distinct,
std::make_unique<DeviationExpression>(distinct, std::move(child)),
distinct, std::make_unique<DeviationExpression>(std::move(child)),
aggregateOp){};
};

Expand All @@ -162,13 +151,11 @@ using StdevExpressionBase =
DeviationAggExpression<AvgOperation, decltype(stdevFinalOperation)>;
class StdevExpression : public StdevExpressionBase {
using StdevExpressionBase::StdevExpressionBase;
ValueId resultForEmptyGroup() const override { return Id::makeFromInt(0); }
ValueId resultForEmptyGroup() const override { return Id::makeFromDouble(0); }
};

} // namespace detail

using detail::StdevExpression;

} // namespace sparqlExpression

#endif // QLEVER_STDEVEXPRESSION_H
33 changes: 33 additions & 0 deletions test/AggregateExpressionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Chair of Algorithms and Data Structures
// Author: Johannes Kalmbach <[email protected]>

#include <optional>
#include <type_traits>

#include "./SparqlExpressionTestHelpers.h"
#include "./util/GTestHelpers.h"
#include "./util/IdTableHelpers.h"
Expand All @@ -11,6 +14,9 @@
#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/CountStarExpression.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "global/ValueId.h"
#include "gtest/gtest.h"

using namespace sparqlExpression;
Expand Down Expand Up @@ -94,6 +100,33 @@ TEST(AggregateExpression, avg) {
testAvgString({lit("alpha"), lit("äpfel"), lit("Beta"), lit("unfug")}, U);
}

// Test `StdevExpression`.
TEST(StdevExpression, avg) {
auto testStdevId = testAggregate<StdevExpression, Id>;

auto inputAsVector = std::vector{I(3), D(0), I(0), I(4), I(-2)};
VectorWithMemoryLimit<ValueId> input(inputAsVector.begin(),
inputAsVector.end(), makeAllocator());
auto d = std::make_unique<SingleUseExpression>(input.clone());
auto t = TestContext{};
t.context._endIndex = 5;
StdevExpression m{false, std::move(d)};
auto resAsVariant = m.evaluate(&t.context);
ASSERT_NEAR(std::get<ValueId>(resAsVariant).getDouble(), 2.44949, 0.0001);

testStdevId({D(2), D(2), D(2), D(2)}, D(0), true);

testStdevId({I(3), U}, U);
testStdevId({I(3), NaN}, NaN);

testStdevId({}, D(0));
testStdevId({D(500)}, D(0));
testStdevId({D(500), D(500), D(500)}, D(0));

auto testStdevString = testAggregate<StdevExpression, IdOrLiteralOrIri, Id>;
testStdevString({lit("alpha"), lit("äpfel"), lit("Beta"), lit("unfug")}, U);
}

// Test `MinExpression`.
TEST(AggregateExpression, min) {
auto testMinId = testAggregate<MinExpression, Id>;
Expand Down
11 changes: 11 additions & 0 deletions test/SparqlAntlrParserTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "engine/sparqlExpressions/RegexExpression.h"
#include "engine/sparqlExpressions/RelationalExpressions.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "engine/sparqlExpressions/UuidExpressions.h"
#include "parser/ConstructClause.h"
#include "parser/SparqlParserHelpers.h"
Expand Down Expand Up @@ -1835,6 +1836,16 @@ TEST(SparqlParser, aggregateExpressions) {
expectAggregate(
"group_concat(DISTINCT ?x; SEPARATOR=\";\")",
matchAggregate<GroupConcatExpression>(true, V{"?x"}, separator(";")));

// The STDEV expression
// TODO<ullingec> Test failing because StdevExpression replaces its child

// expectAggregate("STDEV(?x)", matchAggregate<StdevExpression>(false,
// V{"?x"})); expectAggregate("stdev(?x)",
// matchAggregate<StdevExpression>(false, V{"?x"})); A distinct stdev is
// probably not very useful, but should be possible anyway
// expectAggregate("STDEV(DISTINCT ?x)",
// matchAggregate<StdevExpression>(true, V{"?x"}));
}

// Update queries are WIP. The individual parts to parse some update queries
Expand Down
4 changes: 4 additions & 0 deletions test/SparqlExpressionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "parser/GeoPoint.h"
#include "util/AllocatorTestHelpers.h"
#include "util/Conversions.h"
Expand Down Expand Up @@ -1369,6 +1370,9 @@ TEST(SparqlExpression, isAggregateAndIsDistinct) {
EXPECT_THAT(GroupConcatExpression(false, varX(), " "), match(false));
EXPECT_THAT(GroupConcatExpression(true, varX(), " "), match(true));

EXPECT_THAT(StdevExpression(false, varX()), match(false));
EXPECT_THAT(StdevExpression(true, varX()), match(true));

EXPECT_THAT(SampleExpression(false, varX()), match(false));
// For `SAMPLE` the distinctness makes no difference, so we always return `not
// distinct`.
Expand Down
14 changes: 14 additions & 0 deletions test/SparqlParserTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,20 @@ TEST(ParserTest, testSolutionModifiers) {
ASSERT_FALSE(pq._orderBy[0].isDescending_);
}

{
auto pq = SparqlParser::parseQuery(
"SELECT ?r (STDEV(?r) as ?stdev) WHERE {"
"?a <http://schema.org/name> ?b ."
"?a ql:has-relation ?r }"
"GROUP BY ?r "
"ORDER BY ?stdev");
ASSERT_EQ(1u, pq.children().size());
ASSERT_EQ(1u, pq._orderBy.size());
EXPECT_THAT(pq, m::GroupByVariables({Var{"?r"}}));
ASSERT_EQ(Var{"?stdev"}, pq._orderBy[0].variable_);
ASSERT_FALSE(pq._orderBy[0].isDescending_);
}

{
auto pq = SparqlParser::parseQuery(
"SELECT ?r (COUNT(DISTINCT ?r) as ?count) WHERE {"
Expand Down

0 comments on commit 91d1c15

Please sign in to comment.