Implement linear RoPE scaling #1442

Merged (2 commits) on Aug 29, 2023
9 changes: 9 additions & 0 deletions include/ctranslate2/layers/attention.h
@@ -71,10 +71,17 @@ namespace ctranslate2 {
const dim_t _cache_time_dim;
};

enum class RotaryScalingType {
None = -1,
Linear,
};

class RotaryEmbeddings {
public:
RotaryEmbeddings(const dim_t dim = 0,
const bool interleave = true,
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const dim_t num_initial_positions = 2048,
const float base = 10000);

@@ -88,6 +95,8 @@

const dim_t _dim;
const bool _interleave;
const RotaryScalingType _scaling_type;
const float _scaling_factor;
const dim_t _num_initial_positions;
const float _base;
const ops::Rotary _rotary_op;
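
The new constructor arguments default to RotaryScalingType::None and a factor of 1, so previously converted models behave exactly as before. For readers unfamiliar with linear RoPE scaling (often called position interpolation): position indices are simply divided by the scaling factor before the rotary sine/cosine tables are built. A minimal Python sketch of that idea, assuming the usual RoPE inverse-frequency definition (an illustration, not the CTranslate2 code):

import numpy as np

def rotary_angles(num_positions, dim, base=10000.0, scaling_factor=1.0):
    # One inverse frequency per pair of rotary dimensions.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    # Linear scaling: divide the positions by the factor so that long
    # sequences are compressed into the position range seen during training.
    positions = np.arange(num_positions, dtype=np.float32) / scaling_factor
    return np.outer(positions, inv_freq)  # shape: (num_positions, dim // 2)

# With a factor of 2, position 4096 yields the same angles as position 2048 unscaled.
assert np.allclose(rotary_angles(4097, 64, scaling_factor=2.0)[4096],
                   rotary_angles(2049, 64)[2048])
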
4 changes: 2 additions & 2 deletions include/ctranslate2/models/model.h
@@ -118,8 +118,8 @@ namespace ctranslate2 {
bool get_flag_with_default(const std::string& name, bool default_value) const;

template <typename Enum>
Enum get_enum_value(const std::string& name) const {
return static_cast<Enum>(get_attribute_with_default<int32_t>(name, 0));
Enum get_enum_value(const std::string& name, int32_t default_index = 0) const {
return static_cast<Enum>(get_attribute_with_default<int32_t>(name, default_index));
}

protected:
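
The added default_index lets callers request a sentinel when the attribute is missing; the attention layer passes -1 so that models converted without RoPE scaling map to RotaryScalingType::None rather than to Linear (value 0). A rough Python analogue of this lookup-with-default pattern (the enum and the attributes dict below are illustrative stand-ins, not CTranslate2 APIs):

import enum

class Scaling(enum.IntEnum):  # stand-in for the C++ RotaryScalingType
    NONE = -1
    LINEAR = 0

def get_enum_value(attributes, name, enum_cls, default_index=0):
    # Same idea as the C++ helper: a missing attribute falls back to default_index.
    return enum_cls(attributes.get(name, default_index))

# A model converted before this change has no "rotary_scaling_type" entry,
# so a default of -1 yields NONE instead of silently becoming LINEAR.
assert get_enum_value({}, "rotary_scaling_type", Scaling, -1) is Scaling.NONE
assert get_enum_value({"rotary_scaling_type": 0}, "rotary_scaling_type", Scaling) is Scaling.LINEAR
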
29 changes: 28 additions & 1 deletion python/ctranslate2/converters/transformers.py
@@ -17,7 +17,13 @@

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, model_spec, transformer_spec, whisper_spec
from ctranslate2.specs import (
attention_spec,
common_spec,
model_spec,
transformer_spec,
whisper_spec,
)

_SUPPORTED_ACTIVATIONS = {
"gelu": common_spec.Activation.GELU,
@@ -31,6 +37,10 @@
"swish": common_spec.Activation.SWISH,
}

_SUPPORTED_ROPE_SCALING = {
"linear": attention_spec.RotaryScalingType.Linear,
}

_MODEL_LOADERS = {}


@@ -1198,6 +1208,21 @@ def get_model_spec(self, model):
if num_heads_kv == num_heads:
num_heads_kv = None

rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
rotary_scaling_factor = rope_scaling["factor"]

if rotary_scaling_type is None:
raise NotImplementedError(
"RoPE scaling type '%s' is not yet implemented. "
"The following RoPE scaling types are currently supported: %s"
% (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
)
else:
rotary_scaling_type = None
rotary_scaling_factor = 1

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
@@ -1207,6 +1232,8 @@ def get_model_spec(self, model):
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
num_heads_kv=num_heads_kv,
)

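
For context, the rope_scaling attribute read above is the dictionary used by Llama-style configurations in transformers. Something along the lines of the snippet below (an illustrative example, assuming a transformers version that supports rope_scaling) is what makes the converter store RotaryScalingType.Linear and the factor in the spec, while any unsupported type name triggers the NotImplementedError:

from transformers import LlamaConfig

# Hypothetical Llama-style config with 2x linear RoPE scaling.
config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})

rope_scaling = getattr(config, "rope_scaling", None)
print(rope_scaling["type"], rope_scaling["factor"])  # linear 2.0
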
1 change: 1 addition & 0 deletions python/ctranslate2/specs/__init__.py
@@ -1,3 +1,4 @@
from ctranslate2.specs.attention_spec import RotaryScalingType
from ctranslate2.specs.common_spec import Activation, EmbeddingsMerge
from ctranslate2.specs.model_spec import (
LanguageModelSpec,
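
Re-exporting the enum from ctranslate2.specs makes it available to code that builds specs by hand, e.g.:

from ctranslate2.specs import RotaryScalingType

# The enum's integer value is what ends up stored in the converted model
# as the int8 "rotary_scaling_type" attribute (see attention_spec.py below).
assert int(RotaryScalingType.Linear) == 0
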
17 changes: 17 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -1,8 +1,17 @@
import enum

import numpy as np

from ctranslate2.specs import common_spec, model_spec


# This enum should match the C++ equivalent in include/ctranslate2/layers/attention.h.
class RotaryScalingType(enum.IntEnum):
"""RoPE scaling type."""

Linear = 0


class MultiHeadAttentionSpec(model_spec.LayerSpec):
def __init__(
self,
@@ -12,6 +21,8 @@ def __init__(
rms_norm=False,
rotary_dim=None,
rotary_interleave=True,
rotary_scaling_type=None,
rotary_scaling_factor=1,
num_heads_kv=None,
):
self.queries_scale = model_spec.OPTIONAL
@@ -33,5 +44,11 @@ def __init__(
self.rotary_dim = np.dtype("int32").type(rotary_dim)
self.rotary_interleave = rotary_interleave

if rotary_scaling_type is not None:
self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type)
self.rotary_scaling_factor = np.dtype("float32").type(
rotary_scaling_factor
)

if num_heads_kv is not None:
self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
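
Since the two attributes are only set when rotary_scaling_type is not None, specs for models without RoPE scaling serialize exactly the same set of attributes as before. A minimal sketch of that behavior, assuming the constructor's remaining arguments keep their defaults:

from ctranslate2.specs import attention_spec

scaled = attention_spec.MultiHeadAttentionSpec(
    rotary_dim=0,
    rotary_interleave=False,
    rotary_scaling_type=attention_spec.RotaryScalingType.Linear,
    rotary_scaling_factor=4.0,
)
unscaled = attention_spec.MultiHeadAttentionSpec(rotary_dim=0, rotary_interleave=False)

assert hasattr(scaled, "rotary_scaling_type") and hasattr(scaled, "rotary_scaling_factor")
assert not hasattr(unscaled, "rotary_scaling_type")
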
16 changes: 16 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -92,6 +92,8 @@ def __init__(
alibi_use_positive_positions: bool = False,
rotary_dim: Optional[int] = None,
rotary_interleave: bool = True,
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
@@ -124,6 +126,8 @@ def __init__(
embeddings are applied to all dimensions.
rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
Otherwise the head dimensions are sliced in half.
rotary_scaling_type: Type of RoPE scaling.
rotary_scaling_factor: Factor used in the RoPE scaling.
parallel_residual: Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
@@ -181,6 +185,8 @@ def __init__(
rms_norm=rms_norm,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
num_heads_kv=num_heads_kv,
@@ -223,6 +229,8 @@ def __init__(
rms_norm=False,
rotary_dim=None,
rotary_interleave=True,
rotary_scaling_type=None,
rotary_scaling_factor=1,
parallel_residual=False,
shared_layer_norm=False,
num_heads_kv=None,
@@ -234,6 +242,8 @@ def __init__(
rms_norm=rms_norm,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
num_heads_kv=num_heads_kv,
)

@@ -453,6 +463,8 @@ def from_config(
alibi_use_positive_positions: bool = False,
rotary_dim: Optional[int] = None,
rotary_interleave: bool = True,
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
@@ -479,6 +491,8 @@ def from_config(
embeddings are applied to all dimensions.
rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
Otherwise the head dimensions are sliced in half.
rotary_scaling_type: Type of RoPE scaling.
rotary_scaling_factor: Factor used in the RoPE scaling.
parallel_residual: Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
@@ -502,6 +516,8 @@ def from_config(
alibi_use_positive_positions=alibi_use_positive_positions,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
multi_query_attention=multi_query_attention,
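
In practical terms, with linear scaling a position p enters the rotary table as p / rotary_scaling_factor, so the factor controls how far the usable context stretches beyond the positions the model was trained on. A back-of-the-envelope sketch (the 4096-token training context is an assumption for illustration, not something stored in the spec):

trained_context = 4096          # assumed original context length of the checkpoint
rotary_scaling_factor = 4.0

# Largest sequence length whose scaled positions still fall inside the trained range.
extended_context = trained_context * rotary_scaling_factor
print(extended_context)                      # 16384.0
print(12000 / rotary_scaling_factor)         # 3000.0 -> still within the trained range
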
17 changes: 15 additions & 2 deletions src/layers/attention.cc
@@ -351,7 +351,16 @@ namespace ctranslate2 {
return nullptr;

const bool interleave = model.get_flag_with_default(scope + "/rotary_interleave", true);
return std::make_unique<RotaryEmbeddings>(rotary_dim, interleave);

const auto scaling_type = model.get_enum_value<RotaryScalingType>(
scope + "/rotary_scaling_type", -1);
const auto scaling_factor = model.get_attribute_with_default<float>(
scope + "/rotary_scaling_factor", 1.f);

return std::make_unique<RotaryEmbeddings>(rotary_dim,
interleave,
scaling_type,
scaling_factor);
}


@@ -590,10 +599,14 @@

RotaryEmbeddings::RotaryEmbeddings(const dim_t dim,
const bool interleave,
const RotaryScalingType scaling_type,
const float scaling_factor,
const dim_t num_initial_positions,
const float base)
: _dim(dim)
, _interleave(interleave)
, _scaling_type(scaling_type)
, _scaling_factor(scaling_factor)
, _num_initial_positions(num_initial_positions)
, _base(base)
, _rotary_op(dim, interleave)
@@ -637,7 +650,7 @@

StorageView t({num_positions, 1});
for (dim_t i = 0; i < t.size(); ++i)
t.at<float>(i) = i;
t.at<float>(i) = _scaling_type == RotaryScalingType::None ? i : float(i) / _scaling_factor;
if (t.device() != device)
t = t.to(device);

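
The only numerical change is in how the position vector t is filled: with Linear scaling every index is divided by _scaling_factor before the sine/cosine tables are computed, mirroring the linear position-interpolation approach used by the corresponding transformers implementation. A one-line Python mirror of that expression, purely for illustration:

scaling_type = "linear"         # stand-in for RotaryScalingType::Linear
scaling_factor = 2.0

positions = [float(i) if scaling_type is None else i / scaling_factor for i in range(6)]
print(positions)                # [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]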