-
-
Notifications
You must be signed in to change notification settings - Fork 6.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix biachuan-7b tp #598
fix biachuan-7b tp #598
Conversation
Is the same reason for baichuan-13b? #530 |
Yes. I have tested it on both baichuan13b and 7b, and it can output normal output under tp. |
Can I use this PR directly on 13B? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your contribution! Can you use our official formatting script and remove other additional format changes?
if "embed_tokens" in name or "lm_head" in name: | ||
# Consider padding in the vocab size. | ||
param = state_dict[name] | ||
padded_vocab_size = param.shape[0] * tp_world_size | ||
num_extra_rows = padded_vocab_size - self.config.vocab_size | ||
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) | ||
extra_rows = extra_rows.to(loaded_weight) | ||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) | ||
|
||
if "W_pack" in name: | ||
# W_pack.weight.shape [3*hidden_size, hidden_size] [3*4096, 4096] = [12,288, 4096] | ||
total_num_heads = self.config.num_attention_heads | ||
hidden_size = self.config.hidden_size | ||
head_size = hidden_size // total_num_heads | ||
num_heads = total_num_heads // tp_world_size | ||
head_start = tp_rank * num_heads | ||
head_end = (tp_rank + 1) * num_heads | ||
|
||
loaded_weight = loaded_weight.view( | ||
3, total_num_heads, head_size, hidden_size | ||
) | ||
loaded_weight = loaded_weight[:, head_start:head_end, :, :] | ||
loaded_weight = loaded_weight.reshape(-1, hidden_size) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this part the only part that actually changes the code logic? Can you remove other format-only modifications and use format.sh
script provided by us to re-format the code? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I have already modified the content of the PR and removed the invalid format part.
356793c
to
aeb2d9e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thank you for your contribution!
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
The main modifications are in the "load_weights" function.
Before:

After:
