diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java index df87d3813d..c53e3770b1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java @@ -49,6 +49,7 @@ * @author Umut Erturk * @author Myeonghyeon Lee * @author Chirag Tailor + * @author Mark Paluch */ @SuppressWarnings("rawtypes") class JdbcAggregateChangeExecutionContext { @@ -268,12 +269,10 @@ List populateIdsIfNecessary() { cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(), qualifierValue, newEntity); - } else if (insert.getPropertyPath().getLeafProperty().isCollectionLike()) { cascadingValues.gather(insert.getDependingOn(), insert.getPropertyPath(), qualifierValue, newEntity); - } } } @@ -289,6 +288,7 @@ List populateIdsIfNecessary() { return roots; } + @SuppressWarnings("unchecked") private Object setIdAndCascadingProperties(DbAction.WithEntity action, @Nullable Object generatedId, StagedValues cascadingValues) { @@ -328,6 +328,7 @@ private PersistentPropertyPath getRelativePath(DbAction action, Persistent throw new IllegalArgumentException(String.format("DbAction of type %s is not supported", action.getClass())); } + @SuppressWarnings("unchecked") private RelationalPersistentEntity getRequiredPersistentEntity(Class type) { return (RelationalPersistentEntity) context.getRequiredPersistentEntity(type); } @@ -358,7 +359,7 @@ private void updateWithVersion(DbAction.UpdateRoot update) { */ private static class StagedValues { - static final List aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE, + static final List> aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE, ListAggregator.INSTANCE, SingleElementAggregator.INSTANCE); Map> values = new HashMap<>(); @@ -374,13 +375,14 @@ private static class StagedValues { * be {@literal null}. * @param value The value to be set. Must not be {@literal null}. */ - @SuppressWarnings("unchecked") - void stage(DbAction action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) { - gather(action, path, qualifier, value); - values.get(action).get(path).isStaged = true; + void stage(DbAction action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) { + + StagedValue gather = gather(action, path, qualifier, value); + gather.isStaged = true; } - void gather(DbAction action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) { + @SuppressWarnings("unchecked") + StagedValue gather(DbAction action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) { MultiValueAggregator aggregator = getAggregatorFor(path); @@ -391,19 +393,20 @@ void gather(DbAction action, PersistentPropertyPath path, @Nullable Objec persistentPropertyPath -> new StagedValue(aggregator.createEmptyInstance())); T currentValue = (T) stagedValue.value; - Object newValue = aggregator.add(currentValue, qualifier, value); - - stagedValue.value = newValue; + stagedValue.value = aggregator.add(currentValue, qualifier, value); valuesForPath.put(path, stagedValue); + + return stagedValue; } - private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) { + @SuppressWarnings("unchecked") + private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) { PersistentProperty property = path.getLeafProperty(); - for (MultiValueAggregator aggregator : aggregators) { + for (MultiValueAggregator aggregator : aggregators) { if (aggregator.handles(property)) { - return aggregator; + return (MultiValueAggregator) aggregator; } } @@ -427,10 +430,10 @@ void forEachPath(DbAction dbAction, BiConsumer extends WithPropertyPath, WithEntity { default Pair, Object> getQualifier() { Map, Object> qualifiers = getQualifiers(); + if (qualifiers.isEmpty()) { return null; } Set, Object>> entries = qualifiers.entrySet(); - Map.Entry, Object> entry = entries.stream().sorted(Comparator.comparing(e -> -e.getKey().getLength())).findFirst().get(); + Optional, Object>> optionalEntry = entries.stream() + .filter(e -> e.getValue() != null).min(Comparator.comparing(e -> -e.getKey().getLength())); + + Map.Entry, Object> entry = optionalEntry.orElse(null); - if (entry.getValue() == null) { + if (entry == null) { return null; } + return Pair.of(entry.getKey(), entry.getValue()); }