Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return to load_model function #39

Merged
merged 1 commit into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,71 @@ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor

# For image

autoseg_image = SegAutoMaskPredictor().image_predict(
results = SegAutoMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_batch=64,
min_area=0,
output_path="output.jpg",
show=True,
save=False,
)

# For video

autoseg_video = SegAutoMaskPredictor().video_predict(
results = SegAutoMaskPredictor().video_predict(
source="video.mp4",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_batch=64,
min_area=1000,
output_path="output.mp4",
)

# For manual box and point selection

seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
results = SegManualMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
input_point=[[100, 100], [200, 200]],
input_label=[0, 1],
input_box=[100, 100, 200, 200], # x,y,w,h
input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]]
multimask_output=False,

random_color=False,
show=True,
save=False,
)
```

# For multi box selection
# SAHI + Segment Anything

seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
source="data/brain.png",
model_type="vit_l",
input_point=None,
input_label=None,
input_box= [[100, 100, 400, 400]],
multimask_output=False,
```python
image_path = "test.jpg"
boxes = sahi_sliced_predict(
image_path=image_path,
detection_model_type="yolov5", #yolov8, detectron2, mmdetection, torchvision
detection_model_path="yolov5l6.pt",
conf_th=0.25,
image_size=1280,
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
)

SahiAutoSegmentation().predict(
source=image_path,
model_type="vit_b",
input_box=boxes,
multimask_output=False,
random_color=False,
show=True,
save=False,
)
```
<img width="1000" alt="teaser" src="https://github.com/kadirnar/segment-anything-pip/releases/download/v0.5.0/sahi_autoseg.png">

# Extra Features

- [x] Support for Yolov5/8, Detectron2, Mmdetection, Torchvision models
Expand Down
2 changes: 1 addition & 1 deletion metaseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from metaseg.generator.predictor import SamPredictor
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor

__version__ = "0.4.5"
__version__ = "0.5.0"
43 changes: 21 additions & 22 deletions metaseg/mask_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from tqdm import tqdm

from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes,show_image, save_image


class SegAutoMaskPredictor:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.save = False
self.show = False

def load_model(self, model_type):
if self.model is None:
Expand All @@ -24,7 +22,7 @@ def load_model(self, model_type):

return self.model

def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png"):
def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png", show=False, save=False):
read_image = load_image(source)
model = self.load_model(model_type)
mask_generator = SamAutomaticMaskGenerator(
Expand All @@ -49,15 +47,15 @@ def image_predict(self, source, model_type, points_per_side, points_per_batch, m
mask_image = cv2.add(mask_image, img)

combined_mask = cv2.add(read_image, mask_image)
if self.save:
cv2.imwrite(output_path, combined_mask)

if self.show:
cv2.imshow("Output", combined_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()

return output_path
self.combined_mask = combined_mask
if show:
show_image(combined_mask)
if save:
save_image(output_path=output_path, image=combined_mask)

return masks


def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
cap, out = load_video(source, output_path)
Expand Down Expand Up @@ -128,6 +126,9 @@ def image_predict(
input_label=None,
multimask_output=False,
output_path="output.png",
random_color=False,
show=False,
save=False,
):
image = load_image(source)
model = self.load_model(model_type)
Expand All @@ -144,7 +145,7 @@ def image_predict(
multimask_output=False,
)
for mask in masks:
mask_image = load_mask(mask.cpu().numpy(), False)
mask_image = load_mask(mask.cpu().numpy(), random_color)

for box in input_boxes:
image = load_box(box.cpu().numpy(), image)
Expand All @@ -158,16 +159,14 @@ def image_predict(
box=input_boxes,
multimask_output=multimask_output,
)
mask_image = load_mask(masks, True)
mask_image = load_mask(masks, random_color)
image = load_box(input_box, image)

combined_mask = cv2.add(image, mask_image)
if self.save:
cv2.imwrite(output_path, combined_mask)
if save:
save_image(output_path=output_path, image=combined_mask)

if self.show:
cv2.imshow("Output", combined_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()
if show:
show_image(combined_mask)

return output_path
return masks
19 changes: 14 additions & 5 deletions metaseg/sahi_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask


def sahi_predict(
def sahi_sliced_predict(
image_path,
detection_model_type,
detection_model_path,
Expand Down Expand Up @@ -51,7 +51,7 @@ def sahi_predict(
return boxes


class SahiPredictor:
class SahiAutoSegmentation:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
Expand All @@ -62,15 +62,21 @@ def load_model(self, model_type):
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
self.model.to(device=self.device)

def save_image(
return self.model

def predict(
self,
source,
model_type,
input_box=None,
input_point=None,
input_label=None,
multimask_output=False,
random_color=False,
save=False,
show=True,
):

read_image = load_image(source)
model = self.load_model(model_type)
predictor = SamPredictor(model)
Expand Down Expand Up @@ -99,8 +105,11 @@ def save_image(
plt.figure(figsize=(10, 10))
plt.imshow(read_image)
for mask in masks:
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=random_color)
for box in input_boxes:
plt_load_box(box.cpu().numpy(), plt.gca())
plt.axis("off")
plt.show()
if save:
plt.savefig("output.png")
if show:
plt.show()
11 changes: 11 additions & 0 deletions metaseg/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ def multi_boxes(boxes, predictor, image):
input_boxes = torch.tensor(boxes, device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
return input_boxes, transformed_boxes

def show_image(output_image):
    """Display ``output_image`` in an OpenCV window titled "output".

    Shows the image with ``cv2.imshow``, waits on ``cv2.waitKey(0)``, and
    then calls ``cv2.destroyAllWindows()``.

    Args:
        output_image: Image data handed unchanged to ``cv2.imshow``.
    """
    # cv2 is imported at call time, as in the original helper.
    from cv2 import destroyAllWindows, imshow, waitKey

    window_title = "output"
    imshow(window_title, output_image)
    waitKey(0)
    destroyAllWindows()

def save_image(output_image=None, output_path=None, *, image=None):
    """Write an image to disk with ``cv2.imwrite``.

    Positional use, ``save_image(output_image, output_path)``, is unchanged.
    The image may also be supplied as the keyword ``image=``, because
    existing callers invoke this helper as
    ``save_image(output_path=..., image=...)``; without the alias that call
    fails with an unexpected-keyword ``TypeError``.

    Args:
        output_image: Image data to write.
        output_path: Destination path handed to ``cv2.imwrite``.
        image: Keyword-only alias for ``output_image``.

    Raises:
        TypeError: If the image or the path is missing, or if both
            ``output_image`` and ``image`` are supplied.
    """
    if output_image is not None and image is not None:
        raise TypeError(
            "save_image() got both 'output_image' and 'image'; pass only one"
        )
    if output_image is None:
        output_image = image
    if output_image is None:
        raise TypeError("save_image() missing required argument: 'output_image'")
    if output_path is None:
        raise TypeError("save_image() missing required argument: 'output_path'")

    # Import after validation so argument errors surface before cv2 is loaded.
    import cv2

    cv2.imwrite(output_path, output_image)