Merge pull request #73 from kadirnar/onnx-import-package-req-formatting-fixes

onnx script fix and added as main package
onuralpszr authored Jun 26, 2023
2 parents cdf69a5 + 034c660 commit 6d86650
Showing 8 changed files with 1,235 additions and 44 deletions.
2 changes: 1 addition & 1 deletion metaseg/sahi_predict.py
@@ -4,7 +4,7 @@
 import torch
 from PIL import Image
 
-from metaseg import SamPredictor, sam_model_registry
+from metaseg.generator import SamPredictor, sam_model_registry
 from metaseg.utils import (
     download_model,
     load_image,
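
For reference, a minimal usage sketch of the import path introduced above. It is an assumed example, not code from this commit: the model key "vit_b", the checkpoint path, the image file, and the prompt point are all illustrative.

    # Hypothetical sketch: SamPredictor and sam_model_registry now come from
    # metaseg.generator, mirroring the import change above.
    import cv2
    import numpy as np

    from metaseg.generator import SamPredictor, sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")  # assumed local checkpoint
    predictor = SamPredictor(sam)

    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # assumed input image
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[100, 150]]),  # assumed example foreground point (x, y)
        point_labels=np.array([1]),
    )
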
1 change: 1 addition & 0 deletions metaseg/utils/__init__.py
@@ -8,6 +8,7 @@
 from .data_utils import load_image as load_image
 from .data_utils import load_mask as load_mask
 from .data_utils import load_server_image as load_server_image
+from .data_utils import load_video as load_video
 from .data_utils import multi_boxes as multi_boxes
 from .data_utils import plt_load_box as plt_load_box
 from .data_utils import plt_load_mask as plt_load_mask
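
A quick, assumed smoke test of the new re-export; the function's signature is not shown in this diff, so only the import is exercised.

    # Assumed check: load_video should now be importable from metaseg.utils
    # alongside the loaders re-exported above.
    from metaseg.utils import load_image, load_video  # noqa: F401

    print(callable(load_video))  # expected: True
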
590 changes: 586 additions & 4 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ pillow = "^9.5.0"
pycocotools = "^2.0.6"
onnx = "^1.14.0"
onnxruntime = "^1.15.1"
fal-serverless = "^0.6.35"



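
A small, assumed sanity check that the onnx and onnxruntime dependencies listed above resolve in an installed environment (fal-serverless is omitted here because its import name is not shown in this diff; installed versions are whatever the lock file pins).

    # Assumed environment check for the main-package dependencies above.
    import onnx
    import onnxruntime

    print("onnx", onnx.__version__)                # expected to satisfy ^1.14.0
    print("onnxruntime", onnxruntime.__version__)  # expected to satisfy ^1.15.1
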
283 changes: 278 additions & 5 deletions requirements-dev.txt

Large diffs are not rendered by default.

354 changes: 349 additions & 5 deletions requirements.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion scripts/amg.py
@@ -9,7 +9,7 @@
 import os
 from typing import Any, Dict, List
 
-import cv2  # type: ignore
+import cv2
 
 from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry
 
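
For context, a minimal sketch of the API that scripts/amg.py drives with the import shown above. The model key, checkpoint path, and image file are illustrative assumptions.

    # Hypothetical sketch of automatic mask generation via metaseg.generator.
    import cv2

    from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")  # assumed checkpoint
    generator = SamAutomaticMaskGenerator(sam)

    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # assumed input
    masks = generator.generate(image)  # list of dicts with "segmentation", "area", ...
    print(len(masks))
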
46 changes: 18 additions & 28 deletions scripts/export_onnx_model.py
@@ -8,17 +8,13 @@
 import warnings
 
 import torch
+from onnxruntime import InferenceSession
+from onnxruntime.quantization import QuantType
+from onnxruntime.quantization.quantize import quantize_dynamic
 
-from metaseg import build_sam, build_sam_vit_b, build_sam_vit_l
+from metaseg.generator import build_sam, build_sam_vit_b, build_sam_vit_l
 from metaseg.utils.onnx import SamOnnxModel
 
-try:
-    import onnxruntime  # type: ignore
-
-    onnxruntime_exists = True
-except ImportError:
-    onnxruntime_exists = False
-
 parser = argparse.ArgumentParser(
     description="Export the SAM prompt encoder and mask decoder to an ONNX model."
 )
@@ -169,11 +165,10 @@ def run_export(
             dynamic_axes=dynamic_axes,
         )
 
-    if onnxruntime_exists:
-        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
-        ort_session = onnxruntime.InferenceSession(output)
-        _ = ort_session.run(None, ort_inputs)
-        print("Model has successfully been run with ONNXRuntime.")
+    ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
+    ort_session = InferenceSession(output)
+    _ = ort_session.run(None, ort_inputs)
+    print("Model has successfully been run with ONNXRuntime.")
 
 
 def to_numpy(tensor):
@@ -193,18 +188,13 @@ def to_numpy(tensor):
         return_extra_metrics=args.return_extra_metrics,
     )
 
-    if args.quantize_out is not None:
-        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
-        from onnxruntime.quantization import QuantType  # type: ignore
-        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore
-
-        print(f"Quantizing model and writing to {args.quantize_out}...")
-        quantize_dynamic(
-            model_input=args.output,
-            model_output=args.quantize_out,
-            optimize_model=True,
-            per_channel=False,
-            reduce_range=False,
-            weight_type=QuantType.QUInt8,
-        )
-        print("Done!")
+    print(f"Quantizing model and writing to {args.quantize_out}...")
+    quantize_dynamic(
+        model_input=args.output,
+        model_output=args.quantize_out,
+        optimize_model=True,
+        per_channel=False,
+        reduce_range=False,
+        weight_type=QuantType.QUInt8,
+    )
+    print("Done!")
