This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "bring back torch.autograd.Function for float8 matmul"
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068) [ghstack-poisoned]
- Loading branch information