From d8036cffdeecd395f8bf02e70e481c4207842cf1 Mon Sep 17 00:00:00 2001
From: ZhizhouZhong <1819489045@qq.com>
Date: Fri, 12 Jul 2024 17:57:01 +0800
Subject: [PATCH] feat: gradio acceleration (#123)

* fix: typo

* feat: gradio acceleration

* doc: update readme.md

---------

Co-authored-by: Jianzhu Guo
---
 app.py                         | 13 +++++++++++++
 readme.md                      |  7 +++++++
 src/config/argument_config.py  |  3 ++-
 src/config/inference_config.py |  7 ++++---
 src/live_portrait_wrapper.py   | 10 ++++++++--
 5 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/app.py b/app.py
index 5494d9b..af15991 100644
--- a/app.py
+++ b/app.py
@@ -5,6 +5,7 @@ The entrance of the gradio
 """
 
 import tyro
+import subprocess
 import gradio as gr
 import os.path as osp
 from src.utils.helper import load_description
@@ -18,10 +19,22 @@ def partial_fields(target_class, kwargs):
     return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
 
 
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
+
 # set tyro theme
 tyro.extras.set_accent_color("bright_cyan")
 args = tyro.cli(ArgumentConfig)
 
+if not fast_check_ffmpeg():
+    raise ImportError(
+        "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+    )
+
 # specify configs for inference
 inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
 crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
diff --git a/readme.md b/readme.md
index 4a16bcc..90b8bcf 100644
--- a/readme.md
+++ b/readme.md
@@ -148,6 +148,13 @@ python app.py
 
 You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
 
+🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster. Performance gains may vary with different CUDA versions.
+```bash
+# enable torch.compile for faster inference
+python app.py --flag_do_torch_compile
+```
+**Note**: This method has not been fully tested. e.g., on Windows.
+
 **Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
 
 ### 5. Inference speed evaluation 🚀🚀🚀
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index 0bbaa20..aa86713 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -23,7 +23,7 @@ class ArgumentConfig(PrintableConfig):
     flag_crop_driving_video: bool = False  # whether to crop the driving video, if the given driving info is a video
     device_id: int = 0  # gpu device id
     flag_force_cpu: bool = False  # force cpu inference, WIP!
-    flag_lip_zero : bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
+    flag_lip_zero: bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
     flag_eye_retargeting: bool = False  # not recommend to be True, WIP
     flag_lip_retargeting: bool = False  # not recommend to be True, WIP
     flag_stitching: bool = True  # recommend to True if head movement is small, False if head movement is large
@@ -45,3 +45,4 @@ class ArgumentConfig(PrintableConfig):
     server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890  # port for gradio server
     share: bool = False  # whether to share the server to public
     server_name: Optional[str] = "127.0.0.1"  # set the local server name, "0.0.0.0" to broadcast all
+    flag_do_torch_compile: bool = False  # whether to use torch.compile to accelerate generation
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
index 70eedd8..c25e83d 100644
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path
 
 @dataclass(repr=False)  # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
-    # MODEL CONFIG, NOT EXPOERTED PARAMS
+    # MODEL CONFIG, NOT EXPORTED PARAMS
     models_config: str = make_abs_path('./models.yaml')  # portrait animation config
     checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')  # path to checkpoint of F
     checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')  # path to checkpoint pf M
@@ -22,7 +22,7 @@ class InferenceConfig(PrintableConfig):
     checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')  # path to checkpoint of W
     checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')  # path to checkpoint to S and R_eyes, R_lip
 
-    # EXPOERTED PARAMS
+    # EXPORTED PARAMS
     flag_use_half_precision: bool = True
     flag_crop_driving_video: bool = False
     device_id: int = 0
@@ -35,8 +35,9 @@ class InferenceConfig(PrintableConfig):
     flag_do_crop: bool = True
     flag_do_rot: bool = True
     flag_force_cpu: bool = False
+    flag_do_torch_compile: bool = False
 
-    # NOT EXPOERTED PARAMS
+    # NOT EXPORTED PARAMS
     lip_zero_threshold: float = 0.03  # threshold for flag_lip_zero
     anchor_frame: int = 0  # TO IMPLEMENT
 
diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py
index 8869b95..ca4bd1b 100644
--- a/src/live_portrait_wrapper.py
+++ b/src/live_portrait_wrapper.py
@@ -24,6 +24,7 @@ class LivePortraitWrapper(object):
 
         self.inference_cfg = inference_cfg
         self.device_id = inference_cfg.device_id
+        self.compile = inference_cfg.flag_do_torch_compile
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
@@ -48,8 +49,10 @@ class LivePortraitWrapper(object):
             log(f'Load stitching_retargeting_module done.')
         else:
             self.stitching_retargeting_module = None
-
-
+        # Optimize for inference
+        if self.compile:
+            self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
+            self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
 
         self.timer = Timer()
 
@@ -261,6 +264,9 @@
         # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
         with torch.no_grad():
             with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
+                if self.compile:
+                    # Mark the beginning of a new CUDA Graph step
+                    torch.compiler.cudagraph_mark_step_begin()
                 # get decoder input
                 ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
                 # decode
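
For reference, a minimal sketch of the torch.compile pattern this patch wires in: compile the heavy modules once with `mode='max-autotune'` (which can capture CUDA Graphs), then mark the start of every inference step with `torch.compiler.cudagraph_mark_step_begin()` so replayed graphs do not clobber outputs the caller still holds. The `TinyDecoder` module, tensor shapes, and loop below are illustrative placeholders under those assumptions, not LivePortrait code.

```python
import torch
import torch.nn as nn


# Illustrative stand-in for a heavy decoder (e.g. a warping or SPADE-style generator).
class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

    def forward(self, x):
        return self.net(x)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
decoder = TinyDecoder().to(device).eval()

# Compile once up front; the first call pays the optimization cost,
# subsequent calls reuse the compiled (and possibly CUDA-Graph-captured) kernels.
decoder = torch.compile(decoder, mode='max-autotune')

x = torch.randn(1, 64, device=device)
with torch.no_grad():
    for _ in range(3):
        if device == 'cuda':
            # Tell the compiler a new inference step begins, so CUDA Graph
            # replays do not overwrite output buffers from the previous step.
            torch.compiler.cudagraph_mark_step_begin()
        out = decoder(x)
```

The per-step marker matters because `max-autotune` enables CUDA Graphs, which reuse static output buffers across replays; without it, tensors returned by an earlier iteration may be overwritten in place.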