[Release/2.4] Updating unit test case based on removing amax from _scaled_mm and removing amax constraint to find more solutions. #1742
base: release/2.4
Conversation
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal at::cuda::blas::scaled_gemm as well. This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result. Also removing the amax comparison in the unit test.
Pull Request resolved: pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
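As a rough sketch of what the updated test call looks like (illustrative only, not the exact test code; the sizes and fp8 dtype below are assumptions, and the fp8 dtype differs on ROCm), torch._scaled_mm now returns a single tensor rather than an (output, amax) tuple:

import torch

# Illustrative shapes only; the real test multiplies by an identity matrix.
size = 16
x_fp8 = torch.rand(size, size, device="cuda").to(torch.float8_e4m3fn)
y_fp8 = torch.eye(size, device="cuda").to(torch.float8_e4m3fn).t()

# Before pytorch#128683: (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=torch.float16)
# After: only the output tensor is returned, so there is no amax to compare against.
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, out_dtype=torch.float16)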
Jenkins build for commit 1ac570c3595893c53009ec2601a7860db8384245 finished as FAILURE. Detected error during PyTorch building:
@@ -367,11 +367,8 @@ def _test_tautological_mm(self, device: str = "cuda",
        (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        if out_dtype not in [torch.float16, torch.bfloat16, torch.float]:
You made a lot of ROCm-specific changes, but this change applies to CUDA as well. Are we sure this doesn't break something on their end?
// Validates the scale tensors to scaled_mm
// And returns the type of scaling/which kernel to use
ScalingType get_scaling_type(
I may be out of the loop here, but where exactly is this called?
This is not called in this class. When I do the cherry-pick, this part of the code gets added. There are file differences: pytorch@2df620d#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3ab. I need to merge it properly.
 * - If scale_a.numel() == 1 && scale_b.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
Perhaps also add a comment for the restriction that the tensors are required to be 2D for the RowWise case.
Just some minor questions, but mostly lgtm.
namespace {

enum class ScalingType {
  TensorWise,
  RowWise,
  Error
};

/*
 * Scaling Type Determination:
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale_a.numel() == 1 && scale_b.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
 *   - Returns RowWise.
 *
 * - Otherwise:
 *   - Returns Error.
 */

// Validates the scale tensors to scaled_mm
// And returns the type of scaling/which kernel to use
ScalingType get_scaling_type(
    const at::Tensor& scale_a,
    const at::Tensor& scale_b,
    int64_t dim_m,
    int64_t dim_n) {
  // Both Per-Tensor and Row-wise scaling expect fp32 tensors
  TORCH_CHECK(
      scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
      "Both scale_a and scale_b must be float (fp32) tensors.");

  // Check the singular scale case for per-tensor scaling
  if (scale_a.numel() == 1 && scale_b.numel() == 1) {
    return ScalingType::TensorWise;
  }

  // For non-TensorWise scaling, enforce 2D input tensors
  TORCH_CHECK(
      scale_a.dim() == 2 && scale_b.dim() == 2,
      "For non-TensorWise scaling, scale tensors must be 2-dimensional, "
      "but got scale_a.dim()=",
      scale_a.dim(),
      " and scale_b.dim()=",
      scale_b.dim());

  // Check for RowWise scaling
  if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
      scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
#if !defined(USE_ROCM) && !defined(_MSC_VER) || \
    (defined(USE_ROCM) && ROCM_VERSION >= 60000)
    TORCH_CHECK(
        scale_a.is_contiguous() && scale_b.is_contiguous(),
        "Both scale_a and scale_b must be contiguous for RowWise scaling.");
    return ScalingType::RowWise;
#else
    TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
    return ScalingType::Error;
#endif
  }

  // If we reach here, the input doesn't match any valid scaling type
  TORCH_CHECK(
      false,
      "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
      "For RowWise scaling, scale_a should be (",
      dim_m,
      ", 1) and scale_b should be (1, ",
      dim_n,
      "). "
      "Got scale_a.size()=(",
      scale_a.size(0),
      ", ",
      scale_a.size(1),
      ") and ",
      "scale_b.size()=(",
      scale_b.size(0),
      ", ",
      scale_b.size(1),
      ")");

  return ScalingType::Error;
}

} // namespace
This block of code is not required. It seems like a git cherry-pick merge conflict mistake.
Sorry for the delayed response. Yes, you are right. That piece of code shouldn't be included; I will do a proper merge. These two PRs are related - #1735 and this one. This PR (#1742) solves both issues, so I will reject #1735. Thanks for noting it.
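For reference, here is a minimal Python sketch (not part of this PR; the dimensions are made up for illustration) of the two scale layouts the scaling-type documentation above distinguishes:

import torch

# Hypothetical dimensions for an (M, K) x (K, N) matmul.
M, K, N = 32, 64, 16

# TensorWise: both scales are single-element fp32 tensors (numel() == 1).
scale_a_tensorwise = torch.full((), 1.0, device="cuda", dtype=torch.float32)
scale_b_tensorwise = torch.full((), 1.0, device="cuda", dtype=torch.float32)

# RowWise: scale_a has shape (M, 1) and scale_b has shape (1, N); both must be
# contiguous fp32 tensors, and RowWise is gated to non-ROCm builds or ROCm >= 6.0
# by the preprocessor check in the quoted code.
scale_a_rowwise = torch.ones(M, 1, device="cuda", dtype=torch.float32)
scale_b_rowwise = torch.ones(1, N, device="cuda", dtype=torch.float32)

# Any other shape combination would hit the final TORCH_CHECK and report an
# invalid scaling configuration.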
Cherry-pick commit: 39a6179 (amax was removed from _scaled_mm by pytorch#128683; remove it from the internal at::cuda::blas::scaled_gemm as well, and remove the amax comparison in the unit test).