Skip to content

Commit

Permalink
Reduce calculation prod enhancement in case of inner dim (#2970)
Browse files Browse the repository at this point in the history
  • Loading branch information
seungmanhan authored Aug 1, 2024
1 parent 10beace commit 25169e2
Show file tree
Hide file tree
Showing 31 changed files with 1,606 additions and 914 deletions.
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ The MIOpen API library is structured as follows:
* :doc:`SGD <../doxygen/html/group___s_g_d>` (experimental)
* :doc:`ReduceExtreme <../doxygen/html/group__ReduceExtreme>` (experimental)
* :doc:`Getitem <../doxygen/html/group__getitem>` (experimental)
* :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
2 changes: 1 addition & 1 deletion driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ add_executable(MIOpenDriver
dm_pool.cpp
dm_reduce.cpp
dm_reduceextreme.cpp
dm_reducecalculation.cpp
dm_rnn.cpp
dm_softmax.cpp
dm_sum.cpp
dm_t5layernorm.cpp
dm_tensorop.cpp
dm_transformers_adam_w.cpp
Expand Down
14 changes: 7 additions & 7 deletions driver/dm_sum.cpp → driver/dm_reducecalculation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@
* SOFTWARE.
*
*******************************************************************************/
#include "reducecalculation_driver.hpp"
#include "registry_driver_maker.hpp"
#include "sum_driver.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "sum")
return new SumDriver<float, float>();
if(base_arg == "sumfp16")
return new SumDriver<float16, float>();
if(base_arg == "sumbfp16")
return new SumDriver<bfloat16, float>();
if(base_arg == "reducecalculation")
return new ReduceCalculationDriver<float, float>();
if(base_arg == "reducecalculationfp16")
return new ReduceCalculationDriver<float16, float>();
if(base_arg == "reducecalculationbfp16")
return new ReduceCalculationDriver<bfloat16, float>();
return nullptr;
}

Expand Down
5 changes: 3 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], "
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16]\n");
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -205,7 +205,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "adamw" &&
arg != "adamwfp16" && arg != "ampadamw" && arg != "transformersadamw" &&
arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "getitem" &&
arg != "getitemfp16" && arg != "getitembfp16" && arg != "--version")
arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" &&
arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
Loading

0 comments on commit 25169e2

Please sign in to comment.