chore: slightly refine the codebase

guojianzhu 2024-07-05 15:09:43 +08:00
parent 669487a2fe
commit d09527c762
10 changed files with 35 additions and 39 deletions

app.py

@@ -44,8 +44,8 @@ data_examples = [
#################### interface logic ####################
# Define components first
-eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
-lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
+eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
+lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
retargeting_input_image = gr.Image(type="numpy")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
@@ -56,7 +56,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Reference Portrait"):
with gr.Accordion(open=True, label="Source Portrait"):
image_input = gr.Image(type="filepath")
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video()
@@ -64,7 +64,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
-flag_relative_input = gr.Checkbox(value=True, label="relative pose")
+flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
with gr.Row():
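Note: app.py builds these components before entering gr.Blocks and mounts them into the layout afterwards. A minimal sketch of that define-first, render-later Gradio pattern (the .render() call is assumed here; only the slider definition appears in this hunk):

import gradio as gr

# define the component up front so event handlers elsewhere can reference it
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    eye_retargeting_slider.render()  # mount the pre-built slider into the layout

demo.launch()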

assets/gradio_description_animation.md

@@ -1,7 +1,7 @@
-<span style="font-size: 1.2em;">🔥 To animate the reference portrait with the driving video, please follow these steps:</span>
+<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
</div>
</div>

assets/gradio_description_retargeting.md

@@ -1 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
+<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>

assets/gradio_description_upload.md

@@ -1,4 +1,2 @@
-## 🤗 This is the official gradio demo for **Live Portrait**.
-### Guidance for the gradio page:
-<div style="font-size: 1.2em;">Please upload or use the webcam to get a reference portrait to the <strong>Reference Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
+## 🤗 This is the official gradio demo for **LivePortrait**.
+<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>

src/config/argument_config.py

@@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
-source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the reference portrait
+source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
#####################################
@@ -25,9 +25,9 @@ class ArgumentConfig(PrintableConfig):
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
-flag_relative: bool = True # whether to use relative pose
+flag_relative: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-flag_do_crop: bool = True # whether to crop the reference portrait to the face-cropping space
+flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
#########################################
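Because ArgumentConfig is a tyro dataclass with tyro.conf.arg aliases, these fields double as command-line flags. A minimal entry-point sketch (tyro.cli is tyro's standard API; the script body and print line are illustrative only):

import tyro
from src.config.argument_config import ArgumentConfig

if __name__ == "__main__":
    args = tyro.cli(ArgumentConfig)  # exposes -s/-d/-o aliases plus --flag-* options
    print(args.source_image, args.flag_relative)

Boolean fields with defaults typically become paired switches, e.g. --flag-relative / --no-flag-relative.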

src/config/inference_config.py

@@ -28,7 +28,7 @@ class InferenceConfig(PrintableConfig):
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
-flag_relative: bool = True # whether to use relative pose
+flag_relative: bool = True # whether to use relative motion
anchor_frame: int = 0 # set this value if find_best_frame is True
input_shape: Tuple[int, int] = (256, 256) # input shape
@@ -45,5 +45,5 @@ class InferenceConfig(PrintableConfig):
ref_shape_n: int = 2
device_id: int = 0
-flag_do_crop: bool = False # whether to crop the reference portrait to the face-cropping space
+flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True

src/gradio_pipeline.py

@@ -35,7 +35,7 @@ class GradioPipeline(LivePortraitPipeline):
self.mask_ori = None
self.img_rgb = None
self.crop_M_c2o = None
def execute_video(
self,
@@ -62,9 +62,9 @@ class GradioPipeline(LivePortraitPipeline):
# video driven animation
video_path, video_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
-return video_path, video_path_concat,
+return video_path, video_path_concat,
else:
raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
""" for single image retargeting
@@ -74,12 +74,12 @@ class GradioPipeline(LivePortraitPipeline):
elif self.f_s_user is None:
if self.start_prepare:
raise gr.Error(
"The reference portrait is under processing 💥! Please wait for a second.",
"The source portrait is under processing 💥! Please wait for a second.",
duration=5
)
else:
raise gr.Error(
"The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
duration=5
)
else:
@@ -98,7 +98,7 @@ class GradioPipeline(LivePortraitPipeline):
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
""" for single image retargeting
@@ -107,7 +107,7 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Upload successfully!", duration=2)
self.start_prepare = True
inference_cfg = self.live_portrait_wrapper.cfg
-######## process reference portrait ########
+######## process source portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
@@ -125,7 +125,7 @@ class GradioPipeline(LivePortraitPipeline):
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
self.img_rgb = img_rgb
-self.crop_M_c2o = crop_info['M_c2o']
+self.crop_M_c2o = crop_info['M_c2o']
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
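For context on the paste-back step used above: conceptually, the animated crop is warped back into the original image space with the crop-to-original transform M_c2o and blended under the precomputed mask. A hedged sketch of the idea, not the repo's exact implementation:

import cv2
import numpy as np

def paste_back_sketch(crop_rgb, M_c2o, full_rgb, mask_full):
    # warp the animated crop from crop space back into the original image space
    h, w = full_rgb.shape[:2]
    warped = cv2.warpAffine(crop_rgb, M_c2o[:2, :], (w, h))
    mask = mask_full.astype(np.float32)  # assumed to be in [0, 1]
    if mask.ndim == 2:
        mask = mask[..., None]  # broadcast the mask over the RGB channels
    blended = warped.astype(np.float32) * mask + full_rgb.astype(np.float32) * (1.0 - mask)
    return np.clip(blended, 0, 255).astype(np.uint8)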

src/live_portrait_pipeline.py

@@ -40,7 +40,7 @@ class LivePortraitPipeline(object):
def execute(self, args: ArgumentConfig):
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
-######## process reference portrait ########
+######## process source portrait ########
img_rgb = load_image_rgb(args.source_image)
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
log(f"Load source image from {args.source_image}")

src/live_portrait_wrapper.py

@@ -10,12 +10,12 @@ import cv2
import torch
import yaml
-from src.utils.timer import Timer
-from src.utils.helper import load_model, concat_feat
-from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+from .utils.timer import Timer
+from .utils.helper import load_model, concat_feat
+from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from src.config.inference_config import InferenceConfig
-from src.utils.rprint import rlog as log
+from .config.inference_config import InferenceConfig
+from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
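The point of swapping the src.-absolute imports for relative ones: the modules now resolve against their own package instead of requiring the repo root on sys.path. A minimal two-file sketch of the pattern (hypothetical pkg/ layout, not the repo's):

# pkg/utils/timer.py
class Timer:
    pass

# pkg/wrapper.py
from .utils.timer import Timer  # resolves inside pkg, regardless of what the
                                # top-level directory is called or where it lives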

src/utils/helper.py

@@ -6,17 +6,14 @@ utility functions and classes to handle feature extraction and model loading
import os
import os.path as osp
import cv2
import torch
-from rich.console import Console
from collections import OrderedDict
-from src.modules.spade_generator import SPADEDecoder
-from src.modules.warping_network import WarpingNetwork
-from src.modules.motion_extractor import MotionExtractor
-from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
-from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork
+from .rprint import rlog as log
+from ..modules.spade_generator import SPADEDecoder
+from ..modules.warping_network import WarpingNetwork
+from ..modules.motion_extractor import MotionExtractor
+from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
+from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
def suffix(filename):
@@ -45,6 +42,7 @@ def is_video(file_path):
return True
return False
def is_template(file_path):
if file_path.endswith(".pkl"):
return True
@@ -149,8 +147,8 @@ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R
new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
return new_rotation, new_expression, new_translation, new_scale
def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
return content
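load_description, defined just above, is the helper app.py uses to pull the markdown snippets in this commit into the page. A minimal usage sketch (the asset path is the one app.py passes in this diff; the print line is illustrative):

from src.utils.helper import load_description

md = load_description("assets/gradio_description_upload.md")
print(md[:60])  # raw markdown/HTML, later rendered via gr.Markdown / gr.HTML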