Skip to content

Commit

Permalink
refactor dbrx experts to use FusedMoe layer (vllm-project#186)
Browse files Browse the repository at this point in the history
* refactor dbrx experts to use FusedMoe layer
* prepare for the fp8 input

---------

Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 17, 2024
1 parent c68c242 commit 6bd99d2
Showing 1 changed file with 56 additions and 72 deletions.
128 changes: 56 additions & 72 deletions vllm/model_executor/models/dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
Expand All @@ -22,7 +21,6 @@
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig

Expand Down Expand Up @@ -54,63 +52,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return router_logits


class DbrxExperts(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
class DbrxExperts(FusedMoE):

def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
super().__init__(
num_experts=config.ffn_config.moe_num_experts,
top_k=config.ffn_config.moe_top_k,
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model
self.intermediate_size = (config.ffn_config.ffn_hidden_size //
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size)

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))

set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)

# Define custom weight loader for dbrx model
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
tp_rank = get_tensor_model_parallel_rank()
Expand All @@ -119,47 +86,61 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
if weight_name.endswith("w1."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1"):
if weight_name.endswith("v1."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
if weight_name.endswith("w2"):
if weight_name.endswith("w2."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]


class DbrxMoE(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""

def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.d_model = config.d_model
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

self.router = DbrxRouter(config, self.params_dtype)

self.experts = DbrxExperts(config=config,
quant_config=quant_config,
params_dtype=self.params_dtype)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states.view(num_tokens, hidden_size)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)


class DbrxAttention(nn.Module):
Expand Down Expand Up @@ -288,7 +269,7 @@ def __init__(
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.ffn = DbrxExperts(config, quant_config)
self.ffn = DbrxMoE(config, quant_config)

def forward(
self,
Expand Down Expand Up @@ -409,12 +390,15 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
"w13_" if weight_name in ["w1", "v1"] else "w2_",
f"mlp.{weight_name}.",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name.endswith(("w1", "v1", "w2")):
name = name + ".weight"
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
Expand Down

0 comments on commit 6bd99d2

Please sign in to comment.