args = parser.parse_args()

model_path = args.model_path
+is_4B = "InternVL2-4B" in model_path
folder = f"./tmp/onnx"

origin_model = AutoModelForCausalLM.from_pretrained(
...
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
VOCAB_SIZE = config.llm_config.vocab_size
DOWNSAMPLE_RATIO = config.downsample_ratio
-EOS_TOKEN_ID = config.llm_config.eos_token_id
+ID_EOS = config.llm_config.eos_token_id
print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')

vit = origin_model.vision_model
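
Note on the rename: `ID_EOS` matches the `ID_IM_END` / `ID_END` stop ids introduced further down. Hugging Face configs may store `eos_token_id` as either a single int or a list, so a caller that wants a single stop set could normalize it first; a minimal sketch (the `stop_ids` name is illustrative, not part of this patch):

# Sketch: normalize eos_token_id (int or list) into a set of stop ids.
eos = config.llm_config.eos_token_id
stop_ids = set(eos) if isinstance(eos, (list, tuple)) else {eos}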
@@ -63,9 +64,10 @@ class Embedding(torch.nn.Module):

    def __init__(self):
        super().__init__()
+        self.embed = transformer.get_input_embeddings()

    def forward(self, input_ids):
-        hidden_states = transformer.embed_tokens(input_ids)
+        hidden_states = self.embed(input_ids)
        return hidden_states

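A likely motivation for `get_input_embeddings()`: the embedding attribute is named differently across backbones (`embed_tokens` in Llama-style models, `tok_embeddings` in InternLM2), while the accessor works for both. A minimal usage sketch, assuming `transformer` is the loaded language model from above:

# Sketch: model-agnostic lookup of token embeddings.
embed = transformer.get_input_embeddings()
hidden = embed(torch.tensor([[1, 2, 3]]).cuda())  # [1, 3, HIDDEN_SIZE]
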
@@ -75,13 +77,18 @@ def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]
-        self.rotary_emb = self.layer.self_attn.rotary_emb
+
        position_ids = torch.tensor(
            [range(SEQ_LENGTH)], dtype=torch.long).cuda()
        value_states = torch.randn(
            (1, SEQ_LENGTH, config.llm_config.num_key_value_heads, HEAD_DIM)).bfloat16().cuda()
-        self.cos, self.sin = self.rotary_emb(
-            value_states, position_ids, SEQ_LENGTH)
+        if is_4B:
+            self.rotary_emb = self.layer.self_attn.rotary_emb
+            self.cos, self.sin = self.rotary_emb(
+                value_states, position_ids, SEQ_LENGTH)
+        else:
+            self.rotary_emb = self.layer.attention.rotary_emb
+            self.cos, self.sin = self.rotary_emb(value_states, SEQ_LENGTH)
        self.cos = self.cos.view(SEQ_LENGTH, HEAD_DIM)
        self.sin = self.sin.view(SEQ_LENGTH, HEAD_DIM)

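The branch exists because the two backbones expose rotary embeddings differently: the 4B (Phi-3-based) model keeps the module on `self_attn` and its forward takes `(x, position_ids, seq_len)`, while the InternLM2-based models keep it on `attention` and take `(x, seq_len)`. The same branch is repeated in the cache block below. A hypothetical helper mirroring the dispatch (`get_rope_tables` is not part of the patch):

# Sketch only: returns the (cos, sin) tables for either backbone.
def get_rope_tables(layer, value_states, position_ids, seq_len, is_4B):
    if is_4B:  # Phi-3-style module path and call signature
        return layer.self_attn.rotary_emb(value_states, position_ids, seq_len)
    # InternLM2-style module path and call signature
    return layer.attention.rotary_emb(value_states, seq_len)
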
@@ -105,13 +112,17 @@ def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.layer = layers[layer_id]
-        self.rotary_emb = self.layer.self_attn.rotary_emb
        position_ids = torch.tensor(
            [range(SEQ_LENGTH)], dtype=torch.long).cuda()
        value_states = torch.randn(
            (1, SEQ_LENGTH, config.llm_config.num_key_value_heads, HEAD_DIM)).bfloat16().cuda()
-        self.cos, self.sin = self.rotary_emb(
-            value_states, position_ids, SEQ_LENGTH)
+        if is_4B:
+            self.rotary_emb = self.layer.self_attn.rotary_emb
+            self.cos, self.sin = self.rotary_emb(
+                value_states, position_ids, SEQ_LENGTH)
+        else:
+            self.rotary_emb = self.layer.attention.rotary_emb
+            self.cos, self.sin = self.rotary_emb(value_states, SEQ_LENGTH)
        self.cos = self.cos.view(SEQ_LENGTH, HEAD_DIM)
        self.sin = self.sin.view(SEQ_LENGTH, HEAD_DIM)

@@ -134,10 +145,11 @@ class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()
+        self.lm_head = origin_model.language_model.get_output_embeddings()

    def forward(self, hidden_states):
        hidden_states = transformer.norm(hidden_states)
-        m_logits = origin_model.language_model.lm_head(hidden_states)
+        m_logits = self.lm_head(hidden_states)
        _, token = torch.topk(m_logits.float(), 1)
        return token

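As with `get_input_embeddings()` above, `get_output_embeddings()` avoids hard-coding a head attribute that differs across backbones. For reference, `topk(..., 1)` in `forward` is greedy decoding; the index it returns is the same as an argmax over the vocabulary:

# Equivalent formulation (ties aside):
token = m_logits.float().argmax(dim=-1, keepdim=True)
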
@@ -251,68 +263,10 @@ def build_transform(input_size):
    return transform


-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    best_ratio_diff = float('inf')
-    best_ratio = (1, 1)
-    area = width * height
-    for ratio in target_ratios:
-        target_aspect_ratio = ratio[0] / ratio[1]
-        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-        if ratio_diff < best_ratio_diff:
-            best_ratio_diff = ratio_diff
-            best_ratio = ratio
-        elif ratio_diff == best_ratio_diff:
-            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                best_ratio = ratio
-    return best_ratio
-
-
-def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
-    orig_width, orig_height = image.size
-    aspect_ratio = orig_width / orig_height
-
-    # calculate the existing image aspect ratio
-    target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
-    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-    # find the closest aspect ratio to the target
-    target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
-    # calculate the target width and height
-    target_width = image_size * target_aspect_ratio[0]
-    target_height = image_size * target_aspect_ratio[1]
-    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-    # resize the image
-    resized_img = image.resize((target_width, target_height))
-    processed_images = []
-    for i in range(blocks):
-        box = (
-            (i % (target_width // image_size)) * image_size,
-            (i // (target_width // image_size)) * image_size,
-            ((i % (target_width // image_size)) + 1) * image_size,
-            ((i // (target_width // image_size)) + 1) * image_size
-        )
-        # split the image
-        split_img = resized_img.crop(box)
-        processed_images.append(split_img)
-    assert len(processed_images) == blocks
-    if use_thumbnail and len(processed_images) != 1:
-        thumbnail_img = image.resize((image_size, image_size))
-        processed_images.append(thumbnail_img)
-    return processed_images
-
-
def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
-    images = dynamic_preprocess(
-        image, image_size=input_size, use_thumbnail=True, max_num=max_num)
-    pixel_values = [transform(image) for image in images]
-    pixel_values = torch.stack(pixel_values)
+    pixel_values = transform(image)
    return pixel_values

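This drops InternVL2's dynamic multi-tile preprocessing: every image is now a single 448x448 tile and `max_num` is effectively ignored. Note that `transform(image)` alone yields `[3, 448, 448]`; if downstream code expects the documented `[1, 3, 448, 448]`, an explicit batch dimension may be needed:

# Possible shape check (illustrative); unsqueeze only if callers expect a batch dim.
pixel_values = load_image('test.jpg')      # [3, 448, 448]
pixel_values = pixel_values.unsqueeze(0)   # [1, 3, 448, 448]
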
@@ -332,7 +286,8 @@ def test_net_with_mask():
    pixel_values = load_image(jpg, max_num=1).to(
        torch.bfloat16).cuda()  # [1, 3, 448, 448]
    vit_embeds = vit_infer(pixel_values)  # [1, 256, 3072]
-
+    ID_IM_END = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    ID_END = tokenizer.convert_tokens_to_ids("<|end|>")
    token_len = len(ids)
    ids = ids + (SEQ_LENGTH - token_len) * [0]
    input_ids = torch.tensor(ids).view(SEQ_LENGTH).cuda()
@@ -362,7 +317,7 @@ def test_net_with_mask():
    lm = LmHead()
    token = lm(out.bfloat16()).view(1)
    out_ids = [int(token)]
-    while int(token) < EOS_TOKEN_ID and token_len < SEQ_LENGTH:
+    while int(token) not in [ID_EOS, ID_IM_END, ID_END] and token_len < SEQ_LENGTH:
        token_len += 1
        input_ids = torch.tensor([token]).cuda()
        out = embed(input_ids).view(1, 1, HIDDEN_SIZE)
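
This also fixes the old stop condition, which compared the token id numerically (`<`) instead of testing membership. The loop now stops on any of the three ids, covering both chat templates (`<|im_end|>` for the InternLM2-based models, `<|end|>` for the Phi-3-based 4B). A slightly more defensive sketch that skips tokens a given tokenizer does not know (assumption: `convert_tokens_to_ids` maps unknown tokens to the unk id or `None`, depending on the tokenizer):

# Sketch: collect stop ids once, ignoring tokens absent from the vocab.
stop_ids = {ID_EOS}
for tok in ("<|im_end|>", "<|end|>"):
    tid = tokenizer.convert_tokens_to_ids(tok)
    if tid is not None and tid != tokenizer.unk_token_id:
        stop_ids.add(tid)
while int(token) not in stop_ids and token_len < SEQ_LENGTH:
    ...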