LivePortrait/inference.py
Jianzhu Guo 0f839844f6
feat: support macOS with Apple Silicon  (#155)
* feat: macOS support (#143)

* Support for running on Apple Silicon Macs with MPS

* Minor typo fix: s/provicer/provider/

* Another typo fix: s/concact/concat/

* s/cudaexecutionprovider/CUDAExecutionProvider/

* Add requirements_apple.txt

* doc: macOS support

* chore: refine the structure and doc

* doc: update readme

* doc: update readme

* doc: update readme

* doc: update readme

---------

Co-authored-by: Jeethu Rao <jeethu@jeethurao.com>
Co-authored-by: zzzweakman <1819489045@qq.com>
2024-07-17 16:57:33 +08:00

59 lines
1.6 KiB
Python

# coding: utf-8
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` using only the entries of ``kwargs`` it knows.

    A key is forwarded as a keyword argument when ``hasattr(target_class, key)``
    is true. Every other key is silently dropped. This lets one flat CLI
    namespace feed several config classes, each picking out its own fields.

    Args:
        target_class: The class to instantiate.
        kwargs: Mapping of candidate keyword arguments.

    Returns:
        ``target_class(**selected)``, where ``selected`` holds the filtered
        entries in their original iteration order.
    """
    selected = {}
    for field_name, field_value in kwargs.items():
        if not hasattr(target_class, field_name):
            continue
        selected[field_name] = field_value
    return target_class(**selected)
def fast_check_ffmpeg():
    """Return True if an ``ffmpeg`` executable can be launched, else False.

    Runs ``ffmpeg -version`` with its output captured (so nothing is printed).
    The check succeeds only if the process starts and exits with status 0.

    Returns:
        bool: True when the probe succeeds. False when the executable cannot
        be started or exits with a non-zero status.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (OSError, subprocess.SubprocessError):
        # OSError covers FileNotFoundError (ffmpeg not on PATH) and
        # PermissionError (not executable). SubprocessError covers
        # CalledProcessError (non-zero exit). A bare ``except:`` would also
        # swallow KeyboardInterrupt/SystemExit, so it is deliberately avoided.
        return False
    return True
def fast_check_args(args: ArgumentConfig):
    """Fail fast when the input paths named in ``args`` do not exist.

    ``args.source_image`` is checked first, then ``args.driving_info``. The
    check uses ``os.path.exists``, so a directory at either path also passes.

    Raises:
        FileNotFoundError: if either path is missing. The message names which
            input was missing and echoes the offending path.
    """
    source_path = args.source_image
    if not osp.exists(source_path):
        raise FileNotFoundError(f"source image not found: {source_path}")

    driving_path = args.driving_info
    if not osp.exists(driving_path):
        raise FileNotFoundError(f"driving info not found: {driving_path}")
def main():
    """Command-line entry point: parse arguments, validate inputs, run the pipeline.

    Steps, in order:
      1. Set the tyro accent colour and parse ``ArgumentConfig`` from the CLI.
      2. Raise ``ImportError`` if an ``ffmpeg`` executable cannot be run.
      3. Raise ``FileNotFoundError`` if the source image or driving input is missing.
      4. Split the parsed arguments into ``InferenceConfig`` and ``CropConfig``.
      5. Build ``LivePortraitPipeline`` and call ``execute`` with the parsed args.
    """
    # The tyro theme must be set before the CLI help/parser is built.
    tyro.extras.set_accent_color("bright_cyan")
    args = tyro.cli(ArgumentConfig)

    # Environment and input checks run before any model setup.
    ffmpeg_available = fast_check_ffmpeg()
    if not ffmpeg_available:
        raise ImportError(
            "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
        )
    fast_check_args(args)

    # One flat CLI namespace feeds both config classes. Each keeps only the
    # keys it has a matching attribute for (see partial_fields).
    cli_values = args.__dict__
    inference_cfg = partial_fields(InferenceConfig, cli_values)
    crop_cfg = partial_fields(CropConfig, cli_values)

    pipeline = LivePortraitPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg)
    pipeline.execute(args)


if __name__ == "__main__":
    main()