From 357226b2e7e278dcf4f2b83dbb824d9d075742f8 Mon Sep 17 00:00:00 2001 From: ZhizhouZhong <1819489045@qq.com> Date: Mon, 5 Aug 2024 14:17:17 +0800 Subject: [PATCH] fix: torch.backends check (#280) --- src/live_portrait_wrapper.py | 20 +++++++++++++------- src/utils/cropper.py | 21 ++++++++++++--------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py index a944176..af95812 100644 --- a/src/live_portrait_wrapper.py +++ b/src/live_portrait_wrapper.py @@ -32,9 +32,12 @@ class LivePortraitWrapper(object): if inference_cfg.flag_force_cpu: self.device = 'cpu' else: - if torch.backends.mps.is_available(): - self.device = 'mps' - else: + try: + if torch.backends.mps.is_available(): + self.device = 'mps' + else: + self.device = 'cuda:' + str(self.device_id) + except: self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) @@ -344,10 +347,13 @@ class LivePortraitWrapperAnimal(LivePortraitWrapper): if inference_cfg.flag_force_cpu: self.device = 'cpu' else: - if torch.backends.mps.is_available(): - self.device = 'mps' - else: - self.device = 'cuda:' + str(self.device_id) + try: + if torch.backends.mps.is_available(): + self.device = 'mps' + else: + self.device = 'cuda:' + str(self.device_id) + except: + self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) # init F diff --git a/src/utils/cropper.py b/src/utils/cropper.py index 8b3ac6a..97e26c7 100644 --- a/src/utils/cropper.py +++ b/src/utils/cropper.py @@ -48,15 +48,18 @@ class Cropper(object): device = "cpu" face_analysis_wrapper_provider = ["CPUExecutionProvider"] else: - if torch.backends.mps.is_available(): - # Shape inference currently fails with CoreMLExecutionProvider - # for the retinaface model - device = "mps" - face_analysis_wrapper_provider = ["CPUExecutionProvider"] - else: - device = "cuda" - face_analysis_wrapper_provider = ["CUDAExecutionProvider"] - + try: + if torch.backends.mps.is_available(): + # Shape inference currently fails with CoreMLExecutionProvider + # for the retinaface model + device = "mps" + face_analysis_wrapper_provider = ["CPUExecutionProvider"] + else: + device = "cuda" + face_analysis_wrapper_provider = ["CUDAExecutionProvider"] + except: + device = "cuda" + face_analysis_wrapper_provider = ["CUDAExecutionProvider"] self.face_analysis_wrapper = FaceAnalysisDIY( name="buffalo_l", root=self.crop_cfg.insightface_root,