Rework RotaryEmbedding for dynamic computation (#255)
Some minor changes to the rotary embedding let it better support fusion and avoid using a lookup table. Depending on the backend, one version may provide better overall performance.
rsuderman authored Oct 14, 2024
1 parent acd77e3 commit 854bea3
Showing 1 changed file with 45 additions and 23 deletions.
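To make the intent concrete before the diff, here is a minimal, standalone sketch (not the sharktank code; the function names are illustrative) contrasting the two strategies the commit description mentions: precomputing a freqs_cis lookup table that is indexed by position, versus computing the same values on demand from the requested positions so the arithmetic can fuse with surrounding ops.

import torch

def inv_freqs(dim: int, base: float = 10000.0, device=None) -> torch.Tensor:
    # Per-channel inverse frequencies: 1 / base^(2i / dim) for i in [0, dim / 2).
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))

def precompute_freqs_cis(max_seqlen: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    # Static-table strategy: build complex rotations for every position up front,
    # then gather rows from the table at runtime.
    angles = torch.outer(torch.arange(max_seqlen).float(), inv_freqs(dim, base))
    return torch.polar(torch.ones_like(angles), angles)  # [max_seqlen, dim // 2]

def dynamic_freqs_cis(positions: torch.Tensor, dim: int, base: float = 10000.0) -> torch.Tensor:
    # Dynamic strategy: compute rotations only for the positions actually needed;
    # no table has to live in memory and the computation can be fused.
    angles = torch.outer(positions.float(), inv_freqs(dim, base, device=positions.device))
    return torch.polar(torch.ones_like(angles), angles)  # [positions.numel(), dim // 2]

# Both paths agree for the same positions:
table = precompute_freqs_cis(max_seqlen=32, dim=8)
pos = torch.arange(4, 10)
assert torch.allclose(table[pos], dynamic_freqs_cis(pos, dim=8))

The diff below applies the same idea inside the layer: use_table=True keeps the table-lookup path, while use_table=False routes through the new _compute_rotary_embed_table on the positions themselves.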
sharktank/sharktank/layers/rotary_embedding.py: 45 additions & 23 deletions
@@ -4,6 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+from collections import namedtuple
 from typing import Optional, Union

 import torch
@@ -24,14 +25,17 @@ def __init__(
         rope_freq_base: Optional[float],
         device: Optional[torch.device] = None,
         use_hf: bool = False,
-        static_tables: bool = True,
+        static_tables: bool = False,
+        use_table: bool = True,
         tensor_parallelism_size: int = 1,
     ):
         super().__init__()
         self.device = device
         self.rope_dimension_count = rope_dimension_count
         self.max_seqlen = max_seqlen
         self.use_hf = use_hf
+        self.static_tables = static_tables
+        self.use_table = use_table

         self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
         self.tensor_parallelism_size = tensor_parallelism_size
@@ -44,10 +48,16 @@ def __init__(

     @property
     def rotary_embed_table(self):
-        if self.static_rotary_embed_table is None:
+        if self.use_table:
+            if self.static_tables:
+                return self.static_rotary_embed_table
             return self._create_rotary_embed_table()
-        else:
-            return self.static_rotary_embed_table
+
+        if self.tensor_parallelism_size == 1:
+            return None
+
+        nt = namedtuple("replicated_tensor", ["shards"])
+        return nt([None] * self.tensor_parallelism_size)

     def forward(
         self,
@@ -96,7 +106,7 @@ def forward_unsharded(
         xq: torch.Tensor,
         xk: torch.Tensor,
         start_index: int,
-        rotary_embed_table: torch.Tensor,
+        rotary_embed_table: Optional[torch.Tensor],
     ):
         # xq_, xk_ shape: bs, sl, _, dim
         # freqs_cis shape: max_sl, dim
@@ -142,12 +152,18 @@ def create_ordering_tensor(dim):
             xq = xq[..., create_interleaved_tensor(xq.shape[-1])]
             xk = xk[..., create_interleaved_tensor(xq.shape[-1])]

-        xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
+        xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
+        xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
         _, sl, _, dim = xq_.shape

         # Offset the table based on starting position.
-        freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
+        if self.use_table:
+            freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
+        else:
+            freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device)
+            freqs_cis = self._compute_rotary_embed_table(freqs_cis)
+            freqs_cis = self._replicate(freqs_cis)
+
         assert freqs_cis.shape[-1] == dim
         assert (
             freqs_cis.shape[0] >= sl
@@ -206,7 +222,13 @@ def compute_batch_mask(
         ) + start_positions.unsqueeze(1)
         # Broadcast lookup to [b, ...].
         self.trace_tensor("rope.positions_seq", positions_seq)
-        freqs_cis = self.rotary_embed_table[positions_seq]
+
+        if self.use_table:
+            freqs_cis = self.rotary_embed_table[positions_seq]
+        else:
+            shape = positions_seq.shape
+            freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
+            freqs_cis = freqs_cis.unflatten(0, shape)

         # Unsqueeze a unit dim for attention heads.
         broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -225,10 +247,6 @@ def apply_batched_mask(
             and xq.shard_count == xk.shard_count
             and xk.shard_dim == xq.shard_dim
         )
-        assert (
-            isinstance(self.rotary_embed_table, ReplicatedTensor)
-            and xq.shard_count == self.rotary_embed_table.shard_count
-        )
         assert (
             isinstance(mask, ReplicatedTensor)
             and mask.shard_count == xq.shard_count
@@ -263,24 +281,20 @@ def apply_batched_mask_unsharded(
         """
         # xq_, xk_ shape: bs, sl, _, dim
         # freqs_cis shape: max_sl, dim
-        xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
+        xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
+        xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
         _, sl, _, dim = xq_.shape

         xq_out = torch.view_as_real(xq_ * mask).flatten(3)
         xk_out = torch.view_as_real(xk_ * mask).flatten(3)
         return xq_out.type_as(xq), xk_out.type_as(xk)

-    def _create_rotary_embed_table(
-        self,
-    ):
+    def _compute_rotary_embed_table(self, t):
         dim = self.rope_dimension_count
-        max_seqlen = self.max_seqlen
         freqs = 1.0 / (
             self.rope_freq_base
-            ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
+            ** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim)
         )
-        t = torch.arange(max_seqlen, device=freqs.device)
         freqs = torch.outer(t, freqs).float()

         freqs_cis = (
@@ -289,8 +303,16 @@ def _create_rotary_embed_table(
             else torch.polar(torch.ones_like(freqs), freqs)
         )

+        return freqs_cis
+
+    def _create_rotary_embed_table(self):
+        t = torch.arange(self.max_seqlen, device=self.device)
+        freqs_cis = self._compute_rotary_embed_table(t)
+        return self._replicate(freqs_cis)
+
+    def _replicate(self, t):
         if self.tensor_parallelism_size > 1:
             # Replicate across all devices, the data is not a lot and the computation is cheap.
-            freqs_cis = ops.replicate(freqs_cis, self.tensor_parallelism_size)
+            t = ops.replicate(t, self.tensor_parallelism_size)

-        return freqs_cis
+        return t
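For the batched case, the dynamic path in compute_batch_mask flattens the per-batch position tensor, computes one rotation row per position, and restores the batch shape with unflatten. A standalone sketch of that shape handling (again not the sharktank code; names are illustrative):

import torch

def freqs_for_batched_positions(positions: torch.Tensor, dim: int, base: float = 10000.0) -> torch.Tensor:
    # positions: integer tensor of shape [batch, seq_len], e.g. per-sequence start
    # offsets plus an arange over the current block.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=positions.device).float() / dim))
    angles = torch.outer(positions.flatten().float(), inv_freq)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    # Restore [batch, seq_len, dim // 2] so the result broadcasts against xq/xk.
    return freqs_cis.unflatten(0, positions.shape)

# Two sequences starting at offsets 0 and 16, decoding a block of 4 tokens:
starts = torch.tensor([[0], [16]])
positions = starts + torch.arange(4)
print(freqs_for_batched_positions(positions, dim=8).shape)  # torch.Size([2, 4, 4])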
