Skip to content

Commit

Permalink
Add support for arbitrary where clauses in Single Query Loading.
Browse files Browse the repository at this point in the history
Closes #1601
Original pull request: #1617
  • Loading branch information
schauder authored and mp911de committed Sep 26, 2023
1 parent 6fb6110 commit 0fdeaeb
Show file tree
Hide file tree
Showing 21 changed files with 312 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,22 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.mapping.AggregatePath;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.query.CriteriaDefinition;
import org.springframework.data.relational.core.query.Query;
import org.springframework.data.relational.core.sql.Condition;
import org.springframework.data.relational.core.sql.Table;
import org.springframework.data.relational.core.sqlgeneration.AliasFactory;
import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator;
import org.springframework.data.relational.core.sqlgeneration.SqlGenerator;
import org.springframework.data.relational.domain.RowDocument;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -89,6 +96,35 @@ public Iterable<T> findAllById(Iterable<?> ids) {
return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll);
}

public Iterable<T> findAllBy(Query query) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);
return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll);
}

public Optional<T> findOneByQuery(Query query) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);

return Optional.ofNullable(
jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne));
}

private BiFunction<Table, RelationalPersistentEntity, Condition> createConditionSource(Query query, MapSqlParameterSource parameterSource) {

QueryMapper queryMapper = new QueryMapper(converter);

BiFunction<Table, RelationalPersistentEntity, Condition> condition = (table, aggregate) -> {
Optional<CriteriaDefinition> criteria = query.getCriteria();
return criteria
.map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate))
.orElse(null);
};
return condition;
}

/**
* Extracts a list of aggregates from the given {@link ResultSet} by utilizing the
* {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms
Expand All @@ -115,7 +151,8 @@ private List<T> extractAll(ResultSet rs) throws SQLException {
* to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract.
*
* @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}.
* @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is empty, null is returned.
* @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is
* empty, null is returned.
* @throws SQLException
* @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance.
*/
Expand Down Expand Up @@ -190,9 +227,15 @@ public String findAllById() {
return findAllById;
}

@Override
public String findAllByCondition(BiFunction<Table, RelationalPersistentEntity, Condition> conditionSource) {
return delegate.findAllByCondition(conditionSource);
}

@Override
public AliasFactory getAliasFactory() {
return delegate.getAliasFactory();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc
sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue);
} else if (criteria.getValue() instanceof ValueFunction valueFunction) {

mappedValue = valueFunction;
mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint()));
sqlType = propertyField.getSqlType();

} else if (propertyField instanceof MetadataBackedField metadataBackedField //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {

@Override
public <T> Optional<T> findOne(Query query, Class<T> domainType) {
return Optional.empty();
return getReader(domainType).findOneByQuery(query);
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
throw new UnsupportedOperationException();

return getReader(domainType).findAllBy(query);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package org.springframework.data.jdbc.core.convert;

import java.util.Collections;
import java.util.Optional;

import org.springframework.data.mapping.PersistentPropertyPath;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.query.Query;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -85,13 +87,37 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return super.findAllById(ids, domainType);
}

public <T> Optional<T> findOne(Query query, Class<T> domainType) {

if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) {
return singleSelectDelegate.findOne(query, domainType);
}

return super.findOne(query, domainType);
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {

if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) {
return singleSelectDelegate.findAll(query, domainType);
}

return super.findAll(query, domainType);
}

private static boolean isSingleSelectQuerySupported(Query query) {
return !query.isSorted() && !query.isLimited();
}

private boolean isSingleSelectQuerySupported(Class<?> entityType) {

return sqlGeneratorSource.getDialect().supportsSingleQueryLoading()//
&& entityQualifiesForSingleSelectQuery(entityType);
return converter.getMappingContext().isSingleQueryLoadingEnabled()
&& sqlGeneratorSource.getDialect().supportsSingleQueryLoading()//
&& entityQualifiesForSingleQueryLoading(entityType);
}

private boolean entityQualifiesForSingleSelectQuery(Class<?> entityType) {
private boolean entityQualifiesForSingleQueryLoading(Class<?> entityType) {

boolean referenceFound = false;
for (PersistentPropertyPath<RelationalPersistentProperty> path : converter.getMappingContext()
Expand All @@ -113,9 +139,9 @@ private boolean entityQualifiesForSingleSelectQuery(Class<?> entityType) {
}

// AggregateReferences aren't supported yet
if (property.isAssociation()) {
return false;
}
// if (property.isAssociation()) {
// return false;
// }
}
return true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ public boolean hasValue(String paramName) {
public Object getValue(String paramName) throws IllegalArgumentException {

Object value = parameterSource.getValue(paramName);
if (value instanceof ValueFunction<?>) {
return ((ValueFunction<?>) value).apply(escaper);
if (value instanceof ValueFunction<?> valueFunction) {
return valueFunction.apply(escaper);
}
return value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,8 @@
import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*;

import java.time.LocalDateTime;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;

Expand All @@ -42,6 +35,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.annotation.Id;
Expand All @@ -64,6 +58,9 @@
import org.springframework.data.relational.core.mapping.MappedCollection;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.Table;
import org.springframework.data.relational.core.query.Criteria;
import org.springframework.data.relational.core.query.CriteriaDefinition;
import org.springframework.data.relational.core.query.Query;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.test.context.ActiveProfiles;

Expand Down Expand Up @@ -223,6 +220,62 @@ void findAllById() {
.containsExactlyInAnyOrder(tuple(entity.id, "entity"), tuple(yetAnother.id, "yetAnother"));
}

@Test // GH-1601
void findAllByQuery() {

template.save(SimpleListParent.of("one", "one_1"));
SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2"));
template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3"));

CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id));
Query query = Query.query(criteria);
Iterable<SimpleListParent> reloadedById = template.findAll(query, SimpleListParent.class);

assertThat(reloadedById).extracting(e -> e.id, e -> e.content.size()).containsExactly(tuple(two.id, 2));
}

@Test // GH-1601
void findOneByQuery() {

template.save(SimpleListParent.of("one", "one_1"));
SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2"));
template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3"));

CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id));
Query query = Query.query(criteria);
Optional<SimpleListParent> reloadedById = template.findOne(query, SimpleListParent.class);

assertThat(reloadedById).get().extracting(e -> e.id, e -> e.content.size()).containsExactly(two.id, 2);
}

@Test // GH-1601
void findOneByQueryNothingFound() {

template.save(SimpleListParent.of("one", "one_1"));
SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2"));
template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3"));

CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(4711));
Query query = Query.query(criteria);
Optional<SimpleListParent> reloadedById = template.findOne(query, SimpleListParent.class);

assertThat(reloadedById).isEmpty();
}

@Test // GH-1601
void findOneByQueryToManyResults() {

template.save(SimpleListParent.of("one", "one_1"));
SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2"));
template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3"));

CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").not(two.id));
Query query = Query.query(criteria);

assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class)
.isThrownBy(() -> template.findOne(query, SimpleListParent.class));
}

@Test // DATAJDBC-112
@EnabledOnFeature(SUPPORTS_QUOTED_IDS)
void saveAndLoadAnEntityWithReferencedEntityById() {
Expand Down Expand Up @@ -1266,6 +1319,29 @@ static class ChildNoId {
private String content;
}

@SuppressWarnings("unused")
static class SimpleListParent {

@Id private Long id;
String name;
List<ElementNoId> content = new ArrayList<>();

static SimpleListParent of(String name, String... contents) {

SimpleListParent parent = new SimpleListParent();
parent.name = name;

for (String content : contents) {

ElementNoId element = new ElementNoId();
element.content = content;
parent.content.add(element);
}

return parent;
}
}

@Table("LIST_PARENT")
@SuppressWarnings("unused")
static class ListParent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ DROP TABLE ONE_TO_ONE_PARENT;

DROP TABLE ELEMENT_NO_ID;
DROP TABLE LIST_PARENT;
DROP TABLE SIMPLE_LIST_PARENT;

DROP TABLE BYTE_ARRAY_OWNER;

Expand Down Expand Up @@ -74,11 +75,18 @@ CREATE TABLE LIST_PARENT
"id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY,
NAME VARCHAR(100)
);
CREATE TABLE SIMPLE_LIST_PARENT
(
ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY,
NAME VARCHAR(100)
);
CREATE TABLE ELEMENT_NO_ID
(
CONTENT VARCHAR(100),
LIST_PARENT_KEY BIGINT,
LIST_PARENT BIGINT
SIMPLE_LIST_PARENT_KEY BIGINT,
LIST_PARENT BIGINT,
SIMPLE_LIST_PARENT BIGINT
);
ALTER TABLE ELEMENT_NO_ID
ADD FOREIGN KEY (LIST_PARENT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@ CREATE TABLE LIST_PARENT
NAME VARCHAR(100)
);

CREATE TABLE SIMPLE_LIST_PARENT
(
ID SERIAL PRIMARY KEY,
NAME VARCHAR(100)
);

CREATE TABLE element_no_id
(
content VARCHAR(100),
SIMPLE_LIST_PARENT_key BIGINT,
SIMPLE_LIST_PARENT INTEGER,
LIST_PARENT_key BIGINT,
LIST_PARENT INTEGER
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ CREATE TABLE Child_No_Id
content VARCHAR(30)
);

CREATE TABLE SIMPLE_LIST_PARENT
(
ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY,
NAME VARCHAR(100)
);
CREATE TABLE LIST_PARENT
(
"id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY,
Expand All @@ -34,6 +39,8 @@ CREATE TABLE LIST_PARENT
CREATE TABLE ELEMENT_NO_ID
(
CONTENT VARCHAR(100),
SIMPLE_LIST_PARENT_KEY BIGINT,
SIMPLE_LIST_PARENT BIGINT,
LIST_PARENT_KEY BIGINT,
LIST_PARENT BIGINT
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@ CREATE TABLE LIST_PARENT
`id4` BIGINT AUTO_INCREMENT PRIMARY KEY,
NAME VARCHAR(100)
);
CREATE TABLE SIMPLE_LIST_PARENT
(
ID BIGINT AUTO_INCREMENT PRIMARY KEY,
NAME VARCHAR(100)
);
CREATE TABLE element_no_id
(
CONTENT VARCHAR(100),
SIMPLE_LIST_PARENT_key BIGINT,
SIMPLE_LIST_PARENT BIGINT,
LIST_PARENT_key BIGINT,
LIST_PARENT BIGINT
);
Expand Down
Loading

0 comments on commit 0fdeaeb

Please sign in to comment.