feat: gradio acceleration (#123)

* fix: typo

* feat: gradio acceleration

* doc: update readme.md

---------

Co-authored-by: Jianzhu Guo <guojianzhu@kuaishou.com>
Author: ZhizhouZhong, 2024-07-12 17:57:01 +08:00 (committed by GitHub)
Commit: d8036cffde (parent: e327753042)
5 changed files with 34 additions and 6 deletions

app.py

@@ -5,6 +5,7 @@ The entrance of the gradio
 """
 import tyro
+import subprocess
 import gradio as gr
 import os.path as osp
 from src.utils.helper import load_description
@@ -18,10 +19,22 @@ def partial_fields(target_class, kwargs):
     return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
+
 # set tyro theme
 tyro.extras.set_accent_color("bright_cyan")
 args = tyro.cli(ArgumentConfig)
+
+if not fast_check_ffmpeg():
+    raise ImportError(
+        "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+    )

 # specify configs for inference
 inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
 crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
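
For context on the check above: `subprocess.run(["ffmpeg", "-version"], check=True)` raises `FileNotFoundError` when no `ffmpeg` binary is on the PATH and `subprocess.CalledProcessError` when the binary exits non-zero, which is what the bare `except:` swallows. Below is a minimal standalone sketch of the same check with the exceptions named explicitly; it is not part of this commit, and `fast_check_ffmpeg_strict` is a hypothetical name.

```python
import subprocess

def fast_check_ffmpeg_strict() -> bool:
    """Return True if an `ffmpeg` binary is on the PATH and runs successfully."""
    try:
        # check=True raises CalledProcessError on a non-zero exit code;
        # a missing binary raises FileNotFoundError before the process starts.
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False

if __name__ == "__main__":
    print("ffmpeg available:", fast_check_ffmpeg_strict())
```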

readme.md

@@ -148,6 +148,13 @@ python app.py
 You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!

+🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster. Performance gains may vary with different CUDA versions.
+```bash
+# enable torch.compile for faster inference
+python app.py --flag_do_torch_compile
+```
+**Note**: This method has not been fully tested, e.g., on Windows.
+
 **Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**

 ### 5. Inference speed evaluation 🚀🚀🚀
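
As a rough illustration of the warm-up cost described in the README text above, the sketch below times a first call (which triggers compilation/autotuning) against a warm call on a toy model. It assumes a CUDA-capable PyTorch 2.x install; the toy `nn.Sequential` model is a stand-in, not LivePortrait's networks.

```python
import time
import torch

assert torch.cuda.is_available(), "this sketch assumes a CUDA GPU"

# Stand-in model; LivePortrait compiles its warping and SPADE decoder modules instead.
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU()).cuda().eval()
model = torch.compile(model, mode='max-autotune')  # same mode used in this commit

x = torch.randn(16, 512, device='cuda')
with torch.no_grad():
    t0 = time.time()
    model(x)                      # first call: triggers compilation / autotuning
    torch.cuda.synchronize()
    print(f"first call: {time.time() - t0:.2f}s")

    t0 = time.time()
    model(x)                      # warm call: reuses the compiled kernels
    torch.cuda.synchronize()
    print(f"warm call:  {time.time() - t0:.4f}s")
```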

src/config/argument_config.py

@@ -23,7 +23,7 @@ class ArgumentConfig(PrintableConfig):
     flag_crop_driving_video: bool = False  # whether to crop the driving video, if the given driving info is a video
     device_id: int = 0  # gpu device id
     flag_force_cpu: bool = False  # force cpu inference, WIP!
-    flag_lip_zero : bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
+    flag_lip_zero: bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
     flag_eye_retargeting: bool = False  # not recommend to be True, WIP
     flag_lip_retargeting: bool = False  # not recommend to be True, WIP
     flag_stitching: bool = True  # recommend to True if head movement is small, False if head movement is large
@@ -45,3 +45,4 @@ class ArgumentConfig(PrintableConfig):
     server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890  # port for gradio server
     share: bool = False  # whether to share the server to public
     server_name: Optional[str] = "127.0.0.1"  # set the local server name, "0.0.0.0" to broadcast all
+    flag_do_torch_compile: bool = False  # whether to use torch.compile to accelerate generation
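
How the new flag reaches inference: `partial_fields` in `app.py` (shown above) copies every parsed CLI argument whose name matches an attribute of the target config class, so declaring `flag_do_torch_compile` on both `ArgumentConfig` and `InferenceConfig` is enough to forward the value. A small sketch of that mechanism with a simplified dataclass (`DemoInferenceConfig` is illustrative, not the repo's class):

```python
from dataclasses import dataclass

@dataclass
class DemoInferenceConfig:  # simplified stand-in for InferenceConfig
    device_id: int = 0
    flag_do_torch_compile: bool = False

def partial_fields(target_class, kwargs):
    # Same idea as in app.py: keep only keys that exist as attributes on the target class.
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

args = {"device_id": 1, "flag_do_torch_compile": True, "server_port": 8890}
cfg = partial_fields(DemoInferenceConfig, args)
print(cfg)  # DemoInferenceConfig(device_id=1, flag_do_torch_compile=True); server_port is filtered out
```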

src/config/inference_config.py

@@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path

 @dataclass(repr=False)  # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
-    # MODEL CONFIG, NOT EXPOERTED PARAMS
+    # MODEL CONFIG, NOT EXPORTED PARAMS
     models_config: str = make_abs_path('./models.yaml')  # portrait animation config
     checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')  # path to checkpoint of F
     checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')  # path to checkpoint pf M
@@ -22,7 +22,7 @@ class InferenceConfig(PrintableConfig):
     checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')  # path to checkpoint of W
     checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')  # path to checkpoint to S and R_eyes, R_lip

-    # EXPOERTED PARAMS
+    # EXPORTED PARAMS
     flag_use_half_precision: bool = True
     flag_crop_driving_video: bool = False
     device_id: int = 0
@@ -35,8 +35,9 @@ class InferenceConfig(PrintableConfig):
     flag_do_crop: bool = True
     flag_do_rot: bool = True
     flag_force_cpu: bool = False
+    flag_do_torch_compile: bool = False

-    # NOT EXPOERTED PARAMS
+    # NOT EXPORTED PARAMS
     lip_zero_threshold: float = 0.03  # threshold for flag_lip_zero
     anchor_frame: int = 0  # TO IMPLEMENT

src/live_portrait_wrapper.py

@@ -24,6 +24,7 @@ class LivePortraitWrapper(object):

         self.inference_cfg = inference_cfg
         self.device_id = inference_cfg.device_id
+        self.compile = inference_cfg.flag_do_torch_compile
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
@@ -48,8 +49,10 @@ class LivePortraitWrapper(object):
             log(f'Load stitching_retargeting_module done.')
         else:
             self.stitching_retargeting_module = None
+
+        # Optimize for inference
+        if self.compile:
+            self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
+            self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')

         self.timer = Timer()
@@ -261,6 +264,9 @@ class LivePortraitWrapper(object):
         # The line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i)
         with torch.no_grad():
             with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
+                if self.compile:
+                    # Mark the beginning of a new CUDA Graph step
+                    torch.compiler.cudagraph_mark_step_begin()
                 # get decoder input
                 ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
                 # decode
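
On the `cudagraph_mark_step_begin()` call above: `mode='max-autotune'` can place compiled modules behind CUDA Graphs, and graph replays reuse output buffers, so marking the start of each inference step prevents a later iteration from overwriting tensors the caller still holds. A self-contained sketch of the pattern, assuming PyTorch >= 2.1 with a CUDA GPU (`TinyNet` is a hypothetical module, not part of LivePortrait):

```python
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.linear(x))

if torch.cuda.is_available():
    net = torch.compile(TinyNet().cuda().eval(), mode='max-autotune')
    x = torch.randn(8, 64, device='cuda')
    outputs = []
    with torch.no_grad():
        for _ in range(3):
            # Tell the compiler a new inference step starts, so CUDA Graph output
            # buffers from the previous iteration are not silently reused/overwritten.
            torch.compiler.cudagraph_mark_step_begin()
            outputs.append(net(x).clone())  # clone if results must outlive the step
    print(len(outputs), outputs[0].shape)
```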