From d09527c76275969e99f4d9e29775f02df1b4c48e Mon Sep 17 00:00:00 2001 From: guojianzhu Date: Fri, 5 Jul 2024 15:09:43 +0800 Subject: [PATCH] chore: slightly refine the codebase --- app.py | 8 ++++---- assets/gradio_description_animation.md | 4 ++-- assets/gradio_description_retargeting.md | 2 +- assets/gradio_description_upload.md | 6 ++---- src/config/argument_config.py | 6 +++--- src/config/inference_config.py | 4 ++-- src/gradio_pipeline.py | 16 ++++++++-------- src/live_portrait_pipeline.py | 2 +- src/live_portrait_wrapper.py | 10 +++++----- src/utils/helper.py | 16 +++++++--------- 10 files changed, 35 insertions(+), 39 deletions(-) diff --git a/app.py b/app.py index 33e24ab..a82443b 100644 --- a/app.py +++ b/app.py @@ -44,8 +44,8 @@ data_examples = [ #################### interface logic #################### # Define components first -eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio") -lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio") +eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") +lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") retargeting_input_image = gr.Image(type="numpy") output_image = gr.Image(type="numpy") output_image_paste_back = gr.Image(type="numpy") @@ -56,7 +56,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.HTML(load_description(title_md)) gr.Markdown(load_description("assets/gradio_description_upload.md")) with gr.Row(): - with gr.Accordion(open=True, label="Reference Portrait"): + with gr.Accordion(open=True, label="Source Portrait"): image_input = gr.Image(type="filepath") with gr.Accordion(open=True, label="Driving Video"): video_input = gr.Video() @@ -64,7 +64,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: with gr.Row(): with gr.Accordion(open=True, label="Animation Options"): with gr.Row(): - flag_relative_input = gr.Checkbox(value=True, label="relative pose") + flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_do_crop_input = gr.Checkbox(value=True, label="do crop") flag_remap_input = gr.Checkbox(value=True, label="paste-back") with gr.Row(): diff --git a/assets/gradio_description_animation.md b/assets/gradio_description_animation.md index 6cd6791..34b3897 100644 --- a/assets/gradio_description_animation.md +++ b/assets/gradio_description_animation.md @@ -1,7 +1,7 @@ -🔥 To animate the reference portrait with the driving video, please follow these steps: +🔥 To animate the source portrait with the driving video, please follow these steps:
1. Specify the options in the Animation Options section. We recommend checking the do crop option when the face occupies a relatively small portion of your image.
<br>
2. Press the 🚀 Animate button; your animated video will appear in the result block after a few moments.
-<br>
\ No newline at end of file
+<br>
diff --git a/assets/gradio_description_retargeting.md b/assets/gradio_description_retargeting.md
index 5fe6ebf..a99796d 100644
--- a/assets/gradio_description_retargeting.md
+++ b/assets/gradio_description_retargeting.md
@@ -1 +1 @@
-🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the 🚗 Retargeting button. The result would be shown in the middle block. You can try running it multiple times. 😊 Set both ratios to 0.8 to see what's going on!
+🔥 To change the target eyes-open and lip-open ratios of the source portrait, please drag the sliders and then click the 🚗 Retargeting button. The result will be shown in the middle block. You can run it multiple times. 😊 Set both ratios to 0.8 to see what's going on!
diff --git a/assets/gradio_description_upload.md b/assets/gradio_description_upload.md
index cba21ca..46a5fa5 100644
--- a/assets/gradio_description_upload.md
+++ b/assets/gradio_description_upload.md
@@ -1,4 +1,2 @@
-
-## 🤗 This is the official gradio demo for **Live Portrait**.
-### Guidance for the gradio page:
-Please upload or use the webcam to get a reference portrait to the Reference Portrait field and a driving video to the Driving Video field.
+## 🤗 This is the official gradio demo for **LivePortrait**.
+Please upload a source portrait (or capture one with the webcam) into the Source Portrait field, and a driving video into the Driving Video field.
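As a reading aid for the renamed demo controls above, here is a minimal sketch of how the eyes-open/lip-open sliders and the 🚗 Retargeting button could be wired in Gradio. The labels and 0–0.8 ranges mirror `app.py`; `fake_retarget` is a hypothetical stand-in for `GradioPipeline.execute_image`, not the demo's actual handler:

```python
# A minimal sketch of the retargeting controls, assuming only the public
# Gradio API. fake_retarget is a hypothetical placeholder for
# GradioPipeline.execute_image; labels and ranges mirror app.py.
import gradio as gr

def fake_retarget(image, eyes_open_ratio, lip_open_ratio):
    # Placeholder: the real pipeline edits the implicit keypoints and
    # re-renders the face; here we simply return the input unchanged.
    return image

with gr.Blocks() as demo:
    retargeting_input_image = gr.Image(type="numpy", label="Source Portrait")
    eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01,
                                       label="target eyes-open ratio")
    lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01,
                                       label="target lip-open ratio")
    output_image = gr.Image(type="numpy")
    retargeting_button = gr.Button("🚗 Retargeting")
    retargeting_button.click(
        fn=fake_retarget,
        inputs=[retargeting_input_image, eye_retargeting_slider, lip_retargeting_slider],
        outputs=[output_image],
    )

# demo.launch()  # uncomment to serve locally
```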
diff --git a/src/config/argument_config.py b/src/config/argument_config.py index 142f0b2..0431627 100644 --- a/src/config/argument_config.py +++ b/src/config/argument_config.py @@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path @dataclass(repr=False) # use repr from PrintableConfig class ArgumentConfig(PrintableConfig): ########## input arguments ########## - source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the reference portrait + source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video ##################################### @@ -25,9 +25,9 @@ class ArgumentConfig(PrintableConfig): flag_eye_retargeting: bool = False flag_lip_retargeting: bool = False flag_stitching: bool = True # we recommend setting it to True! - flag_relative: bool = True # whether to use relative pose + flag_relative: bool = True # whether to use relative motion flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space - flag_do_crop: bool = True # whether to crop the reference portrait to the face-cropping space + flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True ######################################### diff --git a/src/config/inference_config.py b/src/config/inference_config.py index 0da3e3c..e94aeb8 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -28,7 +28,7 @@ class InferenceConfig(PrintableConfig): flag_lip_retargeting: bool = False flag_stitching: bool = True # we recommend setting it to True! 
-    flag_relative: bool = True  # whether to use relative pose
+    flag_relative: bool = True  # whether to use relative motion
     anchor_frame: int = 0  # set this value if find_best_frame is True
     input_shape: Tuple[int, int] = (256, 256)  # input shape
@@ -45,5 +45,5 @@ class InferenceConfig(PrintableConfig):
     ref_shape_n: int = 2
     device_id: int = 0
-    flag_do_crop: bool = False  # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = False  # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
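The comment rename here ("relative pose" → "relative motion") reflects that, under `flag_relative`, the pipeline applies the driving sequence's motion relative to its first frame rather than copying each frame's absolute pose. A toy sketch of that idea, assuming (our simplification, not the patch's actual math) that motion is applied as a keypoint offset:

```python
# Toy illustration of relative motion transfer vs. absolute pose copy.
# Assumption (ours, not the repo's exact formula): motion is the keypoint
# delta between driving frame i and driving frame 0.
import numpy as np

def relative_motion(x_s: np.ndarray, x_d_0: np.ndarray, x_d_i: np.ndarray) -> np.ndarray:
    # The source keypoints move by the driving delta, so the source keeps
    # its own starting pose instead of snapping to the driver's pose.
    return x_s + (x_d_i - x_d_0)

x_s = np.zeros((21, 3))          # dummy source keypoints
x_d_0 = np.random.randn(21, 3)   # first driving frame
x_d_i = x_d_0 + 0.05             # a later, slightly moved driving frame
print(relative_motion(x_s, x_d_0, x_d_i).shape)  # (21, 3)
```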
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index 0f89013..c717897 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -35,7 +35,7 @@ class GradioPipeline(LivePortraitPipeline):
         self.mask_ori = None
         self.img_rgb = None
         self.crop_M_c2o = None
-
+
     def execute_video(
         self,
@@ -62,9 +62,9 @@
             # video driven animation
             video_path, video_path_concat = self.execute(self.args)
             gr.Info("Run successfully!", duration=2)
-            return video_path, video_path_concat,
+            return video_path, video_path_concat,
         else:
-            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
+            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)

     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
@@ -74,12 +74,12 @@
         elif self.f_s_user is None:
             if self.start_prepare:
                 raise gr.Error(
-                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    "The source portrait is under processing 💥! Please wait a moment.",
                     duration=5
                 )
             else:
                 raise gr.Error(
-                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
                     duration=5
                 )
         else:
@@ -98,7 +98,7 @@
             out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
             gr.Info("Run successfully!", duration=2)
             return out, out_to_ori_blend
-
+
     def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         """ for single image retargeting
@@ -107,7 +107,7 @@
            gr.Info("Upload successfully!", duration=2)
            self.start_prepare = True
            inference_cfg = self.live_portrait_wrapper.cfg
-           ######## process reference portrait ########
+           ######## process source portrait ########
            img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
            log(f"Load source image from {input_image_path}.")
            crop_info = self.cropper.crop_single_image(img_rgb)
@@ -125,7 +125,7 @@
            self.x_s_info_user = x_s_info
            self.source_lmk_user = crop_info['lmk_crop']
            self.img_rgb = img_rgb
-           self.crop_M_c2o = crop_info['M_c2o']
+           self.crop_M_c2o = crop_info['M_c2o']
            self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
            # update slider
            eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
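`prepare_paste_back` and `paste_back` in the hunk above blend the edited crop back into the original frame via the crop-to-original transform `M_c2o`. A simplified sketch of such a paste-back with OpenCV; this is an illustration under our own assumptions (a 2x3 affine, a soft mask already in original-image space), not the repo's exact helpers:

```python
# Simplified paste-back: warp the edited crop into the original image
# space with the crop->original affine, then alpha-blend with a soft
# mask. Illustrative only; src/utils/'s real helpers may differ.
import cv2
import numpy as np

def paste_back_sketch(crop_rgb: np.ndarray, M_c2o: np.ndarray,
                      full_rgb: np.ndarray, mask_ori: np.ndarray) -> np.ndarray:
    h, w = full_rgb.shape[:2]
    # Map the edited crop back onto the full-resolution canvas.
    warped = cv2.warpAffine(crop_rgb, M_c2o, (w, h), flags=cv2.INTER_LINEAR)
    mask = mask_ori.astype(np.float32)
    if mask.ndim == 2:
        mask = mask[..., None]  # broadcast the mask over RGB channels
    blended = warped * mask + full_rgb.astype(np.float32) * (1.0 - mask)
    return np.clip(blended, 0, 255).astype(np.uint8)
```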
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 933c911..7fda1f5 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -40,7 +40,7 @@ class LivePortraitPipeline(object):
     def execute(self, args: ArgumentConfig):
         inference_cfg = self.live_portrait_wrapper.cfg  # for convenience
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_image_rgb(args.source_image)
         img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
         log(f"Load source image from {args.source_image}")
diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py
index ac3c63a..0ad9d06 100644
--- a/src/live_portrait_wrapper.py
+++ b/src/live_portrait_wrapper.py
@@ -10,12 +10,12 @@
 import cv2
 import torch
 import yaml

-from src.utils.timer import Timer
-from src.utils.helper import load_model, concat_feat
-from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+from .utils.timer import Timer
+from .utils.helper import load_model, concat_feat
+from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from src.config.inference_config import InferenceConfig
-from src.utils.rprint import rlog as log
+from .config.inference_config import InferenceConfig
+from .utils.rprint import rlog as log

 class LivePortraitWrapper(object):
diff --git a/src/utils/helper.py b/src/utils/helper.py
index 05c991e..4974fc5 100644
--- a/src/utils/helper.py
+++ b/src/utils/helper.py
@@ -6,17 +6,14 @@ utility functions and classes to handle feature extraction and model loading
 import os
 import os.path as osp
-import cv2
 import torch
-from rich.console import Console
 from collections import OrderedDict

-from src.modules.spade_generator import SPADEDecoder
-from src.modules.warping_network import WarpingNetwork
-from src.modules.motion_extractor import MotionExtractor
-from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
-from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork
-from .rprint import rlog as log
+from ..modules.spade_generator import SPADEDecoder
+from ..modules.warping_network import WarpingNetwork
+from ..modules.motion_extractor import MotionExtractor
+from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
+from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork


 def suffix(filename):
@@ -45,6 +42,7 @@ def is_video(file_path):
             return True
     return False

+
 def is_template(file_path):
     if file_path.endswith(".pkl"):
         return True
@@ -149,8 +147,8 @@ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R
     new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
     return new_rotation, new_expression, new_translation, new_scale

+
 def load_description(fp):
     with open(fp, 'r', encoding='utf-8') as f:
         content = f.read()
     return content
-
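Because `driving_info` can be either a video or a `.pkl` motion template (see `argument_config.py` above), the `is_video` / `is_template` helpers let callers branch on the input type. A small dispatch sketch; `dispatch_driving` and `VIDEO_EXTS` are hypothetical illustrative names, not part of the repo:

```python
# Hypothetical dispatch on the driving input. Only is_template mirrors
# helper.py directly; the rest is an illustrative simplification.
import os.path as osp

VIDEO_EXTS = (".mp4", ".mov", ".avi", ".webm")

def is_template(file_path: str) -> bool:
    # As in helper.py: motion templates are pickled .pkl files.
    return file_path.endswith(".pkl")

def is_video(file_path: str) -> bool:
    # Simplified check: common video extensions, or a directory of frames.
    return file_path.lower().endswith(VIDEO_EXTS) or osp.isdir(file_path)

def dispatch_driving(driving_info: str) -> str:
    if is_template(driving_info):
        return "load precomputed motion template"
    if is_video(driving_info):
        return "extract motion from the driving video"
    raise ValueError(f"unsupported driving input: {driving_info}")

print(dispatch_driving("assets/examples/driving/d0.mp4"))
```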