diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index 715ba423f28..b51b457359b 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -18,16 +18,16 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.query.Criteria; import org.springframework.data.relational.core.query.CriteriaDefinition; import org.springframework.data.relational.core.query.Query; import org.springframework.data.relational.core.sql.Condition; @@ -36,6 +36,7 @@ import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator; import org.springframework.data.relational.core.sqlgeneration.SqlGenerator; import org.springframework.data.relational.domain.RowDocument; +import org.springframework.data.util.Streamable; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; @@ -43,7 +44,7 @@ /** * Reads complete Aggregates from the database, by generating appropriate SQL using a {@link SingleQuerySqlGenerator} - * through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converterd into an + * through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converted into an * intermediate {@link RowDocumentResultSetExtractor RowDocument} and mapped via * {@link org.springframework.data.relational.core.conversion.RelationalConverter#read(Class, RowDocument)}. * @@ -55,7 +56,8 @@ class AggregateReader { private final RelationalPersistentEntity aggregate; - private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator sqlGenerator; + private final Table table; + private final SqlGenerator sqlGenerator; private final JdbcConverter converter; private final NamedParameterJdbcOperations jdbcTemplate; private final RowDocumentResultSetExtractor extractor; @@ -66,6 +68,7 @@ class AggregateReader { this.converter = converter; this.aggregate = aggregate; this.jdbcTemplate = jdbcTemplate; + this.table = Table.create(aggregate.getQualifiedTableName()); this.sqlGenerator = new CachingSqlGenerator( new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate)); @@ -74,62 +77,58 @@ class AggregateReader { createPathToColumnMapping(aliasFactory)); } - public List findAll() { - return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); - } - @Nullable public T findById(Object id) { - id = converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation()); + Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).is(id)).limit(1); - return jdbcTemplate.query(sqlGenerator.findById(), Map.of("id", id), this::extractZeroOrOne); + return findOne(query); } - public Iterable findAllById(Iterable ids) { + @Nullable + public T findOne(Query query) { - List convertedIds = new ArrayList<>(); - for (Object id : ids) { - convertedIds.add(converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation())); - } + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + Condition condition = createCondition(query, parameterSource); - return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll); + return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractZeroOrOne); } - public Iterable findAllBy(Query query) { + public List findAll() { + return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); + } - MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - BiFunction condition = createConditionSource(query, parameterSource); - return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll); + public List findAllById(Iterable ids) { + + Collection identifiers = ids instanceof Collection idl ? idl : Streamable.of(ids).toList(); + Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)).limit(1); + + return findAll(query); } - public Optional findOneByQuery(Query query) { - - MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - BiFunction condition = createConditionSource(query, parameterSource); + public List findAll(Query query) { - return Optional.ofNullable( - jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne)); + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + Condition condition = createCondition(query, parameterSource); + return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractAll); } - private BiFunction createConditionSource(Query query, MapSqlParameterSource parameterSource) { + @Nullable + private Condition createCondition(Query query, MapSqlParameterSource parameterSource) { QueryMapper queryMapper = new QueryMapper(converter); - BiFunction condition = (table, aggregate) -> { - Optional criteria = query.getCriteria(); - return criteria - .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) - .orElse(null); - }; - return condition; + Optional criteria = query.getCriteria(); + return criteria + .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) + .orElse(null); } /** * Extracts a list of aggregates from the given {@link ResultSet} by utilizing the * {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms * to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract. - * + * * @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. * @return a {@code List} of aggregates, fully converted. * @throws SQLException @@ -195,21 +194,15 @@ public String keyColumn(AggregatePath path) { * @author Jens Schauder * @since 3.2 */ - static class CachingSqlGenerator implements org.springframework.data.relational.core.sqlgeneration.SqlGenerator { - - private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator delegate; + static class CachingSqlGenerator implements SqlGenerator { + private final SqlGenerator delegate; private final String findAll; - private final String findById; - private final String findAllById; public CachingSqlGenerator(SqlGenerator delegate) { this.delegate = delegate; - - findAll = delegate.findAll(); - findById = delegate.findById(); - findAllById = delegate.findAllById(); + this.findAll = delegate.findAll(); } @Override @@ -218,18 +211,8 @@ public String findAll() { } @Override - public String findById() { - return findById; - } - - @Override - public String findAllById() { - return findAllById; - } - - @Override - public String findAllByCondition(BiFunction conditionSource) { - return delegate.findAllByCondition(conditionSource); + public String findAll(@Nullable Condition condition) { + return delegate.findAll(condition); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java index 547eac67163..6695647197b 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java @@ -22,8 +22,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.function.Function; import org.springframework.data.domain.Sort; import org.springframework.data.jdbc.core.mapping.JdbcValue; @@ -35,7 +33,6 @@ import org.springframework.data.mapping.context.InvalidPersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.relational.core.dialect.Dialect; -import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.CriteriaDefinition; @@ -77,7 +74,7 @@ public QueryMapper(Dialect dialect, JdbcConverter converter) { Assert.notNull(converter, "JdbcConverter must not be null"); this.converter = converter; - this.mappingContext = (MappingContext) converter.getMappingContext(); + this.mappingContext = converter.getMappingContext(); } /** @@ -310,7 +307,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())); + mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint())); sqlType = propertyField.getSqlType(); } else if (propertyField instanceof MetadataBackedField metadataBackedField // diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index c0e20c425f7..a609619c2d3 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -77,13 +77,12 @@ public Iterable findAll(Class domainType, Pageable pageable) { @Override public Optional findOne(Query query, Class domainType) { - return getReader(domainType).findOneByQuery(query); + return Optional.ofNullable(getReader(domainType).findOne(query)); } @Override public Iterable findAll(Query query, Class domainType) { - - return getReader(domainType).findAllBy(query); + return getReader(domainType).findAll(query); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java index 0cb0b04638e..9628588f7ae 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java @@ -87,6 +87,7 @@ public Iterable findAllById(Iterable ids, Class domainType) { return super.findAllById(ids, domainType); } + @Override public Optional findOne(Query query, Class domainType) { if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { @@ -137,11 +138,6 @@ private boolean entityQualifiesForSingleQueryLoading(Class entityType) { referenceFound = true; } - - // AggregateReferences aren't supported yet - // if (property.isAssociation()) { - // return false; - // } } return true; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java index 8f0fa6e8184..16bdad90e69 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java @@ -21,12 +21,13 @@ import org.springframework.jdbc.core.namedparam.SqlParameterSource; /** - * This {@link SqlParameterSource} will apply escaping to it's values. - * + * This {@link SqlParameterSource} will apply escaping to its values. + * * @author Jens Schauder * @since 3.2 */ -public class EscapingParameterSource implements SqlParameterSource { +class EscapingParameterSource implements SqlParameterSource { + private final SqlParameterSource parameterSource; private final Escaper escaper; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java index ac3c256d27a..b2f5c6ac936 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java @@ -15,12 +15,12 @@ */ package org.springframework.data.jdbc.repository.query; -import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.jdbc.core.namedparam.SqlParameterSource; /** - * Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the parameters. + * Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the + * parameters. * * @author Mark Paluch * @author Jens Schauder @@ -41,13 +41,12 @@ String getQuery() { return query; } + SqlParameterSource getParameterSource(Escaper escaper) { + return new EscapingParameterSource(parameterSource, escaper); + } + @Override public String toString() { return this.query; } - - public SqlParameterSource getParameterSource(Escaper escaper) { - - return new EscapingParameterSource(parameterSource, escaper); - } } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 08adc38d749..0f233cdd645 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -176,9 +176,8 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer return expression; } - if (expression instanceof Column) { + if (expression instanceof Column column) { - Column column = (Column) expression; Field field = createPropertyField(entity, column.getName()); TableLike table = column.getTable(); @@ -186,9 +185,7 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable; } - if (expression instanceof SimpleFunction) { - - SimpleFunction function = (SimpleFunction) expression; + if (expression instanceof SimpleFunction function) { List arguments = function.getExpressions(); List mappedArguments = new ArrayList<>(arguments.size()); @@ -367,15 +364,14 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind Class typeHint; Comparator comparator = criteria.getComparator(); - if (criteria.getValue() instanceof Parameter) { - - Parameter parameter = (Parameter) criteria.getValue(); + if (criteria.getValue()instanceof Parameter parameter) { mappedValue = convertValue(comparator, parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())).apply(getEscaper(comparator)); + mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint())) + .apply(getEscaper(comparator)); typeHint = actualType.getType(); } else { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index b77d95bab94..372ed39048f 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -118,7 +118,7 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable } else if (value instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); + mappedValue = valueFunction.map(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { return Assignments.value(column, SQL.nullLiteral()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java index 780fdf0d9d7..8951ac2a81e 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java @@ -58,14 +58,18 @@ default Supplier toSupplier(Escaper escaper) { } /** - * Transforms the inner value of the ValueFunction using the profided transformation. + * Return a new ValueFunction applying the given mapping {@link Function}. The mapping function is applied after + * applying {@link Escaper}. * - * The default implementation just return the current {@literal ValueFunction}. - * This is not a valid implementation and serves just to maintain backward compatibility. - * - * @param transformation to be applied to the underlying value. + * @param mapper the mapping function to apply to the value. + * @param the type of the value returned from the mapping function. * @return a new {@literal ValueFunction}. * @since 3.2 */ - default ValueFunction transform(Function transformation) {return this;}; + default ValueFunction map(Function mapper) { + + Assert.notNull(mapper, "Mapping function must not be null"); + + return escaper -> mapper.apply(this.apply(escaper)); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index 505a027dcaa..ff0a61f771c 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -19,9 +19,9 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import org.springframework.data.mapping.PersistentProperty; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PersistentPropertyPaths; @@ -46,7 +46,6 @@ public class SingleQuerySqlGenerator implements SqlGenerator { private final Dialect dialect; private final AliasFactory aliases; private final RelationalPersistentEntity aggregate; - private final Table table; public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory aliasFactory, Dialect dialect, RelationalPersistentEntity aggregate) { @@ -55,47 +54,14 @@ public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory al this.aliases = aliasFactory; this.dialect = dialect; this.aggregate = aggregate; - - this.table = Table.create(aggregate.getQualifiedTableName()); - } - - @Override - public String findAll() { - return createSelect(null); - } - - @Override - public String findById() { - - AggregatePath path = getRootIdPath(); - Condition condition = Conditions.isEqual(table.column(path.getColumnInfo().name()), Expressions.just(":id")); - - return createSelect(condition); } @Override - public String findAllById() { - - AggregatePath path = getRootIdPath(); - Condition condition = Conditions.in(table.column(path.getColumnInfo().name()), Expressions.just(":ids")); - - return createSelect(condition); - } - - @Override - public String findAllByCondition(BiFunction conditionSource) { - Condition condition = conditionSource.apply(table, aggregate); + public String findAll(@Nullable Condition condition) { return createSelect(condition); } - /** - * @return The {@link AggregatePath} to the id property of the aggregate root. - */ - private AggregatePath getRootIdPath() { - return context.getAggregatePath(aggregate).append(aggregate.getRequiredIdProperty()); - } - - String createSelect(Condition condition) { + String createSelect(@Nullable Condition condition) { AggregatePath rootPath = context.getAggregatePath(aggregate); QueryMeta queryMeta = createInlineQuery(rootPath, condition); @@ -168,7 +134,7 @@ private List createInlineQueries(PersistentPropertyPaths inlineQueries = new ArrayList<>(); - for (PersistentPropertyPath ppp : paths) { + for (PersistentPropertyPath ppp : paths) { QueryMeta queryMeta = createInlineQuery(context.getAggregatePath(ppp), null); inlineQueries.add(queryMeta); @@ -188,7 +154,7 @@ private List createInlineQueries(PersistentPropertyPaths entity = basePath.getRequiredLeafEntity(); Table table = Table.create(entity.getQualifiedTableName()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java index 80eb9a1a874..fe783882a54 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java @@ -15,11 +15,8 @@ */ package org.springframework.data.relational.core.sqlgeneration; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.sql.Condition; -import org.springframework.data.relational.core.sql.Table; - -import java.util.function.BiFunction; +import org.springframework.lang.Nullable; /** * Generates SQL statements for loading aggregates. @@ -28,13 +25,12 @@ * @since 3.2 */ public interface SqlGenerator { - String findAll(); - - String findById(); - String findAllById(); + default String findAll() { + return findAll(null); + } - String findAllByCondition(BiFunction conditionSource); + String findAll(@Nullable Condition condition); AliasFactory getAliasFactory(); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java deleted file mode 100644 index 2be3e8c3710..00000000000 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.data.relational.repository.query; - -import org.springframework.data.relational.core.dialect.Escaper; -import org.springframework.data.relational.core.query.ValueFunction; - -import java.util.function.Function; - -/** - * Value function that has an underlying value and a modifier that gets applied after the escaper. - * - * @author Jens Schauder - * @since 3.2 - */ -record ModifyingValueFunction(Object value, Function modifier) implements ValueFunction { - - static ModifyingValueFunction of(Object value, Function modifier) { - return new ModifyingValueFunction(value, modifier); - } - - @Override - public String apply(Escaper escaper) { - return modifier.apply(escaper.escape(value.toString())); - } - - @Override - public ValueFunction transform(Function transformation) { - return new ModifyingValueFunction(transformation.apply(value), modifier); - } -} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java index 1261899b49c..2f781e89c34 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java @@ -20,6 +20,7 @@ import java.util.List; import org.springframework.data.relational.core.dialect.Escaper; +import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.parser.Part; @@ -136,16 +137,12 @@ protected Object prepareParameterValue(@Nullable Object value, Class valueTyp return value; } - switch (partType) { - case STARTING_WITH: - return ModifyingValueFunction.of(value, s -> s + "%"); - case ENDING_WITH: - return ModifyingValueFunction.of(value, s -> "%" + s); - case CONTAINING: - case NOT_CONTAINING: - return ModifyingValueFunction.of(value, s -> "%" + s + "%"); - default: - return value; - } + return switch (partType) { + case STARTING_WITH -> (ValueFunction) escaper -> escaper.escape(value.toString()) + "%"; + case ENDING_WITH -> (ValueFunction) escaper -> "%" + escaper.escape(value.toString()); + case CONTAINING, NOT_CONTAINING -> (ValueFunction) escaper -> "%" + escaper.escape(value.toString()) + + "%"; + default -> value; + }; } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java index 5742a2c4ed2..bb62ab7b91e 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java @@ -16,7 +16,6 @@ package org.springframework.data.relational.core.mapping; import static org.assertj.core.api.Assertions.*; -import static org.assertj.core.api.SoftAssertions.*; import org.junit.jupiter.api.Test; import org.springframework.data.relational.core.sql.IdentifierProcessing; @@ -44,7 +43,6 @@ public void quotedSimpleObjectIdentifierWithAdjustableLetterCasing() { assertThat(identifier.toSql(BRACKETS_LOWER_CASE)).isEqualTo("[somename]"); assertThat(identifier.getReference(BRACKETS_LOWER_CASE)).isEqualTo("someName"); assertThat(identifier.getReference()).isEqualTo("someName"); - } @Test // DATAJDBC-386 @@ -77,12 +75,12 @@ public void equality() { SqlIdentifier notSimple = SqlIdentifier.from(new DerivedSqlIdentifier("simple", false), new DerivedSqlIdentifier("not", false)); - assertSoftly(softly -> { + assertThat(basis).isEqualTo(equal).isEqualTo(SqlIdentifier.unquoted("simple")) + .hasSameHashCodeAs(SqlIdentifier.unquoted("simple")); + assertThat(equal).isEqualTo(basis); + assertThat(basis).isNotEqualTo(quoted); + assertThat(basis).isNotEqualTo(notSimple); - softly.assertThat(basis).isEqualTo(equal); - softly.assertThat(equal).isEqualTo(basis); - softly.assertThat(basis).isNotEqualTo(quoted); - softly.assertThat(basis).isNotEqualTo(notSimple); - }); + assertThat(quoted).isEqualTo(SqlIdentifier.quoted("SIMPLE")).hasSameHashCodeAs(SqlIdentifier.quoted("SIMPLE")); } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java index 5721ce2b422..ade6e0dad12 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.data.relational.core.sqlgeneration; import static org.springframework.data.relational.core.sqlgeneration.SqlAssert.*; @@ -28,7 +27,10 @@ import org.springframework.data.relational.core.dialect.PostgresDialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalMappingContext; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.sql.Conditions; +import org.springframework.data.relational.core.sql.Table; /** * Tests for {@link SingleQuerySqlGenerator}. @@ -76,7 +78,8 @@ void createSelectForFindAll() { @Test // GH-1446 void createSelectForFindById() { - String sql = sqlGenerator.findById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -94,13 +97,14 @@ void createSelectForFindById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"trivial_aggregate\".\"id\" = :id"); + .extractWhereClause().isEqualTo("\"trivial_aggregate\".id = :id"); } @Test // GH-1446 void createSelectForFindAllById() { - String sql = sqlGenerator.findAllById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").in(Conditions.just(":ids"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -118,7 +122,7 @@ void createSelectForFindAllById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"trivial_aggregate\".\"id\" IN (:ids)"); + .extractWhereClause().isEqualTo("\"trivial_aggregate\".id IN (:ids)"); } } @@ -133,7 +137,8 @@ private AggregateWithSingleReference() { @Test // GH-1446 void createSelectForFindById() { - String sql = sqlGenerator.findById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); String rootRowNumber = rnAlias(); String rootCount = rcAlias(); @@ -167,7 +172,7 @@ void createSelectForFindById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"single_reference_aggregate\".\"id\" = :id"); + .extractWhereClause().isEqualTo("\"single_reference_aggregate\".id = :id"); baseSelect.hasInlineViewSelectingFrom("\"trivial_aggregate\"") // .hasExactlyColumns( // rn(col("\"single_reference_aggregate\"")).as(trivialsRowNumber), // @@ -206,13 +211,14 @@ record SingleReferenceAggregate(@Id Long id, String name, List private class AbstractTestFixture { final Class aggregateRootType; final SingleQuerySqlGenerator sqlGenerator; + final RelationalPersistentEntity persistentEntity; final AliasFactory aliases; private AbstractTestFixture(Class aggregateRootType) { this.aggregateRootType = aggregateRootType; - this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect, - context.getRequiredPersistentEntity(aggregateRootType)); + this.persistentEntity = context.getRequiredPersistentEntity(aggregateRootType); + this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect, persistentEntity); this.aliases = sqlGenerator.getAliasFactory(); }