diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 71238d6909a69..4453b4b9f0523 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -17,6 +17,7 @@
     "stabilityai/stablelm-3b-4e1t",
     # "allenai/OLMo-1B",  # Broken
     "bigcode/starcoder2-3b",
+    "google/gemma-1.1-2b-it",
 ]
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index ce97fc808c85e..efefb34814c90 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -26,14 +26,14 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,12 +148,14 @@ def __init__(self,
             quant_config=quant_config,
         )
 
-        self.rotary_emb = get_rope(
+        # TODO(woosuk): Use the `get_rope` interface.
+        self.rotary_emb = GemmaRotaryEmbedding(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
+            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -204,10 +206,10 @@ def __init__(
             hidden_activation=getattr(config, "hidden_activation", None),
             quant_config=quant_config,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -257,7 +259,7 @@ def __init__(
             GemmaDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         # Normalize the embedding by sqrt(hidden_size)
         # The normalizer's data type should be downcasted to the model's
@@ -331,7 +333,6 @@ def __init__(
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -388,10 +389,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            # GemmaRMSNorm is different from Llama's in that it multiplies
-            # (1 + weight) to the output, instead of just weight.
-            if "norm.weight" in name:
-                loaded_weight += 1.0
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
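
Note: the `loaded_weight += 1.0` adjustment removed from `load_weights` was compensating for the fact that Gemma's RMSNorm scales the normalized activations by `(1 + weight)` rather than `weight` (as the removed comment states). With the dedicated `GemmaRMSNorm` layer, that scaling happens inside the norm, so checkpoint weights are loaded as-is. A minimal sketch of the difference, illustrative only and not vLLM's actual implementation:

```python
import torch
import torch.nn as nn


class LlamaStyleRMSNorm(nn.Module):
    """Standard RMSNorm: scales the normalized activations by `weight`."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return x * self.weight


class GemmaStyleRMSNorm(LlamaStyleRMSNorm):
    """Gemma-style RMSNorm: scales by (1 + weight), so Gemma checkpoints
    store norm weights centered around zero rather than one."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return x * (1.0 + self.weight)
```

Keeping the `(1 + weight)` offset inside the layer (instead of patching weights at load time) lets the layer also control where any dtype casting happens, which is why `GemmaRotaryEmbedding` above is likewise constructed with an explicit `dtype` argument.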