Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-22 12:22:38 +00:00).
Commit 357226b2e7 — "fix: torch.backends check (#280)" (parent commit: 67d567f38c).
@@ -32,10 +32,13 @@ class LivePortraitWrapper(object):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            if torch.backends.mps.is_available():
-                self.device = 'mps'
-            else:
-                self.device = 'cuda:' + str(self.device_id)
+            try:
+                if torch.backends.mps.is_available():
+                    self.device = 'mps'
+                else:
+                    self.device = 'cuda:' + str(self.device_id)
+            except:
+                self.device = 'cuda:' + str(self.device_id)
 
         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
         # init F
@@ -344,10 +347,13 @@ class LivePortraitWrapperAnimal(LivePortraitWrapper):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            if torch.backends.mps.is_available():
-                self.device = 'mps'
-            else:
-                self.device = 'cuda:' + str(self.device_id)
+            try:
+                if torch.backends.mps.is_available():
+                    self.device = 'mps'
+                else:
+                    self.device = 'cuda:' + str(self.device_id)
+            except:
+                self.device = 'cuda:' + str(self.device_id)
 
         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
         # init F
|
@@ -48,6 +48,7 @@ class Cropper(object):
             device = "cpu"
             face_analysis_wrapper_provider = ["CPUExecutionProvider"]
         else:
-            if torch.backends.mps.is_available():
-                # Shape inference currently fails with CoreMLExecutionProvider
-                # for the retinaface model
+            try:
+                if torch.backends.mps.is_available():
+                    # Shape inference currently fails with CoreMLExecutionProvider
+                    # for the retinaface model
@@ -56,7 +57,9 @@ class Cropper(object):
-            else:
-                device = "cuda"
-                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
-
+                else:
+                    device = "cuda"
+                    face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
+            except:
+                device = "cuda"
+                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
         self.face_analysis_wrapper = FaceAnalysisDIY(
             name="buffalo_l",
             root=self.crop_cfg.insightface_root,
|
Loading…
Reference in New Issue
Block a user