diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java index 23647874f9b..29b34b9a6d1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java @@ -298,7 +298,7 @@ private Object setIdAndCascadingProperties(DbAction.WithEntity action, @N PersistentPropertyPathAccessor propertyAccessor = converter.getPropertyAccessor(persistentEntity, originalEntity); - if (IdValueSource.GENERATED.equals(action.getIdValueSource())) { + if (IdValueSource.isGeneratedByDatabased(action.getIdValueSource())) { propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), generatedId); } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java index 4d210d516da..04bc8840824 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java @@ -19,9 +19,16 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; @@ -37,6 +44,7 @@ import org.springframework.data.relational.core.query.Query; import org.springframework.data.relational.core.sql.LockMode; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.util.Pair; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; @@ -60,6 +68,7 @@ * @author Radim Tlusty * @author Chirag Tailor * @author Diego Krupitza + * @author Mikhail Polivakha * @since 1.1 */ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -102,31 +111,35 @@ public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, Relation @Override public Object insert(T instance, Class domainType, Identifier identifier, IdValueSource idValueSource) { - SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forInsert(instance, domainType, identifier, - idValueSource); + RelationalPersistentEntity persistentEntity = context.getRequiredPersistentEntity(domainType); + + Optional idFromSequence = getIdFromSequenceIfAnyDefined(idValueSource, persistentEntity); + + SqlIdentifierParameterSource parameterSource = idFromSequence + .map(it -> sqlParametersFactory.forInsert(instance, domainType, identifier, it)) + .orElseGet(() -> sqlParametersFactory.forInsert(instance, domainType, identifier, idValueSource)); String insertSql = sql(domainType).getInsert(parameterSource.getIdentifiers()); - return insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql, - parameterSource); - } + Object idAfterExecute = insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType)) + .execute(insertSql, parameterSource); + + return idFromSequence.map(it -> (Object) it).orElse(idAfterExecute); + } @Override public Object[] insert(List> insertSubjects, Class domainType, IdValueSource idValueSource) { Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject"); - SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream() - .map(insertSubject -> sqlParametersFactory.forInsert(insertSubject.getInstance(), domainType, - insertSubject.getIdentifier(), idValueSource)) - .toArray(SqlIdentifierParameterSource[]::new); - String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers()); + if (IdValueSource.SEQUENCE.equals(idValueSource)) { + return executeBatchInsertWithSequenceAsIdSource(insertSubjects, domainType, idValueSource); + } else { + return executeBatchInsert(insertSubjects, domainType, idValueSource); + } + } - return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql, - sqlParameterSources); - } - - @Override + @Override public boolean update(S instance, Class domainType) { SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forUpdate(instance, domainType); @@ -446,4 +459,70 @@ private Class getBaseType(PersistentPropertyPath Object[] executeBatchInsert(List> insertSubjects, Class domainType, IdValueSource idValueSource) { + SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects + .stream() + .map(insertSubject -> sqlParametersFactory.forInsert( + insertSubject.getInstance(), domainType, + insertSubject.getIdentifier(), idValueSource) + ) + .toArray(SqlIdentifierParameterSource[]::new); + + String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers()); + + return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)) + .execute(insertSql, sqlParameterSources); + } + + private Object[] executeBatchInsertWithSequenceAsIdSource(List> insertSubjects, Class domainType, IdValueSource idValueSource) { + List> sqlParameterSources = createBatchParameterSourcesWithSequence(insertSubjects, domainType, + context.getPersistentEntity(domainType).getIdTargetSequence() + ); + + String insertSql = sql(domainType).getInsert(sqlParameterSources.get(0).getSecond().getIdentifiers()); + + insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)) + .execute(insertSql, sqlParameterSources.stream() + .map(Pair::getSecond) + .toArray(SqlIdentifierParameterSource[]::new)); + + return sqlParameterSources.stream().map(Pair::getFirst).toArray(Object[]::new); + } + + private List> createBatchParameterSourcesWithSequence(List> insertSubjects, Class domainType, Optional idTargetSequence) { + List> sqlParameterSources; + int subjectsSize = insertSubjects.size(); + + List generatedIds = getMultipleIdsFromSequence(idTargetSequence.get(), subjectsSize); + + sqlParameterSources = IntStream + .range(0, subjectsSize) + .mapToObj(index -> { + InsertSubject subject = insertSubjects.get(index); + Long generatedId = generatedIds.get(index); + return Pair.of(generatedId, sqlParametersFactory.forInsert( + subject.getInstance(), domainType, + subject.getIdentifier(), generatedId + )); + }) + .collect(Collectors.toList()); + return sqlParameterSources; + } + + private Optional getIdFromSequenceIfAnyDefined(IdValueSource idValueSource, RelationalPersistentEntity persistentEntity) { + if (IdValueSource.SEQUENCE.equals(idValueSource) && persistentEntity.getIdTargetSequence().isPresent()) { + String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(persistentEntity.getIdTargetSequence().get()); + return Optional.of(operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1))); + } + return Optional.empty(); + } + + private List getMultipleIdsFromSequence(String sequenceName, Integer requiredIds) { + String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(sequenceName); + + return IntStream.range(0, requiredIds) + .mapToObj(operand -> operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1))) + .collect(Collectors.toList()); + } + } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/InsertStrategyFactory.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/InsertStrategyFactory.java index d7e4593c37f..a1291ec8018 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/InsertStrategyFactory.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/InsertStrategyFactory.java @@ -28,6 +28,7 @@ * * @author Chirag Tailor * @author Jens Schauder + * @author Mikhail Polivakha * @since 2.4 */ public class InsertStrategyFactory { @@ -102,4 +103,7 @@ public Object[] execute(String sql, SqlParameterSource[] sqlParameterSources) { } } + public Dialect getDialect() { + return this.dialect; + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlIdentifierParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlIdentifierParameterSource.java index 2b131ac7a9b..7dd90902702 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlIdentifierParameterSource.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlIdentifierParameterSource.java @@ -22,6 +22,7 @@ import java.util.Set; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.util.Pair; import org.springframework.jdbc.core.namedparam.AbstractSqlParameterSource; /** @@ -35,9 +36,11 @@ */ class SqlIdentifierParameterSource extends AbstractSqlParameterSource { - private final Set identifiers = new HashSet<>(); + private final Set sqlIdentifiers = new HashSet<>(); private final Map namesToValues = new HashMap<>(); + private Pair idToValue; + @Override public boolean hasValue(String paramName) { return namesToValues.containsKey(paramName); @@ -54,30 +57,34 @@ public String[] getParameterNames() { } Set getIdentifiers() { - return Collections.unmodifiableSet(identifiers); + return Collections.unmodifiableSet(sqlIdentifiers); } void addValue(SqlIdentifier name, Object value) { addValue(name, value, Integer.MIN_VALUE); } - void addValue(SqlIdentifier identifier, Object value, int sqlType) { + void addValue(SqlIdentifier sqlIdentifier, Object value, int sqlType) { - identifiers.add(identifier); - String name = BindParameterNameSanitizer.sanitize(identifier.getReference()); + sqlIdentifiers.add(sqlIdentifier); + String name = prepareSqlIdentifierName(sqlIdentifier); namesToValues.put(name, value); registerSqlType(name, sqlType); } - void addAll(SqlIdentifierParameterSource others) { + void addAll(SqlIdentifierParameterSource others) { for (SqlIdentifier identifier : others.getIdentifiers()) { - String name = BindParameterNameSanitizer.sanitize( identifier.getReference()); + String name = prepareSqlIdentifierName(identifier); addValue(identifier, others.getValue(name), others.getSqlType(name)); } } + private static String prepareSqlIdentifierName(SqlIdentifier sqlIdentifier) { + return BindParameterNameSanitizer.sanitize(sqlIdentifier.getReference()); + } + int size() { return namesToValues.size(); } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java index 94f90de501f..cc7a319ced1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java @@ -15,10 +15,13 @@ */ package org.springframework.data.jdbc.core.convert; +import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.SQLType; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Predicate; import org.springframework.data.jdbc.core.mapping.JdbcValue; @@ -26,10 +29,14 @@ import org.springframework.data.mapping.PersistentProperty; import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.relational.core.conversion.IdValueSource; +import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.jdbc.support.JdbcUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -46,13 +53,15 @@ public class SqlParametersFactory { private final RelationalMappingContext context; private final JdbcConverter converter; + private final Dialect dialect; - /** - * @since 3.1 - */ - public SqlParametersFactory(RelationalMappingContext context, JdbcConverter converter) { + private final NamedParameterJdbcOperations operations; + + public SqlParametersFactory(RelationalMappingContext context, JdbcConverter converter, Dialect dialect, NamedParameterJdbcOperations operations) { this.context = context; this.converter = converter; + this.dialect = dialect; + this.operations = operations; } /** @@ -70,18 +79,38 @@ public SqlParametersFactory(RelationalMappingContext context, JdbcConverter conv SqlIdentifierParameterSource forInsert(T instance, Class domainType, Identifier identifier, IdValueSource idValueSource) { + RelationalPersistentEntity persistentEntity = getRequiredPersistentEntity(domainType); + + Object idValue = null; + + if (IdValueSource.PROVIDED.equals(idValueSource)) { + idValue = persistentEntity.getIdentifierAccessor(instance).getRequiredIdentifier(); + } + return forInsert(instance, domainType, identifier, idValue); + } + + /** + * Creates the parameters for a SQL insert operation. That method is different from its sibling + * {@link #forInsert(Object, Class, Identifier, IdValueSource) forInsert method} in the sense, that + * this method is invoked when we actually know the id to be added to the {@link SqlParameterSource paarameter source}. + * It might be null, meaning, that we know for sure the id should be coming from the database, or + * it could be not null, meaning, that we've got the id from some source (user provided by himself, + * or we have queried the sequence for instance) + */ + SqlIdentifierParameterSource forInsert(T instance, Class domainType, Identifier identifier, + @Nullable Object id) { + RelationalPersistentEntity persistentEntity = getRequiredPersistentEntity(domainType); SqlIdentifierParameterSource parameterSource = getParameterSource(instance, persistentEntity, "", PersistentProperty::isIdProperty); identifier.forEach((name, value, type) -> addConvertedPropertyValue(parameterSource, name, value, type)); - if (IdValueSource.PROVIDED.equals(idValueSource)) { - - RelationalPersistentProperty idProperty = persistentEntity.getRequiredIdProperty(); - Object idValue = persistentEntity.getIdentifierAccessor(instance).getRequiredIdentifier(); - addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName()); - } + RelationalPersistentProperty idProperty = persistentEntity.getIdProperty(); + Optional + .ofNullable(id) + .filter(it -> idProperty != null) + .ifPresent(it -> addConvertedPropertyValue(parameterSource, idProperty, it, idProperty.getColumnName())); return parameterSource; } @@ -178,6 +207,13 @@ private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSou converter.getTargetSqlType(property)); } + private void addConvertedIdPropertyValue(SqlIdentifierParameterSource parameterSource, + RelationalPersistentProperty property, @Nullable Object value, SqlIdentifier name) { + + addConvertedValue(parameterSource, value, name, converter.getColumnType(property), + converter.getTargetSqlType(property)); + } + private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSource, SqlIdentifier name, Object value, Class javaType) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java index 3b8b8efd349..a4811b681a1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java @@ -86,7 +86,7 @@ public static DataAccessStrategy createCombinedAccessStrategy(RelationalMappingC NamespaceStrategy namespaceStrategy, Dialect dialect) { SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(context, converter, dialect); - SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter); + SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter, dialect, operations); InsertStrategyFactory insertStrategyFactory = new InsertStrategyFactory(operations, dialect); DataAccessStrategy defaultDataAccessStrategy = new DataAccessStrategyFactory( // diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/config/AbstractJdbcConfiguration.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/config/AbstractJdbcConfiguration.java index bd725b98d33..ae64df363b6 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/config/AbstractJdbcConfiguration.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/config/AbstractJdbcConfiguration.java @@ -207,7 +207,7 @@ public DataAccessStrategy dataAccessStrategyBean(NamedParameterJdbcOperations op SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(context, jdbcConverter, dialect); DataAccessStrategyFactory factory = new DataAccessStrategyFactory(sqlGeneratorSource, jdbcConverter, operations, - new SqlParametersFactory(context, jdbcConverter), + new SqlParametersFactory(context, jdbcConverter, dialect, operations), new InsertStrategyFactory(operations, dialect)); return factory.create(); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcRepositoryFactoryBean.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcRepositoryFactoryBean.java index a76db20a139..d51fff941a6 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcRepositoryFactoryBean.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/JdbcRepositoryFactoryBean.java @@ -177,7 +177,7 @@ public void afterPropertiesSet() { SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(this.mappingContext, this.converter, this.dialect); - SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(this.mappingContext, this.converter); + SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(this.mappingContext, this.converter, this.dialect, this.operations); InsertStrategyFactory insertStrategyFactory = new InsertStrategyFactory(this.operations, this.dialect); DataAccessStrategyFactory factory = new DataAccessStrategyFactory(sqlGeneratorSource, this.converter, diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java index 7fd0f6e9a8b..117b48bde29 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java @@ -27,6 +27,7 @@ import java.util.Objects; import org.junit.jupiter.api.Test; +import org.mockito.Mock; import org.springframework.core.convert.converter.Converter; import org.springframework.data.annotation.Id; import org.springframework.data.convert.ReadingConverter; @@ -38,11 +39,14 @@ import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; /** * Unit tests for {@link SqlParametersFactory}. * * @author Chirag Tailor + * @author Mikhail Polivakha */ class SqlParametersFactoryTest { @@ -50,7 +54,9 @@ class SqlParametersFactoryTest { RelationResolver relationResolver = mock(RelationResolver.class); MappingJdbcConverter converter = new MappingJdbcConverter(context, relationResolver); AnsiDialect dialect = AnsiDialect.INSTANCE; - SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter); + + NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class); + SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter, dialect, operations); @Test // DATAJDBC-412 public void considersConfiguredWriteConverterForIdValueObjects_onRead() { @@ -301,6 +307,6 @@ private SqlParametersFactory createSqlParametersFactoryWithConverters(List co MappingJdbcConverter converter = new MappingJdbcConverter(context, relationResolver, new JdbcCustomConversions(converters), new DefaultJdbcTypeFactory(mock(JdbcOperations.class))); - return new SqlParametersFactory(context, converter); + return new SqlParametersFactory(context, converter, dialect, operations); } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java index 0704a16ca0f..2ea102a300a 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java @@ -68,6 +68,8 @@ import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory; import org.springframework.data.jdbc.testing.ConditionalOnDatabase; import org.springframework.data.jdbc.testing.DatabaseType; +import org.springframework.data.jdbc.testing.DisabledOnDatabase; +import org.springframework.data.jdbc.testing.EnabledOnDatabase; import org.springframework.data.jdbc.testing.EnabledOnFeature; import org.springframework.data.jdbc.testing.IntegrationTest; import org.springframework.data.jdbc.testing.TestConfiguration; @@ -75,6 +77,7 @@ import org.springframework.data.relational.core.mapping.Column; import org.springframework.data.relational.core.mapping.MappedCollection; import org.springframework.data.relational.core.mapping.Table; +import org.springframework.data.relational.core.mapping.TargetSequence; import org.springframework.data.relational.core.mapping.event.AbstractRelationalEvent; import org.springframework.data.relational.core.mapping.event.AfterConvertEvent; import org.springframework.data.relational.core.sql.LockMode; @@ -115,8 +118,8 @@ public class JdbcRepositoryIntegrationTests { @Autowired DummyEntityRepository repository; @Autowired MyEventListener eventListener; @Autowired RootRepository rootRepository; - @Autowired WithDelimitedColumnRepository withDelimitedColumnRepository; + @Autowired EntityWithSequenceRepository entityWithSequenceRepository; @BeforeEach public void before() { @@ -135,6 +138,29 @@ public void savesAnEntity() { "id_Prop = " + entity.getIdProp())).isEqualTo(1); } + @Test + @DisabledOnDatabase(database = DatabaseType.MYSQL) + public void saveEntityWithTargetSequenceSpecified() { + EntityWithSequence first = entityWithSequenceRepository.save(new EntityWithSequence("first")); + EntityWithSequence second = entityWithSequenceRepository.save(new EntityWithSequence("second")); + + assertThat(first.getId()).isNotNull(); + assertThat(second.getId()).isNotNull(); + assertThat(first.getId()).isLessThan(second.getId()); + assertThat(first.getName()).isEqualTo("first"); + assertThat(second.getName()).isEqualTo("second"); + } + + @Test + @DisabledOnDatabase(database = DatabaseType.MYSQL) + public void batchInsertEntityWithTargetSequenceSpecified() { + Iterable results = entityWithSequenceRepository.saveAll( + List.of(new EntityWithSequence("first"), new EntityWithSequence("second")) + ); + + assertThat(results).hasSize(2).extracting(EntityWithSequence::getId).containsExactly(1L, 2L); + } + @Test // DATAJDBC-95 public void saveAndLoadAnEntity() { @@ -1515,6 +1541,8 @@ interface RootRepository extends ListCrudRepository { interface WithDelimitedColumnRepository extends CrudRepository {} + interface EntityWithSequenceRepository extends CrudRepository {} + @Configuration @Import(TestConfiguration.class) static class Config { @@ -1536,6 +1564,11 @@ WithDelimitedColumnRepository withDelimitedColumnRepository() { return factory.getRepository(WithDelimitedColumnRepository.class); } + @Bean + EntityWithSequenceRepository entityWithSequenceRepository() { + return factory.getRepository(EntityWithSequenceRepository.class); + } + @Bean NamedQueries namedQueries() throws IOException { @@ -1839,6 +1872,32 @@ private static DummyEntity createEntity(String entityName, Consumer return entity; } + static class EntityWithSequence { + + @Id + @TargetSequence(sequence = "entity_sequence") + private Long id; + + private String name; + + public EntityWithSequence(Long id, String name) { + this.id = id; + this.name = name; + } + + public EntityWithSequence(String name) { + this.name = name; + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } + } + static class DummyEntity { String name; diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/SimpleJdbcRepositoryEventsUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/SimpleJdbcRepositoryEventsUnitTests.java index e50a67bb99f..c0603e6bf10 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/SimpleJdbcRepositoryEventsUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/SimpleJdbcRepositoryEventsUnitTests.java @@ -91,7 +91,7 @@ void before() { JdbcConverter converter = new MappingJdbcConverter(context, delegatingDataAccessStrategy, new JdbcCustomConversions(), new DefaultJdbcTypeFactory(operations.getJdbcOperations())); SqlGeneratorSource generatorSource = new SqlGeneratorSource(context, converter, dialect); - SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter); + SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter, dialect, operations); InsertStrategyFactory insertStrategyFactory = new InsertStrategyFactory(operations, dialect); this.dataAccessStrategy = spy(new DefaultDataAccessStrategy(generatorSource, context, converter, operations, diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/config/EnableJdbcRepositoriesIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/config/EnableJdbcRepositoriesIntegrationTests.java index 1e3b30f4cac..a27756bcef4 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/config/EnableJdbcRepositoriesIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/config/EnableJdbcRepositoriesIntegrationTests.java @@ -167,7 +167,7 @@ DataAccessStrategy defaultDataAccessStrategy( @Qualifier("namedParameterJdbcTemplate") NamedParameterJdbcOperations template, RelationalMappingContext context, JdbcConverter converter, Dialect dialect) { return new DataAccessStrategyFactory(new SqlGeneratorSource(context, converter, dialect), converter, template, - new SqlParametersFactory(context, converter), + new SqlParametersFactory(context, converter, dialect, template), new InsertStrategyFactory(template, dialect)).create(); } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabase.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabase.java new file mode 100644 index 00000000000..c83ec900f68 --- /dev/null +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabase.java @@ -0,0 +1,27 @@ +package org.springframework.data.jdbc.testing; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.test.context.junit.jupiter.EnabledIf; + +/** + * Annotation that allows to disable a particular test to be executed on a particular database + * + * @author Mikhail Polivakha + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@ExtendWith(DisabledOnDatabaseExecutionCondition.class) +public @interface DisabledOnDatabase { + + /** + * The database on which the test is not supposed to run on + */ + DatabaseType database(); +} diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabaseExecutionCondition.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabaseExecutionCondition.java new file mode 100644 index 00000000000..17f9bfdf206 --- /dev/null +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/DisabledOnDatabaseExecutionCondition.java @@ -0,0 +1,36 @@ +package org.springframework.data.jdbc.testing; + +import org.apache.commons.lang3.ArrayUtils; +import org.junit.jupiter.api.extension.ConditionEvaluationResult; +import org.junit.jupiter.api.extension.ExecutionCondition; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.springframework.context.ApplicationContext; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +/** + * {@link ExecutionCondition} for the {@link DisabledOnDatabase} annotation + * + * @author Mikhail Polivakha + */ +public class DisabledOnDatabaseExecutionCondition implements ExecutionCondition { + + @Override + public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext context) { + ApplicationContext applicationContext = SpringExtension.getApplicationContext(context); + + MergedAnnotation disabledOnDatabaseMergedAnnotation = MergedAnnotations + .from(context.getRequiredTestMethod(), MergedAnnotations.SearchStrategy.DIRECT) + .get(DisabledOnDatabase.class); + + DatabaseType database = disabledOnDatabaseMergedAnnotation.getEnum("database", DatabaseType.class); + + if (ArrayUtils.contains(applicationContext.getEnvironment().getActiveProfiles(), database.getProfile())) { + return ConditionEvaluationResult.disabled( + "The test method '%s' is disabled for '%s' because of the @DisabledOnDatabase annotation".formatted(context.getRequiredTestMethod().getName(), database) + ); + } + return ConditionEvaluationResult.enabled("The test method '%s' is enabled".formatted(context.getRequiredTestMethod())); + } +} diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java index b84d93fe6b0..6175f2a055d 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/testing/TestConfiguration.java @@ -109,7 +109,7 @@ DataAccessStrategy defaultDataAccessStrategy( JdbcConverter converter, Dialect dialect) { return new DataAccessStrategyFactory(new SqlGeneratorSource(context, converter, dialect), converter, - template, new SqlParametersFactory(context, converter), + template, new SqlParametersFactory(context, converter, dialect, template), new InsertStrategyFactory(template, dialect)).create(); } diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-db2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-db2.sql index 2c66f226e1a..1c00e779a6e 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-db2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-db2.sql @@ -3,6 +3,8 @@ DROP TABLE ROOT; DROP TABLE INTERMEDIATE; DROP TABLE LEAF; DROP TABLE WITH_DELIMITED_COLUMN; +DROP TABLE ENTITY_WITH_SEQUENCE; +DROP SEQUENCE ENTITY_SEQUENCE; CREATE TABLE dummy_entity ( @@ -45,4 +47,12 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-h2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-h2.sql index b72f6645357..6f9087b69df 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-h2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-h2.sql @@ -39,4 +39,12 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-hsql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-hsql.sql index b72f6645357..6f9087b69df 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-hsql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-hsql.sql @@ -39,4 +39,12 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mariadb.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mariadb.sql index 75b46639892..19ebad8bc38 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mariadb.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mariadb.sql @@ -39,4 +39,12 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT AUTO_INCREMENT PRIMARY KEY, `ORG.XTUNIT.IDENTIFIER` VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mssql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mssql.sql index 9959dea4a81..69f191f65df 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mssql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mssql.sql @@ -3,6 +3,8 @@ DROP TABLE IF EXISTS ROOT; DROP TABLE IF EXISTS INTERMEDIATE; DROP TABLE IF EXISTS LEAF; DROP TABLE IF EXISTS WITH_DELIMITED_COLUMN; +DROP TABLE IF EXISTS ENTITY_WITH_SEQUENCE; +DROP SEQUENCE IF EXISTS ENTITY_SEQUENCE; CREATE TABLE dummy_entity ( @@ -45,4 +47,12 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT IDENTITY PRIMARY KEY, "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mysql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mysql.sql index 0d3e16587ff..43c33bc4404 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mysql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-mysql.sql @@ -42,4 +42,4 @@ CREATE TABLE WITH_DELIMITED_COLUMN ID BIGINT AUTO_INCREMENT PRIMARY KEY, `ORG.XTUNIT.IDENTIFIER` VARCHAR(100), STYPE VARCHAR(100) -); \ No newline at end of file +); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-oracle.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-oracle.sql index 0a08dfbf9ed..179ac5abb99 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-oracle.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-oracle.sql @@ -3,6 +3,8 @@ DROP TABLE ROOT CASCADE CONSTRAINTS PURGE; DROP TABLE INTERMEDIATE CASCADE CONSTRAINTS PURGE; DROP TABLE LEAF CASCADE CONSTRAINTS PURGE; DROP TABLE WITH_DELIMITED_COLUMN CASCADE CONSTRAINTS PURGE; +DROP TABLE ENTITY_WITH_SEQUENCE CASCADE CONSTRAINTS PURGE; +DROP SEQUENCE ENTITY_SEQUENCE; CREATE TABLE DUMMY_ENTITY ( @@ -46,3 +48,11 @@ CREATE TABLE WITH_DELIMITED_COLUMN "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), STYPE VARCHAR(100) ) + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-postgres.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-postgres.sql index 37ad6914dee..14dff05925f 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-postgres.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.repository/JdbcRepositoryIntegrationTests-postgres.sql @@ -3,6 +3,8 @@ DROP TABLE ROOT; DROP TABLE INTERMEDIATE; DROP TABLE LEAF; DROP TABLE WITH_DELIMITED_COLUMN; +DROP TABLE ENTITY_WITH_SEQUENCE; +DROP SEQUENCE ENTITY_SEQUENCE; CREATE TABLE dummy_entity ( @@ -45,4 +47,12 @@ CREATE TABLE "WITH_DELIMITED_COLUMN" ID SERIAL PRIMARY KEY, "ORG.XTUNIT.IDENTIFIER" VARCHAR(100), "STYPE" VARCHAR(100) -); \ No newline at end of file +); + +CREATE TABLE ENTITY_WITH_SEQUENCE +( + ID BIGINT, + NAME VARCHAR(100) +); + +CREATE SEQUENCE ENTITY_SEQUENCE START WITH 1 INCREMENT BY 1 NO MAXVALUE; \ No newline at end of file diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/IdValueSource.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/IdValueSource.java index 0c7961ae619..0a2cc6696c6 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/IdValueSource.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/IdValueSource.java @@ -15,6 +15,10 @@ */ package org.springframework.data.relational.core.conversion; +import java.util.Optional; +import java.util.Set; + +import org.springframework.data.annotation.Id; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; @@ -22,6 +26,7 @@ * Enumeration describing the source of a value for an id property. * * @author Chirag Tailor + * @author Mikhail Polivakha * @since 2.4 */ public enum IdValueSource { @@ -39,7 +44,12 @@ public enum IdValueSource { /** * There is no id property, and therefore no id value source. */ - NONE; + NONE, + + /** + * The id should be dervied from the database sequence + */ + SEQUENCE; /** * Returns the appropriate {@link IdValueSource} for the instance: {@link IdValueSource#NONE} when the entity has no @@ -48,6 +58,11 @@ public enum IdValueSource { */ public static IdValueSource forInstance(Object instance, RelationalPersistentEntity persistentEntity) { + Optional idTargetSequence = persistentEntity.getIdTargetSequence(); + if (idTargetSequence.isPresent()) { + return IdValueSource.SEQUENCE; + } + Object idValue = persistentEntity.getIdentifierAccessor(instance).getIdentifier(); RelationalPersistentProperty idProperty = persistentEntity.getIdProperty(); if (idProperty == null) { @@ -62,4 +77,8 @@ public static IdValueSource forInstance(Object instance, RelationalPersisten return IdValueSource.GENERATED; } } + + public static boolean isGeneratedByDatabased(IdValueSource idValueSource) { + return Set.of(GENERATED, SEQUENCE).contains(idValueSource); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Db2Dialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Db2Dialect.java index d2f2fa3e7c6..4782eda6b9e 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Db2Dialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Db2Dialect.java @@ -25,6 +25,7 @@ * An SQL dialect for DB2. * * @author Jens Schauder + * @author Mikhail Polivakha * @since 2.0 */ public class Db2Dialect extends AbstractDialect { @@ -102,4 +103,14 @@ public IdentifierProcessing getIdentifierProcessing() { public Collection getConverters() { return Collections.singletonList(TimestampAtUtcToOffsetDateTimeConverter.INSTANCE); } + + /** + * This workaround (non-ANSI SQL way of querying sequence) exists for the same reasons it exists for {@link HsqlDbDialect} + * + * @see HsqlDbDialect#nextValueFromSequenceSelect(String) + */ + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT NEXT VALUE FOR %s FROM SYSCAT.SEQUENCES LIMIT 1".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Dialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Dialect.java index 492b84f11fe..6ab754ef762 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Dialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/Dialect.java @@ -146,5 +146,11 @@ default SimpleFunction getExistsFunction() { default boolean supportsSingleQueryLoading() { return true; - }; + } + + default String nextValueFromSequenceSelect(String sequenceName) { + throw new UnsupportedOperationException( + "Currently, there is no support for sequence generation for %s dialect. If you need it, please, submit a ticket".formatted(this.getClass().getSimpleName()) + ); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/H2Dialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/H2Dialect.java index a13212971a2..41b51e131e3 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/H2Dialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/H2Dialect.java @@ -31,6 +31,7 @@ * @author Myeonghyeon Lee * @author Christph Strobl * @author Jens Schauder + * @author Mikhail Polivakha * @since 2.0 */ public class H2Dialect extends AbstractDialect { @@ -113,4 +114,9 @@ public Set> simpleTypes() { public boolean supportsSingleQueryLoading() { return false; } + + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT NEXT VALUE FOR %s".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/HsqlDbDialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/HsqlDbDialect.java index 268f59cc528..b17bc2d22dc 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/HsqlDbDialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/HsqlDbDialect.java @@ -20,6 +20,7 @@ * * @author Jens Schauder * @author Myeonghyeon Lee + * @author Mikhail Polivakha */ public class HsqlDbDialect extends AbstractDialect { @@ -64,4 +65,16 @@ public Position getClausePosition() { return Position.AFTER_ORDER_BY; } }; + + /** + * One may think that this is an over-complication, but it is actually not. + * There is no a direct way to query the next value for the sequence, only to use it as an expression + * inside other queries (SELECT/INSERT). Therefore, such a workaround is required + * + * @see The way JOOQ solves this problem + */ + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT NEXT VALUE FOR %s AS msq FROM INFORMATION_SCHEMA.SEQUENCES LIMIT 1".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/MariaDbDialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/MariaDbDialect.java index 4387724134c..ff95f8ecc5d 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/MariaDbDialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/MariaDbDialect.java @@ -24,6 +24,7 @@ * A SQL dialect for MariaDb. * * @author Jens Schauder + * @author Mikhail Polivakha * @since 2.3 */ public class MariaDbDialect extends MySqlDialect { @@ -38,4 +39,9 @@ public Collection getConverters() { TimestampAtUtcToOffsetDateTimeConverter.INSTANCE, NumberToBooleanConverter.INSTANCE); } + + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT NEXTVAL(%s)".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/OracleDialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/OracleDialect.java index 4970d507591..8ce7e74488e 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/OracleDialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/OracleDialect.java @@ -27,6 +27,7 @@ * An SQL dialect for Oracle. * * @author Jens Schauder + * @author Mikahil Polivakha * @since 2.1 */ public class OracleDialect extends AnsiDialect { @@ -69,4 +70,9 @@ public Integer convert(Boolean bool) { return bool ? 1 : 0; } } + + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT %s.nextval FROM DUAL".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/PostgresDialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/PostgresDialect.java index ca0d52c2eab..063b38a9596 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/PostgresDialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/PostgresDialect.java @@ -42,6 +42,7 @@ * @author Myeonghyeon Lee * @author Jens Schauder * @author Nikita Konev + * @author Mikhail Polivakha * @since 1.1 */ public class PostgresDialect extends AbstractDialect { @@ -163,4 +164,9 @@ public Set> simpleTypes() { public SimpleFunction getExistsFunction() { return Functions.least(Functions.count(SQL.literalOf(1)), SQL.literalOf(1)); } + + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT nextval('%s')".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/SqlServerDialect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/SqlServerDialect.java index 2eb2a1ee9a3..4471ad6bbe3 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/SqlServerDialect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/dialect/SqlServerDialect.java @@ -131,4 +131,9 @@ public InsertRenderContext getInsertRenderContext() { public OrderByNullPrecedence orderByNullHandling() { return OrderByNullPrecedence.NONE; } + + @Override + public String nextValueFromSequenceSelect(String sequenceName) { + return "SELECT NEXT VALUE FOR %s".formatted(sequenceName); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntity.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntity.java index c8d67cb1b25..8bd6b5a2cd2 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntity.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntity.java @@ -17,6 +17,7 @@ import java.util.Optional; +import org.jetbrains.annotations.NotNull; import org.springframework.data.mapping.model.BasicPersistentEntity; import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.util.Lazy; @@ -46,6 +47,8 @@ class BasicRelationalPersistentEntity extends BasicPersistentEntity tableName; private final @Nullable Expression tableNameExpression; + private final Lazy idTargetSequenceName; + private final Lazy> schemaName; private final @Nullable Expression schemaNameExpression; private final ExpressionEvaluator expressionEvaluator; @@ -87,6 +90,8 @@ class BasicRelationalPersistentEntity extends BasicPersistentEntity getIdTargetSequence() { + return idTargetSequenceName.getOptional(); + } + @Override public String toString() { return String.format("BasicRelationalPersistentEntity<%s>", getType()); } + + private @Nullable String determineTargetSequenceName() { + RelationalPersistentProperty idProperty = getIdProperty(); + + if (idProperty != null && idProperty.isAnnotationPresent(TargetSequence.class)) { + TargetSequence requiredAnnotation = idProperty.getRequiredAnnotation(TargetSequence.class); + if (!StringUtils.hasText(requiredAnnotation.sequence()) && !StringUtils.hasText(requiredAnnotation.value())) { + throw new IllegalStateException(""" + For the persistent entity '%s' the @TargetSequence annotation was specified for the @Id, however, neither + the value() nor the sequence() attributes are specified + """ + ); + } else { + String sequenceFullyQualifiedName = getSequenceName(requiredAnnotation); + if (StringUtils.hasText(requiredAnnotation.schema())) { + return String.join(".", requiredAnnotation.schema(), sequenceFullyQualifiedName); + } + return sequenceFullyQualifiedName; + } + } else { + return null; + } + } + + @NotNull + private static String getSequenceName(TargetSequence requiredAnnotation) { + return Optional.of(requiredAnnotation.sequence()) + .filter(s -> !s.isBlank()) + .orElse(requiredAnnotation.value()); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java index e5432499a79..2c915dd21c0 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java @@ -17,6 +17,7 @@ import java.lang.annotation.Annotation; import java.util.Iterator; +import java.util.Optional; import org.springframework.core.env.Environment; import org.springframework.data.mapping.*; @@ -31,6 +32,7 @@ * Embedded entity extension for a {@link Embedded entity}. * * @author Mark Paluch + * @author Mikhail Polivakha * @since 3.2 */ class EmbeddedRelationalPersistentEntity implements RelationalPersistentEntity { @@ -54,6 +56,11 @@ public SqlIdentifier getIdColumn() { throw new MappingException("Embedded entity does not have an id column"); } + @Override + public Optional getIdTargetSequence() { + return Optional.empty(); + } + @Override public void addPersistentProperty(RelationalPersistentProperty property) { throw new UnsupportedOperationException(); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/RelationalPersistentEntity.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/RelationalPersistentEntity.java index f54587a19d5..fea5f9c86c0 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/RelationalPersistentEntity.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/RelationalPersistentEntity.java @@ -15,6 +15,8 @@ */ package org.springframework.data.relational.core.mapping; +import java.util.Optional; + import org.springframework.data.mapping.model.MutablePersistentEntity; import org.springframework.data.relational.core.sql.SqlIdentifier; @@ -25,6 +27,7 @@ * @author Jens Schauder * @author Oliver Gierke * @author Mark Paluch + * @author Mikhail Polivakha */ public interface RelationalPersistentEntity extends MutablePersistentEntity { @@ -52,4 +55,8 @@ default SqlIdentifier getQualifiedTableName() { */ SqlIdentifier getIdColumn(); + /** + * @return the target sequence that should be used for id generation + */ + Optional getIdTargetSequence(); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/TargetSequence.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/TargetSequence.java new file mode 100644 index 00000000000..be16bcfc7fe --- /dev/null +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/TargetSequence.java @@ -0,0 +1,43 @@ +package org.springframework.data.relational.core.mapping; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; + +/** + * Specify the sequence from which the value for the {@link org.springframework.data.annotation.Id} + * should be fetched + * + * @author Mikhail Polivakha + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.FIELD) +@Documented +public @interface TargetSequence { + + /** + * The name of the sequence from which the id should be fetched + */ + String value() default ""; + + /** + * Alias for {@link #value()} + */ + @AliasFor("value") + String sequence() default ""; + + /** + * Schema where the sequence reside. + * Technically, this attribute is not necessarily the schema. It just represents the location/namespace, + * where the sequence resides. For instance, in Oracle databases the schema and user are often used + * interchangeably, so {@link #schema() schema} attribute may represent an Oracle user as well. + *

+ * The final name of the sequence to be queried for the next value will be constructed by the concatenation + * of schema and sequence :

schema().sequence()
+ */ + String schema() default ""; +} diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntityUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntityUnitTests.java index 83f56e80121..8fc0b98033b 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntityUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/BasicRelationalPersistentEntityUnitTests.java @@ -58,6 +58,34 @@ void discoversAnnotatedTableName() { assertThat(entity.getTableName()).isEqualTo(quoted("dummy_sub_entity")); } + @Test + void entityWithNotargetSequence() { + RelationalPersistentEntity entity = mappingContext.getRequiredPersistentEntity(DummySubEntity.class); + + assertThat(entity.getIdTargetSequence()).isEmpty(); + } + + @Test + void determineSequenceName() { + RelationalPersistentEntity persistentEntity = mappingContext.getPersistentEntity(EntityWithSequence.class); + + assertThat(persistentEntity.getIdTargetSequence()).isPresent().hasValue("my_seq"); + } + + @Test + void determineSequenceNameFromValue() { + RelationalPersistentEntity persistentEntity = mappingContext.getPersistentEntity(EntityWithSequenceValueAlias.class); + + assertThat(persistentEntity.getIdTargetSequence()).isPresent().hasValue("my_seq"); + } + + @Test + void determineSequenceNameWithSchemaSpecified() { + RelationalPersistentEntity persistentEntity = mappingContext.getPersistentEntity(EntityWithSequenceAndSchema.class); + + assertThat(persistentEntity.getIdTargetSequence()).isPresent().hasValue("public.my_seq"); + } + @Test // DATAJDBC-294 void considerIdColumnName() { @@ -201,6 +229,24 @@ static class DummySubEntity { @Column("renamedId") Long id; } + @Table("entity_with_sequence") + static class EntityWithSequence { + @Id + @TargetSequence(sequence = "my_seq") Long id; + } + + @Table("entity_with_sequence_value_alias") + static class EntityWithSequenceValueAlias { + @Id + @Column("myId") @TargetSequence(value = "my_seq") Long id; + } + + @Table("entity_with_sequence_and_schema") + static class EntityWithSequenceAndSchema { + @Id + @Column("myId") @TargetSequence(sequence = "my_seq", schema = "public") Long id; + } + @Table() static class DummyEntityWithEmptyAnnotation { @Id