diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 3696c56a903f0..ff5afe51d80e7 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -251,8 +251,8 @@ def forward(
         return next_tokens
 
     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "W_pack.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "embed_tokens.weight",
+        "lm_head.weight",
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
@@ -260,7 +260,8 @@ def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_world_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
@@ -268,15 +269,37 @@ def load_weights(self,
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if "embed_tokens" in name or "lm_head" in name:
+                # Consider padding in the vocab size.
+                param = state_dict[name]
+                padded_vocab_size = param.shape[0] * tp_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
+            if "W_pack" in name:
+                total_num_heads = self.config.num_attention_heads
+                hidden_size = self.config.hidden_size
+                head_size = hidden_size // total_num_heads
+                num_heads = total_num_heads // tp_world_size
+                head_start = tp_rank * num_heads
+                head_end = (tp_rank + 1) * num_heads
+
+                loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                   head_size, hidden_size)
+                loaded_weight = loaded_weight[:, head_start:head_end, :, :]
+                loaded_weight = loaded_weight.reshape(-1, hidden_size)
+
             is_gate_up_weight = False
             for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
@@ -287,7 +310,11 @@ def load_weights(self,
                 continue
 
             param = state_dict[name]
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+            load_tensor_parallel_weights(
+                param,
+                loaded_weight,
+                name,
+                self._column_parallel_weights,
+                self._row_parallel_weights,
+                tp_rank,
+            )
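
For context, here is a minimal standalone sketch (not part of the patch; all dimensions below are invented toy values, not the real Baichuan config) of the two transformations the new `load_weights` branches perform: padding the vocab-parallel `embed_tokens`/`lm_head` weight so every tensor-parallel rank gets an equal number of rows, and slicing the fused `W_pack` QKV weight head-wise so each rank keeps its own contiguous block of heads from Q, K, and V.

```python
import torch

# Toy dimensions for illustration only (not the real Baichuan config).
vocab_size = 1001
hidden_size = 8
total_num_heads = 4
head_size = hidden_size // total_num_heads
tp_world_size = 2

# --- Vocab padding, as in the "embed_tokens" / "lm_head" branch ---
# Assume the vocab is padded up to a multiple of the TP world size, so the
# checkpoint weight gets extra (uninitialized) rows appended before sharding.
per_rank_rows = -(-vocab_size // tp_world_size)  # ceil division
padded_vocab_size = per_rank_rows * tp_world_size
embed = torch.randn(vocab_size, hidden_size)
extra = torch.empty(padded_vocab_size - vocab_size, hidden_size).to(embed)
embed_padded = torch.cat([embed, extra], dim=0)
assert embed_padded.shape[0] % tp_world_size == 0

# --- Head-wise sharding of the fused QKV ("W_pack") weight ---
# W_pack stacks Q, K, V as a (3 * hidden_size, hidden_size) matrix; each rank
# keeps the same slice of heads from each of the three projections.
w_pack = torch.randn(3 * hidden_size, hidden_size)
num_heads = total_num_heads // tp_world_size
for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    shard = w_pack.view(3, total_num_heads, head_size, hidden_size)
    shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
    assert shard.shape == (3 * num_heads * head_size, hidden_size)
```

The head-wise view/slice/reshape is what allows `W_pack.weight` to be dropped from `_column_parallel_weights`: it is no longer split by a plain row slice, because the Q, K, and V blocks each have to contribute the same heads to a given rank.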