Add subdevice support to multicore untilize #16193
base: main
Changes from all commits: 8179e3c, af86fc8, 8b9d860, 5cd6fc4, dd43c84, 2e5edcf, f83acab, 8f549c5, 06ab0c9, ef48a3c, 23e79d4
@@ -12,6 +12,7 @@
 #include "ttnn/operation.hpp"
 #include "ttnn/operations/core/work_split/work_split_tilize.hpp"
 #include "tt_metal/common/constants.hpp"
+#include "tt_metal/common/work_split.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
@@ -29,6 +30,190 @@ uint32_t get_largest_divisor(uint32_t dividend, uint32_t starting_divisor, uint3
     return 1;
 }

+operation::ProgramWithCallbacks untilize_multi_core_parallelize_column_subgrid(
+    const Tensor& a,
+    Tensor& output,
+    bool use_pack_untilize,
+    bool fp32_dest_acc_en,
+    const CoreRangeSet& sub_core_grids) {
+    tt::tt_metal::Program program{};

Review comment: Is the support here 1:1 with the version without subgrids? Asking because I'm curious if the validate needs updating.

Reply: The regular version splits the cores into core_range and core_range_cliff, but when I ran it for my shapes the core_range_cliff was empty, so I think it is 1:1.
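For readers following the thread above, a minimal self-contained sketch of the cliff split being discussed (the struct and function names here are hypothetical, not tt_metal code): when the block count does not divide evenly across cores, the remainder lands on a separate "cliff" core. The subgrid path below shrinks ncores until ntiles divides evenly, so its cliff set is always empty, which is why the behavior matches 1:1.

#include <cstdint>

// Hypothetical sketch, not tt_metal code: how a full/cliff split arises.
struct BlockSplit {
    uint32_t nblocks_per_core;  // blocks on each fully loaded core
    uint32_t ncliff_blocks;     // leftover blocks on the cliff core (0 = no cliff)
};

BlockSplit split_blocks(uint32_t nblocks, uint32_t ncores) {
    return {nblocks / ncores, nblocks % ncores};
}
// split_blocks(16, 4) -> {4, 0}: no cliff, identical to the subgrid path.
// split_blocks(17, 4) -> {4, 1}: one cliff core handles the extra block.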
+    tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());
+    uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
+    tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
+    uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format);
+
+    Device* device = a.device();
+
+    uint32_t ntiles = a.volume() / TILE_HW;
+    uint32_t ncores = sub_core_grids.num_cores();
+    for (uint32_t core_id = ncores; core_id >= 1; core_id--) {
+        if (ntiles % ncores == 0) {
+            break;
+        } else {
+            ncores--;
+        }
+    }
+
+    TT_ASSERT(ntiles % (ncores) == 0);
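The loop above walks the core count downward until it evenly divides the tile count. The same search as a standalone sketch, with assumed example numbers (96 tiles on a 20-core subgrid settles on 16 cores):

#include <cassert>
#include <cstdint>

// Equivalent to the loop above: shrink ncores until it divides ntiles.
uint32_t shrink_to_divisor(uint32_t ntiles, uint32_t ncores) {
    while (ncores > 1 && ntiles % ncores != 0) {
        ncores--;  // drop a core and retry the divisibility check
    }
    return ncores;
}

int main() {
    assert(shrink_to_divisor(96, 20) == 16);  // 96 % 20,19,18,17 != 0; 96 % 16 == 0
    assert(shrink_to_divisor(97, 20) == 1);   // a prime tile count falls back to one core
}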
+    uint32_t max_tiles = 1;
+    uint32_t ntiles_per_block = ntiles / ncores;
+    uint32_t stick_s = a.get_legacy_shape()[-1];
+    uint32_t ntiles_per_row = stick_s / TILE_WIDTH;
+    uint32_t stick_size = stick_s * output.element_size();
+    uint32_t ntiles_per_column = ntiles / ntiles_per_row;
+    uint32_t starting_tile = ntiles_per_block;
+    if (ntiles_per_row > max_tiles) {
+        starting_tile = max_tiles;
+    }
+    ntiles_per_block = get_largest_divisor(ntiles_per_row, starting_tile);
+    TT_ASSERT(
+        ntiles_per_row % ntiles_per_block == 0 and ntiles_per_block >= 1 and ntiles_per_block <= ntiles_per_row and
+        ntiles % ntiles_per_block == 0);
+
+    uint32_t nblocks = (ntiles / ntiles_per_block);
+    uint32_t block_size_nbytes = input_single_tile_size;
+
+    auto cores = corerange_to_cores(sub_core_grids, ncores, true);
+    auto all_cores = num_cores_to_corerangeset_in_subcoregrids(cores[0], ncores, sub_core_grids, true);
Review comment: Is the only difference here that we're using a different work_split and num_cores_to_corerangeset_in_subcoregrids to get the cores? If so, could you just have one function that if-elses when it's nullopt?

Reply: The multicore version has some computations for num_x_cores and num_y_cores and splits the cores into core_range and core_range_cliff.
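For intuition on the two helpers discussed above, here is an illustrative stand-in (types and logic are simplified; the real corerange_to_cores and num_cores_to_corerangeset_in_subcoregrids in tt_metal may differ): a rectangular subgrid is flattened into an ordered core list, and the first ncores entries form the set that actually receives work.

#include <cstdint>
#include <vector>

// Simplified stand-in for tt_metal's CoreCoord.
struct Coord {
    uint32_t x, y;
};

// Illustrative flattening of one rectangular range in row-major order,
// mirroring what corerange_to_cores(..., row_major=true) conceptually does.
std::vector<Coord> flatten_row_major(Coord start, Coord end) {
    std::vector<Coord> cores;
    for (uint32_t y = start.y; y <= end.y; y++) {
        for (uint32_t x = start.x; x <= end.x; x++) {
            cores.push_back({x, y});
        }
    }
    return cores;
}

int main() {
    // A hypothetical 3x2 subgrid yields (0,0),(1,0),(2,0),(0,1),(1,1),(2,1);
    // taking the first ncores of this list mirrors the work assignment above.
    auto cores = flatten_row_major({0, 0}, {2, 1});
    return cores.size() == 6 ? 0 : 1;
}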
+    uint32_t nblocks_per_core = nblocks / ncores;
+
+    bool row_major = true;
+    bool src_block_sharded = false;
+    uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0,
+             num_output_rows_unpadded = 0;
+    CoreCoord end_core;
+    std::vector<CoreCoord> cores_with_rtargs;
+
+    uint32_t num_input_tiles = ntiles_per_block * 2;
+    auto [src0_cb_index, cb_src0] = create_cb(
+        tt::CBIndex::c_0, program, all_cores, input_single_tile_size, num_input_tiles, input_cb_data_format, nullptr);
+
+    uint32_t num_output_tiles = ntiles_per_block * 2;
+    auto [output_cb_index, cb_output] = create_cb(
+        tt::CBIndex::c_16,
+        program,
+        all_cores,
+        output_single_tile_size,
+        num_output_tiles,
+        output_cb_data_format,
+        nullptr);
+
+    Buffer* src0_buffer = a.buffer();
+    Buffer* dst_buffer = output.buffer();
+    bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
+    std::vector<uint32_t> reader_ct_args = {(uint32_t)src0_is_dram};
+
+    auto reader_kernel_id = CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
+        all_cores,
+        ReaderDataMovementConfig(reader_ct_args));
+
+    bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
+    bool stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size);
+    uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t)std::log2(stick_size) : 0;
+    std::vector<uint32_t> writer_ct_args = {
+        (uint32_t)out_is_dram,
+        (uint32_t)stick_size_is_power_of_two,
+        (uint32_t)log2_stick_size,
+    };
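The power-of-two compile-time args let the writer kernel replace stick-address arithmetic with shifts. A quick numeric illustration with assumed values (bf16 output, 64-element rows; the offset formula is a sketch of the idea, not the kernel's actual code):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t stick_s = 64;                         // assumed row width in elements
    uint32_t element_size = 2;                     // bf16
    uint32_t stick_size = stick_s * element_size;  // 128 bytes per stick
    bool is_pow2 = stick_size >= 32 && (stick_size & (stick_size - 1)) == 0;
    uint32_t log2_stick_size = is_pow2 ? (uint32_t)std::log2(stick_size) : 0;
    // With log2_stick_size = 7, stick N starts at (N << 7) instead of N * 128.
    uint32_t stick_id = 5;
    printf("pow2=%d log2=%u offset=%u\n", is_pow2, log2_stick_size, stick_id << log2_stick_size);
}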
+    auto writer_kernel_id = CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/dataflow/"
+        "writer_unary_stick_layout_split_rows_interleaved_parallel_columns.cpp",
+        all_cores,
+        WriterDataMovementConfig(writer_ct_args));
+
+    /** compute
+     */
+    std::vector<uint32_t> compute_args = {
+        (uint32_t)nblocks_per_core,  // per_core_block_cnt
+        (uint32_t)ntiles_per_block,  // per_block_ntiles
+    };
+
+    std::string compute_kernel(
+        "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/pack_untilize.cpp");
+    if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
+        log_debug(tt::LogOp, "Using slow untilize.");
+        compute_kernel =
+            std::string("ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize.cpp");
+    } else {
+        log_debug(tt::LogOp, "Using fast pack untilize.");
+    }
+
+    auto untilize_kernel_id = CreateKernel(
+        program,
+        compute_kernel,
+        all_cores,
+        ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_args});
+
+    uint32_t tile_start_id = 0;
+    uint32_t offset_within_stick = 0;
+
+    auto nsticks_per_core = ntiles_per_column * TILE_HEIGHT;
+
+    for (uint32_t i = 0; i < cores.size(); i++) {
+        CoreCoord core = cores[i];
+
+        // reader runtime args
+        auto ntiles_per_core = ntiles_per_block * nblocks_per_core;
+        const std::array reader_rt_args = {
+            src0_buffer->address(),  // src_addr
+            ntiles_per_core,         // ntiles
+            tile_start_id            // start_id
+        };
+
+        const std::array writer_rt_args = {
+            dst_buffer->address(),               // dst_addr
+            nsticks_per_core,                    // nsticks
+            stick_size,                          // block_size_nbytes
+            ntiles_per_core,                     // ntiles_per_core
+            TILE_WIDTH * output.element_size(),  // tile_width_size
+            std::uint32_t{0},                    // start stick id = 0, since parallelizing on height
+            offset_within_stick};
+
+        tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_rt_args);
+        tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_rt_args);
+        cores_with_rtargs.push_back(core);
+        tile_start_id += ntiles_per_core;
+        offset_within_stick += ntiles_per_core * TILE_WIDTH * output.element_size();
+    }
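To make the per-core bookkeeping concrete, a small worked example with assumed values (not from the PR): 4 tiles per core, bf16 output (2 bytes per element), 32-wide tiles. Each core's reader starts 4 tiles after the previous one, and each writer lands 256 bytes further into every output stick.

#include <cstdint>
#include <cstdio>

int main() {
    uint32_t ntiles_per_core = 4;  // assumed
    uint32_t element_size = 2;     // bf16
    uint32_t tile_width = 32;

    uint32_t tile_start_id = 0;
    uint32_t offset_within_stick = 0;
    for (int core = 0; core < 3; core++) {
        printf("core %d: start_id=%u offset=%u\n", core, tile_start_id, offset_within_stick);
        tile_start_id += ntiles_per_core;  // same updates as the loop above
        offset_within_stick += ntiles_per_core * tile_width * element_size;
    }
    // Prints: core 0: 0/0, core 1: 4/256, core 2: 8/512.
}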
+    auto override_runtime_arguments_callback = [reader_kernel_id = reader_kernel_id,
+                                                writer_kernel_id = writer_kernel_id,
+                                                cb_src0 = cb_src0,
+                                                cb_output = cb_output,
+                                                cores_with_rtargs](
+                                                   const void* operation,
+                                                   Program& program,
+                                                   const std::vector<Tensor>& input_tensors,
+                                                   const std::vector<std::optional<const Tensor>>&,
+                                                   const std::vector<Tensor>& output_tensors) {
+        auto src_buffer = input_tensors.at(0).buffer();
+        auto dst_buffer = output_tensors.at(0).buffer();
+        {
+            auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
+            for (const CoreCoord& core : cores_with_rtargs) {
+                auto& runtime_args = runtime_args_by_core[core.x][core.y];
+                runtime_args[0] = src_buffer->address();
+            }
+        }
+
+        {
+            auto& runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id);
+            for (const CoreCoord& core : cores_with_rtargs) {
+                auto& runtime_args = runtime_args_by_core[core.x][core.y];
+                runtime_args[0] = dst_buffer->address();
+            }
+        }
+    };
+
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
+}
+
 operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
     const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) {
     tt::tt_metal::Program program{};
@@ -41,12 +226,13 @@ operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
     Device* device = a.device();

     auto grid_size = device->compute_with_storage_grid_size();

     uint32_t ntiles = a.volume() / TILE_HW;
     uint32_t ncores_x = grid_size.x;
-    // uint32_t ncores_x = 2;
+    uint32_t ncores_y = grid_size.y;
+
     ncores_x = get_largest_divisor(ntiles, ncores_x);
-    uint32_t ncores_y = grid_size.y;
-    // uint32_t ncores_y = 1;
     ncores_y = get_largest_divisor(ntiles, ncores_y, ncores_x);
@@ -260,7 +446,11 @@ operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
 }

 operation::ProgramWithCallbacks untilize_multi_core(
-    const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) {
+    const Tensor& a,
+    Tensor& output,
+    bool use_pack_untilize,
+    bool fp32_dest_acc_en,
+    const std::optional<CoreRangeSet>& sub_core_grids) {
     tt::tt_metal::Program program{};

     bool src_sharded = a.memory_config().is_sharded();
@@ -289,7 +479,12 @@ operation::ProgramWithCallbacks untilize_multi_core(
     if (!src_sharded and !out_sharded) {
         uint32_t ntiles_height = ntiles / ntiles_per_block;
         if (ntiles_height == 1) {
-            return untilize_multi_core_parallelize_column(a, output, use_pack_untilize, fp32_dest_acc_en);
+            if (sub_core_grids.has_value()) {
+                return untilize_multi_core_parallelize_column_subgrid(
+                    a, output, use_pack_untilize, fp32_dest_acc_en, sub_core_grids.value());
+            } else {
+                return untilize_multi_core_parallelize_column(a, output, use_pack_untilize, fp32_dest_acc_en);
+            }
         } else {
             return untilize_single_core(a, output, use_pack_untilize, fp32_dest_acc_en);
         }
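A self-contained sketch of the dispatch added above (CoreRangeSetStub stands in for tt_metal's CoreRangeSet, for illustration only): a populated optional selects the new subgrid factory, while nullopt preserves the existing behavior.

#include <cstdio>
#include <optional>

// Stand-in for tt_metal's CoreRangeSet, for illustration only.
struct CoreRangeSetStub {
    int num_cores;
};

const char* pick_factory(const std::optional<CoreRangeSetStub>& sub_core_grids) {
    if (sub_core_grids.has_value()) {
        return "untilize_multi_core_parallelize_column_subgrid";  // new path
    }
    return "untilize_multi_core_parallelize_column";  // legacy path
}

int main() {
    printf("%s\n", pick_factory(CoreRangeSetStub{8}));  // subgrid path
    printf("%s\n", pick_factory(std::nullopt));         // legacy path
}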
Review comment: No other shapes are needed?

Reply: For our model use case, this is the only one.