Skip to content

Commit

Permalink
Reuse make_rand_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Dec 12, 2024
1 parent 418fa13 commit 064f8d0
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions sharktank/tests/models/flux/flux_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@
from sharktank.layers.testing import (
make_mmdit_double_block_random_theta,
make_mmdit_single_block_random_theta,
make_rand_torch,
)
from sharktank.types.tensors import DefaultPrimitiveTensor
from sharktank.types.theta import Theta


def make_rand_torch(shape: list[int], dtype: torch.dtype | None = torch.float32):
return torch.rand(shape, dtype=dtype) * 2 - 1


# TODO: Refactor this to a function that generates random toy weights, possibly
# to another file
dtype = torch.float32
Expand Down

0 comments on commit 064f8d0

Please sign in to comment.