Skip to content

Commit

Permalink
Polishing.
Browse files Browse the repository at this point in the history
Replace code duplications with doWithBatch(…) method. Return most concrete type in DefaultDataAccessStrategy and MyBatisDataAccessStrategy.

See #1623
Original pull request: #1897
  • Loading branch information
mp911de committed Oct 1, 2024
1 parent c4f62e9 commit 7cf81ae
Showing 3 changed files with 40 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -16,13 +16,15 @@
package org.springframework.data.jdbc.core;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@@ -56,6 +58,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;

/**
* {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store.
@@ -173,19 +176,8 @@ public <T> T save(T instance) {

@Override
public <T> List<T> saveAll(Iterable<T> instances) {

Assert.notNull(instances, "Aggregate instances must not be null");

if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
verifyIdProperty(instance);
entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance)));
}
return performSaveAll(entityAndChangeCreators);
return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty,
this::performSaveAll);
}

/**
@@ -206,21 +198,7 @@ public <T> T insert(T instance) {

@Override
public <T> List<T> insertAll(Iterable<T> instances) {

Assert.notNull(instances, "Aggregate instances must not be null");

if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {

Function<T, RootAggregateChange<T>> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange);
}
return performSaveAll(entityAndChangeCreators);
return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll);
}

/**
@@ -241,21 +219,35 @@ public <T> T update(T instance) {

@Override
public <T> List<T> updateAll(Iterable<T> instances) {
return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll);
}

private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
return doWithBatch(iterable, changeCreator, entity -> {}, performFunction);
}

Assert.notNull(instances, "Aggregate instances must not be null");
private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Consumer<T> beforeEntityChange, Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {

if (!instances.iterator().hasNext()) {
Assert.notNull(iterable, "Aggregate instances must not be null");

if (ObjectUtils.isEmpty(iterable)) {
return Collections.emptyList();
}

List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>(
iterable instanceof Collection<?> c ? c.size() : 16);

for (T instance : iterable) {

beforeEntityChange.accept(instance);

Function<T, RootAggregateChange<T>> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange);
}
return performSaveAll(entityAndChangeCreators);

return performFunction.apply(entityAndChangeCreators);
}

@Override
Original file line number Diff line number Diff line change
@@ -272,12 +272,12 @@ public <T> T findById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType) {
public <T> List<T> findAll(Class<T> domainType) {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {

if (!ids.iterator().hasNext()) {
return Collections.emptyList();
@@ -290,7 +290,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {

@Override
@SuppressWarnings("unchecked")
public Iterable<Object> findAllByPath(Identifier identifier,
public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> propertyPath) {

Assert.notNull(identifier, "identifier must not be null");
@@ -338,12 +338,12 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
}

@@ -361,7 +361,7 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
public <T> List<T> findAll(Query query, Class<T> domainType) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
@@ -370,7 +370,7 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable);
Original file line number Diff line number Diff line change
@@ -256,21 +256,21 @@ public <T> T findById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType) {
public <T> List<T> findAll(Class<T> domainType) {

String statement = namespace(domainType) + ".findAll";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
return sqlSession().selectList(statement, parameter);
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return sqlSession().selectList(namespace(domainType) + ".findAllById",
new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
}

@Override
public Iterable<Object> findAllByPath(Identifier identifier,
public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {

String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath();
@@ -288,7 +288,7 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
public <T> List<T> findAll(Class<T> domainType, Sort sort) {

Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("sort", sort);
@@ -297,7 +297,7 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {

Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("pageable", pageable);
@@ -311,12 +311,12 @@ public <T> Optional<T> findOne(Query query, Class<T> probeType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType) {
public <T> List<T> findAll(Query query, Class<T> probeType) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
throw new UnsupportedOperationException("Not implemented");
}

0 comments on commit 7cf81ae

Please sign in to comment.