From 5884e0b90491933bfaa9f746bb239a9e461961f0 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 2 Aug 2024 23:18:44 +0000 Subject: [PATCH 1/5] add bitnetforcausallm support --- aphrodite/modeling/layers/linear.py | 24 +- aphrodite/modeling/models/__init__.py | 1 + aphrodite/modeling/models/bitnet.py | 627 ++++++++++++++++++ aphrodite/quantization/__init__.py | 2 + aphrodite/quantization/bitnet.py | 419 ++++++++++++ aphrodite/transformers_utils/tokenizer.py | 22 +- .../transformers_utils/tokenizers/__init__.py | 3 +- .../transformers_utils/tokenizers/bitnet.py | 463 +++++++++++++ requirements-cuda.txt | 3 +- 9 files changed, 1556 insertions(+), 8 deletions(-) create mode 100644 aphrodite/modeling/models/bitnet.py create mode 100644 aphrodite/quantization/bitnet.py create mode 100644 aphrodite/transformers_utils/tokenizers/bitnet.py diff --git a/aphrodite/modeling/layers/linear.py b/aphrodite/modeling/layers/linear.py index e400779f8..4774dcc2e 100644 --- a/aphrodite/modeling/layers/linear.py +++ b/aphrodite/modeling/layers/linear.py @@ -24,6 +24,16 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def adjust_bitblas_shard(param, shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + weight_propagation = getattr(param, "weight_propagation", None) + if weight_propagation and bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) + + return shard_size, shard_offset + + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -262,7 +272,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data, loaded_weight = fp8_scales_shard_indexer(param_data, loaded_weight, shard_id=0) - assert param_data.shape == loaded_weight.shape + assert param_data.dtype == loaded_weight.dtype, ( + f"{param_data.dtype} != {loaded_weight.dtype}") + assert param_data.shape == loaded_weight.shape, ( + f"{param_data.shape} != {loaded_weight.shape}") param_data.copy_(loaded_weight) def forward(self, input_): @@ -382,6 +395,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -404,6 +420,8 @@ def weight_loader(self, # account for the tiling. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -565,6 +583,8 @@ def weight_loader(self, # account for the tiling. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) @@ -596,6 +616,8 @@ def weight_loader(self, # account for the tiling. 
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) param_data = param_data.narrow(output_dim, shard_offset, shard_size) diff --git a/aphrodite/modeling/models/__init__.py b/aphrodite/modeling/models/__init__.py index 709f20520..179149734 100755 --- a/aphrodite/modeling/models/__init__.py +++ b/aphrodite/modeling/models/__init__.py @@ -13,6 +13,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "BitnetForCausalLM": ("bitnet", "BitnetForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), diff --git a/aphrodite/modeling/models/bitnet.py b/aphrodite/modeling/models/bitnet.py new file mode 100644 index 000000000..b6ca4726f --- /dev/null +++ b/aphrodite/modeling/models/bitnet.py @@ -0,0 +1,627 @@ +# coding=utf-8 +# Adapted from +# https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/modeling_bitnet.py +# Copyright 2023 The PygmalionAI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Bitnet model compatible with HuggingFace weights.""" + +# ruff: noqa: E501 + +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from loguru import logger + +from aphrodite.attention import Attention, AttentionMetadata +from aphrodite.common.config import CacheConfig +from aphrodite.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from aphrodite.modeling.layers.activation import SiluAndMul +from aphrodite.modeling.layers.layernorm import RMSNorm +from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from aphrodite.modeling.layers.logits_processor import LogitsProcessor +from aphrodite.quantization.base_config import ( + QuantizationConfig) +from aphrodite.modeling.layers.rotary_embedding import get_rope +from aphrodite.modeling.layers.sampler import Sampler +from aphrodite.modeling.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from aphrodite.modeling.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) +from aphrodite.modeling.sampling_metadata import SamplingMetadata +from aphrodite.common.sequence import IntermediateTensors, SamplerOutput +from aphrodite.common.utils import is_hip, print_warning_once + + +class BitnetConfig(PretrainedConfig): + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + weight_bits=1, + input_bits=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.weight_bits = weight_bits + self.input_bits = input_bits + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", "dynamic" + ]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" + ) + + +class BitnetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + config: BitnetConfig = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + self.ffn_layernorm = RMSNorm(intermediate_size, + eps=config.rms_norm_eps) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.ffn_layernorm(x) + x, _ = self.down_proj(x) + return x + + +class BitnetAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, float]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + config: BitnetConfig = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = hidden_size // self.total_num_heads + self.padded_head_dim = self.find_flash_attn_supported_head_dims( + self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + + self.attn = Attention( + self.num_heads, + self.padded_head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.inner_attn_ln = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: + """ + Find the closest head dimension to the given head dimension that is supported by Flash Attention. + """ + from aphrodite.attention.backends.flash_attn import FlashAttentionBackend + + FLASHATTN_SUPPORTED_HEAD_DIMS = ( + FlashAttentionBackend.get_supported_head_sizes()) + for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: + if head_dim <= supported_head_dim: + return supported_head_dim + raise ValueError( + f"Head dimension {head_dim} is not supported by Flash Attention. 
Supported head dimensions are " + f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # QKV projection cannot be grouped as the they + # do not share the same scaling factor + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + # Padding as paged attention doesn't support head_dim == 100 + q = torch.nn.functional.pad( + q.view(-1, self.total_num_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_heads * self.padded_head_dim) + k = torch.nn.functional.pad( + k.view(-1, self.num_kv_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_kv_heads * self.padded_head_dim) + v = torch.nn.functional.pad( + v.view(-1, self.num_kv_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_kv_heads * self.padded_head_dim) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.total_num_heads, + self.padded_head_dim)[..., :self.head_dim].reshape( + -1, self.total_num_heads * self.head_dim) + attn_output = self.inner_attn_ln(attn_output) + output, _ = self.o_proj(attn_output) + return output + + +class BitnetDecoderLayer(nn.Module): + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = BitnetAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + config=config, + ) + self.mlp = BitnetMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + config=config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # 
Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class BitnetModel(nn.Module): + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + BitnetDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BitnetForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + + self.model = BitnetModel(config, cache_config, quant_config) + + self.unpadded_vocab_size = config.vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + 
sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + # align scaling attr with param + if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.view(param.data.shape) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # align scaling attr with param + if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.view(param.data.shape) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/aphrodite/quantization/__init__.py b/aphrodite/quantization/__init__.py index 82c5890c7..03ffabfd1 100644 --- a/aphrodite/quantization/__init__.py +++ b/aphrodite/quantization/__init__.py @@ -5,6 +5,7 @@ from aphrodite.quantization.aqlm import AQLMConfig from aphrodite.quantization.awq import AWQConfig from aphrodite.quantization.base_config import QuantizationConfig +from aphrodite.quantization.bitnet import BITNETBitBLASConfig from aphrodite.quantization.bitsandbytes import BitsandBytesConfig from aphrodite.quantization.compressed_tensors.compressed_tensors import \ CompressedTensorsConfig @@ -39,6 +40,7 @@ "gguf": GGUFConfig, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) + "bitnet": BITNETBitBLASConfig, "marlin": MarlinConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, diff --git a/aphrodite/quantization/bitnet.py b/aphrodite/quantization/bitnet.py new file mode 100644 index 000000000..04b3d854a --- /dev/null +++ b/aphrodite/quantization/bitnet.py @@ -0,0 +1,419 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from loguru import logger +from torch.nn.parameter import Parameter + +from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + set_weight_attrs) +from aphrodite.quantization.base_config import QuantizationConfig + +try: + import bitblas + import bitblas.cache + from bitblas.utils import auto_detect_nvidia_target +except ImportError as e: + bitblas_import_exception = e + error_message = ( + "Trying to use the bitblas backend, but could not import dependencies " + f"with the following error: {bitblas_import_exception}") + raise ValueError(error_message) from bitblas_import_exception + +bitblas.set_log_level("Debug") +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() + +BITNET_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] + + +class BITNETBitBLASConfig(QuantizationConfig): + """Config class for BITNET BitBLAS""" + + TORCH_DTYPE = torch.int8 + BITNET_CKPT_STORAGE_DTYPE = ( + "float16" # BITNET Default Checkpoints use float16 as storage dtype + ) + BITNET_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, BITNET_BITBLAS_STORAGE_DTYPE) + + def __init__(self, weight_bits: int, is_sym: bool) -> None: + self.input_bits = 8 + self.weight_bits = weight_bits + self.is_sym = is_sym + + # Verify + if self.weight_bits not in 
BITNET_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {BITNET_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + self.storage_dtype = self.BITNET_BITBLAS_STORAGE_DTYPE + self.nbits = weight_bits + + def __repr__(self) -> str: + return (f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, " + f"is_sym={self.is_sym})") + + @classmethod + def get_name(cls) -> str: + return "bitnet_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.int8] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BITNETBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, is_sym) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = user_quant is None or user_quant == "bitblas" + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "bitnet": + logger.info( + "Detected that the model can run with bitnet_bitblas" + ", however you specified quantization=bitnet explicitly," + " so forcing bitnet. Use quantization=bitnet_bitblas for" + " faster inference") + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["BITNETBitBLASLinearMethod"]: + if isinstance(layer, LinearBase): + return BITNETBitBLASLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + sym = quant_config.get("sym", None) + + # If we cannot find the info needed in the config, cannot convert. + if num_bits is None or sym is None: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return num_bits in BITNET_BITBLAS_SUPPORTED_NUM_BITS + + +class BITNETBitBLASState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class BITNETBitBLASLinearMethod(LinearMethodBase): + """Linear method for BITNET BitBLAS. + Args: + quant_config: The BITNET BitBLAS quantization config. + """ + + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: BITNETBitBLASConfig) -> None: + self.quant_config = quant_config + self.Qp = 2**(quant_config.input_bits - 1) - 1 + self.Qn = -2**(quant_config.input_bits - 1) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. 
+ The function initializes and returns a dictionary containing + quantized weights, scales, and zeros for performing quantized + matrix multiplication operations. + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size + in `quant_config`. + """ + del output_size # Unused arguments. + if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + bitblas_dtype = "int8" + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + bitblas_dtype=bitblas_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Init buffers + # Quantized weights + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float16, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + qweight = Parameter( + torch.empty( + *self.bitblas_matmul.retrieve_weight_shape(), + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + layer.register_parameter("weight", weight) + layer.register_parameter("qweight", qweight) + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.bitblas_state = BITNETBitBLASState.REPACK + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + bitblas_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + + W_dtype = f"int{bits}" + + matmul_config = MatmulConfig( + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype="float32", + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.BITNET_BITBLAS_STORAGE_DTYPE, + with_scaling=False, + with_zeros=False, + with_bias=bias, + layout=layout, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul + from bitblas.cache import global_operator_cache + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info("BitBLAS Tuning done, appended operator to " + 
"global_operator_cache.") + else: + _message = ( + f"BitBLAS Operator {config} created without tuning. ") + logger.info(_message) + else: + _message = (f"BitBLAS Operator {config} retrieved from cache.") + logger.info(_message) + return bitblas_matmul + + def weight_quant(self, weight): + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) + return result.type(torch.int8) + + def activation_quant(self, x, num_bits=8): + x = x.float() + Qn = self.Qn + Qp = self.Qp + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) + return result.type(torch.int8) + + def repack_bitblas_from_bitnet(self, + b_q_weight: torch.Tensor, + is_qkv_packed: bool = False, + is_gateup_packed: bool = False): + if is_qkv_packed: + hidden_size = b_q_weight.size(0) + sw_q = 1 / b_q_weight[:hidden_size // + 3].abs().mean().clamp(min=1e-5) + sw_k = 1 / b_q_weight[hidden_size // 3:2 * hidden_size // + 3].abs().mean().clamp(min=1e-5) + sw_v = 1 / b_q_weight[2 * hidden_size // + 3:].abs().mean().clamp(min=1e-5) + self.sw = torch.cat( + (sw_q.repeat(hidden_size // 3), sw_k.repeat( + hidden_size // 3), sw_v.repeat(hidden_size // 3)), + dim=0) + + qweight_q = self.weight_quant(b_q_weight[:hidden_size // + 3]).detach() + qweight_k = self.weight_quant( + b_q_weight[hidden_size // 3:2 * hidden_size // 3]).detach() + qweight_v = self.weight_quant(b_q_weight[2 * hidden_size // + 3:]).detach() + qweight = torch.cat([qweight_q, qweight_k, qweight_v], dim=0) + elif is_gateup_packed: + hidden_size = b_q_weight.size(0) + sw_gate = 1 / b_q_weight[:hidden_size // + 2].abs().mean().clamp(min=1e-5) + sw_up = 1 / b_q_weight[hidden_size // + 2:].abs().mean().clamp(min=1e-5) + self.sw = torch.cat((sw_gate.repeat( + hidden_size // 2), sw_up.repeat(hidden_size // 2)), + dim=0) + qweight_gate = self.weight_quant(b_q_weight[:hidden_size // + 2]).detach() + qweight_up = self.weight_quant(b_q_weight[hidden_size // + 2:]).detach() + qweight = torch.cat([qweight_gate, qweight_up], dim=0) + else: + sw = 1 / b_q_weight.abs().mean().clamp(min=1e-5) + self.sw = sw + qweight = self.weight_quant(b_q_weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + return qweight + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + part_size_n = layer.output_size_per_partition + out_shape = x.shape[:-1] + (part_size_n, ) + quant_input = self.activation_quant( + x, self.quant_config.input_bits).detach() + + if layer.bitblas_state == BITNETBitBLASState.REPACK: + layer.bitblas_state = BITNETBitBLASState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by Aphrodite (and won't be freed) + def free_tensor(name): + # free the original weight tensor + delattr(layer, name) + + def replace_tensor(name, new_t): + # Cannot use copy_() as gptq because the storage + # shape and dtype are different + delattr(layer, name) + setattr(layer, name, new_t) + + # Repack weights + # QKVParallelLinear is a special case where the weight is packed + # For bitnet as different weights matrix shouldn't share the same + # scale, we need to unpack and repack the weight matrix + is_qkv_packed = isinstance(layer, QKVParallelLinear) + is_gateup_packed = isinstance(layer, MergedColumnParallelLinear) + bitblas_qweight = self.repack_bitblas_from_bitnet( + layer.weight, is_qkv_packed, is_gateup_packed) + # free the original 
weight tensor + free_tensor("weight") + replace_tensor("qweight", bitblas_qweight) + + fp32_out = self.bitblas_matmul(quant_input, layer.qweight) + sw = self.sw + Qp = self.Qp + si = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + # if / (si * sw) it will inf in some cases + output = fp32_out / si + output = output / sw + output = output.half() + output = output.type(x.dtype) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/aphrodite/transformers_utils/tokenizer.py b/aphrodite/transformers_utils/tokenizer.py index 9205f4796..29775146a 100644 --- a/aphrodite/transformers_utils/tokenizer.py +++ b/aphrodite/transformers_utils/tokenizer.py @@ -93,11 +93,23 @@ def get_tokenizer( revision=revision, **kwargs) except ValueError as e: - # If the error pertains to the tokenizer class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - if (not trust_remote_code and - ("does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e))): + if "BitnetTokenizer" in str(e): + # This is for the error "'BitnetTokenizer' object has no + # attribute 'sp_model'". + from aphrodite.transformers_utils.tokenizers.bitnet import ( + BitnetTokenizer) + tokenizer = BitnetTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + elif (not trust_remote_code + and ("does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e))): + # If the error pertains to the tokenizer class not existing + # or not currently being imported, suggest using the + # --trust-remote-code flag. err_msg = ( "Failed to load the tokenizer. 
If the tokenizer is a custom "
            "tokenizer not yet available in the HuggingFace transformers "
diff --git a/aphrodite/transformers_utils/tokenizers/__init__.py b/aphrodite/transformers_utils/tokenizers/__init__.py
index cb4cbfc35..9348450ec 100755
--- a/aphrodite/transformers_utils/tokenizers/__init__.py
+++ b/aphrodite/transformers_utils/tokenizers/__init__.py
@@ -1,3 +1,4 @@
 from aphrodite.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+from aphrodite.transformers_utils.tokenizers.bitnet import BitnetTokenizer
-__all__ = ["BaichuanTokenizer"]
+__all__ = ["BaichuanTokenizer", "BitnetTokenizer"]
diff --git a/aphrodite/transformers_utils/tokenizers/bitnet.py b/aphrodite/transformers_utils/tokenizers/bitnet.py
new file mode 100644
index 000000000..828b5fc8b
--- /dev/null
+++ b/aphrodite/transformers_utils/tokenizers/bitnet.py
@@ -0,0 +1,463 @@
+# Adapted from
+# https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/tokenization_bitnet.py
+
+# ruff: noqa: E501
+"""Tokenization classes for Bitnet."""
+import os
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+from transformers.convert_slow_tokenizer import import_protobuf
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+if TYPE_CHECKING:
+    from transformers.tokenization_utils_base import TextInput
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "hf-internal-testing/llama-tokenizer":
+        "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
+    },
+    "tokenizer_file": {
+        "hf-internal-testing/llama-tokenizer":
+        "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
+    },
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "hf-internal-testing/llama-tokenizer": 2048,
+}
+SPIECE_UNDERLINE = "▁"
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class BitnetTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
+    no padding token in the original model.
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Bitnet should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. 
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        spaces_between_special_tokens=False,
+        legacy=None,
+        add_prefix_space=True,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = (AddedToken(bos_token, normalized=False, special=True)
+                     if isinstance(bos_token, str) else bos_token)
+        eos_token = (AddedToken(eos_token, normalized=False, special=True)
+                     if isinstance(eos_token, str) else eos_token)
+        unk_token = (AddedToken(unk_token, normalized=False, special=True)
+                     if isinstance(unk_token, str) else unk_token)
+        pad_token = (AddedToken(pad_token, normalized=False, special=True)
+                     if isinstance(pad_token, str) else pad_token)
+
+        if legacy is None:
+            logger.warning_once(
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
+                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
+                " means, and thoroughly read the reason why this was added as explained in"
+                " https://github.com/huggingface/transformers/pull/24565")
+            legacy = True
+
+        self.legacy = legacy
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
+        self.add_prefix_space = add_prefix_space
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            use_default_system_prompt=use_default_system_prompt,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            legacy=legacy,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
+    def get_spm_processor(self, from_slow=False):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy or from_slow:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf(
+                f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)"
+            )
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model =
spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = { + self.convert_ids_to_tokens(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if (len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE + and tokens[1] in self.all_special_tokens): + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return (tokens[self.unk_token_length:] + if len(tokens) >= self.unk_token_length else tokens) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens: List[str] = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary( + self, + save_directory, + filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]: + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. 
+ Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + error_message = f"Vocabulary path ({save_directory}) should be a directory" + logger.error(error_message) + return None + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file, ) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id) + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + if token_ids_1 is None, only returns the first portion of the mask (0s). + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output
+
+    @property
+    def default_chat_template(self):
+        """
+        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+        to fine-tune a model with more flexible role ordering!
+        The output should look something like:
+        [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer  [INST] Prompt [/INST] Answer
+        [INST] Prompt [/INST]
+        The reference for this chat template is [this code
+        snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
+        in the original repository.
+        """
+        logger.warning_once(
+            "\nNo chat template is defined for this tokenizer - using the default template "
+            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
+            "your model, please set `tokenizer.chat_template` to an appropriate template. "
+            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
+        )
+        template = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+            "{% set system_message = messages[0]['content'] %}"
+            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
+            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% else %}"
+            "{% set loop_messages = messages %}"
+            "{% set system_message = false %}"
+            "{% endif %}"
+            "{% for message in loop_messages %}"  # Loop over all non-system messages
+            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+            "{% endif %}"
+            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
+            "{% else %}"
+            "{% set content = message['content'] %}"
+            "{% endif %}"
+            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
+            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+            "{% elif message['role'] == 'system' %}"
+            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ ' ' + content.strip() + ' ' + eos_token }}"
+            "{% endif %}"
+            "{% endfor %}")
+        template = template.replace(
+            "USE_DEFAULT_PROMPT",
+            "true" if self.use_default_system_prompt else "false")
+        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace(
+            "'", "\\'")
+        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+        return template
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 00dda0700..065415871 100644
---
a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -6,4 +6,5 @@ nvidia-ml-py == 12.555.43 torch == 2.3.0 xformers == 0.0.26.post1 # Requires torch 2.3.0 triton >= 2.2.0 -vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 \ No newline at end of file +vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 +bitnet \ No newline at end of file From 5d965b34a7a6db285053f82a311e182109102918 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 2 Aug 2024 23:20:22 +0000 Subject: [PATCH 2/5] bitnet -> bitblas in reqs --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 065415871..3a49c2b9c 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,4 @@ torch == 2.3.0 xformers == 0.0.26.post1 # Requires torch 2.3.0 triton >= 2.2.0 vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 -bitnet \ No newline at end of file +bitblas \ No newline at end of file From 74cb1aad4e85d3ea2fa168fad33dc604cb5ca278 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 2 Aug 2024 23:34:59 +0000 Subject: [PATCH 3/5] wip --- aphrodite/modeling/models/bitnet.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/aphrodite/modeling/models/bitnet.py b/aphrodite/modeling/models/bitnet.py index b6ca4726f..f213e0a50 100644 --- a/aphrodite/modeling/models/bitnet.py +++ b/aphrodite/modeling/models/bitnet.py @@ -34,7 +34,7 @@ from aphrodite.attention import Attention, AttentionMetadata from aphrodite.common.config import CacheConfig -from aphrodite.distributed import (get_pp_group, get_tensor_model_parallel_rank, +from aphrodite.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from aphrodite.modeling.layers.activation import SiluAndMul from aphrodite.modeling.layers.layernorm import RMSNorm @@ -430,16 +430,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] residual = None for i in range(len(self.layers)): layer = self.layers[i] From 60af35bc348bc479348a0354e14723139653dca4 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 2 Aug 2024 23:37:06 +0000 Subject: [PATCH 4/5] wip --- aphrodite/modeling/models/bitnet.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/aphrodite/modeling/models/bitnet.py b/aphrodite/modeling/models/bitnet.py index f213e0a50..23fc934e8 100644 --- a/aphrodite/modeling/models/bitnet.py +++ b/aphrodite/modeling/models/bitnet.py @@ -51,7 +51,7 @@ from aphrodite.modeling.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from aphrodite.modeling.sampling_metadata import SamplingMetadata -from aphrodite.common.sequence import IntermediateTensors, SamplerOutput +from aphrodite.common.sequence import SamplerOutput from aphrodite.common.utils import is_hip, print_warning_once @@ -427,12 +427,8 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - 
intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -501,10 +497,9 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -521,20 +516,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) From 9bbc75d2e3389428a587af92f60e94482898ae15 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 2 Aug 2024 23:41:26 +0000 Subject: [PATCH 5/5] wip --- aphrodite/modeling/models/bitnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aphrodite/modeling/models/bitnet.py b/aphrodite/modeling/models/bitnet.py index 23fc934e8..be468d5c3 100644 --- a/aphrodite/modeling/models/bitnet.py +++ b/aphrodite/modeling/models/bitnet.py @@ -429,6 +429,10 @@ def forward( attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i]
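
Not part of the patch itself: a minimal, self-contained sketch of the embedding-resolution step introduced in the final hunk above, for reviewers who want to sanity-check the intent. The function name, vocabulary size, and hidden size are illustrative assumptions only; the patched BitnetModel.forward performs the equivalent branch inline before iterating over its decoder layers.

import torch
from typing import Optional


def resolve_hidden_states(
        input_ids: torch.Tensor,
        embed_tokens: torch.nn.Embedding,
        inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Prefer caller-supplied embeddings; otherwise look the token IDs up
    # in the local embedding table, mirroring the branch added above.
    if inputs_embeds is not None:
        return inputs_embeds
    return embed_tokens(input_ids)


# Toy usage with made-up sizes.
embed = torch.nn.Embedding(32000, 2048)
tokens = torch.tensor([1, 42, 7])
assert resolve_hidden_states(tokens, embed).shape == (3, 2048)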
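
Likewise, a small standalone sketch of how the special-token helpers added in the Bitnet tokenizer (patch 1/5) compose, assuming add_bos_token=True and add_eos_token=False, the usual LLaMA-style settings; confirm against the actual tokenizer config. The helper names and token IDs below are illustrative, not the tokenizer's API.

from typing import List, Optional

BOS_ID, EOS_ID = 1, 2
ADD_BOS, ADD_EOS = True, False


def build_inputs(ids_0: List[int], ids_1: Optional[List[int]] = None) -> List[int]:
    # Mirrors build_inputs_with_special_tokens: each segment is wrapped
    # with BOS/EOS according to the add_bos_token/add_eos_token flags.
    bos = [BOS_ID] if ADD_BOS else []
    eos = [EOS_ID] if ADD_EOS else []
    out = bos + ids_0 + eos
    if ids_1 is not None:
        out = out + bos + ids_1 + eos
    return out


def special_tokens_mask(ids_0: List[int],
                        ids_1: Optional[List[int]] = None) -> List[int]:
    # Mirrors get_special_tokens_mask: 1 marks a special token, 0 a sequence token.
    bos = [1] if ADD_BOS else []
    eos = [1] if ADD_EOS else []
    mask = bos + [0] * len(ids_0) + eos
    if ids_1 is not None:
        mask = mask + bos + [0] * len(ids_1) + eos
    return mask


print(build_inputs([10, 11, 12], [20, 21]))         # [1, 10, 11, 12, 1, 20, 21]
print(special_tokens_mask([10, 11, 12], [20, 21]))  # [1, 0, 0, 0, 1, 0, 0]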