Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify small-batched weight only quantization #2213

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
638b533
Change config files to suit in the A100 server
dasistwo Apr 11, 2024
3d643cc
Update submodules
dasistwo Apr 11, 2024
bac459d
Merge branch 'NVIDIA:main' into mlp
dasistwo Apr 11, 2024
c03b407
Change summarization task default setting
dasistwo Apr 18, 2024
298fcc1
Change CMakeLists for a debugging purpose
dasistwo Apr 23, 2024
dfc29ad
Apply shared mem to scale factor of quantization.
dasistwo Apr 23, 2024
40b1cfb
Remove redundancy of loading scale factors
dasistwo Apr 25, 2024
cc1d2c1
Apply asyncs to scale factors
dasistwo Apr 25, 2024
3fa1699
Apply shared mem asyncs to zeropoints
dasistwo Apr 26, 2024
cde2e2e
[feat]: Support weight only gemm with 2bit
gavinchen430 Apr 30, 2024
fcc7144
refactoring offset
dasistwo May 13, 2024
0875817
Merge branch 'mlp'
dasistwo May 13, 2024
b82286f
Merge pull request #1 from dasistwo/main
dasistwo May 13, 2024
249d93d
Update TensorRT-LLM (#506)
dasistwo Jul 29, 2024
a17b14f
Merge branch 'NVIDIA-main' into mlp
dasistwo Jul 29, 2024
cb76c98
Fix GCC 13 compile error
dasistwo Jul 31, 2024
c8c6432
Fix TensorRT layermap error
dasistwo Aug 1, 2024
43bf1a6
Merge branch 'NVIDIA/main'
dasistwo Aug 5, 2024
5a70210
Merge branch 'NVIDIA-main' into mlp
dasistwo Aug 5, 2024
a0f8499
Fix bug: loading ModelSpec in test
dasistwo Aug 6, 2024
a6fe44d
Fix L1 shared bank conflict
dasistwo Aug 6, 2024
bfa1b74
Revoke private changes
dasistwo Aug 16, 2024
7ff5302
Merge branch 'main' into mlp
dasistwo Aug 16, 2024
5971baf
Refactor & Revoke commit 'Fix L1 shared bank conflict'
dasistwo Aug 27, 2024
d5ecf92
Copy to shared memory within K iteration
dasistwo Aug 28, 2024
bde7127
Merge branch 'NVIDIA:main' into mlp
dasistwo Aug 28, 2024
c2ccb90
Refactoring & Apply double buffering for weight
dasistwo Sep 4, 2024
4c224fd
Merge branch 'mlp' of github.com:dasistwo/TensorRT-LLM into mlp
dasistwo Sep 4, 2024
1488f3f
Debug ColumnMajor Test Case
dasistwo Sep 5, 2024
e3e6d93
Apply double buffering for Act
dasistwo Sep 5, 2024
fb8ab20
Reduce shared memory size & increase grid size
dasistwo Sep 5, 2024
599f5e8
Revert "Increase grid size" & reduce shared memory buffer
dasistwo Sep 6, 2024
51b0df6
Compute memory address at compile time
dasistwo Sep 6, 2024
44c6699
Apply compile-time calculation for less instruction
dasistwo Sep 9, 2024
90c798c
Debug for ColumnMajor Case
dasistwo Sep 9, 2024
0396172
Merge branch 'NVIDIA:main' into mlp
dasistwo Sep 10, 2024
a50ccee
Revoke irrelevant commits
dasistwo Sep 10, 2024
f344143
Merge branch 'NVIDIA:main' into mlp
dasistwo Sep 27, 2024
3bbed0e
Merge with gavinchen430/gemm_w2a16
dasistwo Dec 6, 2024
b19f748
Update submodule cutlass
dasistwo Dec 6, 2024
2a8779d
Debug errors with updated cutlass
dasistwo Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 421 files
1 change: 1 addition & 0 deletions cpp/include/tensorrt_llm/common/stringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <memory> // std::make_unique
#include <sstream> // std::stringstream
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,6 @@ namespace epilogue
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

namespace detail
{

/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
ThreadMap>
{
using WarpTileIterator
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;

using SharedLoadIterator
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;

static int const kFragmentsPerIteration = 2;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load output tile from shared memory in epilogue.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ template <typename TypeA, typename Arch>
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

/// Layout details for packed 2-bit weights (B operand) on SM75-SM89.
/// Column tiles are interleaved so that one 128-byte cache line of B covers
/// exactly one ThreadblockK slice of K.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
    // K extent covered by a 128-byte line of activations of type TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    // Number of 2-bit weight elements held by one 128-byte cache line.
    static constexpr int kElementsPerCacheLine = 128 * 8 / cutlass::sizeof_bits<uint2b_t>::value;
    // Columns interleaved so a full cache line maps onto ThreadblockK rows.
    static constexpr int kColumnsInterleaved = kElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, kColumnsInterleaved>;
    // Weight elements transferred by a single 128-bit access.
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};


template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
Expand All @@ -148,6 +166,15 @@ struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::
using Operator = cutlass::arch::OpMultiplyAdd;
};

/// Layout details for packed 2-bit weights (B operand) on SM90 and newer.
/// Hopper uses plain column-major B (no tile interleaving) with the standard
/// multiply-add operator instead of the interleaved dequantizing one.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
    // K extent covered by a 128-byte line of activations of type TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    // Elements per 128-bit access. Derived from TypeA rather than a hard-coded
    // half_t so the access width stays correct for non-16-bit activations,
    // matching the uint8_t/uint4b_t SM90 specializations above (identical
    // value for fp16/bf16, so existing instantiations are unaffected).
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");

static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value
|| platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint8, uint4 or uint2");

static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");

static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value ||
platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint8, uint4 or uint2");

using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
Expand Down Expand Up @@ -213,8 +215,10 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");

static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value ||
platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint8, uint4 or uint2");

using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,54 @@ struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAli
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
Expand Down Expand Up @@ -232,6 +280,59 @@ struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAli
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
Expand Down Expand Up @@ -287,6 +388,59 @@ struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int2 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#endif

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
Expand Down
Loading