diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs
index cead54005804..623d6257676f 100644
--- a/control_plane/attachment_service/src/persistence.rs
+++ b/control_plane/attachment_service/src/persistence.rs
@@ -381,16 +381,22 @@ impl Persistence {
         self.with_conn(move |conn| -> DatabaseResult<()> {
             conn.transaction(|conn| -> DatabaseResult<()> {
                 // Mark parent shards as splitting
+
+                let expect_parent_records = std::cmp::max(1, old_shard_count.0);
+
                 let updated = diesel::update(tenant_shards)
                     .filter(tenant_id.eq(split_tenant_id.to_string()))
                     .filter(shard_count.eq(old_shard_count.0 as i32))
                     .set((splitting.eq(1),))
                     .execute(conn)?;
-                if ShardCount(updated.try_into().map_err(|_| DatabaseError::Logical(format!("Overflow existing shard count {} while splitting", updated)))?) != old_shard_count {
+                if u8::try_from(updated)
+                    .map_err(|_| DatabaseError::Logical(
+                        format!("Overflow existing shard count {} while splitting", updated))
+                    )? != expect_parent_records {
                     // Perhaps a deletion or another split raced with this attempt to split, mutating
                     // the parent shards that we intend to split. In this case the split request should fail.
                     return Err(DatabaseError::Logical(
-                        format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {old_shard_count:?})")
+                        format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})")
                     ));
                 }
 
diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py
index 805eaa34b0fc..27d1cf2f3482 100644
--- a/test_runner/regress/test_sharding.py
+++ b/test_runner/regress/test_sharding.py
@@ -4,7 +4,7 @@
     tenant_get_shards,
 )
 from fixtures.remote_storage import s3_storage
-from fixtures.types import TimelineId
+from fixtures.types import TenantShardId, TimelineId
 from fixtures.workload import Workload
 
 
@@ -84,6 +84,35 @@ def get_sizes():
     assert timelines == {env.initial_timeline, timeline_b}
 
 
+def test_sharding_split_unsharded(
+    neon_env_builder: NeonEnvBuilder,
+):
+    """
+    Test that shard splitting works on a tenant created as unsharded (i.e. with
+    ShardCount(0)).
+    """
+    env = neon_env_builder.init_start()
+    tenant_id = env.initial_tenant
+    timeline_id = env.initial_timeline
+
+    workload = Workload(env, tenant_id, timeline_id, branch_name="main")
+    workload.init()
+    workload.write_rows(256)
+
+    # Check that we created with an unsharded TenantShardId: this is the default,
+    # but check it in case we change the default in future
+    assert env.attachment_service.inspect(TenantShardId(tenant_id, 0, 0)) is not None
+
+    # Split one shard into two
+    env.attachment_service.tenant_shard_split(tenant_id, shard_count=2)
+
+    # Check we got the shard IDs we expected
+    assert env.attachment_service.inspect(TenantShardId(tenant_id, 0, 2)) is not None
+    assert env.attachment_service.inspect(TenantShardId(tenant_id, 1, 2)) is not None
+
+    workload.validate()
+
+
 def test_sharding_split_smoke(
     neon_env_builder: NeonEnvBuilder,
 ):