Skip to content

Commit 3c32c53

Browse files
authored
Update export.py
1 parent 60bcdfe commit 3c32c53

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

export.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
2323

2424
from models.experimental import attempt_load
25-
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
25+
from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
2626
from utils.dataloaders import LoadImages
2727
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
2828
check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
@@ -494,7 +494,7 @@ def run(
494494
# Update model
495495
model.eval()
496496
for k, m in model.named_modules():
497-
if isinstance(m, (Detect, V6Detect)):
497+
if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
498498
m.inplace = inplace
499499
m.dynamic = dynamic
500500
m.export = True
@@ -503,7 +503,7 @@ def run(
503503
y = model(im) # dry runs
504504
if half and not coreml:
505505
im, model = im.half(), model.half() # to FP16
506-
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
506+
shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape) # model output shape
507507
metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
508508
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
509509

0 commit comments

Comments
 (0)