Skip to content

Commit

Permalink
support for cumsum and transpose_dst apis
Browse files Browse the repository at this point in the history
* cumsum implementation

* Add transpose dest, needed for rowwise cumsum
  • Loading branch information
lpremovicTT authored Sep 19, 2024
1 parent b3d5095 commit d7a12ee
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 1 deletion.
2 changes: 1 addition & 1 deletion common/inc/ckernel_addrmod.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct addr_mod_t
// CLR, CR, INCR(8 bits)
struct addr_mod_dest_t
{
uint8_t incr = 0;
int16_t incr = 0;
uint8_t clr = 0;
uint8_t cr = 0;
uint8_t c_to_cr = 0;
Expand Down
1 change: 1 addition & 0 deletions common/inc/ckernel_sfpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "sfpu/ckernel_sfpu_cast_fp32_to_fp16a.h"
#include "sfpu/ckernel_sfpu_clamp.h"
#include "sfpu/ckernel_sfpu_comp.h"
#include "sfpu/ckernel_sfpu_cumsum.h"
#include "sfpu/ckernel_sfpu_dropout.h"
#include "sfpu/ckernel_sfpu_exp.h"
#include "sfpu/ckernel_sfpu_gelu.h"
Expand Down
178 changes: 178 additions & 0 deletions common/inc/sfpu/ckernel_sfpu_cumsum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_defs.h"
#include "ckernel.h"
#include "noc_nonblocking_api.h"

#include "sfpi.h"

using namespace sfpi;

namespace ckernel
{
namespace sfpu
{

// Emits the SFPU instruction stream for one cumulative-sum pass.
//
// The body is eight identical-shaped groups. Each group:
//   1. loads four values into LREG0-3 or LREG4-7 (SFPLOAD),
//   2. issues SFPTRANSP,
//   3. replays eight instructions from the replay buffer (TTI_REPLAY),
//   4. issues SFPTRANSP again,
//   5. stores the same four LREGs back to the addresses they were loaded from.
// Groups alternate between LREG0-3 with TTI_REPLAY(0, 8, ...) and LREG4-7 with
// TTI_REPLAY(8, 8, ...).
//
// The replayed instructions are the SFPADD/SFPNOP pairs recorded by
// _cumsum_init_(), so that function must have run first. Presumably the
// REPLAY arguments are (start slot, length, ...), which would make slots 0-7
// the adds that target LREG0-3 and slots 8-15 the adds that target LREG4-7
// -- TODO confirm against the ISA documentation.
//
// Template parameters: both are marked unused by the original author.
//
// first: when true, four SFPMOV instructions targeting LREG4-7 are issued
//        before the first group (see the note inside the branch).
template <bool APPROXIMATION_MODE /*unused*/, int ITERATIONS /*unused*/>
inline void _calculate_cumsum_(const bool first)
{
if(first)
{
// Clear context for F0
// NOTE(review): presumably copies LREG9 (assumed to be a constant-zero
// register) into LREG4-7 so that no stale running sum is carried into the
// first group -- confirm SFPMOV operand order and the LREG9 contents.
TTI_SFPMOV(0, 9, 4, 0);
TTI_SFPMOV(0, 9, 5, 0);
TTI_SFPMOV(0, 9, 6, 0);
TTI_SFPMOV(0, 9, 7, 0);
}

// F0,1 R0
// (presumably faces 0/1, rows starting at 0 -- TODO confirm; offsets 0, 2, 16, 18)
TTI_SFPLOAD(0, 0, ADDR_MOD_7, 0);
TTI_SFPLOAD(1, 0, ADDR_MOD_7, 2);
TTI_SFPLOAD(2, 0, ADDR_MOD_7, 0 + 16);
TTI_SFPLOAD(3, 0, ADDR_MOD_7, 2 + 16);

// Transpose, replay the recorded adds for LREG0-3, transpose back.
TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(0, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

// Write the results back to the same offsets that were loaded.
TTI_SFPSTORE(0, 0, ADDR_MOD_7, 0);
TTI_SFPSTORE(1, 0, ADDR_MOD_7, 2);
TTI_SFPSTORE(2, 0, ADDR_MOD_7, 0 + 16);
TTI_SFPSTORE(3, 0, ADDR_MOD_7, 2 + 16);

// F0,1 R4
// (offsets 4, 6, 20, 22; uses LREG4-7 and replay slots starting at 8)
TTI_SFPLOAD(4, 0, ADDR_MOD_7, 4);
TTI_SFPLOAD(5, 0, ADDR_MOD_7, 6);
TTI_SFPLOAD(6, 0, ADDR_MOD_7, 4 + 16);
TTI_SFPLOAD(7, 0, ADDR_MOD_7, 6 + 16);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(8, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(4, 0, ADDR_MOD_7, 4);
TTI_SFPSTORE(5, 0, ADDR_MOD_7, 6);
TTI_SFPSTORE(6, 0, ADDR_MOD_7, 4 + 16);
TTI_SFPSTORE(7, 0, ADDR_MOD_7, 6 + 16);

// F0,1 R8
// (offsets 8, 10, 24, 26; back to LREG0-3 and replay slots starting at 0)
TTI_SFPLOAD(0, 0, ADDR_MOD_7, 8);
TTI_SFPLOAD(1, 0, ADDR_MOD_7, 10);
TTI_SFPLOAD(2, 0, ADDR_MOD_7, 8 + 16);
TTI_SFPLOAD(3, 0, ADDR_MOD_7, 10 + 16);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(0, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(0, 0, ADDR_MOD_7, 8);
TTI_SFPSTORE(1, 0, ADDR_MOD_7, 10);
TTI_SFPSTORE(2, 0, ADDR_MOD_7, 8 + 16);
TTI_SFPSTORE(3, 0, ADDR_MOD_7, 10 + 16);

// F0,1 R12
// (offsets 12, 14, 28, 30; LREG4-7 and replay slots starting at 8)
TTI_SFPLOAD(4, 0, ADDR_MOD_7, 12);
TTI_SFPLOAD(5, 0, ADDR_MOD_7, 14);
TTI_SFPLOAD(6, 0, ADDR_MOD_7, 12 + 16);
TTI_SFPLOAD(7, 0, ADDR_MOD_7, 14 + 16);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(8, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(4, 0, ADDR_MOD_7, 12);
TTI_SFPSTORE(5, 0, ADDR_MOD_7, 14);
TTI_SFPSTORE(6, 0, ADDR_MOD_7, 12 + 16);
TTI_SFPSTORE(7, 0, ADDR_MOD_7, 14 + 16);

// F2,3 R0
// (same pattern as the four groups above with every offset shifted by +32;
// presumably faces 2/3 -- TODO confirm)
TTI_SFPLOAD(0, 0, ADDR_MOD_7, 0 + 32);
TTI_SFPLOAD(1, 0, ADDR_MOD_7, 2 + 32);
TTI_SFPLOAD(2, 0, ADDR_MOD_7, 0 + 16 + 32);
TTI_SFPLOAD(3, 0, ADDR_MOD_7, 2 + 16 + 32);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(0, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(0, 0, ADDR_MOD_7, 0 + 32);
TTI_SFPSTORE(1, 0, ADDR_MOD_7, 2 + 32);
TTI_SFPSTORE(2, 0, ADDR_MOD_7, 0 + 16 + 32);
TTI_SFPSTORE(3, 0, ADDR_MOD_7, 2 + 16 + 32);

// F2,3 R4
TTI_SFPLOAD(4, 0, ADDR_MOD_7, 4 + 32);
TTI_SFPLOAD(5, 0, ADDR_MOD_7, 6 + 32);
TTI_SFPLOAD(6, 0, ADDR_MOD_7, 4 + 16 + 32);
TTI_SFPLOAD(7, 0, ADDR_MOD_7, 6 + 16 + 32);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(8, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(4, 0, ADDR_MOD_7, 4 + 32);
TTI_SFPSTORE(5, 0, ADDR_MOD_7, 6 + 32);
TTI_SFPSTORE(6, 0, ADDR_MOD_7, 4 + 16 + 32);
TTI_SFPSTORE(7, 0, ADDR_MOD_7, 6 + 16 + 32);

// F2,3 R8
TTI_SFPLOAD(0, 0, ADDR_MOD_7, 8 + 32);
TTI_SFPLOAD(1, 0, ADDR_MOD_7, 10 + 32);
TTI_SFPLOAD(2, 0, ADDR_MOD_7, 8 + 16 + 32);
TTI_SFPLOAD(3, 0, ADDR_MOD_7, 10 + 16 + 32);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(0, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(0, 0, ADDR_MOD_7, 8 + 32);
TTI_SFPSTORE(1, 0, ADDR_MOD_7, 10 + 32);
TTI_SFPSTORE(2, 0, ADDR_MOD_7, 8 + 16 + 32);
TTI_SFPSTORE(3, 0, ADDR_MOD_7, 10 + 16 + 32);

// F2,3 R12
TTI_SFPLOAD(4, 0, ADDR_MOD_7, 12 + 32);
TTI_SFPLOAD(5, 0, ADDR_MOD_7, 14 + 32);
TTI_SFPLOAD(6, 0, ADDR_MOD_7, 12 + 16 + 32);
TTI_SFPLOAD(7, 0, ADDR_MOD_7, 14 + 16 + 32);

TTI_SFPTRANSP(0, 0, 0, 0);
TTI_REPLAY(8, 8, 0, 0);
TTI_SFPTRANSP(0, 0, 0, 0);

TTI_SFPSTORE(4, 0, ADDR_MOD_7, 12 + 32);
TTI_SFPSTORE(5, 0, ADDR_MOD_7, 14 + 32);
TTI_SFPSTORE(6, 0, ADDR_MOD_7, 12 + 16 + 32);
TTI_SFPSTORE(7, 0, ADDR_MOD_7, 14 + 16 + 32);
}

// Records the add chain used by _calculate_cumsum_() into the replay buffer.
//
// Sixteen instructions are recorded: eight SFPADD instructions, each followed
// by an SFPNOP. _calculate_cumsum_() later replays them in two halves with
// TTI_REPLAY(0, 8, ...) and TTI_REPLAY(8, 8, ...).
//
// Presumably the load_replay_buf template arguments are <start slot, length,
// execute-while-loading flag> -- TODO confirm against its definition.
//
// Each add uses the previous add's destination LREG as its second operand,
// and the first add takes LREG7, the last destination in the chain.
// NOTE(review): this looks like dest = LREG10 * src_b + src_c with LREG10
// assumed to hold 1.0, i.e. a running sum LREG7 -> LREG0 -> ... -> LREG7
// that carries over between replays -- confirm SFPADD operand semantics.
//
// The template parameter is marked unused by the original author.
template <bool APPROXIMATION_MODE /*unused*/>
inline void _cumsum_init_()
{
load_replay_buf<0, 16, 0>(
[] {
// First eight recorded instructions: adds targeting LREG0-3.
TTI_SFPADD(10, 7, 0, 0, 0);
// NOTE(review): the SFPNOP after every add presumably covers SFPADD
// result latency before the dependent add reads it -- confirm.
TTI_SFPNOP;
TTI_SFPADD(10, 0, 1, 1, 0);
TTI_SFPNOP;
TTI_SFPADD(10, 1, 2, 2, 0);
TTI_SFPNOP;
TTI_SFPADD(10, 2, 3, 3, 0);
TTI_SFPNOP;
// Last eight recorded instructions: adds targeting LREG4-7.
TTI_SFPADD(10, 3, 4, 4, 0);
TTI_SFPNOP;
TTI_SFPADD(10, 4, 5, 5, 0);
TTI_SFPNOP;
TTI_SFPADD(10, 5, 6, 6, 0);
TTI_SFPNOP;
TTI_SFPADD(10, 6, 7, 7, 0);
TTI_SFPNOP;
});
}

} // namespace sfpu
} // namespace ckernel
104 changes: 104 additions & 0 deletions llk_lib/llk_math_transpose_dest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0


#pragma once

#include "ckernel_include.h"
#include "ckernel_template.h"

#include "cmath_common.h"
#include "llk_math_common.h"
#include "ckernel_globals.h"

using namespace ckernel;

// local function declarations
inline void transpose_dest_configure_addrmod();

// Runs the transpose sequence on the destination-register tile selected by
// dst_index.
//
// Relies on state programmed by _llk_math_transpose_dest_init_(): the address
// modifiers, the replay-buffer contents and the ckernel_template written to
// instrn_buffer.
//
// dst_index: tile index passed to math::set_dst_write_addr (Default layout,
//            Tile32x32 shape).
inline void _llk_math_transpose_dest_(const std::uint32_t dst_index) {
math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(dst_index);

// NOTE(review): presumably holds the math thread until outstanding SFPU work
// has finished, since this typically follows an SFPU op -- confirm.
TTI_STALLWAIT(p_stall::STALL_MATH, p_stall::WAIT_SFPU);

// Execute the template programmed by transpose_dest_configure_mop().
ckernel_template::run(instrn_buffer);

// These two replays use the same (20, 5) and (26, 4) arguments that
// transpose_dest_configure_mop() names BC and E; the values are hard-coded
// here and must be kept in sync with that function's replay-buffer layout.
TTI_REPLAY(20, 5, 0, 0);
TTI_REPLAY(26, 4, 0, 0);

// NOTE(review): presumably clears the SrcA/SrcB read-write counters and
// releases both source banks -- confirm SETRWC field meanings.
TTI_SETRWC(p_setrwc::CLR_AB, 0, 0, 0, 0, p_setrwc::SET_AB);

math::clear_dst_reg_addr();
}

inline void transpose_dest_configure_addrmod() {
addr_mod_t{
.srca = {.incr = 0},
.srcb = {.incr = 0},
.dest = {.incr = 16},
}.set(ADDR_MOD_0);

addr_mod_t{
.srca = {.incr = 0},
.srcb = {.incr = 0},
.dest = {.incr = 0},
}.set(ADDR_MOD_1);

addr_mod_t{
.srca = {.incr = 0},
.srcb = {.incr = 0},
.dest = {.incr = -16},
}.set(ADDR_MOD_2);
}

// Records the transpose-dest instruction groups into the replay buffer and
// programs the ckernel_template that sequences them.
//
// Sixteen instructions are recorded, labelled A-F below. Presumably the
// load_replay_buf arguments are <start slot, length, ...>, which would place
// them at:
//   A: slots 16-19   B: slots 20-23   C: slot 24
//   D: slot 25       E: slots 26-29   F: slots 30-31
// -- TODO confirm. The REPLAY arguments used below, and the hard-coded
// (20, 5) / (26, 4) replays in _llk_math_transpose_dest_(), match this layout.
//
// NOTE(review): the MOVB2D instructions take their source offsets from
// p_movd2b (SRC_ZERO_OFFSET / SRC_ROW16_OFFSET) rather than p_movb2d --
// confirm the two sets of constants are identical.
inline void transpose_dest_configure_mop() {
load_replay_buf<16, 16, 0>([] {
// A
// Four MOVD2B instructions (4 rows each) at source offsets 0, 4, 8, 12,
// with last operands -16, -12, -8, -4.
TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 0, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 0 - 16);
TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 4, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 4 - 16);
TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 8, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 8 - 16);
TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 12, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 12 - 16);

// B
// Same shape as A, but from SRC_ROW16_OFFSET and with last operands 0..12.
TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 0, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 0);
TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 4, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 4);
TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 8, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 8);
TTI_MOVD2B(0, p_movd2b::SRC_ROW16_OFFSET + 12, ADDR_MOD_1, p_movd2b::MOV_4_ROWS, 12);

// C
// Transpose of the SrcB contents.
TTI_TRNSPSRCB;

// D
// Uses ADDR_MOD_2, which transpose_dest_configure_addrmod() sets to dest -= 16.
TTI_MOVD2B(0, p_movd2b::SRC_ZERO_OFFSET + 32, ADDR_MOD_2, p_movd2b::MOV_1_ROW, 0); // throwaway to decrement dst

// E
// Four MOVB2D instructions from SRC_ROW16_OFFSET; the last one uses
// ADDR_MOD_0 (dest += 16), the others ADDR_MOD_1 (no change).
TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 0, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 0);
TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 4, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 4);
TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 8, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 8);
TTI_MOVB2D(0, p_movd2b::SRC_ROW16_OFFSET + 12, ADDR_MOD_0, p_movb2d::MOV_4_ROWS, 12);

//F
// First two of four MOVB2D instructions from SRC_ZERO_OFFSET; the remaining
// two (offsets 8 and 12) are the X and Y ops defined below.
TTI_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 0, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 0);
TTI_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 4, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 4);
});

// Instruction words handed to the template. AF uses (16, 16), i.e. all
// sixteen recorded instructions (presumably "A through F").
uint AF = TT_OP_REPLAY(16, 16, 0, 0);
uint BC = TT_OP_REPLAY(20, 5, 0, 0);
uint E = TT_OP_REPLAY(26, 4, 0, 0);
// X and Y continue the F pattern at offsets 8 and 12; Y uses ADDR_MOD_0 (dest += 16).
uint X = TT_OP_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 8, ADDR_MOD_1, p_movb2d::MOV_4_ROWS, 8);
uint Y = TT_OP_MOVB2D(0, p_movd2b::SRC_ZERO_OFFSET + 12, ADDR_MOD_0, p_movb2d::MOV_4_ROWS, 12);

// NOTE(review): presumably (outer loop count 1, inner loop count 2, loop ops
// E and BC) -- confirm against the ckernel_template constructor.
ckernel_template tmp(1, 2, E, BC);
tmp.set_start_op(BC);
tmp.set_last_outer_loop_instr(AF);
tmp.set_end_ops(X, Y);
tmp.program(instrn_buffer);
}

// One-time setup for _llk_math_transpose_dest_(): programs the address
// modifiers, then records the replay buffer and programs the template.
inline void _llk_math_transpose_dest_init_() {

// Sets ADDR_MOD_0/1/2, which the instructions recorded below refer to.
transpose_dest_configure_addrmod();

transpose_dest_configure_mop();
}
6 changes: 6 additions & 0 deletions llk_lib/llk_unpack_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,9 @@ inline void _llk_unpack_clear_dbg_feature_disable_(){
// LLK-level entry point that forwards to enable_int8_fpu_math(), which is
// defined outside this view.
inline void _llk_enable_int8_fpu_math_() {
enable_int8_fpu_math();
}

// Issues UNPACR_NOP instructions with SET_DVALID and UNP_ZEROSRC for SrcB and
// then SrcA, after a STALLWAIT on the unpacker.
//
// NOTE(review): presumably this marks the source registers as valid without
// unpacking real data, so that a math op with no unpacked inputs (such as the
// transpose-dest sequence) can proceed -- confirm the intended use.
// NOTE(review): the name mentions only srcb, but SrcA is set as well.
inline void _llk_unpack_set_srcb_dummy_valid_() {
// NOTE(review): presumably waits for in-flight unpack work to finish -- confirm.
TTI_STALLWAIT(p_stall::STALL_UNPACK, p_stall::UNPACK);
TTI_UNPACR_NOP(SrcB, 0, 0, p_unpacr_nop::SET_DVALID, 0, 0, 0, 0, p_unpacr_nop::UNP_ZEROSRC);
TTI_UNPACR_NOP(SrcA, 0, 0, p_unpacr_nop::SET_DVALID, 0, 0, 0, 0, p_unpacr_nop::UNP_ZEROSRC);
}

0 comments on commit d7a12ee

Please sign in to comment.