-
Notifications
You must be signed in to change notification settings - Fork 0
/
rmscaleCUTLASS.patch
94 lines (82 loc) · 4.75 KB
/
rmscaleCUTLASS.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
index a4f80dc6..d2844f49 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
@@ -408,7 +408,7 @@ public:
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_scale.clear_mask(gemm_k_iterations == 0);
+ // iterator_scale.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
@@ -460,7 +460,7 @@ public:
++this->smem_iterator_B_;
}
- copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
+ // copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
@@ -545,15 +545,17 @@ public:
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
- warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
+ // warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
+ warp_frag_scales.fill(static_cast<ElementScale>(1));
+ warp_frag_zeros.fill(static_cast<ElementScale>(0));
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
- warp_dequantizer_.add_pointer_offset(Shape::kN);
+ // warp_dequantizer_.add_pointer_offset(Shape::kN);
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_scale.clear_mask(gemm_k_iterations == 0);
+ // iterator_scale.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
@@ -619,10 +621,10 @@ public:
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
- if (group_start_iteration_B == 0)
- {
- copy_scales_and_advance(iterator_scale);
- }
+ // if (group_start_iteration_B == 0)
+ // {
+ // copy_scales_and_advance(iterator_scale);
+ // }
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
@@ -654,7 +656,7 @@ public:
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
- this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
+ // this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
@@ -668,7 +670,7 @@ public:
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
- warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
+ // warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
smem_read_stage_idx = 0;
}
else
@@ -679,14 +681,14 @@ public:
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_scale.clear_mask(gemm_k_iterations == 0);
+ // iterator_scale.clear_mask(gemm_k_iterations == 0);
}
}
// Load the scale needed for the next tile iteration.
- warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
+ // warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
// Update internal pointer to set of scales in shared memory.
- warp_dequantizer_.add_pointer_offset(Shape::kN);
+ // warp_dequantizer_.add_pointer_offset(Shape::kN);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)