fix: clarify driving modes and add lip normalization option (#372)

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
This commit is contained in:
Jianzhu Guo 2024-09-06 16:43:08 +08:00 committed by GitHub
parent 1b96d32e4b
commit b01aeb4050
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 2 deletions

2
app.py
View File

@ -241,6 +241,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Row(): with gr.Row():
with gr.Accordion(open=True, label="Animation Options"): with gr.Accordion(open=True, label="Animation Options"):
with gr.Row(): with gr.Row():
flag_normalize_lip = gr.Checkbox(value=False, label="normalize lip")
flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching") flag_stitching_input = gr.Checkbox(value=True, label="stitching")
@ -435,6 +436,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
driving_video_input, driving_video_input,
driving_image_input, driving_image_input,
driving_video_pickle_input, driving_video_pickle_input,
flag_normalize_lip,
flag_relative_input, flag_relative_input,
flag_do_crop_input, flag_do_crop_input,
flag_remap_input, flag_remap_input,

View File

@ -149,6 +149,7 @@ class GradioPipeline(LivePortraitPipeline):
input_driving_video_path=None, input_driving_video_path=None,
input_driving_image_path=None, input_driving_image_path=None,
input_driving_video_pickle_path=None, input_driving_video_pickle_path=None,
flag_normalize_lip=False,
flag_relative_input=True, flag_relative_input=True,
flag_do_crop_input=True, flag_do_crop_input=True,
flag_remap_input=True, flag_remap_input=True,
@ -187,7 +188,7 @@ class GradioPipeline(LivePortraitPipeline):
input_driving_path = input_driving_video_path input_driving_path = input_driving_video_path
if input_source_path is not None and input_driving_path is not None: if input_source_path is not None and input_driving_path is not None:
if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False: if osp.exists(input_driving_path) and v_tab_selection == 'Video' and not flag_crop_driving_video_input and is_square_video(input_driving_path) is False:
flag_crop_driving_video_input = True flag_crop_driving_video_input = True
log("The driving video is not square, it will be cropped to square automatically.") log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2) gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
@ -195,6 +196,7 @@ class GradioPipeline(LivePortraitPipeline):
args_user = { args_user = {
'source': input_source_path, 'source': input_source_path,
'driving': input_driving_path, 'driving': input_driving_path,
'flag_normalize_lip' : flag_normalize_lip,
'flag_relative_motion': flag_relative_input, 'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input, 'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input, 'flag_pasteback': flag_remap_input,

View File

@ -386,7 +386,7 @@ class LivePortraitPipeline(object):
t_new[..., 2].fill_(0) # zero tz t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video: if inf_cfg.flag_relative_motion and inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
if i == 0: if i == 0:
x_d_0_new = x_d_i_new x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)