
Polishing.
Simplify ValueFunction mapping. Remove the variants of findBy SQL generation in favor of the Condition-based variant. Reduce visibility. Change the return type of AggregateReader methods to List.

See #1601
Original pull request: #1617
schauder authored and mp911de committed Sep 26, 2023
1 parent 0fdeaeb commit f3bc0af
Showing 15 changed files with 109 additions and 216 deletions.
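
A minimal caller-side sketch, not part of the commit, of what the removed findById and findAllById SQL variants reduce to: both lookups are expressed as Criteria-based queries that the reader funnels through the single Condition-based findAll path shown in the AggregateReader diff below. The property name "id" and the identifier values are illustrative placeholders.

import java.util.List;

import org.springframework.data.relational.core.query.Criteria;
import org.springframework.data.relational.core.query.Query;

class ConditionBasedLookupSketch {

	// Lookup by a single identifier: previously a dedicated findById SQL string,
	// now an ordinary Criteria that gets mapped to a Condition.
	Query byId(Object id) {
		return Query.query(Criteria.where("id").is(id)).limit(1);
	}

	// Lookup by a collection of identifiers: previously findAllById, now an IN criteria.
	Query byIds(List<?> ids) {
		return Query.query(Criteria.where("id").in(ids));
	}
}
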
@@ -18,16 +18,16 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.mapping.AggregatePath;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.query.Criteria;
import org.springframework.data.relational.core.query.CriteriaDefinition;
import org.springframework.data.relational.core.query.Query;
import org.springframework.data.relational.core.sql.Condition;
@@ -36,14 +36,15 @@
import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator;
import org.springframework.data.relational.core.sqlgeneration.SqlGenerator;
import org.springframework.data.relational.domain.RowDocument;
import org.springframework.data.util.Streamable;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Reads complete Aggregates from the database, by generating appropriate SQL using a {@link SingleQuerySqlGenerator}
* through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converterd into an
* through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converted into an
* intermediate {@link RowDocumentResultSetExtractor RowDocument} and mapped via
* {@link org.springframework.data.relational.core.conversion.RelationalConverter#read(Class, RowDocument)}.
*
@@ -55,7 +56,8 @@
class AggregateReader<T> {

private final RelationalPersistentEntity<T> aggregate;
private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator sqlGenerator;
private final Table table;
private final SqlGenerator sqlGenerator;
private final JdbcConverter converter;
private final NamedParameterJdbcOperations jdbcTemplate;
private final RowDocumentResultSetExtractor extractor;
@@ -66,6 +68,7 @@ class AggregateReader<T> {
this.converter = converter;
this.aggregate = aggregate;
this.jdbcTemplate = jdbcTemplate;
this.table = Table.create(aggregate.getQualifiedTableName());

this.sqlGenerator = new CachingSqlGenerator(
new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate));
@@ -74,62 +77,58 @@ class AggregateReader<T> {
createPathToColumnMapping(aliasFactory));
}

public List<T> findAll() {
return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll);
}

@Nullable
public T findById(Object id) {

id = converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation());
Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).is(id)).limit(1);

return jdbcTemplate.query(sqlGenerator.findById(), Map.of("id", id), this::extractZeroOrOne);
return findOne(query);
}

public Iterable<T> findAllById(Iterable<?> ids) {
@Nullable
public T findOne(Query query) {

List<Object> convertedIds = new ArrayList<>();
for (Object id : ids) {
convertedIds.add(converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation()));
}
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
Condition condition = createCondition(query, parameterSource);

return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll);
return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractZeroOrOne);
}

public Iterable<T> findAllBy(Query query) {
public List<T> findAll() {
return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll);
}

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);
return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll);
public List<T> findAllById(Iterable<?> ids) {

Collection<?> identifiers = ids instanceof Collection<?> idl ? idl : Streamable.of(ids).toList();
Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)).limit(1);

return findAll(query);
}

public Optional<T> findOneByQuery(Query query) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
BiFunction<Table, RelationalPersistentEntity, Condition> condition = createConditionSource(query, parameterSource);
public List<T> findAll(Query query) {

return Optional.ofNullable(
jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne));
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
Condition condition = createCondition(query, parameterSource);
return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractAll);
}

private BiFunction<Table, RelationalPersistentEntity, Condition> createConditionSource(Query query, MapSqlParameterSource parameterSource) {
@Nullable
private Condition createCondition(Query query, MapSqlParameterSource parameterSource) {

QueryMapper queryMapper = new QueryMapper(converter);

BiFunction<Table, RelationalPersistentEntity, Condition> condition = (table, aggregate) -> {
Optional<CriteriaDefinition> criteria = query.getCriteria();
return criteria
.map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate))
.orElse(null);
};
return condition;
Optional<CriteriaDefinition> criteria = query.getCriteria();
return criteria
.map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate))
.orElse(null);
}

/**
* Extracts a list of aggregates from the given {@link ResultSet} by utilizing the
* {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms
* to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract.
*
*
* @param rs the {@link ResultSet} from which to extract the data. Must not be {@literal null}.
* @return a {@code List} of aggregates, fully converted.
* @throws SQLException
@@ -195,21 +194,15 @@ public String keyColumn(AggregatePath path) {
* @author Jens Schauder
* @since 3.2
*/
static class CachingSqlGenerator implements org.springframework.data.relational.core.sqlgeneration.SqlGenerator {

private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator delegate;
static class CachingSqlGenerator implements SqlGenerator {

private final SqlGenerator delegate;
private final String findAll;
private final String findById;
private final String findAllById;

public CachingSqlGenerator(SqlGenerator delegate) {

this.delegate = delegate;

findAll = delegate.findAll();
findById = delegate.findById();
findAllById = delegate.findAllById();
this.findAll = delegate.findAll();
}

@Override
@@ -218,18 +211,8 @@ public String findAll() {
}

@Override
public String findById() {
return findById;
}

@Override
public String findAllById() {
return findAllById;
}

@Override
public String findAllByCondition(BiFunction<Table, RelationalPersistentEntity, Condition> conditionSource) {
return delegate.findAllByCondition(conditionSource);
public String findAll(@Nullable Condition condition) {
return delegate.findAll(condition);
}

@Override
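
The extractAll Javadoc in the hunk above points out that the method conforms to the ResultSetExtractor contract when used as a method reference. A standalone sketch of that pattern, with a made-up table and column name rather than anything from the commit:

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;

class ExtractorSketch {

	// this::extractAll satisfies ResultSetExtractor<List<String>> because the method
	// matches the functional shape (ResultSet) -> List<String> and may throw SQLException.
	List<String> findNames(NamedParameterJdbcOperations jdbcTemplate) {
		return jdbcTemplate.query("SELECT name FROM minion WHERE name LIKE :name",
				new MapSqlParameterSource("name", "B%"), this::extractAll);
	}

	private List<String> extractAll(ResultSet rs) throws SQLException {

		List<String> names = new ArrayList<>();
		while (rs.next()) {
			names.add(rs.getString("name"));
		}
		return names;
	}
}
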
@@ -22,8 +22,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import org.springframework.data.domain.Sort;
import org.springframework.data.jdbc.core.mapping.JdbcValue;
@@ -35,7 +33,6 @@
import org.springframework.data.mapping.context.InvalidPersistentPropertyPath;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.dialect.Escaper;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.query.CriteriaDefinition;
@@ -77,7 +74,7 @@ public QueryMapper(Dialect dialect, JdbcConverter converter) {
Assert.notNull(converter, "JdbcConverter must not be null");

this.converter = converter;
this.mappingContext = (MappingContext) converter.getMappingContext();
this.mappingContext = converter.getMappingContext();
}

/**
@@ -310,7 +307,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc
sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue);
} else if (criteria.getValue() instanceof ValueFunction valueFunction) {

mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint()));
mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint()));
sqlType = propertyField.getSqlType();

} else if (propertyField instanceof MetadataBackedField metadataBackedField //
@@ -77,13 +77,12 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {

@Override
public <T> Optional<T> findOne(Query query, Class<T> domainType) {
return getReader(domainType).findOneByQuery(query);
return Optional.ofNullable(getReader(domainType).findOne(query));
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {

return getReader(domainType).findAllBy(query);
return getReader(domainType).findAll(query);
}

@Override
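
For context, a hedged usage sketch of the Query-based API that the hunk above delegates to the reader. The Minion aggregate and the wiring of JdbcAggregateOperations are invented for illustration; only the findOne and findAll signatures come from the framework.

import java.util.Optional;

import org.springframework.data.jdbc.core.JdbcAggregateOperations;
import org.springframework.data.relational.core.query.Criteria;
import org.springframework.data.relational.core.query.Query;

class QueryApiUsageSketch {

	record Minion(Long id, String name) {}

	void loadMinions(JdbcAggregateOperations operations) {

		Query byName = Query.query(Criteria.where("name").is("Bob"));

		// findOne wraps the reader's nullable result in an Optional.
		Optional<Minion> bob = operations.findOne(byName, Minion.class);

		// findAll(Query, Class) is backed by the reader's List-returning findAll(Query).
		Iterable<Minion> all = operations.findAll(byName, Minion.class);
	}
}
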
@@ -87,6 +87,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return super.findAllById(ids, domainType);
}

@Override
public <T> Optional<T> findOne(Query query, Class<T> domainType) {

if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) {
@@ -137,11 +138,6 @@ private boolean entityQualifiesForSingleQueryLoading(Class<?> entityType) {

referenceFound = true;
}

// AggregateReferences aren't supported yet
// if (property.isAssociation()) {
// return false;
// }
}
return true;

@@ -21,12 +21,13 @@
import org.springframework.jdbc.core.namedparam.SqlParameterSource;

/**
* This {@link SqlParameterSource} will apply escaping to it's values.
*
* This {@link SqlParameterSource} will apply escaping to its values.
*
* @author Jens Schauder
* @since 3.2
*/
public class EscapingParameterSource implements SqlParameterSource {
class EscapingParameterSource implements SqlParameterSource {

private final SqlParameterSource parameterSource;
private final Escaper escaper;

@@ -15,12 +15,12 @@
*/
package org.springframework.data.jdbc.repository.query;

import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.dialect.Escaper;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;

/**
* Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the parameters.
* Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the
* parameters.
*
* @author Mark Paluch
* @author Jens Schauder
@@ -41,13 +41,12 @@ String getQuery() {
return query;
}

SqlParameterSource getParameterSource(Escaper escaper) {
return new EscapingParameterSource(parameterSource, escaper);
}

@Override
public String toString() {
return this.query;
}

public SqlParameterSource getParameterSource(Escaper escaper) {

return new EscapingParameterSource(parameterSource, escaper);
}
}
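
A conceptual stand-in for what getParameterSource(Escaper) enables, assuming the (now package-private) EscapingParameterSource resolves ValueFunction values by applying the dialect's Escaper when a value is read for binding; a simplified sketch, not the actual implementation.

import org.springframework.data.relational.core.dialect.Escaper;
import org.springframework.data.relational.core.query.ValueFunction;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;

class EscapingSketch {

	// Simplified resolution: plain values pass through, ValueFunctions get the Escaper applied.
	static Object resolve(SqlParameterSource source, String name, Escaper escaper) {

		Object value = source.getValue(name);
		return value instanceof ValueFunction<?> function ? function.apply(escaper) : value;
	}

	public static void main(String[] args) {

		ValueFunction<String> like = escaper -> "%" + escaper.escape("50%") + "%";
		SqlParameterSource source = new MapSqlParameterSource("name", like);

		System.out.println(resolve(source, "name", Escaper.DEFAULT)); // prints %50\%%
	}
}
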
@@ -176,19 +176,16 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer
return expression;
}

if (expression instanceof Column) {
if (expression instanceof Column column) {

Column column = (Column) expression;
Field field = createPropertyField(entity, column.getName());
TableLike table = column.getTable();

Column columnFromTable = table.column(field.getMappedColumnName());
return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable;
}

if (expression instanceof SimpleFunction) {

SimpleFunction function = (SimpleFunction) expression;
if (expression instanceof SimpleFunction function) {

List<Expression> arguments = function.getExpressions();
List<Expression> mappedArguments = new ArrayList<>(arguments.size());
@@ -367,15 +364,14 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind
Class<?> typeHint;

Comparator comparator = criteria.getComparator();
if (criteria.getValue() instanceof Parameter) {

Parameter parameter = (Parameter) criteria.getValue();
if (criteria.getValue()instanceof Parameter parameter) {

mappedValue = convertValue(comparator, parameter.getValue(), propertyField.getTypeHint());
typeHint = getTypeHint(mappedValue, actualType.getType(), parameter);
} else if (criteria.getValue() instanceof ValueFunction<?> valueFunction) {

mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())).apply(getEscaper(comparator));
mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint()))
.apply(getEscaper(comparator));

typeHint = actualType.getType();
} else {
@@ -118,7 +118,7 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable

} else if (value instanceof ValueFunction<?> valueFunction) {

mappedValue = valueFunction.transform(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT);
mappedValue = valueFunction.map(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT);

if (mappedValue == null) {
return Assignments.value(column, SQL.nullLiteral());
@@ -58,14 +58,18 @@ default Supplier<T> toSupplier(Escaper escaper) {
}

/**
* Transforms the inner value of the ValueFunction using the profided transformation.
* Return a new ValueFunction applying the given mapping {@link Function}. The mapping function is applied after
* applying {@link Escaper}.
*
* The default implementation just return the current {@literal ValueFunction}.
* This is not a valid implementation and serves just to maintain backward compatibility.
*
* @param transformation to be applied to the underlying value.
* @param mapper the mapping function to apply to the value.
* @param <R> the type of the value returned from the mapping function.
* @return a new {@literal ValueFunction}.
* @since 3.2
*/
default ValueFunction<T> transform(Function<Object, Object> transformation) {return this;};
default <R> ValueFunction<R> map(Function<T, R> mapper) {

Assert.notNull(mapper, "Mapping function must not be null");

return escaper -> mapper.apply(this.apply(escaper));
}
}
