@@ -239,16 +239,27 @@ def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, s
         return short_factor

     def rope_fwd(self, x, position_ids, seq_len=None):
-        seq_len = torch.tensor(seq_len) or torch.max(position_ids) + 1
+        seq_len = torch.max(position_ids) + 1
+        original_max_position_embeddings = (
+            self.original_max_position_embeddings
+            if hasattr(self, "original_max_position_embeddings")
+            else self.config.original_max_position_embeddings
+        )
+        max_position_embeddings = self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
+        short_factor = self.short_factor if hasattr(self, "short_factor") else self.config.rope_scaling["short_factor"]
+        long_factor = self.long_factor if hasattr(self, "long_factor") else self.config.rope_scaling["long_factor"]
         ext_factors = select_ext_factor(
             seq_len,
-            torch.tensor(self.original_max_position_embeddings),
-            torch.tensor(self.short_factor, dtype=torch.float32, device=x.device),
-            torch.tensor(self.long_factor, dtype=torch.float32, device=x.device),
+            torch.tensor(original_max_position_embeddings),
+            torch.tensor(short_factor, dtype=torch.float32, device=x.device),
+            torch.tensor(long_factor, dtype=torch.float32, device=x.device),
         )
+        base = self.config.rope_theta if not hasattr(self, "base") else self.base
+
+        dim = self.dim if hasattr(self, "dim") else getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)

-        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
-        inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+        inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=x.device).float() / dim
+        inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

         inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -257,23 +268,111 @@ def rope_fwd(self, x, position_ids, seq_len=None):
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)

-        scale = self.max_position_embeddings / self.original_max_position_embeddings
-        if scale <= 1.0:
-            scaling_factor = 1.0
-        else:
-            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-        cos = emb.cos() * scaling_factor
-        sin = emb.sin() * scaling_factor
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        scale = max_position_embeddings / original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+        cos = emb.cos() * scaling_factor
+        sin = emb.sin() * scaling_factor
+        return cos, sin

     pipe.model.llm._orig_forward = pipe.model.llm.forward
     pipe.model.llm.forward = MethodType(forward_wrap, pipe.model.llm)
     if hasattr(pipe.model.llm, "rotary_emb"):
         pipe.model.llm.rotary_emb.forward = MethodType(rope_fwd, pipe.model.llm.rotary_emb)
+        from transformers.cache_utils import Cache, DynamicCache
+        from transformers.modeling_outputs import BaseModelOutputWithPast
+
+        def new_transformers_forward(
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            cache_position: Optional[torch.LongTensor] = None,
+            offload_model: Optional[bool] = False,
+        ) -> Union[Tuple, BaseModelOutputWithPast]:
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+            if (input_ids is None) ^ (inputs_embeds is not None):
+                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+            # kept for BC (non `Cache` `past_key_values` inputs)
+
+            if attention_mask is not None and attention_mask.dim() == 3:
+                dtype = inputs_embeds.dtype
+                min_dtype = torch.finfo(dtype).min
+                attention_mask = (1 - attention_mask) * min_dtype
+                attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
+            else:
+                raise Exception("attention_mask parameter was unavailable or invalid")
+            # causal_mask = self._update_causal_mask(
+            #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            # )
+
+            hidden_states = inputs_embeds
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+            # decoder layers
+            all_hidden_states = () if output_hidden_states else None
+            all_self_attns = () if output_attentions else None
+
+            layer_idx = -1
+            for decoder_layer in self.layers:
+                layer_idx += 1
+
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+
+                if offload_model and not self.training:
+                    self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+                hidden_states = layer_outputs[0]
+
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)
+
+            hidden_states = self.norm(hidden_states)
+
+            # add hidden states from the last decoder layer
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            next_cache = past_key_values if use_cache else None
+
+            if not return_dict:
+                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+            )
+
+        pipe.model.llm._orig_forward = MethodType(new_transformers_forward, pipe.model.llm)
     else:
         for layer in pipe.model.llm.layers:
             layer.self_attn.rotary_emb.forward = MethodType(rope_fwd, layer.self_attn.rotary_emb)
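
For context, the patched `rope_fwd` follows the Phi-3-style LongRoPE recipe: pick the long rescale factors once the sequence length exceeds the original training context, and scale cos/sin by `sqrt(1 + log(scale) / log(original_max_position_embeddings))`. The snippet below is a minimal standalone sketch of that rule; the config numbers (4096 and 131072) are illustrative assumptions, not values read from this model.

```python
# Minimal sketch (not part of the patch): the LongRoPE-style factor selection and
# attention scaling that the patched rope_fwd applies. The values below are
# illustrative assumptions, not taken from the model being converted.
import math


def pick_ext_factor(seq_len, original_max_pos, short_factor, long_factor):
    # Long factors are used only once the sequence grows past the original context.
    return long_factor if seq_len > original_max_pos else short_factor


original_max_pos = 4096      # original_max_position_embeddings (assumed)
max_pos = 131072             # max_position_embeddings after extension (assumed)

scale = max_pos / original_max_pos
scaling_factor = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_pos))
print(round(scaling_factor, 4))  # ~1.1902: cos/sin get multiplied by this value
```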
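
The `_orig_forward` / `MethodType` lines use the usual instance-level monkey-patching pattern: keep a handle to the original bound method, then rebind a replacement onto that one instance. A toy sketch of the pattern, where `ToyModule` and `new_forward` are hypothetical stand-ins for `pipe.model.llm` and the wrappers above:

```python
# Toy sketch of the MethodType rebinding pattern used by the patch; ToyModule and
# new_forward are hypothetical stand-ins, not part of the real pipeline.
from types import MethodType

import torch


class ToyModule(torch.nn.Module):
    def forward(self, x):
        return x + 1


def new_forward(self, x):
    # Delegate to the preserved original, then post-process its output.
    return self._orig_forward(x) * 2


m = ToyModule()
m._orig_forward = m.forward             # keep the original bound method around
m.forward = MethodType(new_forward, m)  # rebind on this instance only, not the class
print(m(torch.tensor(1)))               # tensor(4): (1 + 1) * 2
```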