Skip to content

Commit

Permalink
fix _round in onnx_ops to look more like new Tensor.round (tinygrad#3239
Browse files Browse the repository at this point in the history
)

* fix: _round in onnxops

* fix: minor things

* fix: no more n

* fix: smol

* fix: smoller
  • Loading branch information
geohotstan authored Jan 25, 2024
1 parent aa0d1b6 commit b0b5eba
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions extra/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,30 +394,28 @@ def GatherElements(x: Tensor, indices: Tensor, axis):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(indices, axis)

def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.trunc()
b = (b >= 0).where(b+n, b-n)
if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
if equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
def _round(x:Tensor, equidistant_case = "round_down") -> Tensor:
if equidistant_case == "round_down": return (x - 0.5).ceil()
if equidistant_case == "round_up": return x.round()
if equidistant_case == "round_to_even":
b = x.trunc()
b = (b >= 0).where(b + 0.5, b - 0.5)
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
x = (And(x == b, cond_ceil_even)).where(x+1-n, x)
x = (x > b).where(b+1-n, b-n)
x = (And(x == b, cond_ceil_even)).where(x + 0.5, x)
x = (x - 0.5).ceil()
return x

# TODO: this is different from Tensor.round?
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
def Round(X:Tensor): return _round(X, "round_to_even")

# TODO clean this up, it's taking the longest in CI
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel',
cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch',
mode='nearest', nearest_mode='round_prefer_floor'):
def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
if nearest_mode == "round_prefer_floor": ret = _round(x_resized, "round_down")
elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, "round_up")
elif nearest_mode == "floor": ret = x_resized.floor()
elif nearest_mode == "ceil": ret = x_resized.ceil()
return ret.cast(dtypes.int32).clip(0, x_len-1)
Expand Down Expand Up @@ -468,11 +466,11 @@ def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
if keep_aspect_ratio_policy == "not_larger":
scale = min(scales)
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
sizes = _round(Tensor(list(X.shape[-2:]))*scale, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
elif keep_aspect_ratio_policy == "not_smaller":
scale = max(scales)
sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
sizes = _round(Tensor(list(X.shape[-2:]))*scale, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
Expand Down

0 comments on commit b0b5eba

Please sign in to comment.