diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index 77813ea9b3..715ba423f2 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -21,15 +21,22 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.query.CriteriaDefinition; +import org.springframework.data.relational.core.query.Query; +import org.springframework.data.relational.core.sql.Condition; +import org.springframework.data.relational.core.sql.Table; import org.springframework.data.relational.core.sqlgeneration.AliasFactory; import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator; import org.springframework.data.relational.core.sqlgeneration.SqlGenerator; import org.springframework.data.relational.domain.RowDocument; +import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -89,6 +96,35 @@ public Iterable findAllById(Iterable ids) { return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll); } + public Iterable findAllBy(Query query) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + BiFunction condition = createConditionSource(query, parameterSource); + return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll); + } + + public Optional findOneByQuery(Query query) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + BiFunction condition = createConditionSource(query, parameterSource); + + return Optional.ofNullable( + jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne)); + } + + private BiFunction createConditionSource(Query query, MapSqlParameterSource parameterSource) { + + QueryMapper queryMapper = new QueryMapper(converter); + + BiFunction condition = (table, aggregate) -> { + Optional criteria = query.getCriteria(); + return criteria + .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) + .orElse(null); + }; + return condition; + } + /** * Extracts a list of aggregates from the given {@link ResultSet} by utilizing the * {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms @@ -115,7 +151,8 @@ private List extractAll(ResultSet rs) throws SQLException { * to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract. * * @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. - * @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is empty, null is returned. + * @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is + * empty, null is returned. * @throws SQLException * @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance. */ @@ -190,9 +227,15 @@ public String findAllById() { return findAllById; } + @Override + public String findAllByCondition(BiFunction conditionSource) { + return delegate.findAllByCondition(conditionSource); + } + @Override public AliasFactory getAliasFactory() { return delegate.getAliasFactory(); } + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java index b4dcd05b64..547eac6716 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java @@ -310,7 +310,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction; + mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())); sqlType = propertyField.getSqlType(); } else if (propertyField instanceof MetadataBackedField metadataBackedField // diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index 3f43d0652e..c0e20c425f 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -77,12 +77,13 @@ public Iterable findAll(Class domainType, Pageable pageable) { @Override public Optional findOne(Query query, Class domainType) { - return Optional.empty(); + return getReader(domainType).findOneByQuery(query); } @Override public Iterable findAll(Query query, Class domainType) { - throw new UnsupportedOperationException(); + + return getReader(domainType).findAllBy(query); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java index bc93cd09dd..0cb0b04638 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java @@ -16,9 +16,11 @@ package org.springframework.data.jdbc.core.convert; import java.util.Collections; +import java.util.Optional; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.query.Query; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.util.Assert; @@ -85,13 +87,37 @@ public Iterable findAllById(Iterable ids, Class domainType) { return super.findAllById(ids, domainType); } + public Optional findOne(Query query, Class domainType) { + + if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { + return singleSelectDelegate.findOne(query, domainType); + } + + return super.findOne(query, domainType); + } + + @Override + public Iterable findAll(Query query, Class domainType) { + + if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { + return singleSelectDelegate.findAll(query, domainType); + } + + return super.findAll(query, domainType); + } + + private static boolean isSingleSelectQuerySupported(Query query) { + return !query.isSorted() && !query.isLimited(); + } + private boolean isSingleSelectQuerySupported(Class entityType) { - return sqlGeneratorSource.getDialect().supportsSingleQueryLoading()// - && entityQualifiesForSingleSelectQuery(entityType); + return converter.getMappingContext().isSingleQueryLoadingEnabled() + && sqlGeneratorSource.getDialect().supportsSingleQueryLoading()// + && entityQualifiesForSingleQueryLoading(entityType); } - private boolean entityQualifiesForSingleSelectQuery(Class entityType) { + private boolean entityQualifiesForSingleQueryLoading(Class entityType) { boolean referenceFound = false; for (PersistentPropertyPath path : converter.getMappingContext() @@ -113,9 +139,9 @@ private boolean entityQualifiesForSingleSelectQuery(Class entityType) { } // AggregateReferences aren't supported yet - if (property.isAssociation()) { - return false; - } + // if (property.isAssociation()) { + // return false; + // } } return true; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java index abc85b3d11..8f0fa6e818 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java @@ -45,8 +45,8 @@ public boolean hasValue(String paramName) { public Object getValue(String paramName) throws IllegalArgumentException { Object value = parameterSource.getValue(paramName); - if (value instanceof ValueFunction) { - return ((ValueFunction) value).apply(escaper); + if (value instanceof ValueFunction valueFunction) { + return valueFunction.apply(escaper); } return value; } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index aa790fc854..e76825b29e 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -23,15 +23,8 @@ import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*; import java.time.LocalDateTime; +import java.util.*; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; @@ -42,6 +35,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; +import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.Id; @@ -64,6 +58,9 @@ import org.springframework.data.relational.core.mapping.MappedCollection; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; +import org.springframework.data.relational.core.query.Criteria; +import org.springframework.data.relational.core.query.CriteriaDefinition; +import org.springframework.data.relational.core.query.Query; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.test.context.ActiveProfiles; @@ -223,6 +220,62 @@ void findAllById() { .containsExactlyInAnyOrder(tuple(entity.id, "entity"), tuple(yetAnother.id, "yetAnother")); } + @Test // GH-1601 + void findAllByQuery() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id)); + Query query = Query.query(criteria); + Iterable reloadedById = template.findAll(query, SimpleListParent.class); + + assertThat(reloadedById).extracting(e -> e.id, e -> e.content.size()).containsExactly(tuple(two.id, 2)); + } + + @Test // GH-1601 + void findOneByQuery() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id)); + Query query = Query.query(criteria); + Optional reloadedById = template.findOne(query, SimpleListParent.class); + + assertThat(reloadedById).get().extracting(e -> e.id, e -> e.content.size()).containsExactly(two.id, 2); + } + + @Test // GH-1601 + void findOneByQueryNothingFound() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(4711)); + Query query = Query.query(criteria); + Optional reloadedById = template.findOne(query, SimpleListParent.class); + + assertThat(reloadedById).isEmpty(); + } + + @Test // GH-1601 + void findOneByQueryToManyResults() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").not(two.id)); + Query query = Query.query(criteria); + + assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class) + .isThrownBy(() -> template.findOne(query, SimpleListParent.class)); + } + @Test // DATAJDBC-112 @EnabledOnFeature(SUPPORTS_QUOTED_IDS) void saveAndLoadAnEntityWithReferencedEntityById() { @@ -1266,6 +1319,29 @@ static class ChildNoId { private String content; } + @SuppressWarnings("unused") + static class SimpleListParent { + + @Id private Long id; + String name; + List content = new ArrayList<>(); + + static SimpleListParent of(String name, String... contents) { + + SimpleListParent parent = new SimpleListParent(); + parent.name = name; + + for (String content : contents) { + + ElementNoId element = new ElementNoId(); + element.content = content; + parent.content.add(element); + } + + return parent; + } + } + @Table("LIST_PARENT") @SuppressWarnings("unused") static class ListParent { diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql index 8ad4fda2dc..f086a03b5c 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql @@ -6,6 +6,7 @@ DROP TABLE ONE_TO_ONE_PARENT; DROP TABLE ELEMENT_NO_ID; DROP TABLE LIST_PARENT; +DROP TABLE SIMPLE_LIST_PARENT; DROP TABLE BYTE_ARRAY_OWNER; @@ -74,11 +75,18 @@ CREATE TABLE LIST_PARENT "id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE ELEMENT_NO_ID ( CONTENT VARCHAR(100), LIST_PARENT_KEY BIGINT, - LIST_PARENT BIGINT + SIMPLE_LIST_PARENT_KEY BIGINT, + LIST_PARENT BIGINT, + SIMPLE_LIST_PARENT BIGINT ); ALTER TABLE ELEMENT_NO_ID ADD FOREIGN KEY (LIST_PARENT) diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql index a0aff08ce8..a6e5eabad7 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql @@ -32,9 +32,17 @@ CREATE TABLE LIST_PARENT NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID SERIAL PRIMARY KEY, + NAME VARCHAR(100) +); + CREATE TABLE element_no_id ( content VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT INTEGER, LIST_PARENT_key BIGINT, LIST_PARENT INTEGER ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql index 4dd1294ab2..dc73899207 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql @@ -26,6 +26,11 @@ CREATE TABLE Child_No_Id content VARCHAR(30) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE LIST_PARENT ( "id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, @@ -34,6 +39,8 @@ CREATE TABLE LIST_PARENT CREATE TABLE ELEMENT_NO_ID ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_KEY BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_KEY BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql index 4dd82b9003..4258e7b438 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql @@ -31,9 +31,16 @@ CREATE TABLE LIST_PARENT `id4` BIGINT AUTO_INCREMENT PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_key BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql index 880528cdbf..e9a378f49b 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql @@ -30,14 +30,22 @@ CREATE TABLE Child_No_Id DROP TABLE IF EXISTS element_no_id; DROP TABLE IF EXISTS LIST_PARENT; +DROP TABLE IF EXISTS SIMPLE_LIST_PARENT; CREATE TABLE LIST_PARENT ( [id4] BIGINT IDENTITY PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT IDENTITY PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_key BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql index 6808c8a912..40e32f1692 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql @@ -26,6 +26,11 @@ CREATE TABLE Child_No_Id `content` VARCHAR(30) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE LIST_PARENT ( `id4` BIGINT AUTO_INCREMENT PRIMARY KEY, @@ -35,7 +40,9 @@ CREATE TABLE element_no_id ( CONTENT VARCHAR(100), LIST_PARENT_key BIGINT, - LIST_PARENT BIGINT + SIMPLE_LIST_PARENT_key BIGINT, + LIST_PARENT BIGINT, + SIMPLE_LIST_PARENT BIGINT ); CREATE TABLE BYTE_ARRAY_OWNER diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql index 084e5db460..5a5c5baf40 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql @@ -4,6 +4,7 @@ DROP TABLE CHILD_NO_ID CASCADE CONSTRAINTS PURGE; DROP TABLE ONE_TO_ONE_PARENT CASCADE CONSTRAINTS PURGE; DROP TABLE ELEMENT_NO_ID CASCADE CONSTRAINTS PURGE; DROP TABLE LIST_PARENT CASCADE CONSTRAINTS PURGE; +DROP TABLE SIMPLE_LIST_PARENT CASCADE CONSTRAINTS PURGE; DROP TABLE BYTE_ARRAY_OWNER CASCADE CONSTRAINTS PURGE; DROP TABLE CHAIN0 CASCADE CONSTRAINTS PURGE; DROP TABLE CHAIN1 CASCADE CONSTRAINTS PURGE; @@ -64,9 +65,16 @@ CREATE TABLE LIST_PARENT "id4" NUMBER GENERATED by default on null as IDENTITY PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID NUMBER GENERATED by default on null as IDENTITY PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key NUMBER, + SIMPLE_LIST_PARENT NUMBER, LIST_PARENT_key NUMBER, LIST_PARENT NUMBER ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql index 0c77c88139..d43b5750b1 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql @@ -4,6 +4,7 @@ DROP TABLE ONE_TO_ONE_PARENT; DROP TABLE Child_No_Id; DROP TABLE element_no_id; DROP TABLE "LIST_PARENT"; +DROP TABLE SIMPLE_LIST_PARENT; DROP TABLE "ARRAY_OWNER"; DROP TABLE DOUBLE_LIST_OWNER; DROP TABLE FLOAT_LIST_OWNER; @@ -68,11 +69,19 @@ CREATE TABLE "LIST_PARENT" NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + id SERIAL PRIMARY KEY, + NAME VARCHAR(100) +); + CREATE TABLE element_no_id ( content VARCHAR(100), LIST_PARENT_key BIGINT, - "LIST_PARENT" INTEGER + SIMPLE_LIST_PARENT_key BIGINT, + "LIST_PARENT" INTEGER, + SIMPLE_LIST_PARENT INTEGER ); CREATE TABLE "ARRAY_OWNER" diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 3a6a2936ed..08adc38d74 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -373,12 +373,10 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind mappedValue = convertValue(comparator, parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); - } else if (criteria.getValue() instanceof ValueFunction) { + } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - ValueFunction valueFunction = (ValueFunction) criteria.getValue(); - Object value = valueFunction.apply(getEscaper(comparator)); + mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())).apply(getEscaper(comparator)); - mappedValue = convertValue(comparator, value, propertyField.getTypeHint()); typeHint = actualType.getType(); } else { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index 770010dc31..b77d95bab9 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -111,18 +111,14 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable Object mappedValue; Class typeHint; - if (value instanceof Parameter) { - - Parameter parameter = (Parameter) value; + if (value instanceof Parameter parameter) { mappedValue = convertValue(parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); - } else if (value instanceof ValueFunction) { - - ValueFunction valueFunction = (ValueFunction) value; + } else if (value instanceof ValueFunction valueFunction) { - mappedValue = convertValue(valueFunction.apply(Escaper.DEFAULT), propertyField.getTypeHint()); + mappedValue = valueFunction.transform(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { return Assignments.value(column, SQL.nullLiteral()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java index cd6908174e..780fdf0d9d 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java @@ -56,4 +56,16 @@ default Supplier toSupplier(Escaper escaper) { return () -> apply(escaper); } + + /** + * Transforms the inner value of the ValueFunction using the profided transformation. + * + * The default implementation just return the current {@literal ValueFunction}. + * This is not a valid implementation and serves just to maintain backward compatibility. + * + * @param transformation to be applied to the underlying value. + * @return a new {@literal ValueFunction}. + * @since 3.2 + */ + default ValueFunction transform(Function transformation) {return this;}; } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index 5bb11e4b81..505a027dca 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -19,6 +19,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import org.jetbrains.annotations.NotNull; import org.springframework.data.mapping.PersistentProperty; @@ -81,6 +82,12 @@ public String findAllById() { return createSelect(condition); } + @Override + public String findAllByCondition(BiFunction conditionSource) { + Condition condition = conditionSource.apply(table, aggregate); + return createSelect(condition); + } + /** * @return The {@link AggregatePath} to the id property of the aggregate root. */ @@ -88,13 +95,7 @@ private AggregatePath getRootIdPath() { return context.getAggregatePath(aggregate).append(aggregate.getRequiredIdProperty()); } - /** - * Creates a SQL suitable of loading all the data required for constructing complete aggregates. - * - * @param condition a constraint for limiting the aggregates to be loaded. - * @return a {@literal String} containing the generated SQL statement - */ - private String createSelect(Condition condition) { + String createSelect(Condition condition) { AggregatePath rootPath = context.getAggregatePath(aggregate); QueryMeta queryMeta = createInlineQuery(rootPath, condition); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java index 78049657e0..80eb9a1a87 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java @@ -15,6 +15,12 @@ */ package org.springframework.data.relational.core.sqlgeneration; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.sql.Condition; +import org.springframework.data.relational.core.sql.Table; + +import java.util.function.BiFunction; + /** * Generates SQL statements for loading aggregates. * @@ -28,5 +34,7 @@ public interface SqlGenerator { String findAllById(); + String findAllByCondition(BiFunction conditionSource); + AliasFactory getAliasFactory(); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java new file mode 100644 index 0000000000..2be3e8c371 --- /dev/null +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.data.relational.repository.query; + +import org.springframework.data.relational.core.dialect.Escaper; +import org.springframework.data.relational.core.query.ValueFunction; + +import java.util.function.Function; + +/** + * Value function that has an underlying value and a modifier that gets applied after the escaper. + * + * @author Jens Schauder + * @since 3.2 + */ +record ModifyingValueFunction(Object value, Function modifier) implements ValueFunction { + + static ModifyingValueFunction of(Object value, Function modifier) { + return new ModifyingValueFunction(value, modifier); + } + + @Override + public String apply(Escaper escaper) { + return modifier.apply(escaper.escape(value.toString())); + } + + @Override + public ValueFunction transform(Function transformation) { + return new ModifyingValueFunction(transformation.apply(value), modifier); + } +} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java index 818fa8578f..1261899b49 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java @@ -20,7 +20,6 @@ import java.util.List; import org.springframework.data.relational.core.dialect.Escaper; -import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.parser.Part; @@ -139,12 +138,12 @@ protected Object prepareParameterValue(@Nullable Object value, Class valueTyp switch (partType) { case STARTING_WITH: - return (ValueFunction) escaper -> escaper.escape(value.toString()) + "%"; + return ModifyingValueFunction.of(value, s -> s + "%"); case ENDING_WITH: - return (ValueFunction) escaper -> "%" + escaper.escape(value.toString()); + return ModifyingValueFunction.of(value, s -> "%" + s); case CONTAINING: case NOT_CONTAINING: - return (ValueFunction) escaper -> "%" + escaper.escape(value.toString()) + "%"; + return ModifyingValueFunction.of(value, s -> "%" + s + "%"); default: return value; }