chore: slightly refine the codebase

guojianzhu 2024-07-05 15:09:43 +08:00
parent 669487a2fe
commit d09527c762
10 changed files with 35 additions and 39 deletions

app.py

@@ -44,8 +44,8 @@ data_examples = [
 #################### interface logic ####################

 # Define components first
-eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
-lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
+eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
+lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
 retargeting_input_image = gr.Image(type="numpy")
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
@@ -56,7 +56,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.HTML(load_description(title_md))
     gr.Markdown(load_description("assets/gradio_description_upload.md"))
     with gr.Row():
-        with gr.Accordion(open=True, label="Reference Portrait"):
+        with gr.Accordion(open=True, label="Source Portrait"):
             image_input = gr.Image(type="filepath")
         with gr.Accordion(open=True, label="Driving Video"):
             video_input = gr.Video()
@@ -64,7 +64,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         with gr.Accordion(open=True, label="Animation Options"):
             with gr.Row():
-                flag_relative_input = gr.Checkbox(value=True, label="relative pose")
+                flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                 flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
             with gr.Row():


@@ -1,4 +1,4 @@
-<span style="font-size: 1.2em;">🔥 To animate the reference portrait with the driving video, please follow these steps:</span>
+<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
 <div style="font-size: 1.2em; margin-left: 20px;">
 1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
 </div>


@@ -1 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
+<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>


@@ -1,4 +1,2 @@
 ## 🤗 This is the official gradio demo for **LivePortrait**.
-### Guidance for the gradio page:
-<div style="font-size: 1.2em;">Please upload or use the webcam to get a reference portrait to the <strong>Reference Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
+<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>


@@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False) # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the reference portrait
+    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
     driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
     #####################################
@@ -25,9 +25,9 @@ class ArgumentConfig(PrintableConfig):
     flag_eye_retargeting: bool = False
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True # we recommend setting it to True!
-    flag_relative: bool = True # whether to use relative pose
+    flag_relative: bool = True # whether to use relative motion
     flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-    flag_do_crop: bool = True # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
     #########################################
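
Since `ArgumentConfig` is a tyro-annotated dataclass, the repository's CLI entry point can expose these fields as command-line options via `tyro.cli`; a hedged sketch of that pattern (the wrapper below and the import path are assumptions, not part of this diff):

```python
# Sketch of a tyro-based entry point consuming ArgumentConfig.
# The module path is assumed; only the dataclass itself appears in this diff.
import tyro
from src.config.argument_config import ArgumentConfig

def main():
    # Parses the -s/-d/-o aliases plus the boolean flag_* fields from the command line.
    args = tyro.cli(ArgumentConfig)
    print(args)  # PrintableConfig supplies the readable repr mentioned in the decorator comment

if __name__ == "__main__":
    main()
```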


@@ -28,7 +28,7 @@ class InferenceConfig(PrintableConfig):
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True # we recommend setting it to True!
-    flag_relative: bool = True # whether to use relative pose
+    flag_relative: bool = True # whether to use relative motion
     anchor_frame: int = 0 # set this value if find_best_frame is True
     input_shape: Tuple[int, int] = (256, 256) # input shape
@@ -45,5 +45,5 @@ class InferenceConfig(PrintableConfig):
     ref_shape_n: int = 2
     device_id: int = 0
-    flag_do_crop: bool = False # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
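
Note that `flag_do_crop` defaults to `True` in `ArgumentConfig` but to `False` in `InferenceConfig`, so the effective value depends on how the argument config is mapped onto the inference config at runtime. Because `InferenceConfig` is a plain dataclass, individual fields can be overridden when it is constructed; a small sketch (the import path matches the one used elsewhere in this diff, the overrides are purely illustrative):

```python
from src.config.inference_config import InferenceConfig  # same path as imported elsewhere in this diff

# Override selected fields at construction time; everything else keeps its default.
inference_cfg = InferenceConfig(
    flag_relative=True,   # use relative motion
    flag_do_crop=True,    # crop the source portrait to the face-cropping space
    device_id=0,
)
print(inference_cfg.input_shape)  # (256, 256) by default
```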


@@ -64,7 +64,7 @@ class GradioPipeline(LivePortraitPipeline):
             gr.Info("Run successfully!", duration=2)
             return video_path, video_path_concat,
         else:
-            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
+            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)

     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
@@ -74,12 +74,12 @@ class GradioPipeline(LivePortraitPipeline):
         elif self.f_s_user is None:
             if self.start_prepare:
                 raise gr.Error(
-                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    "The source portrait is under processing 💥! Please wait for a second.",
                     duration=5
                 )
             else:
                 raise gr.Error(
-                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
                     duration=5
                 )
         else:
@@ -107,7 +107,7 @@ class GradioPipeline(LivePortraitPipeline):
         gr.Info("Upload successfully!", duration=2)
         self.start_prepare = True
         inference_cfg = self.live_portrait_wrapper.cfg
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
         log(f"Load source image from {input_image_path}.")
         crop_info = self.cropper.crop_single_image(img_rgb)
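
`load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)` itself is not shown in this diff; judging from its arguments it presumably caps the longer image side at 1280 px and trims both dimensions to multiples of 16 before cropping. A rough, purely hypothetical equivalent for illustration:

```python
import cv2

def load_img_online_sketch(path: str, mode: str = 'rgb', max_dim: int = 1280, n: int = 16):
    """Hypothetical stand-in for load_img_online: limit the longer side to max_dim
    and make both dimensions divisible by n."""
    img = cv2.imread(path)  # BGR by default
    if mode == 'rgb':
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    if max(h, w) > max_dim:  # downscale overly large uploads
        scale = max_dim / max(h, w)
        img = cv2.resize(img, (int(w * scale), int(h * scale)))
        h, w = img.shape[:2]
    return img[: h - h % n, : w - w % n]  # crop so height and width are multiples of n
```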


@@ -40,7 +40,7 @@ class LivePortraitPipeline(object):
     def execute(self, args: ArgumentConfig):
         inference_cfg = self.live_portrait_wrapper.cfg # for convenience
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_image_rgb(args.source_image)
         img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
         log(f"Load source image from {args.source_image}")


@@ -10,12 +10,12 @@ import cv2
 import torch
 import yaml

-from src.utils.timer import Timer
-from src.utils.helper import load_model, concat_feat
-from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+from .utils.timer import Timer
+from .utils.helper import load_model, concat_feat
+from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from src.config.inference_config import InferenceConfig
-from src.utils.rprint import rlog as log
+from .config.inference_config import InferenceConfig
+from .utils.rprint import rlog as log

 class LivePortraitWrapper(object):
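
The switch from `src.`-prefixed absolute imports to package-relative ones means the module no longer assumes the project root sits on `sys.path` under the name `src`; the imports resolve against the containing package however it is mounted or renamed. Illustrative contrast only, as the two lines would appear inside a module that lives directly in the `src` package:

```python
# Absolute import: tied to a top-level package literally named `src`,
# so it breaks if the repository is vendored under another name.
from src.utils.timer import Timer

# Relative import: resolves against this module's own package,
# so it keeps working wherever the package is mounted (only valid inside the package).
from .utils.timer import Timer
```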


@@ -6,17 +6,14 @@ utility functions and classes to handle feature extraction and model loading
 import os
 import os.path as osp
-import cv2
 import torch
-from rich.console import Console
 from collections import OrderedDict

-from src.modules.spade_generator import SPADEDecoder
-from src.modules.warping_network import WarpingNetwork
-from src.modules.motion_extractor import MotionExtractor
-from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
-from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork
+from ..modules.spade_generator import SPADEDecoder
+from ..modules.warping_network import WarpingNetwork
+from ..modules.motion_extractor import MotionExtractor
+from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
+from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
-from .rprint import rlog as log

 def suffix(filename):
@@ -45,6 +42,7 @@ def is_video(file_path):
         return True
     return False

+
 def is_template(file_path):
     if file_path.endswith(".pkl"):
         return True
@@ -149,8 +147,8 @@ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R
     new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
     return new_rotation, new_expression, new_translation, new_scale

 def load_description(fp):
     with open(fp, 'r', encoding='utf-8') as f:
         content = f.read()
     return content