diff --git a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp index 28f59f265..0a29276db 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp @@ -144,23 +144,21 @@ kernel void depalettize( device uchar *source [[buffer(0)]], device real4 *destination [[buffer(1)]], - uint3 tgid [[threadgroup_position_in_grid]], + uint tgid [[threadgroup_position_in_grid]], ushort lid [[thread_index_in_threadgroup]] ) { - device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 4) * tgid.y; + const uint block_idx = tgid / number_of_elements_per_segment; + device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 4) * block_idx; threadgroup real palette[palette_size]; if (lid < palette_size) { palette[lid] = ((device real*)ui0)[lid]; } threadgroup_barrier(mem_flags::mem_threadgroup); - const uint x = tgid.x * threadgroup_size + lid; - if (x + number_in_blocks * tgid.y >= number_of_elements) { - return; - } + const uint x = (tgid % number_of_elements_per_segment) * threadgroup_size + lid; device const uchar4 *ui1 = (device const uchar4*)(ui0 + sizeof(real) * palette_size); const uchar4 u = ui1[x]; const real4 d = real4(palette[u.x], palette[u.y], palette[u.z], palette[u.w]); - destination[number_in_blocks * tgid.y + x] = d; + destination[number_in_blocks * block_idx + x] = d; } )"; } @@ -201,16 +199,18 @@ kernel void depalettize( defines += "constant uint number_in_blocks = "; defines += std::to_string(hash.number_in_blocks / 4) + ";"; defines += "\n"; - if (hash.length % hash.number_in_blocks != 0) { - defines += "constant uint number_of_elements = "; - defines += std::to_string(hash.length / 4) + ";"; - defines += "\n"; - } const int num_blocks = (hash.length + hash.number_in_blocks - 1) / hash.number_in_blocks; CCV_NNC_MFA_PRECONDITION((hash.number_in_blocks % (256 * 4)) == 0); const int repeat_4 = hash.number_in_blocks / (256 * 4); - this->grid_size = MTL::Size(repeat_4, num_blocks, 1); - CCV_NNC_MFA_PRECONDITION((hash.length % 4) == 0); + if (hash.length % hash.number_in_blocks != 0) { + defines += "constant uint number_of_elements_per_segment = "; + defines += std::to_string(repeat_4) + ";"; + defines += "\n"; + this->grid_size = MTL::Size(repeat_4 * num_blocks, 1, 1); + } else { + this->grid_size = MTL::Size(repeat_4, num_blocks, 1); + } + CCV_NNC_MFA_PRECONDITION((hash.length % (256 * 4)) == 0); } auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());