Skip to content

Commit

Permalink
fix biachuan-7b tp (#598)
Browse files Browse the repository at this point in the history
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
  • Loading branch information
Sanster and wq.chu authored Aug 1, 2023
1 parent aa39e42 commit d4c7755
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,32 +251,55 @@ def forward(
return next_tokens

_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "W_pack.weight",
"gate_proj.weight", "up_proj.weight"
"embed_tokens.weight",
"lm_head.weight",
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()

for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue

if "embed_tokens" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)

is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
Expand All @@ -287,7 +310,11 @@ def load_weights(self,
continue

param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)

0 comments on commit d4c7755

Please sign in to comment.