feat: update

This commit is contained in:
zhangdingyun 2024-08-12 15:08:47 +08:00
parent 8a7682aaa4
commit 2638f3b10a
3 changed files with 114 additions and 55 deletions

View File

@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig): class ArgumentConfig(PrintableConfig):
########## input arguments ########## ########## input arguments ##########
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human) source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg') # path to the source portrait (human/animal) or video (human)
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
########## inference arguments ########## ########## inference arguments ##########
@ -35,6 +35,7 @@ class ArgumentConfig(PrintableConfig):
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly" driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "pose" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
########## source crop arguments ########## ########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger scale: float = 2.3 # the ratio of face area is smaller if scale is larger

View File

@ -49,6 +49,7 @@ class InferenceConfig(PrintableConfig):
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
source_max_dim: int = 1280 # the max dim of height and width of source image or video source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
# NOT EXPORTED PARAMS # NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip

View File

@ -111,6 +111,7 @@ class LivePortraitPipeline(object):
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames'] driving_n_frames = driving_template_dct['n_frames']
flag_is_driving_video = True if driving_n_frames > 1 else False
if flag_is_source_video: if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
else: else:
@ -123,16 +124,25 @@ class LivePortraitPipeline(object):
if args.flag_crop_driving_video: if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving): elif osp.exists(args.driving):
# load from video file, AND make motion template if is_video(args.driving):
output_fps = int(get_fps(args.driving)) flag_is_driving_video = True
log(f"Load driving video from: {args.driving}, FPS is {output_fps}") # load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
driving_rgb_lst = load_video(args.driving) log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_n_frames = len(driving_rgb_lst)
driving_rgb_lst = load_video(args.driving)
elif is_image(args.driving):
flag_is_driving_video = False
driving_img_rgb = load_image_rgb(args.driving)
output_fps = 1
log(f"Load driving image from {args.driving}")
driving_rgb_lst = [driving_img_rgb]
else:
raise Exception(f"{args.driving} is not a supported type!")
######## make motion template ######## ######## make motion template ########
log("Start making driving motion template...") log("Start making driving motion template...")
driving_n_frames = len(driving_rgb_lst)
if flag_is_source_video: if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames] driving_rgb_lst = driving_rgb_lst[:n_frames]
@ -158,9 +168,8 @@ class LivePortraitPipeline(object):
wfp_template = remove_suffix(args.driving) + '.pkl' wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct) dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}") log(f"Dump motion template to {wfp_template}")
else: else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!") raise Exception(f"{args.driving} does not exist!")
######## prepare for pasteback ######## ######## prepare for pasteback ########
I_p_pstbk_lst = None I_p_pstbk_lst = None
@ -288,11 +297,30 @@ class LivePortraitPipeline(object):
else: else:
R_new = R_s R_new = R_s
else: else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
delta_new = x_s_info['exp']
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
elif inf_cfg.animation_region == "lip":
delta_new = x_s_info['exp']
for lip_idx in [14, 17, 19, 20]:
delta_new[:, lip_idx, :] += (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :])
elif inf_cfg.animation_region == "eyes":
delta_new = x_s_info['exp']
for eyes_idx in [11, 13, 15, 16]:
delta_new[:, eyes_idx, :] += (x_d_i_info['exp'][:, eyes_idx, :] - x_d_0_info['exp'][:, eyes_idx, :])
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) if inf_cfg.animation_region == "all":
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) else:
scale_new = x_s_info['scale']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
t_new = x_s_info['t']
else: else:
if flag_is_source_video: if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation: if inf_cfg.flag_video_editing_head_rotation:
@ -300,15 +328,31 @@ class LivePortraitPipeline(object):
else: else:
R_new = R_s R_new = R_s
else: else:
R_new = R_d_i if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp'] delta_new = x_s_info['exp']
R_new = R_d_i
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
elif inf_cfg.animation_region == "lip":
delta_new = x_s_info['exp']
for lip_idx in [14, 17, 19, 20]:
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
delta_new = x_s_info['exp']
for eyes_idx in [11, 13, 15, 16]:
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
scale_new = x_s_info['scale'] scale_new = x_s_info['scale']
t_new = x_d_i_info['t'] if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_d_i_info['t']
else:
t_new = x_s_info['t']
t_new[..., 2].fill_(0) # zero tz t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video: if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
if i == 0: if i == 0:
x_d_0_new = x_d_i_new x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
@ -373,50 +417,63 @@ class LivePortraitPipeline(object):
mkdir(args.output_dir) mkdir(args.output_dir)
wfp_concat = None wfp_concat = None
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result ######### ######### build the final concatenation result #########
# driving frame | source frame | generation, or source frame | generation # driving frame | source frame | generation
if flag_is_source_video: if flag_is_source_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else: else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst) frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps if flag_is_driving_video:
output_fps = source_fps if flag_is_source_video else output_fps flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
if flag_source_has_audio or flag_driving_has_audio: wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result # NOTE: update output fps
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') output_fps = source_fps if flag_is_source_video else output_fps
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
else: else:
images2video(I_p_lst, wfp=wfp, fps=output_fps) wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
######### build the final result ######### wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
if flag_source_has_audio or flag_driving_has_audio: if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source else:
log(f"Audio is selected from {audio_from_which_video}") cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) # final log
os.replace(wfp_with_audio, wfp) log(f'Animated image: {wfp}')
log(f"Replace {wfp_with_audio} with {wfp}") log(f'Animated image with concat: {wfp_concat}')
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
return wfp, wfp_concat return wfp, wfp_concat