mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2024-12-22 20:42:38 +00:00
feat: gradio acceleration (#123)
* fix: typo * feat: gradio acceleration * doc: update readme.md --------- Co-authored-by: Jianzhu Guo <guojianzhu@kuaishou.com>
This commit is contained in:
parent
e327753042
commit
d8036cffde
13
app.py
13
app.py
@ -5,6 +5,7 @@ The entrance of the gradio
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import tyro
|
import tyro
|
||||||
|
import subprocess
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from src.utils.helper import load_description
|
from src.utils.helper import load_description
|
||||||
@ -18,10 +19,22 @@ def partial_fields(target_class, kwargs):
|
|||||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||||
|
|
||||||
|
|
||||||
|
def fast_check_ffmpeg():
|
||||||
|
try:
|
||||||
|
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
# set tyro theme
|
# set tyro theme
|
||||||
tyro.extras.set_accent_color("bright_cyan")
|
tyro.extras.set_accent_color("bright_cyan")
|
||||||
args = tyro.cli(ArgumentConfig)
|
args = tyro.cli(ArgumentConfig)
|
||||||
|
|
||||||
|
if not fast_check_ffmpeg():
|
||||||
|
raise ImportError(
|
||||||
|
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
|
||||||
|
)
|
||||||
|
|
||||||
# specify configs for inference
|
# specify configs for inference
|
||||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||||
|
@ -148,6 +148,13 @@ python app.py
|
|||||||
|
|
||||||
You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
|
You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
|
||||||
|
|
||||||
|
🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster. Performance gains may vary with different CUDA versions.
|
||||||
|
```bash
|
||||||
|
# enable torch.compile for faster inference
|
||||||
|
python app.py --flag_do_torch_compile
|
||||||
|
```
|
||||||
|
**Note**: This method has not been fully tested. e.g., on Windows.
|
||||||
|
|
||||||
**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
|
**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
|
||||||
|
|
||||||
### 5. Inference speed evaluation 🚀🚀🚀
|
### 5. Inference speed evaluation 🚀🚀🚀
|
||||||
|
@ -23,7 +23,7 @@ class ArgumentConfig(PrintableConfig):
|
|||||||
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
|
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
|
||||||
device_id: int = 0 # gpu device id
|
device_id: int = 0 # gpu device id
|
||||||
flag_force_cpu: bool = False # force cpu inference, WIP!
|
flag_force_cpu: bool = False # force cpu inference, WIP!
|
||||||
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||||
flag_eye_retargeting: bool = False # not recommend to be True, WIP
|
flag_eye_retargeting: bool = False # not recommend to be True, WIP
|
||||||
flag_lip_retargeting: bool = False # not recommend to be True, WIP
|
flag_lip_retargeting: bool = False # not recommend to be True, WIP
|
||||||
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
|
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
|
||||||
@ -45,3 +45,4 @@ class ArgumentConfig(PrintableConfig):
|
|||||||
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
|
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
|
||||||
share: bool = False # whether to share the server to public
|
share: bool = False # whether to share the server to public
|
||||||
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
|
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
|
||||||
|
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
|
||||||
|
@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path
|
|||||||
|
|
||||||
@dataclass(repr=False) # use repr from PrintableConfig
|
@dataclass(repr=False) # use repr from PrintableConfig
|
||||||
class InferenceConfig(PrintableConfig):
|
class InferenceConfig(PrintableConfig):
|
||||||
# MODEL CONFIG, NOT EXPOERTED PARAMS
|
# MODEL CONFIG, NOT EXPORTED PARAMS
|
||||||
models_config: str = make_abs_path('./models.yaml') # portrait animation config
|
models_config: str = make_abs_path('./models.yaml') # portrait animation config
|
||||||
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
|
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
|
||||||
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M
|
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M
|
||||||
@ -22,7 +22,7 @@ class InferenceConfig(PrintableConfig):
|
|||||||
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
|
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
|
||||||
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip
|
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip
|
||||||
|
|
||||||
# EXPOERTED PARAMS
|
# EXPORTED PARAMS
|
||||||
flag_use_half_precision: bool = True
|
flag_use_half_precision: bool = True
|
||||||
flag_crop_driving_video: bool = False
|
flag_crop_driving_video: bool = False
|
||||||
device_id: int = 0
|
device_id: int = 0
|
||||||
@ -35,8 +35,9 @@ class InferenceConfig(PrintableConfig):
|
|||||||
flag_do_crop: bool = True
|
flag_do_crop: bool = True
|
||||||
flag_do_rot: bool = True
|
flag_do_rot: bool = True
|
||||||
flag_force_cpu: bool = False
|
flag_force_cpu: bool = False
|
||||||
|
flag_do_torch_compile: bool = False
|
||||||
|
|
||||||
# NOT EXPOERTED PARAMS
|
# NOT EXPORTED PARAMS
|
||||||
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
|
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
|
||||||
anchor_frame: int = 0 # TO IMPLEMENT
|
anchor_frame: int = 0 # TO IMPLEMENT
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class LivePortraitWrapper(object):
|
|||||||
|
|
||||||
self.inference_cfg = inference_cfg
|
self.inference_cfg = inference_cfg
|
||||||
self.device_id = inference_cfg.device_id
|
self.device_id = inference_cfg.device_id
|
||||||
|
self.compile = inference_cfg.flag_do_torch_compile
|
||||||
if inference_cfg.flag_force_cpu:
|
if inference_cfg.flag_force_cpu:
|
||||||
self.device = 'cpu'
|
self.device = 'cpu'
|
||||||
else:
|
else:
|
||||||
@ -48,8 +49,10 @@ class LivePortraitWrapper(object):
|
|||||||
log(f'Load stitching_retargeting_module done.')
|
log(f'Load stitching_retargeting_module done.')
|
||||||
else:
|
else:
|
||||||
self.stitching_retargeting_module = None
|
self.stitching_retargeting_module = None
|
||||||
|
# Optimize for inference
|
||||||
|
if self.compile:
|
||||||
|
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
|
||||||
|
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
|
||||||
|
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
|
|
||||||
@ -261,6 +264,9 @@ class LivePortraitWrapper(object):
|
|||||||
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||||
|
if self.compile:
|
||||||
|
# Mark the beginning of a new CUDA Graph step
|
||||||
|
torch.compiler.cudagraph_mark_step_begin()
|
||||||
# get decoder input
|
# get decoder input
|
||||||
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
||||||
# decode
|
# decode
|
||||||
|
Loading…
Reference in New Issue
Block a user