@@ -239,7 +239,6 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
-        self.max_pos_len = config.max_pos_len

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -265,6 +264,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        max_pos_len: Optional[int] = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -294,7 +294,7 @@ def forward(
         )
         # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        cos, sin = self.rotary_emb(value_states, seq_len=self.max_pos_len)
+        cos, sin = self.rotary_emb(value_states, seq_len=max_pos_len)
         # if past_key_value is not None:
         #     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len-1)
         # else:
@@ -764,6 +764,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        max_pos_len: Optional[int] = 0,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -791,6 +792,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            max_pos_len=max_pos_len
         )
         hidden_states = residual + hidden_states

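For reference, a minimal standalone sketch (not part of the diff) of why the seq_len handed to rotary_emb should cover the final generation length: the cos/sin tables are precomputed for seq_len positions and then indexed by position_ids, so passing the target max_pos_len up front keeps later decode steps in range. The build_rope_cache helper and the concrete sizes below are hypothetical, chosen only to illustrate the indexing pattern.

import torch

def build_rope_cache(head_dim: int, seq_len: int, base: float = 10000.0):
    # Standard RoPE frequency table (same shape of computation as the Qwen2 rotary embedding):
    # one row of cos/sin values per position, so the table must be at least as long
    # as the largest position that will ever be indexed.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

# A caller planning to generate up to, say, 4096 tokens would pass max_pos_len=4096
# so that positions beyond the current kv length are still covered.
cos, sin = build_rope_cache(head_dim=128, seq_len=4096)   # example sizes only
position_ids = torch.tensor([[0, 1, 2, 3]])               # current step's positions
print(cos[position_ids].shape)                            # torch.Size([1, 4, 128])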