diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9c0a74cdab96e..a0b19046b7491 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -507,8 +507,8 @@ def __init__( dtype: torch.dtype, short_factor: List[float], long_factor: List[float], - short_mscale: float = 1.1, - long_mscale: float = 1.225, + short_mscale: float = 1.0, + long_mscale: float = 1.0, ): super().__init__() @@ -530,6 +530,16 @@ def __init__( self.short_mscale = short_mscale self.long_mscale = long_mscale + scale = (self.max_position_embeddings / + self.original_max_position_embeddings) + + if scale <= 1.0: + self.scaling_factor = 1.0 + else: + self.scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) short_cache = short_cache.to(dtype) @@ -565,8 +575,8 @@ def _compute_cos_sin_cache( inv_freq = self._compute_inv_freq(rescale_factors) t = torch.arange(max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * mscale - sin = freqs.sin() * mscale + cos = freqs.cos() * mscale * self.scaling_factor + sin = freqs.sin() * mscale * self.scaling_factor cache = torch.cat((cos, sin), dim=-1) return cache