Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-22 12:22:38 +00:00)
feat: support macOS with Apple Silicon (#155)
* feat: macOS support (#143)
* Support for running on Apple Silicon Macs with MPS
* Minor typo fix: s/provicer/provider/
* Another typo fix: s/concact/concat/
* s/cudaexecutionprovider/CUDAExecutionProvider/
* Add requirements_apple.txt
* doc: macOS support
* chore: refine the structure and doc
* doc: update readme
* doc: update readme
* doc: update readme
* doc: update readme

Co-authored-by: Jeethu Rao <jeethu@jeethurao.com>
Co-authored-by: zzzweakman <1819489045@qq.com>
This commit is contained in:
parent 54e50986b2
commit 0f839844f6
.gitignore (vendored, 2 changes)
@@ -11,6 +11,7 @@ __pycache__/
 pretrained_weights/*.md
 pretrained_weights/docs
+pretrained_weights/liveportrait
 
 # Ipython notebook
 *.ipynb
@@ -19,3 +20,4 @@ pretrained_weights/docs
 animations/*
 tmp/*
 .vscode/launch.json
+**/*.DS_Store
@@ -42,8 +42,8 @@ def main():
     fast_check_args(args)
 
     # specify configs for inference
-    inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
-    crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
+    inference_cfg = partial_fields(InferenceConfig, args.__dict__)
+    crop_cfg = partial_fields(CropConfig, args.__dict__)
 
     live_portrait_pipeline = LivePortraitPipeline(
         inference_cfg=inference_cfg,
readme.md (27 changes)
@@ -35,6 +35,7 @@
 
 
 ## 🔥 Updates
+- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
 - **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
 - **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
 - **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned!
@@ -55,11 +56,14 @@ cd LivePortrait
 # create env using conda
 conda create -n LivePortrait python==3.9.18
 conda activate LivePortrait
-# install dependencies with pip
+# install dependencies with pip (for Linux and Windows)
 pip install -r requirements.txt
+
+# for macOS with Apple Silicon
+pip install -r requirements_macOS.txt
 ```
 
-**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!
+**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/download.html) installed, including both `ffmpeg` and `ffprobe`!
 
 ### 2. Download pretrained weights
 
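The note above asks for both `ffmpeg` and `ffprobe` because the pipeline shells out to them for video concatenation and audio probing. As a quick sanity check (a hypothetical snippet, not part of the repo), you can verify that both binaries are on your PATH:

```python
import shutil

# Hypothetical sanity check: the repo itself does not ship this snippet.
for tool in ("ffmpeg", "ffprobe"):
    path = shutil.which(tool)
    print(f"{tool}: {path or 'NOT FOUND - install FFmpeg first'}")
```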
@@ -67,8 +71,10 @@ The easiest way to download the pretrained weights is from HuggingFace:
 ```bash
 # first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
 git lfs install
-# clone the weights
-git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
+# clone and move the weights
+git clone https://huggingface.co/KwaiVGI/liveportrait temp_pretrained_weights
+mv temp_pretrained_weights/* pretrained_weights/
+rm -rf temp_pretrained_weights
 ```
 
 Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
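If you prefer not to clone with git-lfs, a rough alternative (assuming the `huggingface_hub` package is installed; this is not part of the repo's own instructions) is to fetch the same weights programmatically:

```python
# Sketch only: assumes `pip install huggingface_hub` and network access to the Hub.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="KwaiVGI/liveportrait", local_dir="pretrained_weights")
```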
@@ -96,7 +102,11 @@ pretrained_weights
 
 #### Fast hands-on
 ```bash
+# For Linux and Windows
 python inference.py
+
+# For macOS with Apple Silicon (Intel is not supported); this may be ~20x slower than an RTX 4090
+PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
 ```
 
 If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
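Before running the macOS command, it can help to confirm that your PyTorch build actually exposes the MPS backend; a minimal check (not from the repo) is:

```python
import torch

print(torch.backends.mps.is_built())      # was PyTorch compiled with MPS support?
print(torch.backends.mps.is_available())  # is a usable MPS device present on this Mac?
```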
@@ -145,7 +155,11 @@ python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/
 We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run by:
 
 ```bash
+# For Linux and Windows:
 python app.py
+
+# For macOS with Apple Silicon (Intel is not supported); this may be ~20x slower than an RTX 4090
+PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py
 ```
 
 You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
@@ -155,7 +169,7 @@ You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
 # enable torch.compile for faster inference
 python app.py --flag_do_torch_compile
 ```
-**Note**: This method has not been fully tested. e.g., on Windows.
+**Note**: This method is not supported on Windows and macOS.
 
 **Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
@@ -163,6 +177,7 @@ python app.py --flag_do_torch_compile
 We have also provided a script to evaluate the inference speed of each module:
 
 ```bash
+# For NVIDIA GPU
 python speed.py
 ```
@@ -184,9 +199,9 @@ Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
 
 - [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
 - [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
+- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
 - [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
 - [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
-- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
 - [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
 
 And many more amazing contributions from our community!
requirements.txt
@@ -1,22 +1,2 @@
--extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.3.0
-torchvision==0.18.0
-torchaudio==2.3.0
-
-numpy==1.26.4
-pyyaml==6.0.1
-opencv-python==4.10.0.84
-scipy==1.13.1
-imageio==2.34.2
-lmdb==1.4.1
-tqdm==4.66.4
-rich==13.7.1
-ffmpeg-python==0.2.0
+-r requirements_base.txt
 onnxruntime-gpu==1.18.0
-onnx==1.16.1
-scikit-image==0.24.0
-albumentations==1.4.10
-matplotlib==3.9.0
-imageio-ffmpeg==0.5.1
-tyro==0.8.5
-gradio==4.37.1
requirements_base.txt (new file, 21 lines)
@@ -0,0 +1,21 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.3.0
+torchvision==0.18.0
+torchaudio==2.3.0
+
+numpy==1.26.4
+pyyaml==6.0.1
+opencv-python==4.10.0.84
+scipy==1.13.1
+imageio==2.34.2
+lmdb==1.4.1
+tqdm==4.66.4
+rich==13.7.1
+ffmpeg-python==0.2.0
+onnx==1.16.1
+scikit-image==0.24.0
+albumentations==1.4.10
+matplotlib==3.9.0
+imageio-ffmpeg==0.5.1
+tyro==0.8.5
+gradio==4.37.1
requirements_macOS.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+-r requirements_base.txt
+onnxruntime-silicon==1.16.3
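`onnxruntime-silicon` replaces `onnxruntime-gpu` on Apple Silicon and ships the CoreML execution provider. A small check (assuming the package installed correctly; this snippet is not part of the repo) is to list the providers ONNX Runtime reports:

```python
import onnxruntime as ort

# With onnxruntime-silicon, this list is expected to contain
# 'CoreMLExecutionProvider' alongside 'CPUExecutionProvider'.
print(ort.get_available_providers())
```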
@@ -216,14 +216,14 @@ class LivePortraitPipeline(object):
         wfp_concat = None
         flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
 
-        ######### build final concact result #########
+        ######### build final concat result #########
         # driving frame | source image | generation, or source image | generation
         frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
         wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
         images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
 
         if flag_has_audio:
-            # final result with concact
+            # final result with concat
             wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
             add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
             os.replace(wfp_concat_with_audio, wfp_concat)
@@ -247,7 +247,7 @@ class LivePortraitPipeline(object):
         if wfp_template not in (None, ''):
             log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
         log(f'Animated video: {wfp}')
-        log(f'Animated video with concact: {wfp_concat}')
+        log(f'Animated video with concat: {wfp_concat}')
 
         return wfp, wfp_concat
@@ -4,6 +4,7 @@
 Wrapper for LivePortrait core functions
 """
 
+import contextlib
 import os.path as osp
 import numpy as np
 import cv2
@@ -28,7 +29,10 @@ class LivePortraitWrapper(object):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            self.device = 'cuda:' + str(self.device_id)
+            if torch.backends.mps.is_available():
+                self.device = 'mps'
+            else:
+                self.device = 'cuda:' + str(self.device_id)
 
         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
         # init F
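The device selection above prefers `mps` whenever it is available. Because some operators are still missing on MPS, the readme commands export `PYTORCH_ENABLE_MPS_FALLBACK=1` so unsupported ops fall back to the CPU; a minimal sketch of the same idea (hypothetical, and the variable needs to be set before `torch` is imported) is:

```python
import os

# Must be set before torch is imported for the fallback to take effect.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.ones(2, 3, device=device)
print(x.device)
```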
@@ -57,6 +61,14 @@ class LivePortraitWrapper(object):
 
         self.timer = Timer()
 
+    def inference_ctx(self):
+        if self.device == "mps":
+            ctx = contextlib.nullcontext()
+        else:
+            ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
+                                 enabled=self.inference_cfg.flag_use_half_precision)
+        return ctx
+
     def update_config(self, user_args):
         for k, v in user_args.items():
             if hasattr(self.inference_cfg, k):
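The new `inference_ctx()` helper centralizes the autocast logic so every forward pass can stack it with `torch.no_grad()` instead of repeating the autocast call. A simplified, self-contained sketch of the pattern (names invented for illustration, not taken from the repo) looks like this:

```python
import contextlib
import torch

def make_inference_ctx(device: str, use_half: bool):
    # On MPS, float16 autocast is skipped and inference stays in float32;
    # elsewhere, autocast enables half precision when requested.
    if device == "mps":
        return contextlib.nullcontext()
    return torch.autocast(device_type=device.split(":")[0], dtype=torch.float16, enabled=use_half)

# Usage sketch (CPU example; on CUDA the context would enable float16 autocast):
model = torch.nn.Linear(4, 2)
with torch.no_grad(), make_inference_ctx("cpu", use_half=False):
    out = model(torch.randn(1, 4))
print(out.shape)
```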
@@ -105,9 +117,8 @@ class LivePortraitWrapper(object):
         """ get the appearance feature of the image by F
         x: Bx3xHxW, normalized to 0~1
         """
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                feature_3d = self.appearance_feature_extractor(x)
+        with torch.no_grad(), self.inference_ctx():
+            feature_3d = self.appearance_feature_extractor(x)
 
         return feature_3d.float()
@@ -117,9 +128,8 @@ class LivePortraitWrapper(object):
         flag_refine_info: whether to transform the pose to degrees and the dimension of the reshape
         return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
         """
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                kp_info = self.motion_extractor(x)
+        with torch.no_grad(), self.inference_ctx():
+            kp_info = self.motion_extractor(x)
 
         if self.inference_cfg.flag_use_half_precision:
             # float the dict
@@ -264,15 +274,14 @@ class LivePortraitWrapper(object):
         kp_driving: BxNx3
         """
         # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                if self.compile:
-                    # Mark the beginning of a new CUDA Graph step
-                    torch.compiler.cudagraph_mark_step_begin()
-                # get decoder input
-                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
-                # decode
-                ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
+        with torch.no_grad(), self.inference_ctx():
+            if self.compile:
+                # Mark the beginning of a new CUDA Graph step
+                torch.compiler.cudagraph_mark_step_begin()
+            # get decoder input
+            ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
+            # decode
+            ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
 
         # float the dict
         if self.inference_cfg.flag_use_half_precision:
@@ -59,7 +59,7 @@ class DenseMotionNetwork(nn.Module):
         heatmap = gaussian_driving - gaussian_source  # (bs, num_kp, d, h, w)
 
         # adding background feature
-        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
+        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
         heatmap = torch.cat([zeros, heatmap], dim=1)
         heatmap = heatmap.unsqueeze(2)  # (bs, 1+num_kp, 1, d, h, w)
         return heatmap
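The `heatmap.type()` to `heatmap.dtype` change matters because `Tensor.type()` returns a device-specific type string, while `Tensor.dtype` is a device-agnostic dtype object that `Tensor.type(...)` also accepts. A tiny illustration (not from the repo):

```python
import torch

t = torch.zeros(2, 3)
print(t.type())   # 'torch.FloatTensor'  (string tied to the CPU/CUDA tensor class)
print(t.dtype)    # torch.float32        (plain dtype, valid on cpu, cuda, and mps alike)

# Tensor.type() accepts a dtype too, so this works regardless of device:
zeros = torch.zeros(1, 3).type(t.dtype)
print(zeros.dtype)
```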
@@ -6,6 +6,7 @@ from typing import List, Tuple, Union
 
 import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
 import numpy as np
+import torch
 
 from ..config.crop_config import CropConfig
 from .crop import (
@@ -43,10 +44,16 @@ class Cropper(object):
         flag_force_cpu = kwargs.get("flag_force_cpu", False)
         if flag_force_cpu:
             device = "cpu"
-            face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
+            face_analysis_wrapper_provider = ["CPUExecutionProvider"]
         else:
-            device = "cuda"
-            face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
+            if torch.backends.mps.is_available():
+                # Shape inference currently fails with CoreMLExecutionProvider
+                # for the retinaface model
+                device = "mps"
+                face_analysis_wrapper_provider = ["CPUExecutionProvider"]
+            else:
+                device = "cuda"
+                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
         self.landmark_runner = LandmarkRunner(
             ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
             onnx_provider=device,
|
|||||||
self.face_analysis_wrapper = FaceAnalysisDIY(
|
self.face_analysis_wrapper = FaceAnalysisDIY(
|
||||||
name="buffalo_l",
|
name="buffalo_l",
|
||||||
root=make_abs_path(self.crop_cfg.insightface_root),
|
root=make_abs_path(self.crop_cfg.insightface_root),
|
||||||
providers=face_analysis_wrapper_provicer,
|
providers=face_analysis_wrapper_provider,
|
||||||
)
|
)
|
||||||
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
|
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
|
||||||
self.face_analysis_wrapper.warmup()
|
self.face_analysis_wrapper.warmup()
|
||||||
|
@@ -68,7 +68,7 @@ def find_onnx_file(dir_path):
     return paths[-1]
 
 def get_default_providers():
-    return ['CUDAExecutionProvider', 'CPUExecutionProvider']
+    return ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider']
 
 def get_default_provider_options():
     return None
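Adding `CoreMLExecutionProvider` to the default provider list lets the same code path serve CUDA builds and the macOS build: to my understanding, ONNX Runtime walks the list in order and skips, with a warning, providers that are not present in the installed package. A hedged sketch of how you might inspect what actually gets used (the model path below is a placeholder):

```python
import onnxruntime as ort

# "model.onnx" is a placeholder; substitute any real ONNX model file.
sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CoreMLExecutionProvider", "CPUExecutionProvider"],
)
print(sess.get_providers())  # the providers the session actually selected
```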
@@ -39,6 +39,12 @@ class LandmarkRunner(object):
                     ('CUDAExecutionProvider', {'device_id': device_id})
                 ]
             )
+        elif onnx_provider.lower() == 'mps':
+            self.session = onnxruntime.InferenceSession(
+                ckpt_path, providers=[
+                    'CoreMLExecutionProvider'
+                ]
+            )
         else:
             opts = onnxruntime.SessionOptions()
             opts.intra_op_num_threads = 4  # the default number of threads is 4
@@ -175,8 +175,13 @@ def has_audio_stream(video_path: str) -> bool:
         # Check if there is any output from ffprobe command
         return bool(result.stdout.strip())
     except Exception as e:
-        log(f"Error occurred while probing video: {video_path}, you may need to install ffprobe! Now set audio to false!", style="bold red")
+        log(
+            f"Error occurred while probing video: {video_path}, "
+            "you may need to install ffprobe! (https://ffmpeg.org/download.html) "
+            "Now set audio to false!",
+            style="bold red"
+        )
         return False
 
 
 def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
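`has_audio_stream()` relies on `ffprobe` being installed, which is why both the readme note and the new error message point to the FFmpeg download page. A rough stand-in for what such a probe looks like (an illustrative sketch, not the repo's exact implementation):

```python
import subprocess

def probe_has_audio(video_path: str) -> bool:
    # Ask ffprobe to print the codec_type of audio streams only; any output means audio exists.
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "a",
        "-show_entries", "stream=codec_type",
        "-of", "csv=p=0",
        video_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return bool(result.stdout.strip())

print(probe_has_audio("assets/examples/driving/d0.mp4"))  # example path from the repo's assets
```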