From 2638f3b10aee8c7208774ccd7863926ec76884e5 Mon Sep 17 00:00:00 2001
From: zhangdingyun
Date: Mon, 12 Aug 2024 15:08:47 +0800
Subject: [PATCH] feat: add animation_region option and driving-image support

---
 src/config/argument_config.py  |   5 +-
 src/config/inference_config.py |   1 +
 src/live_portrait_pipeline.py  | 163 ++++++++++++++++++++++-----------
 3 files changed, 114 insertions(+), 55 deletions(-)

diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index bea2d2f..6599ce5 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
-    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg') # path to the source portrait (human/animal) or video (human)
+    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg') # path to the driving video, image, or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video

     ########## inference arguments ##########
@@ -35,6 +35,7 @@ class ArgumentConfig(PrintableConfig):
     driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
     driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
     audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
+    animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "pose" # the region to animate: "exp" drives the full expression, "pose" the head pose, "lip"/"eyes" only those keypoints, "all" everything
     ########## source crop arguments ##########
     det_thresh: float = 0.15 # detection threshold
     scale: float = 2.3 # the ratio of face area is smaller if scale is larger
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
index 38f1ecf..c56f01a 100644
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -49,6 +49,7 @@ class InferenceConfig(PrintableConfig):
     driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
     source_max_dim: int = 1280 # the max dim of height and width of source image or video
     source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
+    animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region to animate: "exp" drives the full expression, "pose" the head pose, "lip"/"eyes" only those keypoints, "all" everything

     # NOT EXPORTED PARAMS
     lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 7f06d2a..dd66b89 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -111,6 +111,7 @@ class LivePortraitPipeline(object):
             c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
             c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
             driving_n_frames = driving_template_dct['n_frames']
+            flag_is_driving_video = driving_n_frames > 1 # a single-frame template comes from a driving image
             if flag_is_source_video:
                 n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
             else:
@@ -123,16 +124,25 @@ class LivePortraitPipeline(object):
             if args.flag_crop_driving_video:
                 log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")

-        elif osp.exists(args.driving) and is_video(args.driving):
-            # load from video file, AND make motion template
-            output_fps = int(get_fps(args.driving))
-            log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
-
-            driving_rgb_lst = load_video(args.driving)
-            driving_n_frames = len(driving_rgb_lst)
+        elif osp.exists(args.driving):
+            if is_video(args.driving):
+                flag_is_driving_video = True
+                # load from video file, AND make motion template
+                output_fps = int(get_fps(args.driving))
+                log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
+                driving_rgb_lst = load_video(args.driving)
+            elif is_image(args.driving):
+                flag_is_driving_video = False
+                driving_img_rgb = load_image_rgb(args.driving)
+                output_fps = 25 # nominal fps; a single driving image yields a single animated frame
+                log(f"Load driving image from {args.driving}")
+                driving_rgb_lst = [driving_img_rgb]
+            else:
+                raise Exception(f"{args.driving} is not a supported type!")

             ######## make motion template ########
             log("Start making driving motion template...")
+            driving_n_frames = len(driving_rgb_lst)
             if flag_is_source_video:
                 n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
                 driving_rgb_lst = driving_rgb_lst[:n_frames]
@@ -158,9 +168,8 @@ class LivePortraitPipeline(object):
             wfp_template = remove_suffix(args.driving) + '.pkl'
             dump(wfp_template, driving_template_dct)
             log(f"Dump motion template to {wfp_template}")
-
         else:
-            raise Exception(f"{args.driving} not exists or unsupported driving info types!")
+            raise Exception(f"{args.driving} does not exist!")

         ######## prepare for pasteback ########
         I_p_pstbk_lst = None
@@ -288,11 +297,30 @@ class LivePortraitPipeline(object):
                     else:
                         R_new = R_s
                 else:
-                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
+                    if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                        delta_new = x_s_info['exp'] # keep the source expression when only the pose is driven
+                        R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
+                    else:
+                        R_new = R_s
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
+                    delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                elif inf_cfg.animation_region == "lip":
+                    delta_new = x_s_info['exp']
+                    for lip_idx in [14, 17, 19, 20]: # lip-related keypoint indices
+                        delta_new[:, lip_idx, :] += (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :])
+                elif inf_cfg.animation_region == "eyes":
+                    delta_new = x_s_info['exp']
+                    for eyes_idx in [11, 13, 15, 16]: # eye-related keypoint indices
+                        delta_new[:, eyes_idx, :] += (x_d_i_info['exp'][:, eyes_idx, :] - x_d_0_info['exp'][:, eyes_idx, :])
-
-                delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
-                scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
-                t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
+                if inf_cfg.animation_region == "all":
+                    scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
+                else:
+                    scale_new = x_s_info['scale'] # scale is only driven when everything is animated
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                    t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
+                else:
+                    t_new = x_s_info['t']
             else:
                 if flag_is_source_video:
                     if inf_cfg.flag_video_editing_head_rotation:
@@ -300,15 +328,31 @@ class LivePortraitPipeline(object):
                     else:
                         R_new = R_s
                 else:
-                    R_new = R_d_i
-                delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
+                    if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                        delta_new = x_s_info['exp'] # keep the source expression when only the pose is driven
+                        R_new = R_d_i
+                    else:
+                        R_new = R_s
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
+                    delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
+                elif inf_cfg.animation_region == "lip":
+                    delta_new = x_s_info['exp']
+                    for lip_idx in [14, 17, 19, 20]: # lip-related keypoint indices
+                        delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
+                elif inf_cfg.animation_region == "eyes":
+                    delta_new = x_s_info['exp']
+                    for eyes_idx in [11, 13, 15, 16]: # eye-related keypoint indices
+                        delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
                 scale_new = x_s_info['scale']
-                t_new = x_d_i_info['t']
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                    t_new = x_d_i_info['t']
+                else:
+                    t_new = x_s_info['t']

             t_new[..., 2].fill_(0) # zero tz
             x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

-            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
+            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
                 if i == 0:
                     x_d_0_new = x_d_i_new
                     motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
@@ -373,50 +417,63 @@ class LivePortraitPipeline(object):

         mkdir(args.output_dir)
         wfp_concat = None
-        flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
-        flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)

         ######### build the final concatenation result #########
-        # driving frame | source frame | generation, or source frame | generation
+        # driving frame | source frame | generation
         if flag_is_source_video:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
         else:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
-        wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
-
-        # NOTE: update output fps
-        output_fps = source_fps if flag_is_source_video else output_fps
-        images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
+        if flag_is_driving_video:
+            flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
+            flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)

-        if flag_source_has_audio or flag_driving_has_audio:
-            # final result with concatenation
-            wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
-            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
-            log(f"Audio is selected from {audio_from_which_video}, concat mode")
-            add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
-            os.replace(wfp_concat_with_audio, wfp_concat)
-            log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
+            wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')

-        # save the animated result
-        wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
-        if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
-            images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
+            # NOTE: update output fps
+            output_fps = source_fps if flag_is_source_video else output_fps
+            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
+
+            if flag_source_has_audio or flag_driving_has_audio:
+                # final result with concatenation
+                wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
+                audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
+                log(f"Audio is selected from {audio_from_which_video}, concat mode")
+                add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
+                os.replace(wfp_concat_with_audio, wfp_concat)
+                log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
+
+            # save the animated result
+            wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
+            if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
+                images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
+            else:
+                images2video(I_p_lst, wfp=wfp, fps=output_fps)
+
+            ######### build the final result #########
+            if flag_source_has_audio or flag_driving_has_audio:
+                wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
+                audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
+                log(f"Audio is selected from {audio_from_which_video}")
+                add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
+                os.replace(wfp_with_audio, wfp)
+                log(f"Replace {wfp_with_audio} with {wfp}")
+
+            # final log
+            if wfp_template not in (None, ''):
+                log(f'Animated template: {wfp_template}; you can pass this template path to `-d` next time to skip cropping and motion-template extraction, and to protect privacy.', style='bold green')
+            log(f'Animated video: {wfp}')
+            log(f'Animated video with concat: {wfp_concat}')
         else:
-            images2video(I_p_lst, wfp=wfp, fps=output_fps)
-
-        ######### build the final result #########
-        if flag_source_has_audio or flag_driving_has_audio:
-            wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
-            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
-            log(f"Audio is selected from {audio_from_which_video}")
-            add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
-            os.replace(wfp_with_audio, wfp)
-            log(f"Replace {wfp_with_audio} with {wfp}")
-
-        # final log
-        if wfp_template not in (None, ''):
-            log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
-        log(f'Animated video: {wfp}')
-        log(f'Animated video with concat: {wfp_concat}')
+            wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
+            cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
+            wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
+            if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
+                cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
+            else:
+                cv2.imwrite(wfp, I_p_lst[0][..., ::-1]) # write the animated frame itself, not the concatenation
+            # final log
+            log(f'Animated image: {wfp}')
+            log(f'Animated image with concat: {wfp_concat}')

         return wfp, wfp_concat
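
Note: every branch in the hunks above only selects R_new, delta_new, scale_new, and t_new; they all feed the same keypoint composition x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new. A minimal PyTorch sketch of that composition with assumed shapes (1 batch, 21 implicit keypoints; the values below are illustrative stand-ins, not code from this patch):

    import torch

    B, K = 1, 21                          # batch size and keypoint count (assumed)
    x_c_s = torch.randn(B, K, 3)          # canonical source keypoints
    R_new = torch.eye(3).expand(B, 3, 3)  # rotation selected by the "pose" gating
    delta_new = torch.zeros(B, K, 3)      # expression offsets selected by the region gating
    scale_new = torch.ones(B, 1, 1)       # driven only when animation_region == "all"
    t_new = torch.zeros(B, 1, 3)          # translation; the pipeline zeroes tz before use

    x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
    print(x_d_i_new.shape)                # torch.Size([1, 21, 3])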
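
The region gating itself reduces to: start from the source expression and override (absolute mode) or offset (relative mode) only the keypoints of the selected region. A self-contained sketch of the absolute-motion case; gate_delta is a hypothetical helper, while the index sets are the ones used in the hunks above:

    import torch

    LIP_IDX = [14, 17, 19, 20]    # lip-related keypoint indices (from the patch)
    EYES_IDX = [11, 13, 15, 16]   # eye-related keypoint indices (from the patch)

    def gate_delta(region: str, exp_s: torch.Tensor, exp_d: torch.Tensor) -> torch.Tensor:
        if region in ("all", "exp"):
            return exp_d                         # take the full driving expression
        delta = exp_s.clone()                    # keep the source expression...
        idx = {"lip": LIP_IDX, "eyes": EYES_IDX}.get(region, [])
        delta[:, idx, :] = exp_d[:, idx, :]      # ...and swap in only the chosen region
        return delta                             # "pose" matches no set and passes through

    exp_s, exp_d = torch.zeros(1, 21, 3), torch.ones(1, 21, 3)
    print(gate_delta("lip", exp_s, exp_d)[0, :, 0])  # 1.0 only at indices 14, 17, 19, 20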
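
To exercise the new option end to end, the entry point only has to forward ArgumentConfig fields into InferenceConfig, where animation_region now lives. A sketch of that wiring, assuming it mirrors the repo's current inference.py (partial_fields here keeps only the kwargs a dataclass declares):

    import tyro

    from src.config.argument_config import ArgumentConfig
    from src.config.crop_config import CropConfig
    from src.config.inference_config import InferenceConfig
    from src.live_portrait_pipeline import LivePortraitPipeline

    def partial_fields(target_class, kwargs):
        # keep only the kwargs that the target dataclass declares
        return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

    args = tyro.cli(ArgumentConfig)  # e.g. -s s3.jpg -d d3.jpg plus the new animation_region option
    inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # animation_region is copied here
    crop_cfg = partial_fields(CropConfig, args.__dict__)
    pipeline = LivePortraitPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg)
    wfp, wfp_concat = pipeline.execute(args)  # paths of the animated result and the concat preview

With a driving image and animation_region set to "lip", for example, only the lip keypoints are transplanted onto the source, and the new branch at the end of the diff writes .jpg outputs instead of videos.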