From 8a2ea1547122f1894df4da18b78aaef4c2685032 Mon Sep 17 00:00:00 2001
From: zhangdingyun
Date: Thu, 15 Aug 2024 19:24:04 +0800
Subject: [PATCH] feat: update

---
 app.py                         |  8 ++++---
 src/config/argument_config.py  |  4 ++--
 src/config/inference_config.py |  5 ++++
 src/gradio_pipeline.py         |  1 +
 src/live_portrait_pipeline.py  | 43 +++++++++++++++++++++++++++-------
 5 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/app.py b/app.py
index 8068eee..15cbaa2 100644
--- a/app.py
+++ b/app.py
@@ -270,9 +270,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                 driving_image_input = gr.Image(type="filepath")
                 gr.Examples(
                     examples=[
-                        [osp.join(example_video_dir, "d3.jpg")],
+                        [osp.join(example_video_dir, "d30.jpg")],
                         [osp.join(example_video_dir, "d9.jpg")],
-                        [osp.join(example_video_dir, "d11.jpg")],
+                        [osp.join(example_video_dir, "d19.jpg")],
+                        [osp.join(example_video_dir, "d8.jpg")],
                     ],
                     inputs=[driving_image_input],
                     cache_examples=False,
@@ -312,7 +313,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                     flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                     flag_stitching_input = gr.Checkbox(value=True, label="stitching")
-                    animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
+                    animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="exp", label="animation region")
                     driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
                     driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
                     driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
@@ -327,6 +328,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video"):
                 output_video_concat_i2v.render()
+    with gr.Row():
         with gr.Column():
             with gr.Accordion(open=True, label="The animated image in the original image space"):
                 output_image_i2i.render()
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index d6e6916..e0dccd8 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s18.mp4')  # path to the source portrait (human/animal) or video (human)
-    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg')  # path to driving video or template (.pkl format)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to the source portrait (human/animal) or video (human)
+    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video

     ########## inference arguments ##########
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
index d1b5572..c9ed197 100644
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -6,10 +6,14 @@ config dataclass used for inference

 import cv2
 from numpy import ndarray
+import pickle as pkl
 from dataclasses import dataclass, field
 from typing import Literal, Tuple
 from .base_config import PrintableConfig, make_abs_path

+def load_lip_array():
+    with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
+        return pkl.load(f)

 @dataclass(repr=False)  # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
@@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
     output_fps: int = 25  # default output fps

     mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
+    lip_array: ndarray = field(default_factory=load_lip_array)
     size_gif: int = 256  # default gif size, TO IMPLEMENT
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index 5a9398a..a364c3d 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -315,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
             if input_lip_ratio != self.source_lip_ratio:
                 combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
                 lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+                print(lip_delta)
             x_d_new = x_d_new + \
                 (eyes_delta if eyes_delta is not None else 0) + \
                 (lip_delta if lip_delta is not None else 0)
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 0ec24ab..54651a8 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -159,7 +159,7 @@ class LivePortraitPipeline(object):
             if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
                 ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
                 log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
-                if len(ret_d["frame_crop_lst"]) is not n_frames:
+                if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
                     n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
                 driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
                 driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
@@ -213,11 +213,19 @@ class LivePortraitPipeline(object):
             key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d'  # compatible with previous keys
             if inf_cfg.flag_relative_motion:
-                x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
-                x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                if flag_is_driving_video:
+                    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
+                    x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                else:
+                    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+                    x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
-                    x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
-                    x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                    if flag_is_driving_video:
+                        x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
+                        x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                    else:
+                        x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
+                        x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
             else:
                 if flag_is_driving_video:
                     x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
@@ -261,7 +269,10 @@ class LivePortraitPipeline(object):
             mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))

         ######## animate ########
-        log(f"The animated video consists of {n_frames} frames.")
+        if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
+            log(f"The animated video consists of {n_frames} frames.")
+        else:
+            log(f"The output of image-driven portrait animation is an image.")
         for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
             if flag_is_source_video:  # source video
                 x_s_info = source_template_dct['motion'][i]
@@ -306,7 +317,9 @@ class LivePortraitPipeline(object):
             if i == 0:  # cache the first frame
                 R_d_0 = R_d_i
-                x_d_0_info = x_d_i_info
+                x_d_0_info = x_d_i_info.copy()
+                # if not flag_is_driving_video:
+                #     x_d_0_info['exp'] = 0

             delta_new = x_s_info['exp'].clone()
             if inf_cfg.flag_relative_motion:
@@ -315,7 +328,20 @@ class LivePortraitPipeline(object):
                 else:
                     R_new = R_s
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
-                    delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                    # delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                    if flag_is_source_video:
+                        for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
+                            delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
+                        delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
+                        delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
+                        delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
+                        delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
+                    else:
+                        if flag_is_driving_video:
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                        else:
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
+
                 elif inf_cfg.animation_region == "lip":
                     for lip_idx in [6, 12, 14, 17, 19, 20]:
                         delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
@@ -355,6 +381,7 @@ class LivePortraitPipeline(object):
                 t_new = x_s_info['t']

             t_new[..., 2].fill_(0)  # zero tz
+            # x_d_i_new = x_s_info['scale'] * (x_c_s @ R_s) + x_s_info['t']
             x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
             if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
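
A minimal standalone sketch (not part of the patch) of how the new lip_array config value could be loaded and turned into a tensor, mirroring load_lip_array() in inference_config.py and the torch.from_numpy(inf_cfg.lip_array) call in live_portrait_pipeline.py; the file path and the (1, 21, 3) shape are assumptions for illustration.

    import pickle as pkl

    import torch

    def load_lip_array(path="src/utils/resources/lip_array.pkl"):
        # Load the pickled neutral-lip expression offset (assumed to be an ndarray).
        with open(path, 'rb') as f:
            return pkl.load(f)

    if __name__ == "__main__":
        lip_array = load_lip_array()                     # e.g. shape (1, 21, 3), assumed
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Same conversion the pipeline applies before subtracting it from the driving expression.
        lip_tensor = torch.from_numpy(lip_array).to(dtype=torch.float32, device=device)
        print(lip_tensor.shape)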
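
For a driving image there is no first frame to serve as the relative-motion reference, so the new branch under flag_relative_motion falls back to either a zero offset or the neutral lip_array, chosen by the sign of the driving expression's mean. A minimal sketch of that rule, with hypothetical names and an assumed (1, 21, 3) expression shape:

    import numpy as np

    def relative_exp_for_driving_image(source_exp_per_frame, driving_exp, lip_array):
        # source_exp_per_frame: list of per-frame source expressions, each (1, 21, 3).
        # driving_exp: the single driving image's expression; lip_array: neutral lip offset.
        reference = 0 if driving_exp.mean() > 0 else lip_array
        return [src_exp + (driving_exp - reference) for src_exp in source_exp_per_frame]

    # Example with dummy data:
    frames = [np.zeros((1, 21, 3), dtype=np.float32) for _ in range(3)]
    driving = np.random.randn(1, 21, 3).astype(np.float32)
    neutral = np.zeros((1, 21, 3), dtype=np.float32)
    out = relative_exp_for_driving_image(frames, driving, neutral)
    print(len(out), out[0].shape)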
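
In the "exp"/"all" branch, a source video no longer takes the whole smoothed driving expression; only a subset of the 21 implicit keypoints (plus single components of keypoints 3-5, 8 and 9) are overwritten, keeping the rest of the source expression intact. A sketch of that masking with assumed tensor shapes (the patch indexes a (21, 3) smoothed expression into a (1, 21, 3) delta):

    import torch

    DRIVEN_IDX = [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

    def blend_expression(source_exp, driving_exp_smooth):
        # source_exp: (1, 21, 3); driving_exp_smooth: (21, 3), already smoothed.
        delta = source_exp.clone()
        for idx in DRIVEN_IDX:
            delta[:, idx, :] = driving_exp_smooth[idx, :]    # full keypoint taken from driving
        delta[:, 3:5, 1] = driving_exp_smooth[3:5, 1]        # second component of keypoints 3-4
        delta[:, 5, 2] = driving_exp_smooth[5, 2]            # third component of keypoint 5
        delta[:, 8, 2] = driving_exp_smooth[8, 2]            # third component of keypoint 8
        delta[:, 9, 1:] = driving_exp_smooth[9, 1:]          # last two components of keypoint 9
        return delta

    delta = blend_expression(torch.zeros(1, 21, 3), torch.randn(21, 3))
    print(delta.shape)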
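
The final driving keypoints are still produced by the transform kept at the end of the patch, x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new, with tz zeroed first. A sketch under assumed shapes: scale (1, 1), canonical keypoints (1, 21, 3), rotation (1, 3, 3), translation (1, 3).

    import torch

    def transform_keypoints(x_c_s, R_new, delta_new, scale_new, t_new):
        # x_c_s: (1, 21, 3) canonical keypoints; R_new: (1, 3, 3); delta_new: (1, 21, 3);
        # scale_new: (1, 1); t_new: (1, 3). Shapes are assumptions for illustration.
        t_new = t_new.clone()
        t_new[..., 2] = 0                  # zero tz, as the pipeline does before the transform
        return scale_new * (x_c_s @ R_new + delta_new) + t_new

    x_d = transform_keypoints(torch.randn(1, 21, 3), torch.eye(3).unsqueeze(0),
                              torch.zeros(1, 21, 3), torch.ones(1, 1), torch.zeros(1, 3))
    print(x_d.shape)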