
Commit ef8942f

Merge pull request #279 from SakhinetiPraveena/master
Support GPU acceleration using MPS on macOS
2 parents cb7564f + 883f48e commit ef8942f
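
The change swaps the hard-coded CUDA-or-CPU choice for a three-way fallback. A minimal standalone sketch of that selection chain (plain PyTorch, not the DetectionMetrics API) shows how it resolves on different hosts:

import torch

# Prefer CUDA, then Apple's Metal Performance Shaders (MPS), then CPU.
# This mirrors the chained conditional expression added in this commit.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(device)  # "mps" on Apple Silicon with MPS support, otherwise "cuda" or "cpu"

On a Mac without CUDA, the old expression always fell back to "cpu"; with the new chain, torch.backends.mps.is_available() lets the same code pick up the GPU through Metal.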

File tree

1 file changed: +10 -6 lines


detectionmetrics/models/torch.py

@@ -313,18 +313,20 @@ def __init__(
         :type ontology_fname: str
         """
         # Get device (CPU or GPU)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else
+                                   "mps" if torch.backends.mps.is_available() else
+                                   "cpu")
 
         # If 'model' contains a string, check that it is a valid filename and load model
         if isinstance(model, str):
             assert os.path.isfile(model), "TorchScript Model file not found"
             model_fname = model
             try:
-                model = torch.jit.load(model)
+                model = torch.jit.load(model, map_location=self.device)
                 model_type = "compiled"
             except:
                 print("Model is not a TorchScript model. Loading as a PyTorch module.")
-                model = torch.load(model)
+                model = torch.load(model, map_location=self.device)
                 model_type = "native"
         # Otherwise, check that it is a PyTorch module
         elif isinstance(model, torch.nn.Module):
@@ -587,18 +589,20 @@ def __init__(
         :type ontology_fname: str
         """
         # Get device (CPU or GPU)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else
+                                   "mps" if torch.backends.mps.is_available() else
+                                   "cpu")
 
         # If 'model' contains a string, check that it is a valid filename and load model
         if isinstance(model, str):
             assert os.path.isfile(model), "TorchScript Model file not found"
             model_fname = model
             try:
-                model = torch.jit.load(model)
+                model = torch.jit.load(model, map_location=self.device)
                 model_type = "compiled"
             except Exception:
                 print("Model is not a TorchScript model. Loading as a PyTorch module.")
-                model = torch.load(model)
+                model = torch.load(model, map_location=self.device)
                 model_type = "native"
         # Otherwise, check that it is a PyTorch module
         elif isinstance(model, torch.nn.Module):
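
Both load paths now pass map_location, so a checkpoint exported on a CUDA machine can be deserialized straight onto the selected device instead of failing on a host without CUDA. A rough usage sketch under the same device-selection logic (the file name scripted_model.pt is hypothetical):

import torch

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# map_location remaps storages saved on another device (e.g. a CUDA GPU)
# onto `device`, so the same TorchScript file loads on MPS- or CPU-only hosts.
model = torch.jit.load("scripted_model.pt", map_location=device)  # hypothetical path
model.eval()

# Inputs must live on the same device as the model.
dummy = torch.rand(1, 3, 224, 224, device=device)  # example image-shaped tensor
with torch.no_grad():
    out = model(dummy)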
