Skip to content

Commit 3f27231

Browse files
authored
Add files via upload
1 parent 2ec64d3 commit 3f27231

File tree

1 file changed

+122
-1
lines changed

1 file changed

+122
-1
lines changed

tools/reparameterization.ipynb

+122-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
"outputs": [],
99
"source": [
1010
"import torch\n",
11-
"from models.yolo import Model\n"
11+
"from models.yolo import Model"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"id": "8680f822",
17+
"metadata": {},
18+
"source": [
19+
"## Convert YOLOv9-C"
1220
]
1321
},
1422
{
@@ -97,6 +105,119 @@
97105
" 'epoch': -1}\n",
98106
"torch.save(m_ckpt, \"./yolov9-c-converted.pt\")"
99107
]
108+
},
109+
{
110+
"cell_type": "markdown",
111+
"id": "47c6e6ae",
112+
"metadata": {},
113+
"source": [
114+
"## Convert YOLOv9-E"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "801a1b7c",
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"device = torch.device(\"cpu\")\n",
125+
"cfg = \"./models/detect/gelan-e.yaml\"\n",
126+
"model = Model(cfg, ch=3, nc=80, anchors=3)\n",
127+
"#model = model.half()\n",
128+
"model = model.to(device)\n",
129+
"_ = model.eval()\n",
130+
"ckpt = torch.load('./yolov9-e.pt', map_location='cpu')\n",
131+
"model.names = ckpt['model'].names\n",
132+
"model.nc = ckpt['model'].nc"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": null,
138+
"id": "a2ef4fe6",
139+
"metadata": {},
140+
"outputs": [],
141+
"source": [
142+
"idx = 0\n",
143+
"for k, v in model.state_dict().items():\n",
144+
" if \"model.{}.\".format(idx) in k:\n",
145+
" if idx < 29:\n",
146+
" kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
147+
" model.state_dict()[k] -= model.state_dict()[k]\n",
148+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
149+
" print(k, \"perfectly matched!!\")\n",
150+
" elif idx < 42:\n",
151+
" kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n",
152+
" model.state_dict()[k] -= model.state_dict()[k]\n",
153+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
154+
" print(k, \"perfectly matched!!\")\n",
155+
" elif \"model.{}.cv2.\".format(idx) in k:\n",
156+
" kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
157+
" model.state_dict()[k] -= model.state_dict()[k]\n",
158+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
159+
" print(k, \"perfectly matched!!\")\n",
160+
" elif \"model.{}.cv3.\".format(idx) in k:\n",
161+
" kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
162+
" model.state_dict()[k] -= model.state_dict()[k]\n",
163+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
164+
" print(k, \"perfectly matched!!\")\n",
165+
" elif \"model.{}.dfl.\".format(idx) in k:\n",
166+
" kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
167+
" model.state_dict()[k] -= model.state_dict()[k]\n",
168+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
169+
" print(k, \"perfectly matched!!\")\n",
170+
" else:\n",
171+
" while True:\n",
172+
" idx += 1\n",
173+
" if \"model.{}.\".format(idx) in k:\n",
174+
" break\n",
175+
" if idx < 29:\n",
176+
" kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
177+
" model.state_dict()[k] -= model.state_dict()[k]\n",
178+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
179+
" print(k, \"perfectly matched!!\")\n",
180+
" elif idx < 42:\n",
181+
" kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n",
182+
" model.state_dict()[k] -= model.state_dict()[k]\n",
183+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
184+
" print(k, \"perfectly matched!!\")\n",
185+
" elif \"model.{}.cv2.\".format(idx) in k:\n",
186+
" kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
187+
" model.state_dict()[k] -= model.state_dict()[k]\n",
188+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
189+
" print(k, \"perfectly matched!!\")\n",
190+
" elif \"model.{}.cv3.\".format(idx) in k:\n",
191+
" kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
192+
" model.state_dict()[k] -= model.state_dict()[k]\n",
193+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
194+
" print(k, \"perfectly matched!!\")\n",
195+
" elif \"model.{}.dfl.\".format(idx) in k:\n",
196+
" kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
197+
" model.state_dict()[k] -= model.state_dict()[k]\n",
198+
" model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
199+
" print(k, \"perfectly matched!!\")\n",
200+
"_ = model.eval()"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": null,
206+
"id": "27bc1869",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"m_ckpt = {'model': model.half(),\n",
211+
" 'optimizer': None,\n",
212+
" 'best_fitness': None,\n",
213+
" 'ema': None,\n",
214+
" 'updates': None,\n",
215+
" 'opt': None,\n",
216+
" 'git': None,\n",
217+
" 'date': None,\n",
218+
" 'epoch': -1}\n",
219+
"torch.save(m_ckpt, \"./yolov9-e-converted.pt\")"
220+
]
100221
}
101222
],
102223
"metadata": {

0 commit comments

Comments
 (0)