fix: torch.backends check (#280)

Author: ZhizhouZhong, 2024-08-05 14:17:17 +08:00 (committed by GitHub)
Parent: 67d567f38c
Commit: 357226b2e7
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 16 deletions


@@ -32,9 +32,12 @@ class LivePortraitWrapper(object):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            if torch.backends.mps.is_available():
-                self.device = 'mps'
-            else:
-                self.device = 'cuda:' + str(self.device_id)
+            try:
+                if torch.backends.mps.is_available():
+                    self.device = 'mps'
+                else:
+                    self.device = 'cuda:' + str(self.device_id)
+            except:
+                self.device = 'cuda:' + str(self.device_id)
 
         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
@@ -344,10 +347,13 @@ class LivePortraitWrapperAnimal(LivePortraitWrapper):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            if torch.backends.mps.is_available():
-                self.device = 'mps'
-            else:
-                self.device = 'cuda:' + str(self.device_id)
+            try:
+                if torch.backends.mps.is_available():
+                    self.device = 'mps'
+                else:
+                    self.device = 'cuda:' + str(self.device_id)
+            except:
+                self.device = 'cuda:' + str(self.device_id)
 
         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
         # init F
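
Both hunks above guard against PyTorch builds where torch.backends has no mps attribute at all, so the availability check itself typically raises AttributeError before it can return False. A minimal, hypothetical sketch of the patched device selection, runnable on its own (pick_device is an illustrative helper, not part of the repo; the patch itself uses a bare except):

    import torch

    def pick_device(flag_force_cpu: bool = False, device_id: int = 0) -> str:
        # Mirrors the patched selection logic from LivePortraitWrapper.__init__.
        if flag_force_cpu:
            return 'cpu'
        try:
            # On builds without the MPS backend, accessing torch.backends.mps
            # raises AttributeError, which falls through to the CUDA device.
            if torch.backends.mps.is_available():
                return 'mps'
            return 'cuda:' + str(device_id)
        except AttributeError:
            return 'cuda:' + str(device_id)

    print(pick_device())  # e.g. 'mps' on Apple Silicon, 'cuda:0' elsewhere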


@@ -48,15 +48,18 @@ class Cropper(object):
             device = "cpu"
             face_analysis_wrapper_provider = ["CPUExecutionProvider"]
         else:
-            if torch.backends.mps.is_available():
-                # Shape inference currently fails with CoreMLExecutionProvider
-                # for the retinaface model
-                device = "mps"
-                face_analysis_wrapper_provider = ["CPUExecutionProvider"]
-            else:
-                device = "cuda"
-                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
+            try:
+                if torch.backends.mps.is_available():
+                    # Shape inference currently fails with CoreMLExecutionProvider
+                    # for the retinaface model
+                    device = "mps"
+                    face_analysis_wrapper_provider = ["CPUExecutionProvider"]
+                else:
+                    device = "cuda"
+                    face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
+            except:
+                device = "cuda"
+                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
 
         self.face_analysis_wrapper = FaceAnalysisDIY(
             name="buffalo_l",
             root=self.crop_cfg.insightface_root,
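
The cropper applies the same guard but has to choose two things at once: the torch device and the ONNX Runtime execution providers handed to insightface, and it keeps the providers on CPU under MPS because of the retinaface shape-inference issue noted in the diff. A hedged, standalone sketch of that combined choice (pick_cropper_backend is an illustrative name; the repo does this inline in Cropper.__init__):

    import torch

    def pick_cropper_backend(force_cpu: bool = False):
        # Returns (torch device string, ONNX Runtime provider list).
        if force_cpu:
            return "cpu", ["CPUExecutionProvider"]
        try:
            if torch.backends.mps.is_available():
                # CoreMLExecutionProvider shape inference fails for retinaface,
                # so face analysis stays on the CPU provider even under MPS.
                return "mps", ["CPUExecutionProvider"]
            return "cuda", ["CUDAExecutionProvider"]
        except AttributeError:
            # torch.backends.mps is missing entirely on this build.
            return "cuda", ["CUDAExecutionProvider"]

    device, providers = pick_cropper_backend()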