diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/QueryTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/QueryTests.cs index 1e3ee682..2aeb847d 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/QueryTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/QueryTests.cs @@ -55,6 +55,30 @@ public async Task CanFilterOnProperty() Assert.Equal("Pete Peterson", singer.FullName); } + [Fact] + public async Task CanFilterOnListOfIds() + { + using var db = new TestSpannerSampleDbContext(_fixture.DatabaseName); + var singerIds = new List { _fixture.RandomLong(), _fixture.RandomLong(), _fixture.RandomLong() }; + var index = 0; + foreach (var singerId in singerIds) + { + index++; + db.Singers.Add(new Singers { SingerId = singerId, FirstName = "Pete", LastName = $"Peterson{index}" }); + } + await db.SaveChangesAsync(); + + singerIds.Add(_fixture.RandomLong()); + var singers = await db.Singers + .Where(s => singerIds.Contains(s.SingerId)) + .OrderBy(s => s.LastName) + .ToListAsync(); + Assert.Collection(singers, + singer => Assert.Equal("Peterson1", singer.LastName), + singer => Assert.Equal("Peterson2", singer.LastName), + singer => Assert.Equal("Peterson3", singer.LastName)); + } + [Fact] public async Task CanOrderByProperty() { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs index 96b1599b..e0a63a15 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs @@ -110,6 +110,38 @@ public async Task FindSingerAsync_ReturnsNull_IfNotFound() Assert.Null(singer); } + [Fact] + public async Task FindSingersUsingListOfIds_UsesParameterizedQuery() + { + var sql = $"SELECT `s`.`SingerId`, `s`.`BirthDate`, `s`.`FirstName`, `s`.`FullName`, `s`.`LastName`, `s`.`Picture`{Environment.NewLine}" + + $"FROM `Singers` AS `s`{Environment.NewLine}" + + $"WHERE `s`.`SingerId` IN UNNEST (@__singerIds_0)"; + AddFindSingerResult(sql); + + var singerIds = new List{1L, 2L, 3L}; + using var db = new MockServerSampleDbContext(ConnectionString); + var singers = await db.Singers.Where(singer => singerIds.Contains(singer.SingerId)).ToListAsync(); + Assert.Single(singers); + Assert.Collection( + _fixture.SpannerMock.Requests.OfType(), + request => + { + Assert.Equal(sql, request.Sql); + Assert.Single(request.Params.Fields); + var fields = request.Params.Fields; + Assert.Collection(fields["__singerIds_0"].ListValue.Values, + v => Assert.Equal("1", v.StringValue), + v => Assert.Equal("2", v.StringValue), + v => Assert.Equal("3", v.StringValue) + ); + Assert.Single(request.ParamTypes); + var type = request.ParamTypes["__singerIds_0"]; + Assert.Equal(V1.TypeCode.Array, type.Code); + Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); + } + ); + } + [Fact] public async Task FindSingerAsync_ReturnsInstance_IfFound() { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerInExpression.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerInExpression.cs new file mode 100644 index 00000000..de4ed4b0 --- /dev/null +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerInExpression.cs @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Storage; + +namespace Google.Cloud.EntityFrameworkCore.Spanner.Query.Internal; + +/// +/// Generates an `(NOT) IN UNNEST (@param)`-style IN expression. +/// +class SpannerInExpression( + SqlExpression item, + SqlParameterExpression valuesParameter, + RelationalTypeMapping itemTypeMapping) + : InExpression(item, valuesParameter, itemTypeMapping) +{ + protected override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Visit(Item); + expressionPrinter.Append(" IN UNNEST ("); + expressionPrinter.Visit(ValuesParameter); + expressionPrinter.Append(")"); + } +} diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerQuerySqlGenerator.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerQuerySqlGenerator.cs index 31aa1950..178590a8 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerQuerySqlGenerator.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerQuerySqlGenerator.cs @@ -128,6 +128,20 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return base.VisitSqlFunction(sqlFunctionExpression); } + + protected override void GenerateIn(InExpression inExpression, bool negated) + { + if (inExpression.GetType() != typeof(SpannerInExpression)) + { + base.GenerateIn(inExpression, negated); + return; + } + Visit(inExpression.Item); + Sql.Append(negated ? " NOT IN " : " IN "); + Sql.Append(" UNNEST ("); + Visit(inExpression.ValuesParameter); + Sql.Append(")"); + } protected virtual Expression VisitContains(SpannerContainsExpression containsExpression) { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlExpressionFactory.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlExpressionFactory.cs index 4adce174..91b9e868 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlExpressionFactory.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlExpressionFactory.cs @@ -30,6 +30,19 @@ public SpannerSqlExpressionFactory(SqlExpressionFactoryDependencies dependencies _boolTypeMapping = dependencies.TypeMappingSource.FindMapping(typeof(bool), dependencies.Model)!; } + public override InExpression In(SqlExpression item, SqlParameterExpression valuesParameter) + { + var parametersTypeMapping = Dependencies.TypeMappingSource.FindMapping(valuesParameter.Type); + if (parametersTypeMapping != null) + { + return new SpannerInExpression( + item, + (SqlParameterExpression) valuesParameter.ApplyTypeMapping(parametersTypeMapping), + _boolTypeMapping); + } + return base.In(item, valuesParameter); + } + public virtual SpannerContainsExpression SpannerContains(SqlExpression item, SqlExpression values, bool negated) { var typeMapping = item.TypeMapping ?? Dependencies.TypeMappingSource.FindMapping(item.Type, Dependencies.Model); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlNullabilityProcessor.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlNullabilityProcessor.cs index a68b859c..c4ac4663 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlNullabilityProcessor.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Query/Internal/SpannerSqlNullabilityProcessor.cs @@ -44,6 +44,19 @@ protected override SqlExpression VisitCustomSqlExpression( return sqlExpression; } + protected override SqlExpression VisitIn( + InExpression inExpression, + bool allowOptimizedExpansion, + out bool nullable) + { + if (inExpression.GetType() == typeof(SpannerInExpression)) + { + nullable = false; + return inExpression; + } + return base.VisitIn(inExpression, allowOptimizedExpansion, out nullable); + } + protected virtual SqlExpression VisitSpannerContains(SpannerContainsExpression containsExpression, out bool nullable) { var item = Visit(containsExpression.Item, out var itemNullable);