@@ -313,18 +313,20 @@ def __init__(
313
313
:type ontology_fname: str
314
314
"""
315
315
# Get device (CPU or GPU)
316
- self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
316
+ self .device = torch .device ("cuda" if torch .cuda .is_available () else
317
+ "mps" if torch .backends .mps .is_available () else
318
+ "cpu" )
317
319
318
320
# If 'model' contains a string, check that it is a valid filename and load model
319
321
if isinstance (model , str ):
320
322
assert os .path .isfile (model ), "TorchScript Model file not found"
321
323
model_fname = model
322
324
try :
323
- model = torch .jit .load (model )
325
+ model = torch .jit .load (model , map_location = self . device )
324
326
model_type = "compiled"
325
327
except :
326
328
print ("Model is not a TorchScript model. Loading as a PyTorch module." )
327
- model = torch .load (model )
329
+ model = torch .load (model , map_location = self . device )
328
330
model_type = "native"
329
331
# Otherwise, check that it is a PyTorch module
330
332
elif isinstance (model , torch .nn .Module ):
@@ -587,18 +589,20 @@ def __init__(
587
589
:type ontology_fname: str
588
590
"""
589
591
# Get device (CPU or GPU)
590
- self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
592
+ self .device = torch .device ("cuda" if torch .cuda .is_available () else
593
+ "mps" if torch .backends .mps .is_available () else
594
+ "cpu" )
591
595
592
596
# If 'model' contains a string, check that it is a valid filename and load model
593
597
if isinstance (model , str ):
594
598
assert os .path .isfile (model ), "TorchScript Model file not found"
595
599
model_fname = model
596
600
try :
597
- model = torch .jit .load (model )
601
+ model = torch .jit .load (model , map_location = self . device )
598
602
model_type = "compiled"
599
603
except Exception :
600
604
print ("Model is not a TorchScript model. Loading as a PyTorch module." )
601
- model = torch .load (model )
605
+ model = torch .load (model , map_location = self . device )
602
606
model_type = "native"
603
607
# Otherwise, check that it is a PyTorch module
604
608
elif isinstance (model , torch .nn .Module ):
0 commit comments