diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py index a944176..af95812 100644 --- a/src/live_portrait_wrapper.py +++ b/src/live_portrait_wrapper.py @@ -32,9 +32,12 @@ class LivePortraitWrapper(object): if inference_cfg.flag_force_cpu: self.device = 'cpu' else: - if torch.backends.mps.is_available(): - self.device = 'mps' - else: + try: + if torch.backends.mps.is_available(): + self.device = 'mps' + else: + self.device = 'cuda:' + str(self.device_id) + except: self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) @@ -344,10 +347,13 @@ class LivePortraitWrapperAnimal(LivePortraitWrapper): if inference_cfg.flag_force_cpu: self.device = 'cpu' else: - if torch.backends.mps.is_available(): - self.device = 'mps' - else: - self.device = 'cuda:' + str(self.device_id) + try: + if torch.backends.mps.is_available(): + self.device = 'mps' + else: + self.device = 'cuda:' + str(self.device_id) + except: + self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) # init F diff --git a/src/utils/cropper.py b/src/utils/cropper.py index 8b3ac6a..97e26c7 100644 --- a/src/utils/cropper.py +++ b/src/utils/cropper.py @@ -48,15 +48,18 @@ class Cropper(object): device = "cpu" face_analysis_wrapper_provider = ["CPUExecutionProvider"] else: - if torch.backends.mps.is_available(): - # Shape inference currently fails with CoreMLExecutionProvider - # for the retinaface model - device = "mps" - face_analysis_wrapper_provider = ["CPUExecutionProvider"] - else: - device = "cuda" - face_analysis_wrapper_provider = ["CUDAExecutionProvider"] - + try: + if torch.backends.mps.is_available(): + # Shape inference currently fails with CoreMLExecutionProvider + # for the retinaface model + device = "mps" + face_analysis_wrapper_provider = ["CPUExecutionProvider"] + else: + device = "cuda" + face_analysis_wrapper_provider = ["CUDAExecutionProvider"] + except: + device = "cuda" + face_analysis_wrapper_provider = ["CUDAExecutionProvider"] self.face_analysis_wrapper = FaceAnalysisDIY( name="buffalo_l", root=self.crop_cfg.insightface_root,