Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
rename all variables to use input/weight/grad_output notation
Browse files Browse the repository at this point in the history
Summary:

In #323 we
changed the user facing variable notation from `x/w/dL_dY` to
`input/weight/grad_output`.

This PR follows up by changing most of the internal variables to also match
the new notation, to reduce confusion.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
vkuzo committed Jul 25, 2024
1 parent 994057c commit 97887f7
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 94 deletions.
4 changes: 2 additions & 2 deletions float8_experimental/float8_dynamic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def backward(ctx, gradY):
gradY_scale,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.DL_DY,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)
return fp8_tensor, None

Expand All @@ -51,7 +51,7 @@ def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.X,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
Expand Down
76 changes: 39 additions & 37 deletions float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def backward(ctx, go):
fp8_scale_grad_output,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.DL_DY,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads
Expand Down Expand Up @@ -273,21 +273,21 @@ def convert_amax_buffer_to_float32(self):
if self._buffers[key] is not None:
self._buffers[key] = self._buffers[key].to(torch.float32)

def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
def cast_input_to_float8(
self, input: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)
input = input.to(autocast_dtype)

if self.scaling_type_input is TensorScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
x,
input,
self.fp8_amax_input,
self.fp8_amax_history_input,
self.fp8_scale_input,
Expand All @@ -296,29 +296,29 @@ def cast_x_to_float8(
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
input_fp8 = Float8Tensor.to_float8(
input,
self.fp8_scale_input,
e4m3_dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is TensorScalingType.DYNAMIC
x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
return x_fp8
input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
return input_fp8

def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
def cast_weight_to_float8(
self, weight: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
if self.scaling_type_weight is TensorScalingType.DELAYED:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
weight_fp8 = self.weight
else:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
Expand All @@ -328,29 +328,31 @@ def cast_w_to_float8(
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
weight_fp8 = Float8Tensor.to_float8(
weight,
self.fp8_scale_weight,
e4m3_dtype,
self.fp8_amax_weight,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
else:
assert self.scaling_type_weight is TensorScalingType.DYNAMIC
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
weight_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
weight_fp8 = cast_to_float8_e4m3_dynamic(
self.weight,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return w_fp8
return weight_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is TensorScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
output = NoopFwToFloat8E5M2Bw.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
Expand All @@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
)
else:
assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
return y
output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
return output

def float8_pre_forward(self, x):
def float8_pre_forward(self, input):
if not self.enable_pre_and_post_forward:
return
if (
Expand All @@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
raise AssertionError(
"amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
)
self.last_seen_input_dtype = x.dtype
self.last_seen_input_dtype = input.dtype

def float8_post_forward(self):
if not self.enable_pre_and_post_forward:
Expand All @@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
output = torch.matmul(input_fp8, weight_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y)
# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)

if self.bias is not None:
y = y + self.bias.to(y.dtype)
output = output + self.bias.to(output.dtype)

if self.has_any_delayed_scaling:
self.float8_post_forward()
return y
return output

def scaling_repr(self):
# add scaling settings without using too many characters
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
# example: "i:del,w:del,go:dyn"
return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"

def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
Expand Down
55 changes: 28 additions & 27 deletions float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@
#
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. x @ w_t = y (forward pass)
# 2. dL_dY @ w = dL_dX (backward pass)
# 3. x_t @ dL_dY = dL_dW (backward pass)
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# In the formulas above, there are:
# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
# - Note that dL_dY_t is implied because of memory format requirements
# A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t).
# - Note that grad_output_t is implied because of memory format requirements
# of float8 gemms
# B. three output tensors (y, dL_dX, dL_dW)
# B. three output tensors (output, grad_input, grad_weight)
#
# We want each input tensor, gemm, and output tensor to be configurable.
# The state of this configuration today is:
#
# i. pairs of input tensors (non-t and t variants) have their scaling
# configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
# configurable via the scaling_type_* arguments to Float8Linear
# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
# to configure all three gemms, also not user facing
Expand All @@ -60,11 +60,12 @@

# The object below is not user facing and exists for convenience,
# to allow Float8Tensor to use
# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
# the right config based on which gemm from gemms with outputs
# `output`, `grad_input`, `grad_weight` is
# being called.
LinearMMConfig = namedtuple(
"LinearMMConfig",
["y", "dL_dX", "dL_dW"],
["output", "grad_input", "grad_weight"],
defaults=[
ScaledMMConfig(False, True, False, False),
ScaledMMConfig(False, False, False, False),
Expand All @@ -81,9 +82,9 @@ class GemmInputRole(enum.Enum):
gemm is performed.
"""

X = "x"
W = "w"
DL_DY = "dL_dY"
INPUT = "input"
WEIGHT = "weight"
GRAD_OUTPUT = "grad_output"


# choose which scaled_mm_config to use based on gemm inputs
Expand All @@ -93,21 +94,21 @@ def choose_scaled_mm_config(
b_role: GemmInputRole,
b_linear_mm_config: LinearMMConfig,
):
if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT:
assert (
a_linear_mm_config.y == b_linear_mm_config.y
), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}"
return a_linear_mm_config.y
elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
a_linear_mm_config.output == b_linear_mm_config.output
), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}"
return a_linear_mm_config.output
elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT:
assert (
a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
return a_linear_mm_config.dL_dX
elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
a_linear_mm_config.grad_input == b_linear_mm_config.grad_input
), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}"
return a_linear_mm_config.grad_input
elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT:
assert (
a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
return a_linear_mm_config.dL_dW
a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight
), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}"
return a_linear_mm_config.grad_weight
else:
raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")

Expand Down Expand Up @@ -207,7 +208,7 @@ def forward(
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
Args
Expand Down Expand Up @@ -287,7 +288,7 @@ def __new__(
scale: torch.Tensor,
orig_dtype: torch.dtype,
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
assert (
scale.numel() == 1
Expand Down Expand Up @@ -348,7 +349,7 @@ def to_float8(
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""Converts a higher precision tensor to float8 in a differentiable way.
Expand Down
6 changes: 3 additions & 3 deletions float8_experimental/float8_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _prepare_input_fn(
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
Expand Down Expand Up @@ -101,7 +101,7 @@ def _prepare_input_fn(
input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
Expand Down Expand Up @@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
dt_inp = cast_to_float8_e4m3_dynamic(
dt_inp,
self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
Expand Down
10 changes: 5 additions & 5 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,14 @@ def fsdp_pre_all_gather(self, mesh):
self._precomputed_scale,
torch.float8_e4m3fn,
linear_mm_config=self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
else:
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor,
self._linear_mm_config,
reduce_amax=True,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)

Expand All @@ -201,7 +201,7 @@ def fsdp_post_all_gather(
scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
), (data,)


Expand Down Expand Up @@ -364,7 +364,7 @@ def fsdp_pre_all_gather(self, mesh):
e4m3_dtype,
self._amax_buffer,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
return (float8_tensor._data,), (float8_tensor._scale,)

Expand All @@ -387,5 +387,5 @@ def fsdp_post_all_gather(
scale,
param_dtype,
self._linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
), (data,)
4 changes: 2 additions & 2 deletions float8_experimental/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
scale,
dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.W,
gemm_input_role=GemmInputRole.WEIGHT,
)
self.weight = nn.Parameter(quantized_weight)
self.weight.requires_grad = False
Expand Down Expand Up @@ -205,7 +205,7 @@ def cast_to_float8_e4m3_inference(
scale,
e4m3_dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=GemmInputRole.X,
gemm_input_role=GemmInputRole.INPUT,
)


Expand Down
Loading

0 comments on commit 97887f7

Please sign in to comment.