Skip to content

Commit

Permalink
perf: translate Contains to a parameterized IN query (#472)
Browse files Browse the repository at this point in the history
IN queries were automatically translated by EF Core to queries that
used `IN (value1, value2, ..., valueN)`. This requires Spanner to
re-compile the query for every possible permutation of the list of
values.

This change registers a specific SpannerInExpression that is used
when a query using a fixed list of values is being generated. These
queries are then translated to queries of this form:

`IN UNNEST (@values)`

The `@values` parameter is of type ARRAY.
  • Loading branch information
olavloite authored Nov 29, 2024
1 parent 8a4eb12 commit 85cad1c
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ public async Task CanFilterOnProperty()
Assert.Equal("Pete Peterson", singer.FullName);
}

[Fact]
public async Task CanFilterOnListOfIds()
{
using var db = new TestSpannerSampleDbContext(_fixture.DatabaseName);
var singerIds = new List<long> { _fixture.RandomLong(), _fixture.RandomLong(), _fixture.RandomLong() };
var index = 0;
foreach (var singerId in singerIds)
{
index++;
db.Singers.Add(new Singers { SingerId = singerId, FirstName = "Pete", LastName = $"Peterson{index}" });
}
await db.SaveChangesAsync();

singerIds.Add(_fixture.RandomLong());
var singers = await db.Singers
.Where(s => singerIds.Contains(s.SingerId))
.OrderBy(s => s.LastName)
.ToListAsync();
Assert.Collection(singers,
singer => Assert.Equal("Peterson1", singer.LastName),
singer => Assert.Equal("Peterson2", singer.LastName),
singer => Assert.Equal("Peterson3", singer.LastName));
}

[Fact]
public async Task CanOrderByProperty()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,38 @@ public async Task FindSingerAsync_ReturnsNull_IfNotFound()
Assert.Null(singer);
}

[Fact]
public async Task FindSingersUsingListOfIds_UsesParameterizedQuery()
{
var sql = $"SELECT `s`.`SingerId`, `s`.`BirthDate`, `s`.`FirstName`, `s`.`FullName`, `s`.`LastName`, `s`.`Picture`{Environment.NewLine}" +
$"FROM `Singers` AS `s`{Environment.NewLine}" +
$"WHERE `s`.`SingerId` IN UNNEST (@__singerIds_0)";
AddFindSingerResult(sql);

var singerIds = new List<long>{1L, 2L, 3L};
using var db = new MockServerSampleDbContext(ConnectionString);
var singers = await db.Singers.Where(singer => singerIds.Contains(singer.SingerId)).ToListAsync();
Assert.Single(singers);
Assert.Collection(
_fixture.SpannerMock.Requests.OfType<ExecuteSqlRequest>(),
request =>
{
Assert.Equal(sql, request.Sql);
Assert.Single(request.Params.Fields);
var fields = request.Params.Fields;
Assert.Collection(fields["__singerIds_0"].ListValue.Values,
v => Assert.Equal("1", v.StringValue),
v => Assert.Equal("2", v.StringValue),
v => Assert.Equal("3", v.StringValue)
);
Assert.Single(request.ParamTypes);
var type = request.ParamTypes["__singerIds_0"];
Assert.Equal(V1.TypeCode.Array, type.Code);
Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code);
}
);
}

[Fact]
public async Task FindSingerAsync_ReturnsInstance_IfFound()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Google.Cloud.EntityFrameworkCore.Spanner.Query.Internal;

/// <summary>
/// Generates an `(NOT) IN UNNEST (@param)`-style IN expression.
/// </summary>
class SpannerInExpression(
SqlExpression item,
SqlParameterExpression valuesParameter,
RelationalTypeMapping itemTypeMapping)
: InExpression(item, valuesParameter, itemTypeMapping)
{
protected override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Visit(Item);
expressionPrinter.Append(" IN UNNEST (");
expressionPrinter.Visit(ValuesParameter);
expressionPrinter.Append(")");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction

return base.VisitSqlFunction(sqlFunctionExpression);
}

protected override void GenerateIn(InExpression inExpression, bool negated)
{
if (inExpression.GetType() != typeof(SpannerInExpression))
{
base.GenerateIn(inExpression, negated);
return;
}
Visit(inExpression.Item);
Sql.Append(negated ? " NOT IN " : " IN ");
Sql.Append(" UNNEST (");
Visit(inExpression.ValuesParameter);
Sql.Append(")");
}

protected virtual Expression VisitContains(SpannerContainsExpression containsExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ public SpannerSqlExpressionFactory(SqlExpressionFactoryDependencies dependencies
_boolTypeMapping = dependencies.TypeMappingSource.FindMapping(typeof(bool), dependencies.Model)!;
}

public override InExpression In(SqlExpression item, SqlParameterExpression valuesParameter)
{
var parametersTypeMapping = Dependencies.TypeMappingSource.FindMapping(valuesParameter.Type);
if (parametersTypeMapping != null)
{
return new SpannerInExpression(
item,
(SqlParameterExpression) valuesParameter.ApplyTypeMapping(parametersTypeMapping),
_boolTypeMapping);
}
return base.In(item, valuesParameter);
}

public virtual SpannerContainsExpression SpannerContains(SqlExpression item, SqlExpression values, bool negated)
{
var typeMapping = item.TypeMapping ?? Dependencies.TypeMappingSource.FindMapping(item.Type, Dependencies.Model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ protected override SqlExpression VisitCustomSqlExpression(
return sqlExpression;
}

protected override SqlExpression VisitIn(
InExpression inExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
if (inExpression.GetType() == typeof(SpannerInExpression))
{
nullable = false;
return inExpression;
}
return base.VisitIn(inExpression, allowOptimizedExpansion, out nullable);
}

protected virtual SqlExpression VisitSpannerContains(SpannerContainsExpression containsExpression, out bool nullable)
{
var item = Visit(containsExpression.Item, out var itemNullable);
Expand Down

0 comments on commit 85cad1c

Please sign in to comment.