22
22
ROOT = Path (os .path .relpath (ROOT , Path .cwd ())) # relative
23
23
24
24
from models .experimental import attempt_load
25
- from models .yolo import ClassificationModel , Detect , DetectionModel , SegmentationModel
25
+ from models .yolo import ClassificationModel , Detect , DDetect , DualDetect , DualDDetect , DetectionModel , SegmentationModel
26
26
from utils .dataloaders import LoadImages
27
27
from utils .general import (LOGGER , Profile , check_dataset , check_img_size , check_requirements , check_version ,
28
28
check_yaml , colorstr , file_size , get_default_args , print_args , url2file , yaml_save )
@@ -494,7 +494,7 @@ def run(
494
494
# Update model
495
495
model .eval ()
496
496
for k , m in model .named_modules ():
497
- if isinstance (m , (Detect , V6Detect )):
497
+ if isinstance (m , (Detect , DDetect , DualDetect , DualDDetect )):
498
498
m .inplace = inplace
499
499
m .dynamic = dynamic
500
500
m .export = True
@@ -503,7 +503,7 @@ def run(
503
503
y = model (im ) # dry runs
504
504
if half and not coreml :
505
505
im , model = im .half (), model .half () # to FP16
506
- shape = tuple ((y [0 ] if isinstance (y , tuple ) else y ).shape ) # model output shape
506
+ shape = tuple ((y [0 ] if isinstance (y , ( tuple , list ) ) else y ).shape ) # model output shape
507
507
metadata = {'stride' : int (max (model .stride )), 'names' : model .names } # model metadata
508
508
LOGGER .info (f"\n { colorstr ('PyTorch:' )} starting from { file } with output shape { shape } ({ file_size (file ):.1f} MB)" )
509
509
0 commit comments