diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/AbstractJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/AbstractJdbcQuery.java index 2d7df924ed..98fcac2861 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/AbstractJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/AbstractJdbcQuery.java @@ -101,7 +101,7 @@ protected JdbcQueryExecution getQueryExecution(JdbcQueryMethod queryMethod, return extractor != null ? getQueryExecution(extractor) : singleObjectQuery(rowMapper); } - private JdbcQueryExecution createModifyingQueryExecutor() { + protected JdbcQueryExecution createModifyingQueryExecutor() { return (query, parameters) -> { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcDeleteQueryCreator.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcDeleteQueryCreator.java new file mode 100644 index 0000000000..2037fc5929 --- /dev/null +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcDeleteQueryCreator.java @@ -0,0 +1,148 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jdbc.repository.query; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import org.springframework.data.domain.Sort; +import org.springframework.data.jdbc.core.convert.JdbcConverter; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.relational.core.dialect.Dialect; +import org.springframework.data.relational.core.dialect.RenderContextFactory; +import org.springframework.data.relational.core.mapping.PersistentPropertyPathExtension; +import org.springframework.data.relational.core.mapping.RelationalMappingContext; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.query.Criteria; +import org.springframework.data.relational.core.sql.Condition; +import org.springframework.data.relational.core.sql.Conditions; +import org.springframework.data.relational.core.sql.Delete; +import org.springframework.data.relational.core.sql.DeleteBuilder.DeleteWhere; +import org.springframework.data.relational.core.sql.Select; +import org.springframework.data.relational.core.sql.SelectBuilder.SelectWhere; +import org.springframework.data.relational.core.sql.StatementBuilder; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.render.SqlRenderer; +import org.springframework.data.relational.repository.query.RelationalEntityMetadata; +import org.springframework.data.relational.repository.query.RelationalParameterAccessor; +import org.springframework.data.relational.repository.query.RelationalQueryCreator; +import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link RelationalQueryCreator} that creates {@link Stream} of deletion {@link ParametrizedQuery} + * from a {@link PartTree}. + * + * @author Yunyoung LEE + * @since 2.3 + */ +class JdbcDeleteQueryCreator extends RelationalQueryCreator> { + + private final RelationalMappingContext context; + private final QueryMapper queryMapper; + private final RelationalEntityMetadata entityMetadata; + private final RenderContextFactory renderContextFactory; + + /** + * Creates new instance of this class with the given {@link PartTree}, {@link JdbcConverter}, {@link Dialect}, + * {@link RelationalEntityMetadata} and {@link RelationalParameterAccessor}. + * + * @param context + * @param tree part tree, must not be {@literal null}. + * @param converter must not be {@literal null}. + * @param dialect must not be {@literal null}. + * @param entityMetadata relational entity metadata, must not be {@literal null}. + * @param accessor parameter metadata provider, must not be {@literal null}. + */ + JdbcDeleteQueryCreator(RelationalMappingContext context, PartTree tree, JdbcConverter converter, Dialect dialect, + RelationalEntityMetadata entityMetadata, RelationalParameterAccessor accessor) { + super(tree, accessor); + + Assert.notNull(converter, "JdbcConverter must not be null"); + Assert.notNull(dialect, "Dialect must not be null"); + Assert.notNull(entityMetadata, "Relational entity metadata must not be null"); + + this.context = context; + + this.entityMetadata = entityMetadata; + this.queryMapper = new QueryMapper(dialect, converter); + this.renderContextFactory = new RenderContextFactory(dialect); + } + + @Override + protected Stream complete(@Nullable Criteria criteria, Sort sort) { + + RelationalPersistentEntity entity = entityMetadata.getTableEntity(); + Table table = Table.create(entityMetadata.getTableName()); + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + + SqlContext sqlContext = new SqlContext(entity); + + Condition condition = criteria == null ? null + : queryMapper.getMappedObject(parameterSource, criteria, table, entity); + + // create select criteria query for subselect + SelectWhere selectBuilder = StatementBuilder.select(sqlContext.getIdColumn()).from(table); + Select select = condition == null ? selectBuilder.build() : selectBuilder.where(condition).build(); + + // create delete relation queries + List deleteChain = new ArrayList<>(); + deleteRelations(deleteChain, entity, select); + + // crate delete query + DeleteWhere deleteBuilder = StatementBuilder.delete(table); + Delete delete = condition == null ? deleteBuilder.build() : deleteBuilder.where(condition).build(); + + deleteChain.add(delete); + + SqlRenderer renderer = SqlRenderer.create(renderContextFactory.createRenderContext()); + return deleteChain.stream().map(d -> new ParametrizedQuery(renderer.render(d), parameterSource)); + } + + private void deleteRelations(List deleteChain, RelationalPersistentEntity entity, Select parentSelect) { + + for (PersistentPropertyPath path : context + .findPersistentPropertyPaths(entity.getType(), p -> true)) { + + PersistentPropertyPathExtension extPath = new PersistentPropertyPathExtension(context, path); + + // prevent duplication on recursive call + if (path.getLength() > 1 && !extPath.getParentPath().isEmbedded()) { + continue; + } + + if (extPath.isEntity() && !extPath.isEmbedded()) { + + SqlContext sqlContext = new SqlContext(extPath.getLeafEntity()); + + Condition inCondition = Conditions.in(sqlContext.getTable().column(extPath.getReverseColumnName()), + parentSelect); + + Select select = StatementBuilder + .select(sqlContext.getTable().column(extPath.getIdDefiningParentPath().getIdColumnName()) + // sqlContext.getIdColumn() + ).from(sqlContext.getTable()).where(inCondition).build(); + deleteRelations(deleteChain, extPath.getLeafEntity(), select); + + deleteChain.add(StatementBuilder.delete(sqlContext.getTable()).where(inCondition).build()); + } + } + } +} diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java index fccbe0a00c..d217bd79c9 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.function.LongSupplier; +import java.util.stream.Stream; import org.springframework.core.convert.converter.Converter; import org.springframework.data.domain.Pageable; @@ -123,6 +124,12 @@ public Object execute(Object[] values) { RelationalParametersParameterAccessor accessor = new RelationalParametersParameterAccessor(getQueryMethod(), values); + if (tree.isDelete()) { + JdbcQueryExecution execution = createModifyingQueryExecutor(); + return createDeleteQueries(accessor).map(query -> execution.execute(query.getQuery(), query.getParameterSource())) + .reduce((a, b) -> b); + } + ResultProcessor processor = getQueryMethod().getResultProcessor().withDynamicProjection(accessor); ParametrizedQuery query = createQuery(accessor, processor.getReturnedType()); JdbcQueryExecution execution = getQueryExecution(processor, accessor); @@ -185,6 +192,15 @@ protected ParametrizedQuery createQuery(RelationalParametersParameterAccessor ac return queryCreator.createQuery(getDynamicSort(accessor)); } + private Stream createDeleteQueries(RelationalParametersParameterAccessor accessor) { + + RelationalEntityMetadata entityMetadata = getQueryMethod().getEntityInformation(); + + JdbcDeleteQueryCreator queryCreator = new JdbcDeleteQueryCreator(context, tree, converter, dialect, entityMetadata, + accessor); + return queryCreator.createQuery(); + } + /** * {@link JdbcQueryExecution} returning a {@link org.springframework.data.domain.Slice}. * diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryEmbeddedWithCollectionIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryEmbeddedWithCollectionIntegrationTests.java index 9fd5523283..4fbce0a4ca 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryEmbeddedWithCollectionIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryEmbeddedWithCollectionIntegrationTests.java @@ -22,6 +22,7 @@ import java.sql.SQLException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import org.junit.jupiter.api.Test; @@ -231,9 +232,31 @@ public void deleteAll() { assertThat(repository.findAll()).isEmpty(); } + @Test // DATAJDBC-551 + public void deleteByTest() { + + DummyEntity one = repository.save(createDummyEntity("root1")); + DummyEntity two = repository.save(createDummyEntity("root2")); + DummyEntity three = repository.save(createDummyEntity("root3")); + + assertThat(repository.deleteByTest(two.getTest())).isEqualTo(1); + + assertThat(repository.findAll()) // + .extracting(DummyEntity::getId) // + .containsExactlyInAnyOrder(one.getId(), three.getId()); + + Long count = template.queryForObject("select count(1) from dummy_entity2", Collections.emptyMap(), Long.class); + assertThat(count).isEqualTo(4); + + } + private static DummyEntity createDummyEntity() { + return createDummyEntity("root"); + } + + private static DummyEntity createDummyEntity(String test) { DummyEntity entity = new DummyEntity(); - entity.setTest("root"); + entity.setTest(test); final Embeddable embeddable = new Embeddable(); embeddable.setTest("embedded"); @@ -252,7 +275,9 @@ private static DummyEntity createDummyEntity() { return entity; } - interface DummyEntityRepository extends CrudRepository {} + interface DummyEntityRepository extends CrudRepository { + int deleteByTest(String test); + } @Data private static class DummyEntity { diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsChainIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsChainIntegrationTests.java new file mode 100644 index 0000000000..0260f73c02 --- /dev/null +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsChainIntegrationTests.java @@ -0,0 +1,132 @@ +package org.springframework.data.jdbc.repository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.context.TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS; + +import lombok.Data; +import lombok.RequiredArgsConstructor; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.data.annotation.Id; +import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory; +import org.springframework.data.jdbc.testing.AssumeFeatureTestExecutionListener; +import org.springframework.data.jdbc.testing.TestConfiguration; +import org.springframework.data.repository.CrudRepository; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit.jupiter.SpringExtension; +import org.springframework.transaction.annotation.Transactional; + +/** + * Integration tests with collections chain. + * + * @author Yunyoung LEE + */ +@ContextConfiguration +@Transactional +@TestExecutionListeners(value = AssumeFeatureTestExecutionListener.class, mergeMode = MERGE_WITH_DEFAULTS) +@ExtendWith(SpringExtension.class) +public class JdbcRepositoryWithCollectionsChainIntegrationTests { + + @Autowired NamedParameterJdbcTemplate template; + @Autowired DummyEntityRepository repository; + + private static DummyEntity createDummyEntity() { + + DummyEntity entity = new DummyEntity(); + entity.setName("Entity Name"); + return entity; + } + + @Test // DATAJDBC-551 + public void deleteByName() { + + ChildElement element1 = createChildElement("one"); + ChildElement element2 = createChildElement("two"); + + DummyEntity entity = createDummyEntity(); + entity.content.add(element1); + entity.content.add(element2); + + entity = repository.save(entity); + + assertThat(repository.deleteByName("Entity Name")).isEqualTo(1); + + assertThat(repository.findById(entity.id)).isEmpty(); + + Long count = template.queryForObject("select count(1) from grand_child_element", new HashMap<>(), Long.class); + assertThat(count).isEqualTo(0); + } + + private ChildElement createChildElement(String name) { + + ChildElement element = new ChildElement(); + element.name = name; + element.content.add(createGrandChildElement(name + "1")); + element.content.add(createGrandChildElement(name + "2")); + return element; + } + + private GrandChildElement createGrandChildElement(String content) { + + GrandChildElement element = new GrandChildElement(); + element.content = content; + return element; + } + + interface DummyEntityRepository extends CrudRepository { + long deleteByName(String name); + } + + @Configuration + @Import(TestConfiguration.class) + static class Config { + + @Autowired JdbcRepositoryFactory factory; + + @Bean + Class testClass() { + return JdbcRepositoryWithCollectionsChainIntegrationTests.class; + } + + @Bean + DummyEntityRepository dummyEntityRepository() { + return factory.getRepository(DummyEntityRepository.class); + } + } + + @Data + static class DummyEntity { + + String name; + Set content = new HashSet<>(); + @Id private Long id; + + } + + @RequiredArgsConstructor + static class ChildElement { + + String name; + Set content = new HashSet<>(); + @Id private Long id; + } + + @RequiredArgsConstructor + static class GrandChildElement { + + String content; + @Id private Long id; + } + +} diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsIntegrationTests.java index 5f62da52dc..3224789a4a 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithCollectionsIntegrationTests.java @@ -184,6 +184,26 @@ public void deletingWithSet() { assertThat(count).isEqualTo(0); } + @Test // DATAJDBC-551 + public void deleteByName() { + + Element element1 = createElement("one"); + Element element2 = createElement("two"); + + DummyEntity entity = createDummyEntity(); + entity.content.add(element1); + entity.content.add(element2); + + entity = repository.save(entity); + + assertThat(repository.deleteByName("Entity Name")).isEqualTo(1); + + assertThat(repository.findById(entity.id)).isEmpty(); + + Long count = template.queryForObject("select count(1) from Element", new HashMap<>(), Long.class); + assertThat(count).isEqualTo(0); + } + private Element createElement(String content) { Element element = new Element(); @@ -191,7 +211,9 @@ private Element createElement(String content) { return element; } - interface DummyEntityRepository extends CrudRepository {} + interface DummyEntityRepository extends CrudRepository { + long deleteByName(String name); + } @Configuration @Import(TestConfiguration.class) diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryWithCollectionsChainIntegrationTests-hsql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryWithCollectionsChainIntegrationTests-hsql.sql new file mode 100644 index 0000000000..3c26f132dc --- /dev/null +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryWithCollectionsChainIntegrationTests-hsql.sql @@ -0,0 +1,3 @@ +CREATE TABLE dummy_entity ( id BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, NAME VARCHAR(100)); +CREATE TABLE child_element (id BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1) PRIMARY KEY, NAME VARCHAR(100), dummy_entity BIGINT); +CREATE TABLE grand_child_element (id BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1) PRIMARY KEY, CONTENT VARCHAR(100), child_element BIGINT);