From e1385f9e56f3f8f9209367c68ee24e9305797556 Mon Sep 17 00:00:00 2001
From: Shukant Pal
Date: Tue, 18 Jun 2024 00:28:06 -0700
Subject: [PATCH 1/4] Fix Phi-3 Long RoPE scaling implementation

When evaluating our fine-tunes of Phi-3, we noticed a big difference in
the logits output by Huggingface transformers and vLLM. The most
significant cause of this was a deviation in the positional embedding
implementation of vLLM.

Since Microsoft contributed their Phi-3 implementation to Huggingface
transformers, I believe that is the reference implementation to follow:
https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/models/phi3/modeling_phi3.py#L153

The above link shows how Phi-3's rotation matrix is scaled. I've added
similar code to vLLM's `Phi3LongRoPEScaledRotaryEmbedding` and removed
the hardcoded scaling factors.
---
 vllm/model_executor/layers/rotary_embedding.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 9c0a74cdab96e..7bba3657ccc93 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -507,8 +507,8 @@ def __init__(
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.1,
-        long_mscale: float = 1.225,
+        short_mscale: float = 1.0,
+        long_mscale: float = 1.0,
     ):
         super().__init__()
 
@@ -530,6 +530,13 @@ def __init__(
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
 
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+
+        if scale <= 1.0:
+            self.scaling_factor = 1.0
+        else:
+            self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
@@ -565,8 +572,8 @@ def _compute_cos_sin_cache(
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos() * mscale
-        sin = freqs.sin() * mscale
+        cos = freqs.cos() * mscale * self.scaling_factor
+        sin = freqs.sin() * mscale * self.scaling_factor
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 

From 15968ab08381f3442af9d5732cae16fa70efbaa3 Mon Sep 17 00:00:00 2001
From: Shukant Pal
Date: Tue, 18 Jun 2024 00:36:03 -0700
Subject: [PATCH 2/4] Manual Lint

---
 vllm/model_executor/layers/rotary_embedding.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 7bba3657ccc93..4f51f9fa05b86 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -530,12 +530,19 @@ def __init__(
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
 
-        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        scale = (
+            self.max_position_embeddings / self.original_max_position_embeddings
+        )
+
         if scale <= 1.0:
             self.scaling_factor = 1.0
         else:
-            self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+            self.scaling_factor = math.sqrt(
+                1
+                + math.log(scale)
+                / math.log(self.original_max_position_embeddings)
+            )
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)

From c9ad1a8a7d3e3a0c3853108b6709be8aa0989297 Mon Sep 17 00:00:00 2001
From: Shukant Pal
Date: Tue, 18 Jun 2024 00:36:12 -0700
Subject: [PATCH 3/4] Fix extra line

---
 vllm/model_executor/layers/rotary_embedding.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 4f51f9fa05b86..f6f9c6a13c9b6 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -534,7 +534,6 @@ def __init__(
             self.max_position_embeddings / self.original_max_position_embeddings
         )
 
-
         if scale <= 1.0:
             self.scaling_factor = 1.0
         else:

From d1b651c9d8d8184b80992158c54cae371186f459 Mon Sep 17 00:00:00 2001
From: Shukant Pal
Date: Tue, 18 Jun 2024 10:03:42 -0700
Subject: [PATCH 4/4] Yapf format

---
 vllm/model_executor/layers/rotary_embedding.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index f6f9c6a13c9b6..a0b19046b7491 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -530,18 +530,15 @@ def __init__(
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
 
-        scale = (
-            self.max_position_embeddings / self.original_max_position_embeddings
-        )
+        scale = (self.max_position_embeddings /
+                 self.original_max_position_embeddings)
 
         if scale <= 1.0:
             self.scaling_factor = 1.0
         else:
             self.scaling_factor = math.sqrt(
-                1
-                + math.log(scale)
-                / math.log(self.original_max_position_embeddings)
-            )
+                1 + math.log(scale) /
+                math.log(self.original_max_position_embeddings))
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
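
A minimal standalone sketch of the scaling factor that PATCH 1/4 introduces, mirroring the Hugging Face Phi-3 reference linked in its commit message. The 4096/131072 context lengths below are hypothetical Phi-3-mini-128k-style example values, not taken from these patches:

import math

def longrope_scaling_factor(max_position_embeddings: int,
                            original_max_position_embeddings: int) -> float:
    # Ratio between the extended and the original context window.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        # Context is not extended, so the cos/sin caches stay unscaled.
        return 1.0
    # Same formula the patch adds to Phi3LongRoPEScaledRotaryEmbedding:
    # sqrt(1 + log(scale) / log(original_max_position_embeddings)).
    return math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings))

# Hypothetical example: 4k original context extended to 128k.
print(longrope_scaling_factor(131072, 4096))  # ~1.19

Both the short and the long cos/sin caches get multiplied by this computed factor, which is why the patch also lowers the hardcoded 1.1/1.225 mscale defaults to 1.0.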