diff --git a/elasticai/creator/base_modules/autograd_functions/round_to_block_biased_float.py b/elasticai/creator/base_modules/autograd_functions/round_to_block_biased_float.py
new file mode 100644
index 00000000..ffb51921
--- /dev/null
+++ b/elasticai/creator/base_modules/autograd_functions/round_to_block_biased_float.py
@@ -0,0 +1,36 @@
+from typing import Any
+
+import torch
+
+
+class RoundToBlockMinifloat(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
+        if len(args) != 3:
+            raise TypeError(
+                "apply() takes exactly three arguments ("
+                "x: torch.Tensor, mantissa_bits: int, exponent_bits: int"
+                ")"
+            )
+        x: torch.Tensor = args[0]
+        mantissa_bits: int = args[1]
+        exponent_bits: int = args[2]
+
+        shared_bias = x.abs().log2().floor().max() - (2**exponent_bits - 1)
+
+        largest_value = (2 - 1 / 2**mantissa_bits) * 2 ** (
+            (2**exponent_bits - 1) - shared_bias
+        )
+
+        out_of_bounds = (x < -largest_value) | (x > largest_value)
+        if torch.any(out_of_bounds):
+            raise ValueError("Cannot quantize tensor. Values out of bounds.")
+
+        # TODO: Here we are...
+
+        scale = 2 ** (x.abs().log2().floor() - mantissa_bits)
+        return scale * torch.round(x / scale)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        return *grad_outputs, None, None
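
A minimal usage sketch for the new autograd function (not part of the diff): it assumes the package is importable at the path introduced above, the bit widths 3 and 4 are arbitrary example values, and the input avoids exact zeros so that log2 stays finite.

import torch

from elasticai.creator.base_modules.autograd_functions.round_to_block_biased_float import (
    RoundToBlockMinifloat,
)

# Quantize a small tensor and check that gradients pass straight through to x,
# while the integer bit-width arguments receive no gradient.
x = torch.linspace(-4.0, 4.0, steps=16, requires_grad=True)
y = RoundToBlockMinifloat.apply(x, 3, 4)  # (x, mantissa_bits, exponent_bits)
y.sum().backward()
print(y)       # values rounded to the reduced mantissa precision
print(x.grad)  # all ones: the backward pass acts as a straight-through estimator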