|
8 | 8 | "outputs": [],
|
9 | 9 | "source": [
|
10 | 10 | "import torch\n",
|
11 |
| - "from models.yolo import Model\n" |
| 11 | + "from models.yolo import Model" |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "id": "8680f822", |
| 17 | + "metadata": {}, |
| 18 | + "source": [ |
| 19 | + "## Convert YOLOv9-C" |
12 | 20 | ]
|
13 | 21 | },
|
14 | 22 | {
|
|
97 | 105 | " 'epoch': -1}\n",
|
98 | 106 | "torch.save(m_ckpt, \"./yolov9-c-converted.pt\")"
|
99 | 107 | ]
|
| 108 | + }, |
| 109 | + { |
| 110 | + "cell_type": "markdown", |
| 111 | + "id": "47c6e6ae", |
| 112 | + "metadata": {}, |
| 113 | + "source": [ |
| 114 | + "## Convert YOLOv9-E" |
| 115 | + ] |
| 116 | + }, |
| 117 | + { |
| 118 | + "cell_type": "code", |
| 119 | + "execution_count": null, |
| 120 | + "id": "801a1b7c", |
| 121 | + "metadata": {}, |
| 122 | + "outputs": [], |
| 123 | + "source": [ |
| 124 | + "device = torch.device(\"cpu\")\n", |
| 125 | + "cfg = \"./models/detect/gelan-e.yaml\"\n", |
| 126 | + "model = Model(cfg, ch=3, nc=80, anchors=3)\n", |
| 127 | + "#model = model.half()\n", |
| 128 | + "model = model.to(device)\n", |
| 129 | + "_ = model.eval()\n", |
| 130 | + "ckpt = torch.load('./yolov9-e.pt', map_location='cpu')\n", |
| 131 | + "model.names = ckpt['model'].names\n", |
| 132 | + "model.nc = ckpt['model'].nc" |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "cell_type": "code", |
| 137 | + "execution_count": null, |
| 138 | + "id": "a2ef4fe6", |
| 139 | + "metadata": {}, |
| 140 | + "outputs": [], |
| 141 | + "source": [ |
| 142 | + "idx = 0\n", |
| 143 | + "for k, v in model.state_dict().items():\n", |
| 144 | + " if \"model.{}.\".format(idx) in k:\n", |
| 145 | + " if idx < 29:\n", |
| 146 | + " kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n", |
| 147 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 148 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 149 | + " print(k, \"perfectly matched!!\")\n", |
| 150 | + " elif idx < 42:\n", |
| 151 | + " kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n", |
| 152 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 153 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 154 | + " print(k, \"perfectly matched!!\")\n", |
| 155 | + " elif \"model.{}.cv2.\".format(idx) in k:\n", |
| 156 | + " kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n", |
| 157 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 158 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 159 | + " print(k, \"perfectly matched!!\")\n", |
| 160 | + " elif \"model.{}.cv3.\".format(idx) in k:\n", |
| 161 | + " kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n", |
| 162 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 163 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 164 | + " print(k, \"perfectly matched!!\")\n", |
| 165 | + " elif \"model.{}.dfl.\".format(idx) in k:\n", |
| 166 | + " kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n", |
| 167 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 168 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 169 | + " print(k, \"perfectly matched!!\")\n", |
| 170 | + " else:\n", |
| 171 | + " while True:\n", |
| 172 | + " idx += 1\n", |
| 173 | + " if \"model.{}.\".format(idx) in k:\n", |
| 174 | + " break\n", |
| 175 | + " if idx < 29:\n", |
| 176 | + " kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n", |
| 177 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 178 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 179 | + " print(k, \"perfectly matched!!\")\n", |
| 180 | + " elif idx < 42:\n", |
| 181 | + " kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n", |
| 182 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 183 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 184 | + " print(k, \"perfectly matched!!\")\n", |
| 185 | + " elif \"model.{}.cv2.\".format(idx) in k:\n", |
| 186 | + " kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n", |
| 187 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 188 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 189 | + " print(k, \"perfectly matched!!\")\n", |
| 190 | + " elif \"model.{}.cv3.\".format(idx) in k:\n", |
| 191 | + " kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n", |
| 192 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 193 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 194 | + " print(k, \"perfectly matched!!\")\n", |
| 195 | + " elif \"model.{}.dfl.\".format(idx) in k:\n", |
| 196 | + " kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n", |
| 197 | + " model.state_dict()[k] -= model.state_dict()[k]\n", |
| 198 | + " model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n", |
| 199 | + " print(k, \"perfectly matched!!\")\n", |
| 200 | + "_ = model.eval()" |
| 201 | + ] |
| 202 | + }, |
| 203 | + { |
| 204 | + "cell_type": "code", |
| 205 | + "execution_count": null, |
| 206 | + "id": "27bc1869", |
| 207 | + "metadata": {}, |
| 208 | + "outputs": [], |
| 209 | + "source": [ |
| 210 | + "m_ckpt = {'model': model.half(),\n", |
| 211 | + " 'optimizer': None,\n", |
| 212 | + " 'best_fitness': None,\n", |
| 213 | + " 'ema': None,\n", |
| 214 | + " 'updates': None,\n", |
| 215 | + " 'opt': None,\n", |
| 216 | + " 'git': None,\n", |
| 217 | + " 'date': None,\n", |
| 218 | + " 'epoch': -1}\n", |
| 219 | + "torch.save(m_ckpt, \"./yolov9-e-converted.pt\")" |
| 220 | + ] |
100 | 221 | }
|
101 | 222 | ],
|
102 | 223 | "metadata": {
|
|
0 commit comments