Skip to content

Commit

Permalink
New q8p kernel requires alignment to 256 * 4.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Mar 15, 2024
1 parent b51f161 commit 277a544
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,21 @@ kernel void depalettize(
device uchar *source [[buffer(0)]],
device real4 *destination [[buffer(1)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tgid [[threadgroup_position_in_grid]],
ushort lid [[thread_index_in_threadgroup]]
) {
device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 4) * tgid.y;
const uint block_idx = tgid / number_of_elements_per_segment;
device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 4) * block_idx;
threadgroup real palette[palette_size];
if (lid < palette_size) {
palette[lid] = ((device real*)ui0)[lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint x = tgid.x * threadgroup_size + lid;
if (x + number_in_blocks * tgid.y >= number_of_elements) {
return;
}
const uint x = (tgid % number_of_elements_per_segment) * threadgroup_size + lid;
device const uchar4 *ui1 = (device const uchar4*)(ui0 + sizeof(real) * palette_size);
const uchar4 u = ui1[x];
const real4 d = real4(palette[u.x], palette[u.y], palette[u.z], palette[u.w]);
destination[number_in_blocks * tgid.y + x] = d;
destination[number_in_blocks * block_idx + x] = d;
}
)";
}
Expand Down Expand Up @@ -201,16 +199,18 @@ kernel void depalettize(
defines += "constant uint number_in_blocks = ";
defines += std::to_string(hash.number_in_blocks / 4) + ";";
defines += "\n";
if (hash.length % hash.number_in_blocks != 0) {
defines += "constant uint number_of_elements = ";
defines += std::to_string(hash.length / 4) + ";";
defines += "\n";
}
const int num_blocks = (hash.length + hash.number_in_blocks - 1) / hash.number_in_blocks;
CCV_NNC_MFA_PRECONDITION((hash.number_in_blocks % (256 * 4)) == 0);
const int repeat_4 = hash.number_in_blocks / (256 * 4);
this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
CCV_NNC_MFA_PRECONDITION((hash.length % 4) == 0);
if (hash.length % hash.number_in_blocks != 0) {
defines += "constant uint number_of_elements_per_segment = ";
defines += std::to_string(repeat_4) + ";";
defines += "\n";
this->grid_size = MTL::Size(repeat_4 * num_blocks, 1, 1);
} else {
this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
}
CCV_NNC_MFA_PRECONDITION((hash.length % (256 * 4)) == 0);
}

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
Expand Down

0 comments on commit 277a544

Please sign in to comment.