Use F.silu instead of nn.SiLU to avoid creating a module on every call
sogartar committed Dec 12, 2024
1 parent 064f8d0 · commit c1fd416
Showing 1 changed file with 4 additions and 8 deletions.
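
Rationale: constructing nn.SiLU() inside forward allocates a fresh nn.Module object on every call, whereas torch.nn.functional.silu applies the same activation as a plain function with no object creation. Below is a minimal standalone sketch of the two patterns in plain PyTorch (independent of sharktank's ops.elementwise wrapper; the class names are illustrative, not taken from the repository):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PerCallSilu(nn.Module):
        # Builds a new nn.SiLU module object every time forward runs.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return nn.SiLU()(x)

    class FunctionalSilu(nn.Module):
        # Calls the stateless functional form directly; nothing is allocated.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.silu(x)

    x = torch.randn(2, 4)
    # Both produce identical outputs; only the per-call module creation differs.
    assert torch.allclose(PerCallSilu()(x), FunctionalSilu()(x))

Both forms compute x * sigmoid(x); the change is purely about avoiding the throwaway module allocation in the hot path.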
12 changes: 4 additions & 8 deletions sharktank/sharktank/models/flux/flux.py
@@ -10,11 +10,7 @@
 """
 
 import math
-from typing import Optional
-
-from dataclasses import dataclass
-from typing import Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -128,7 +124,7 @@ def forward(
         # running on sequences img
         img = self.img_in(img)
         time_in_0 = self.time_in_0(timestep_embedding(timesteps, 256))
-        time_in_silu = ops.elementwise(nn.SiLU(), time_in_0)
+        time_in_silu = ops.elementwise(F.silu, time_in_0)
         vec = self.time_in_1(time_in_silu)
         if self.guidance:
             if guidance is None:
@@ -137,11 +133,11 @@
                 )
             guidance_inp = timestep_embedding(guidance, 256)
             guidance0 = self.guidance_in0(guidance_inp)
-            guidance_silu = ops.elementwise(nn.SiLU(), guidance0)
+            guidance_silu = ops.elementwise(F.silu, guidance0)
             guidance_out = self.guidance_in1(guidance_silu)
             vec = vec + self.guidance_in(guidance_out)
         vector_in_0 = self.vector_in_0(y)
-        vector_in_silu = ops.elementwise(nn.SiLU(), vector_in_0)
+        vector_in_silu = ops.elementwise(F.silu, vector_in_0)
         vector_in_1 = self.vector_in_1(vector_in_silu)
         vec = vec + vector_in_1
 
@@ -247,7 +243,7 @@ def __init__(
         self.add_module("ada_linear", LinearLayer(theta("ada_linear")))
 
     def forward(self, x: AnyTensor, vec: AnyTensor) -> AnyTensor:
-        silu = ops.elementwise(nn.SiLU(), vec)
+        silu = ops.elementwise(F.silu, vec)
         lin = self.ada_linear(silu)
         shift, scale = lin.chunk(2, dim=1)
         print(x.shape, shift.shape, scale.shape)
