-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support for cumsum and transpose_dst apis
* cumsum implementation
* Add transpose dest, needed for rowwise cumsum
- Loading branch information
1 parent
b3d5095
commit d7a12ee
Showing
5 changed files
with
290 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ckernel_defs.h" | ||
#include "ckernel.h" | ||
#include "noc_nonblocking_api.h" | ||
|
||
#include "sfpi.h" | ||
|
||
using namespace sfpi; | ||
|
||
namespace ckernel | ||
{ | ||
namespace sfpu | ||
{ | ||
|
||
// Issues one cumulative-sum pass as a fixed, fully unrolled sequence of SFPU instructions. | ||
// | ||
// The body consists of eight groups. Each group issues four TTI_SFPLOADs, a TTI_SFPTRANSP, | ||
// a TTI_REPLAY of eight recorded instructions, a second TTI_SFPTRANSP, and four | ||
// TTI_SFPSTOREs to the same offsets that were just loaded. Groups alternate between | ||
// register operands 0..3 with TTI_REPLAY(0, 8, ...) and register operands 4..7 with | ||
// TTI_REPLAY(8, 8, ...). | ||
// | ||
// The replayed instructions are the TTI_SFPADD/TTI_SFPNOP pairs recorded by | ||
// _cumsum_init_(), which is therefore expected to have been called beforehand. | ||
// | ||
// Template parameters APPROXIMATION_MODE and ITERATIONS are not used by the body. | ||
// | ||
// first: when true, four TTI_SFPMOVs targeting operands 4..7 are issued before the | ||
//        first group. | ||
// | ||
// NOTE(review): the add chain recorded in _cumsum_init_ starts by reading operand 7, so the | ||
// running total is presumably carried from group to group (and from call to call when | ||
// first is false) in the SFPU registers - confirm against the ISA documentation. | ||
template <bool APPROXIMATION_MODE /*unused*/, int ITERATIONS /*unused*/> | ||
inline void _calculate_cumsum_(const bool first) | ||
{ | ||
    if(first) | ||
    { | ||
        // Clear context for F0. | ||
        // NOTE(review): presumably copies operand 9 (a zero constant?) into registers 4..7 so | ||
        // that the recorded add chain starts from zero - confirm against the SFPMOV ISA docs. | ||
        TTI_SFPMOV(0, 9, 4, 0); | ||
        TTI_SFPMOV(0, 9, 5, 0); | ||
        TTI_SFPMOV(0, 9, 6, 0); | ||
        TTI_SFPMOV(0, 9, 7, 0); | ||
    } | ||
|
||
    // F0,1 R0 | ||
    // Group labels ("F0,1 R0", ...) presumably mean faces 0 and 1 starting at row 0 - TODO confirm. | ||
    // Offsets 0, 2, 16, 18 go to registers 0..3. | ||
    TTI_SFPLOAD(0, 0, ADDR_MOD_7, 0); | ||
    TTI_SFPLOAD(1, 0, ADDR_MOD_7, 2); | ||
    TTI_SFPLOAD(2, 0, ADDR_MOD_7, 0 + 16); | ||
    TTI_SFPLOAD(3, 0, ADDR_MOD_7, 2 + 16); | ||
|
||
    // Transpose, replay the first eight recorded instructions (the adds whose destinations | ||
    // are registers 0..3), then transpose again before storing. | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(0, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    // Write registers 0..3 back to the offsets they were loaded from. | ||
    TTI_SFPSTORE(0, 0, ADDR_MOD_7, 0); | ||
    TTI_SFPSTORE(1, 0, ADDR_MOD_7, 2); | ||
    TTI_SFPSTORE(2, 0, ADDR_MOD_7, 0 + 16); | ||
    TTI_SFPSTORE(3, 0, ADDR_MOD_7, 2 + 16); | ||
|
||
    // F0,1 R4 | ||
    // Same pattern with registers 4..7 and offsets 4, 6, 20, 22. | ||
    TTI_SFPLOAD(4, 0, ADDR_MOD_7, 4); | ||
    TTI_SFPLOAD(5, 0, ADDR_MOD_7, 6); | ||
    TTI_SFPLOAD(6, 0, ADDR_MOD_7, 4 + 16); | ||
    TTI_SFPLOAD(7, 0, ADDR_MOD_7, 6 + 16); | ||
|
||
    // Replay the second eight recorded instructions (the adds whose destinations are | ||
    // registers 4..7). | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(8, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(4, 0, ADDR_MOD_7, 4); | ||
    TTI_SFPSTORE(5, 0, ADDR_MOD_7, 6); | ||
    TTI_SFPSTORE(6, 0, ADDR_MOD_7, 4 + 16); | ||
    TTI_SFPSTORE(7, 0, ADDR_MOD_7, 6 + 16); | ||
|
||
    // F0,1 R8 | ||
    // Registers 0..3 again, offsets 8, 10, 24, 26. | ||
    TTI_SFPLOAD(0, 0, ADDR_MOD_7, 8); | ||
    TTI_SFPLOAD(1, 0, ADDR_MOD_7, 10); | ||
    TTI_SFPLOAD(2, 0, ADDR_MOD_7, 8 + 16); | ||
    TTI_SFPLOAD(3, 0, ADDR_MOD_7, 10 + 16); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(0, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(0, 0, ADDR_MOD_7, 8); | ||
    TTI_SFPSTORE(1, 0, ADDR_MOD_7, 10); | ||
    TTI_SFPSTORE(2, 0, ADDR_MOD_7, 8 + 16); | ||
    TTI_SFPSTORE(3, 0, ADDR_MOD_7, 10 + 16); | ||
|
||
    // F0,1 R12 | ||
    // Registers 4..7 again, offsets 12, 14, 28, 30. | ||
    TTI_SFPLOAD(4, 0, ADDR_MOD_7, 12); | ||
    TTI_SFPLOAD(5, 0, ADDR_MOD_7, 14); | ||
    TTI_SFPLOAD(6, 0, ADDR_MOD_7, 12 + 16); | ||
    TTI_SFPLOAD(7, 0, ADDR_MOD_7, 14 + 16); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(8, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(4, 0, ADDR_MOD_7, 12); | ||
    TTI_SFPSTORE(5, 0, ADDR_MOD_7, 14); | ||
    TTI_SFPSTORE(6, 0, ADDR_MOD_7, 12 + 16); | ||
    TTI_SFPSTORE(7, 0, ADDR_MOD_7, 14 + 16); | ||
|
||
    // F2,3 R0 | ||
    // The remaining four groups repeat the four above with 32 added to every offset. | ||
    TTI_SFPLOAD(0, 0, ADDR_MOD_7, 0 + 32); | ||
    TTI_SFPLOAD(1, 0, ADDR_MOD_7, 2 + 32); | ||
    TTI_SFPLOAD(2, 0, ADDR_MOD_7, 0 + 16 + 32); | ||
    TTI_SFPLOAD(3, 0, ADDR_MOD_7, 2 + 16 + 32); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(0, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(0, 0, ADDR_MOD_7, 0 + 32); | ||
    TTI_SFPSTORE(1, 0, ADDR_MOD_7, 2 + 32); | ||
    TTI_SFPSTORE(2, 0, ADDR_MOD_7, 0 + 16 + 32); | ||
    TTI_SFPSTORE(3, 0, ADDR_MOD_7, 2 + 16 + 32); | ||
|
||
    // F2,3 R4 | ||
    TTI_SFPLOAD(4, 0, ADDR_MOD_7, 4 + 32); | ||
    TTI_SFPLOAD(5, 0, ADDR_MOD_7, 6 + 32); | ||
    TTI_SFPLOAD(6, 0, ADDR_MOD_7, 4 + 16 + 32); | ||
    TTI_SFPLOAD(7, 0, ADDR_MOD_7, 6 + 16 + 32); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(8, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(4, 0, ADDR_MOD_7, 4 + 32); | ||
    TTI_SFPSTORE(5, 0, ADDR_MOD_7, 6 + 32); | ||
    TTI_SFPSTORE(6, 0, ADDR_MOD_7, 4 + 16 + 32); | ||
    TTI_SFPSTORE(7, 0, ADDR_MOD_7, 6 + 16 + 32); | ||
|
||
    // F2,3 R8 | ||
    TTI_SFPLOAD(0, 0, ADDR_MOD_7, 8 + 32); | ||
    TTI_SFPLOAD(1, 0, ADDR_MOD_7, 10 + 32); | ||
    TTI_SFPLOAD(2, 0, ADDR_MOD_7, 8 + 16 + 32); | ||
    TTI_SFPLOAD(3, 0, ADDR_MOD_7, 10 + 16 + 32); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(0, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(0, 0, ADDR_MOD_7, 8 + 32); | ||
    TTI_SFPSTORE(1, 0, ADDR_MOD_7, 10 + 32); | ||
    TTI_SFPSTORE(2, 0, ADDR_MOD_7, 8 + 16 + 32); | ||
    TTI_SFPSTORE(3, 0, ADDR_MOD_7, 10 + 16 + 32); | ||
|
||
    // F2,3 R12 | ||
    TTI_SFPLOAD(4, 0, ADDR_MOD_7, 12 + 32); | ||
    TTI_SFPLOAD(5, 0, ADDR_MOD_7, 14 + 32); | ||
    TTI_SFPLOAD(6, 0, ADDR_MOD_7, 12 + 16 + 32); | ||
    TTI_SFPLOAD(7, 0, ADDR_MOD_7, 14 + 16 + 32); | ||
|
||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
    TTI_REPLAY(8, 8, 0, 0); | ||
    TTI_SFPTRANSP(0, 0, 0, 0); | ||
|
||
    TTI_SFPSTORE(4, 0, ADDR_MOD_7, 12 + 32); | ||
    TTI_SFPSTORE(5, 0, ADDR_MOD_7, 14 + 32); | ||
    TTI_SFPSTORE(6, 0, ADDR_MOD_7, 12 + 16 + 32); | ||
    TTI_SFPSTORE(7, 0, ADDR_MOD_7, 14 + 16 + 32); | ||
} | ||
|
||
// Records the 16 instructions that _calculate_cumsum_ later replays, using | ||
// load_replay_buf<0, 16, 0>. | ||
// | ||
// The recording is eight TTI_SFPADD instructions, each followed by a TTI_SFPNOP. | ||
// _calculate_cumsum_ replays it in two halves: TTI_REPLAY(0, 8, ...) covers the adds whose | ||
// last register operand is 0..3, and TTI_REPLAY(8, 8, ...) covers those targeting 4..7. | ||
// | ||
// The add operands form a chain: the second operand of each add equals the fourth operand | ||
// of the previous one (7 -> 0 -> 1 -> ... -> 7). | ||
// NOTE(review): presumably each add computes dst = reg[10] * reg[prev] + reg[cur], with | ||
// operand 10 selecting a constant 1.0, which would make this a running sum across registers | ||
// 0..7 seeded from register 7 - confirm operand meaning against the SFPADD ISA docs. | ||
// | ||
// Template parameter APPROXIMATION_MODE is not used by the body. | ||
template <bool APPROXIMATION_MODE /*unused*/> | ||
inline void _cumsum_init_() | ||
{ | ||
    load_replay_buf<0, 16, 0>( | ||
        [] { | ||
            // Each add is followed by a NOP; presumably this lets the result become | ||
            // available before the next add reads it - TODO confirm. | ||
            TTI_SFPADD(10, 7, 0, 0, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 0, 1, 1, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 1, 2, 2, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 2, 3, 3, 0); | ||
            TTI_SFPNOP; | ||
            // Second half of the recording (replayed separately with start index 8). | ||
            TTI_SFPADD(10, 3, 4, 4, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 4, 5, 5, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 5, 6, 6, 0); | ||
            TTI_SFPNOP; | ||
            TTI_SFPADD(10, 6, 7, 7, 0); | ||
            TTI_SFPNOP; | ||
        }); | ||
} | ||
|
||
} // namespace sfpu | ||
} // namespace ckernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#pragma once | ||
|
||
#include "ckernel_include.h" | ||
#include "ckernel_template.h" | ||
|
||
#include "cmath_common.h" | ||
#include "llk_math_common.h" | ||
#include "ckernel_globals.h" | ||
|
||
using namespace ckernel; | ||
|
||
// local function declarations | ||
inline void transpose_dest_configure_addrmod(); | ||
|
||
// Runs the transpose-dest instruction sequence on the tile selected by dst_index. | ||
// | ||
// Relies on the address modifiers, replay-buffer recording and ckernel_template that | ||
// _llk_math_transpose_dest_init_() programs; that init is expected to have run first. | ||
// | ||
// dst_index: tile index passed to math::set_dst_write_addr (default layout, 32x32 tile shape). | ||
inline void _llk_math_transpose_dest_(const std::uint32_t dst_index) { | ||
    math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(dst_index); | ||
|
||
    // NOTE(review): presumably holds back math instructions until outstanding SFPU work has | ||
    // finished - confirm against the STALLWAIT documentation. | ||
    TTI_STALLWAIT(p_stall::STALL_MATH, p_stall::WAIT_SFPU); | ||
|
||
    // Execute the template programmed by transpose_dest_configure_mop(). | ||
    ckernel_template::run(instrn_buffer); | ||
|
||
    // Two further replays with the same operands as the BC and E words built in | ||
    // transpose_dest_configure_mop(). | ||
    TTI_REPLAY(20, 5, 0, 0); | ||
    TTI_REPLAY(26, 4, 0, 0); | ||
|
||
    TTI_SETRWC(p_setrwc::CLR_AB, 0, 0, 0, 0, p_setrwc::SET_AB); | ||
|
||
    math::clear_dst_reg_addr(); | ||
} | ||
|
||
// Programs the three address modifiers used by the transpose-dest instruction sequence. | ||
// All three leave srca and srcb increments at 0 and differ only in the dest increment: | ||
//   ADDR_MOD_0: dest += 16 | ||
//   ADDR_MOD_1: dest unchanged | ||
//   ADDR_MOD_2: dest -= 16 | ||
// In transpose_dest_configure_mop(), ADDR_MOD_1 is used on most moves, ADDR_MOD_0 on the | ||
// last move of a four-move MOVB2D run, and ADDR_MOD_2 on the "throwaway" MOVD2B. | ||
inline void transpose_dest_configure_addrmod() { | ||
    addr_mod_t{ | ||
        .srca = {.incr = 0}, | ||
        .srcb = {.incr = 0}, | ||
        .dest = {.incr = 16}, | ||
    }.set(ADDR_MOD_0); | ||
|
||
    addr_mod_t{ | ||
        .srca = {.incr = 0}, | ||
        .srcb = {.incr = 0}, | ||
        .dest = {.incr = 0}, | ||
    }.set(ADDR_MOD_1); | ||
|
||
    addr_mod_t{ | ||
        .srca = {.incr = 0}, | ||
        .srcb = {.incr = 0}, | ||
        .dest = {.incr = -16}, | ||
    }.set(ADDR_MOD_2); | ||
} | ||
|
||
// Records the transpose-dest instruction sequence into the replay buffer and programs the | ||
// ckernel_template that _llk_math_transpose_dest_() runs. | ||
// | ||
// load_replay_buf<16, 16, 0> records 16 instructions, labelled in sections below: | ||
//   A: 4 x MOVD2B   B: 4 x MOVD2B   C: TRNSPSRCB   D: 1 x MOVD2B   E: 4 x MOVB2D   F: 2 x MOVB2D | ||
// NOTE(review): assuming TTI_REPLAY/TT_OP_REPLAY take (start slot, count, ...), the recording | ||
// occupies slots 16..31, with A at 16..19, B at 20..23, C at 24, D at 25, E at 26..29 and | ||
// F at 30..31 - confirm against the REPLAY documentation. | ||
inline void transpose_dest_configure_mop() { | ||
    load_replay_buf<16, 16, 0>([] { | ||
        // A | ||
        // Four 4-row MOVD2B moves from SRC_ZERO_OFFSET + {0, 4, 8, 12}; the last operand | ||
        // is the same row number minus 16. | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 0, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 0 - 16); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 4, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 4 - 16); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 8, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 8 - 16); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 12, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 12 - 16); | ||
|
||
        // B | ||
        // Four 4-row MOVD2B moves from SRC_ROW16_OFFSET + {0, 4, 8, 12}. | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 0, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 0); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 4, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 4); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 8, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 8); | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 12, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 12); | ||
|
||
        // C | ||
        TTI_TRNSPSRCB; | ||
|
||
        // D | ||
        // Uses ADDR_MOD_2, which transpose_dest_configure_addrmod() sets to dest -= 16. | ||
        TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 32, ADDR_MOD_2, p_movd2b::MOV_1_ROW, 0); // throwaway to decrement dst | ||
|
||
        // E | ||
        // Four 4-row MOVB2D moves; the last one uses ADDR_MOD_0 (dest += 16). | ||
        // NOTE(review): the source-offset constants here and in F/X/Y come from p_movd2b, not | ||
        // p_movb2d; presumably the values are the same - confirm. | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 0, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 0); | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 4, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 4); | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 8, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 8); | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 12, ADDR_MOD_0, p_movb2d::MOV_4_ROWS, 12); | ||
|
||
        // F | ||
        // Only rows 0 and 4 are recorded; the matching moves for rows 8 and 12 are the | ||
        // X and Y instruction words built below. | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 0, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 0); | ||
        TTI_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 4, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 4); | ||
    }); | ||
|
||
    // Instruction words handed to the template. AF spans the whole 16-instruction recording; | ||
    // BC and E presumably select sections B+C and E respectively (see slot note above). | ||
    uint AF = TT_OP_REPLAY(16, 16, 0, 0); | ||
    uint BC = TT_OP_REPLAY(20, 5, 0, 0); | ||
    uint E = TT_OP_REPLAY(26, 4, 0, 0); | ||
    // X and Y continue section F's pattern for rows 8 and 12; Y uses ADDR_MOD_0 (dest += 16). | ||
    uint X = TT_OP_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 8, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 8); | ||
    uint Y = TT_OP_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 12, ADDR_MOD_0, p_movb2d::MOV_4_ROWS, 12); | ||
|
||
    // NOTE(review): constructor arguments are presumably (outer loop count, inner loop count, | ||
    // loop op 0, loop op 1) - confirm against the ckernel_template definition. | ||
    ckernel_template tmp(1, 2, E, BC); | ||
    tmp.set_start_op(BC); | ||
    tmp.set_last_outer_loop_instr(AF); | ||
    tmp.set_end_ops(X, Y); | ||
    tmp.program(instrn_buffer); | ||
} | ||
|
||
// One-time setup for _llk_math_transpose_dest_(): programs the address modifiers first, | ||
// then records the replay buffer and programs the instruction template. | ||
inline void _llk_math_transpose_dest_init_() { | ||
|
||
    transpose_dest_configure_addrmod(); | ||
|
||
    transpose_dest_configure_mop(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters