Skip to content

Commit

Permalink
Merge branch '20-fp64_int8_14' into 'main'
Browse files Browse the repository at this point in the history
Add fp64_int8_14~17

See merge request mutsuki/ozimma!21
  • Loading branch information
enp1s0 committed Sep 29, 2023
2 parents 76be147 + 312a9a6 commit 2c74a2a
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 13 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ The supported compute modes are [here](#supported-compute-mode).
| Mode | Tensor Core type | Num splits | |
|:--------------|:-----------------|:-----------|:------------------------|
|dgemm | -- | -- | Disable hijacking |
|fp64_int8_3 | Int8 TC | 3 | |
|fp64_int8_4 | Int8 TC | 4 | |
|fp64_int8_5 | Int8 TC | 5 | |
|fp64_int8_6 | Int8 TC | 6 | |
|fp64_int8_7 | Int8 TC | 7 | |
|fp64_int8_8 | Int8 TC | 8 | |
Expand All @@ -39,7 +42,12 @@ The supported compute modes are [here](#supported-compute-mode).
|fp64_int8_11 | Int8 TC | 11 | |
|fp64_int8_12 | Int8 TC | 12 | |
|fp64_int8_13 | Int8 TC | 13 | |
|fp64_int8_auto | Int8 TC | AUTO | fp64_int8_6..13 / dgemm |
|fp64_int8_14 | Int8 TC | 14 | |
|fp64_int8_15 | Int8 TC | 15 | |
|fp64_int8_16 | Int8 TC | 16 | |
|fp64_int8_17 | Int8 TC | 17 | |
|fp64_int8_18 | Int8 TC | 18 | |
|fp64_int8_auto | Int8 TC | AUTO | fp64_int8_3..18 / dgemm |


### Optional environmental variables
Expand Down
8 changes: 8 additions & 0 deletions include/ozimmu/ozimmu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ enum compute_mode_t {
sgemm,
dgemm,

fp64_int8_3,
fp64_int8_4,
fp64_int8_5,
fp64_int8_6,
fp64_int8_7,
fp64_int8_8,
Expand All @@ -25,6 +28,11 @@ enum compute_mode_t {
fp64_int8_11,
fp64_int8_12,
fp64_int8_13,
fp64_int8_14,
fp64_int8_15,
fp64_int8_16,
fp64_int8_17,
fp64_int8_18,

fp64_int8_auto,
};
Expand Down
16 changes: 16 additions & 0 deletions src/config.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ mtk::ozimmu::detail::split_config_t mtk::ozimmu::detail::get_split_config(
{original},
{{0, 0, detail::cublas_dgemm}}
};
case mtk::ozimmu::fp64_int8_3:
case mtk::ozimmu::fp64_int8_4:
case mtk::ozimmu::fp64_int8_5:
case mtk::ozimmu::fp64_int8_6:
case mtk::ozimmu::fp64_int8_7:
case mtk::ozimmu::fp64_int8_8:
Expand All @@ -25,8 +28,16 @@ mtk::ozimmu::detail::split_config_t mtk::ozimmu::detail::get_split_config(
case mtk::ozimmu::fp64_int8_11:
case mtk::ozimmu::fp64_int8_12:
case mtk::ozimmu::fp64_int8_13:
case mtk::ozimmu::fp64_int8_14:
case mtk::ozimmu::fp64_int8_15:
case mtk::ozimmu::fp64_int8_16:
case mtk::ozimmu::fp64_int8_17:
case mtk::ozimmu::fp64_int8_18:
{
unsigned num_split = 0;
if (compute_mode == mtk::ozimmu::fp64_int8_3 ) {num_split = 3;}
if (compute_mode == mtk::ozimmu::fp64_int8_4 ) {num_split = 4;}
if (compute_mode == mtk::ozimmu::fp64_int8_5 ) {num_split = 5;}
if (compute_mode == mtk::ozimmu::fp64_int8_6 ) {num_split = 6;}
if (compute_mode == mtk::ozimmu::fp64_int8_7 ) {num_split = 7;}
if (compute_mode == mtk::ozimmu::fp64_int8_8 ) {num_split = 8;}
Expand All @@ -35,6 +46,11 @@ mtk::ozimmu::detail::split_config_t mtk::ozimmu::detail::get_split_config(
if (compute_mode == mtk::ozimmu::fp64_int8_11) {num_split = 11;}
if (compute_mode == mtk::ozimmu::fp64_int8_12) {num_split = 12;}
if (compute_mode == mtk::ozimmu::fp64_int8_13) {num_split = 13;}
// The split count is the numeric suffix of the mode name, so each mode must
// be compared against its own enumerator. A mode that matches none of these
// tests leaves num_split at its initial value of 0.
if (compute_mode == mtk::ozimmu::fp64_int8_14) {num_split = 14;}
if (compute_mode == mtk::ozimmu::fp64_int8_15) {num_split = 15;}
if (compute_mode == mtk::ozimmu::fp64_int8_16) {num_split = 16;}
if (compute_mode == mtk::ozimmu::fp64_int8_17) {num_split = 17;}
if (compute_mode == mtk::ozimmu::fp64_int8_18) {num_split = 18;}

// Data
std::vector<mtk::ozimmu::data_t> split_types(num_split + 1);
Expand Down
8 changes: 8 additions & 0 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ mtk::ozimmu::compute_mode_t get_compute_mode(
std::vector<mtk::ozimmu::compute_mode_t> supported_gemm_mode = {
mtk::ozimmu::sgemm,
mtk::ozimmu::dgemm,
mtk::ozimmu::fp64_int8_3,
mtk::ozimmu::fp64_int8_4,
mtk::ozimmu::fp64_int8_5,
mtk::ozimmu::fp64_int8_6,
mtk::ozimmu::fp64_int8_7,
mtk::ozimmu::fp64_int8_8,
Expand All @@ -33,6 +36,11 @@ mtk::ozimmu::compute_mode_t get_compute_mode(
mtk::ozimmu::fp64_int8_11,
mtk::ozimmu::fp64_int8_12,
mtk::ozimmu::fp64_int8_13,
mtk::ozimmu::fp64_int8_14,
mtk::ozimmu::fp64_int8_15,
mtk::ozimmu::fp64_int8_16,
mtk::ozimmu::fp64_int8_17,
mtk::ozimmu::fp64_int8_18,
mtk::ozimmu::fp64_int8_auto,
};

Expand Down
18 changes: 17 additions & 1 deletion src/gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,9 @@ int mtk::ozimmu::gemm(
input_type = mtk::ozimmu::fp32;
break;
case mtk::ozimmu::dgemm:
case mtk::ozimmu::fp64_int8_3:
case mtk::ozimmu::fp64_int8_4:
case mtk::ozimmu::fp64_int8_5:
case mtk::ozimmu::fp64_int8_6:
case mtk::ozimmu::fp64_int8_7:
case mtk::ozimmu::fp64_int8_8:
Expand All @@ -739,6 +742,11 @@ int mtk::ozimmu::gemm(
case mtk::ozimmu::fp64_int8_11:
case mtk::ozimmu::fp64_int8_12:
case mtk::ozimmu::fp64_int8_13:
case mtk::ozimmu::fp64_int8_14:
case mtk::ozimmu::fp64_int8_15:
case mtk::ozimmu::fp64_int8_16:
case mtk::ozimmu::fp64_int8_17:
case mtk::ozimmu::fp64_int8_18:
case mtk::ozimmu::fp64_int8_auto:
input_type = mtk::ozimmu::fp64;
break;
Expand All @@ -753,14 +761,22 @@ int mtk::ozimmu::gemm(

if (input_type == mtk::ozimmu::fp64) {
if (
compute_mode == mtk::ozimmu::fp64_int8_3 ||
compute_mode == mtk::ozimmu::fp64_int8_4 ||
compute_mode == mtk::ozimmu::fp64_int8_5 ||
compute_mode == mtk::ozimmu::fp64_int8_6 ||
compute_mode == mtk::ozimmu::fp64_int8_7 ||
compute_mode == mtk::ozimmu::fp64_int8_8 ||
compute_mode == mtk::ozimmu::fp64_int8_9 ||
compute_mode == mtk::ozimmu::fp64_int8_10 ||
compute_mode == mtk::ozimmu::fp64_int8_11 ||
compute_mode == mtk::ozimmu::fp64_int8_12 ||
compute_mode == mtk::ozimmu::fp64_int8_13
compute_mode == mtk::ozimmu::fp64_int8_13 ||
compute_mode == mtk::ozimmu::fp64_int8_14 ||
compute_mode == mtk::ozimmu::fp64_int8_15 ||
compute_mode == mtk::ozimmu::fp64_int8_16 ||
compute_mode == mtk::ozimmu::fp64_int8_17 ||
compute_mode == mtk::ozimmu::fp64_int8_18
) {
if (element_kind == mtk::ozimmu::real) {
using T = double;
Expand Down
34 changes: 33 additions & 1 deletion src/handle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,22 @@ std::size_t mtk::ozimmu::reallocate_working_memory(
const auto working_memory_C_fp64 = m * n * mtk::ozimmu::get_data_size_in_byte(fp64) * (element_kind == mtk::ozimmu::real ? 1 : 2);
std::size_t etc = 0;
if (
mode == mtk::ozimmu::fp64_int8_3 ||
mode == mtk::ozimmu::fp64_int8_4 ||
mode == mtk::ozimmu::fp64_int8_5 ||
mode == mtk::ozimmu::fp64_int8_6 ||
mode == mtk::ozimmu::fp64_int8_7 ||
mode == mtk::ozimmu::fp64_int8_8 ||
mode == mtk::ozimmu::fp64_int8_9 ||
mode == mtk::ozimmu::fp64_int8_10 ||
mode == mtk::ozimmu::fp64_int8_11 ||
mode == mtk::ozimmu::fp64_int8_12 ||
mode == mtk::ozimmu::fp64_int8_13
mode == mtk::ozimmu::fp64_int8_13 ||
mode == mtk::ozimmu::fp64_int8_14 ||
mode == mtk::ozimmu::fp64_int8_15 ||
mode == mtk::ozimmu::fp64_int8_16 ||
mode == mtk::ozimmu::fp64_int8_17 ||
mode == mtk::ozimmu::fp64_int8_18
) {
etc = (m + n) * mtk::ozimmu::get_data_size_in_byte(fp64) * (element_kind == mtk::ozimmu::real ? 1 : 2);
}
Expand Down Expand Up @@ -127,6 +135,12 @@ std::string mtk::ozimmu::get_compute_mode_name_str(
return "sgemm";
case mtk::ozimmu::dgemm:
return "dgemm";
case mtk::ozimmu::fp64_int8_3:
return "fp64_int8_3";
case mtk::ozimmu::fp64_int8_4:
return "fp64_int8_4";
case mtk::ozimmu::fp64_int8_5:
return "fp64_int8_5";
case mtk::ozimmu::fp64_int8_6:
return "fp64_int8_6";
case mtk::ozimmu::fp64_int8_7:
Expand All @@ -143,6 +157,16 @@ std::string mtk::ozimmu::get_compute_mode_name_str(
return "fp64_int8_12";
case mtk::ozimmu::fp64_int8_13:
return "fp64_int8_13";
case mtk::ozimmu::fp64_int8_14:
return "fp64_int8_14";
case mtk::ozimmu::fp64_int8_15:
return "fp64_int8_15";
case mtk::ozimmu::fp64_int8_16:
return "fp64_int8_16";
case mtk::ozimmu::fp64_int8_17:
return "fp64_int8_17";
case mtk::ozimmu::fp64_int8_18:
return "fp64_int8_18";
case mtk::ozimmu::fp64_int8_auto:
return "fp64_int8_auto";
default:
Expand All @@ -159,6 +183,9 @@ mtk::ozimmu::data_t mtk::ozimmu::get_output_type(
case mtk::ozimmu::sgemm:
return mtk::ozimmu::fp32;

// fp64_int8_* modes, listed in ascending split count like the other mode lists.
case mtk::ozimmu::fp64_int8_3:
case mtk::ozimmu::fp64_int8_4:
case mtk::ozimmu::fp64_int8_5:
case mtk::ozimmu::fp64_int8_6:
case mtk::ozimmu::fp64_int8_7:
case mtk::ozimmu::fp64_int8_8:
Expand All @@ -167,6 +194,11 @@ mtk::ozimmu::data_t mtk::ozimmu::get_output_type(
case mtk::ozimmu::fp64_int8_11:
case mtk::ozimmu::fp64_int8_12:
case mtk::ozimmu::fp64_int8_13:
case mtk::ozimmu::fp64_int8_14:
case mtk::ozimmu::fp64_int8_15:
case mtk::ozimmu::fp64_int8_16:
case mtk::ozimmu::fp64_int8_17:
case mtk::ozimmu::fp64_int8_18:
case mtk::ozimmu::fp64_int8_auto:
case mtk::ozimmu::dgemm:
return mtk::ozimmu::fp64;
Expand Down
36 changes: 26 additions & 10 deletions src/split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ std::unordered_map<mtk::ozimmu::compute_mode_t, std::uint64_t> mtk::ozimmu::get_
m, n,
in_ptr,
ld,
6,
13,
3,
18,
bits_per_int8,
is_col_major
);
Expand All @@ -449,14 +449,22 @@ std::unordered_map<mtk::ozimmu::compute_mode_t, std::uint64_t> mtk::ozimmu::get_
unsigned long long int host_buffer[mtk::ozimmu::handle::mantissa_loss_counter_length];
CUTF_CHECK_ERROR(cudaMemcpy(host_buffer, handle.d_mantissa_loss_counter_ptr, sizeof(unsigned long long int) * mtk::ozimmu::handle::mantissa_loss_counter_length, cudaMemcpyDefault));

result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_6 , host_buffer[0]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_7 , host_buffer[1]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_8 , host_buffer[2]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_9 , host_buffer[3]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_10, host_buffer[4]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_11, host_buffer[5]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_12, host_buffer[6]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_13, host_buffer[7]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_3 , host_buffer[ 0]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_4 , host_buffer[ 1]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_5 , host_buffer[ 2]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_6 , host_buffer[ 3]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_7 , host_buffer[ 4]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_8 , host_buffer[ 5]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_9 , host_buffer[ 6]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_10, host_buffer[ 7]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_11, host_buffer[ 8]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_12, host_buffer[ 9]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_13, host_buffer[10]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_14, host_buffer[11]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_15, host_buffer[12]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_16, host_buffer[13]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_17, host_buffer[14]));
result.insert(std::make_pair<mtk::ozimmu::compute_mode_t, std::uint64_t>(mtk::ozimmu::fp64_int8_18, host_buffer[15]));
}

return result;
Expand Down Expand Up @@ -504,6 +512,9 @@ mtk::ozimmu::compute_mode_t auto_mode_select_core(
);

const std::vector<mtk::ozimmu::compute_mode_t> mode_candidate_order = {
mtk::ozimmu::fp64_int8_3,
mtk::ozimmu::fp64_int8_4,
mtk::ozimmu::fp64_int8_5,
mtk::ozimmu::fp64_int8_6,
mtk::ozimmu::fp64_int8_7,
mtk::ozimmu::fp64_int8_8,
Expand All @@ -512,6 +523,11 @@ mtk::ozimmu::compute_mode_t auto_mode_select_core(
mtk::ozimmu::fp64_int8_11,
mtk::ozimmu::fp64_int8_12,
mtk::ozimmu::fp64_int8_13,
mtk::ozimmu::fp64_int8_14,
mtk::ozimmu::fp64_int8_15,
mtk::ozimmu::fp64_int8_16,
mtk::ozimmu::fp64_int8_17,
mtk::ozimmu::fp64_int8_18,
};

for (const auto mode : mode_candidate_order) {
Expand Down
8 changes: 8 additions & 0 deletions test/main_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,9 @@ std::vector<mtk::ozimmu::compute_mode_t> get_supported_compute_mode() {
return std::vector<mtk::ozimmu::compute_mode_t>{
mtk::ozimmu::sgemm,
mtk::ozimmu::dgemm,
mtk::ozimmu::fp64_int8_3,
mtk::ozimmu::fp64_int8_4,
mtk::ozimmu::fp64_int8_5,
mtk::ozimmu::fp64_int8_6,
mtk::ozimmu::fp64_int8_7,
mtk::ozimmu::fp64_int8_8,
Expand All @@ -632,6 +635,11 @@ std::vector<mtk::ozimmu::compute_mode_t> get_supported_compute_mode() {
mtk::ozimmu::fp64_int8_11,
mtk::ozimmu::fp64_int8_12,
mtk::ozimmu::fp64_int8_13,
mtk::ozimmu::fp64_int8_14,
mtk::ozimmu::fp64_int8_15,
mtk::ozimmu::fp64_int8_16,
mtk::ozimmu::fp64_int8_17,
mtk::ozimmu::fp64_int8_18,
mtk::ozimmu::fp64_int8_auto,
};
}
Expand Down

0 comments on commit 2c74a2a

Please sign in to comment.