From cbe1d2e0e67c6aa5191a09288f5faf0e4a6c1573 Mon Sep 17 00:00:00 2001 From: BoringDoggie <57415741+ericzhou571@users.noreply.github.com> Date: Wed, 19 Jul 2023 13:54:51 +0800 Subject: [PATCH 1/5] add support for baichuan_13b --- vllm/model_executor/models/baichuan_13b.py | 326 +++++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 vllm/model_executor/models/baichuan_13b.py diff --git a/vllm/model_executor/models/baichuan_13b.py b/vllm/model_executor/models/baichuan_13b.py new file mode 100644 index 0000000000000..92c8f48c801fb --- /dev/null +++ b/vllm/model_executor/models/baichuan_13b.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Adapted from https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/modeling_baichuan.py +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Baichuan-13b model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" +from typing import Dict, List, Optional, Tuple + +import torch +import math +from torch import nn + +from vllm.sequence import SequenceOutputs +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.sequence import SequenceOutputs + +from vllm.transformers_utils.configs import BaichuanConfig +from vllm.model_executor.layers.attention import PagedAttentionWithALiBi + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BaichuanMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaichuanAttention(nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.W_pack = ColumnParallelLinear( + self.hidden_size, + 3 * self.total_num_heads * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = PagedAttentionWithALiBi( + self.num_heads, self.head_dim, scaling, alibi_slopes + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class BaichuanDecoderLayer(nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaichuanAttention(config=config) + self.mlp = BaichuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class BaichuanModel(nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, perform_initialization=False + ) + self.layers = nn.ModuleList( + [BaichuanDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + del positions # unused + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class BaichuanForCausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.model = BaichuanModel(config) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model( + input_ids, positions, kv_caches, input_metadata, cache_events + ) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, input_metadata) + return next_tokens + + _column_parallel_weights = [ + "embed_tokens.weight", + "lm_head.weight", + "W_pack.weight", + "gate_proj.weight", + "up_proj.weight", + ] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False, + ): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache + ): + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank, + ) From aedf32df642fa79b12d946fa8fca37a2461a2ada Mon Sep 17 00:00:00 2001 From: BoringDoggie <57415741+ericzhou571@users.noreply.github.com> Date: Wed, 19 Jul 2023 13:55:44 +0800 Subject: [PATCH 2/5] add support for baichuan_13b --- vllm/model_executor/models/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c3e3e5723e533..563889348ec0b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -7,9 +7,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.models.baichuan_13b import BaichuanForCausalLM __all__ = [ - "BaiChuanForCausalLM", + "BaiChuanForCausalLM", # 7b + "BaichuanForCausalLM", # 13b "BloomForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", From 7c1282d3179a4cc2423defc07cf4f07c780d8d92 Mon Sep 17 00:00:00 2001 From: BoringDoggie <57415741+ericzhou571@users.noreply.github.com> Date: Wed, 19 Jul 2023 13:56:22 +0800 Subject: [PATCH 3/5] Update model_loader.py --- vllm/model_executor/model_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index a3fd24c911b14..09880a74456bf 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -11,7 +11,8 @@ # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { - "BaiChuanForCausalLM": BaiChuanForCausalLM, + "BaiChuanForCausalLM": BaiChuanForCausalLM, # 7b + "BaichuanForCausalLM": BaichuanForCausalLM, # 13b "BloomForCausalLM": BloomForCausalLM, "GPT2LMHeadModel": GPT2LMHeadModel, "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, From eb7728096fb72d1ff0791bba7b21bf17d0d03a64 Mon Sep 17 00:00:00 2001 From: BoringDoggie <57415741+ericzhou571@users.noreply.github.com> Date: Wed, 19 Jul 2023 13:57:46 +0800 Subject: [PATCH 4/5] add support for baichuan_13b --- .../configs/baichuan_13b.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 vllm/transformers_utils/configs/baichuan_13b.py diff --git a/vllm/transformers_utils/configs/baichuan_13b.py b/vllm/transformers_utils/configs/baichuan_13b.py new file mode 100644 index 0000000000000..e7bd8f6275bc3 --- /dev/null +++ b/vllm/transformers_utils/configs/baichuan_13b.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +from transformers.configuration_utils import PretrainedConfig + + +class BaichuanConfig(PretrainedConfig): + model_type = "baichuan" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=64000, + hidden_size=5120, + intermediate_size=13696, + num_hidden_layers=40, + num_attention_heads=40, + hidden_act="silu", + model_max_length=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + gradient_checkpointing=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.model_max_length = model_max_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.gradient_checkpointing = (gradient_checkpointing,) + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) From 2f9d4553ec7b3db445ac8b727e209a43af869539 Mon Sep 17 00:00:00 2001 From: BoringDoggie <57415741+ericzhou571@users.noreply.github.com> Date: Wed, 19 Jul 2023 14:01:13 +0800 Subject: [PATCH 5/5] add support for baichuan_13b --- vllm/transformers_utils/configs/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 5f0ba4eb9fea4..643dfc8d7427e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,7 +1,9 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig +from vllm.transformers_utils.configs.baichuan_13b import BaichuanConfig __all__ = [ "MPTConfig", - "BaiChuanConfig", + "BaiChuanConfig", # 7b + "BaichuanConfig", # 13b ]