mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-14 21:22:43 +00:00
feat: update
This commit is contained in:
parent
8a7682aaa4
commit
2638f3b10a
@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
|
||||
@dataclass(repr=False) # use repr from PrintableConfig
|
||||
class ArgumentConfig(PrintableConfig):
|
||||
########## input arguments ##########
|
||||
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
|
||||
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
|
||||
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg') # path to the source portrait (human/animal) or video (human)
|
||||
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg') # path to driving video or template (.pkl format)
|
||||
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
|
||||
|
||||
########## inference arguments ##########
|
||||
@ -35,6 +35,7 @@ class ArgumentConfig(PrintableConfig):
|
||||
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
|
||||
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "pose" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
|
||||
########## source crop arguments ##########
|
||||
det_thresh: float = 0.15 # detection threshold
|
||||
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
|
||||
|
@ -49,6 +49,7 @@ class InferenceConfig(PrintableConfig):
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
source_max_dim: int = 1280 # the max dim of height and width of source image or video
|
||||
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
|
||||
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
|
||||
|
||||
# NOT EXPORTED PARAMS
|
||||
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
|
||||
|
@ -111,6 +111,7 @@ class LivePortraitPipeline(object):
|
||||
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
|
||||
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
|
||||
driving_n_frames = driving_template_dct['n_frames']
|
||||
flag_is_driving_video = True if driving_n_frames > 1 else False
|
||||
if flag_is_source_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
else:
|
||||
@ -123,16 +124,25 @@ class LivePortraitPipeline(object):
|
||||
if args.flag_crop_driving_video:
|
||||
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||
|
||||
elif osp.exists(args.driving) and is_video(args.driving):
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
driving_n_frames = len(driving_rgb_lst)
|
||||
elif osp.exists(args.driving):
|
||||
if is_video(args.driving):
|
||||
flag_is_driving_video = True
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
elif is_image(args.driving):
|
||||
flag_is_driving_video = False
|
||||
driving_img_rgb = load_image_rgb(args.driving)
|
||||
output_fps = 1
|
||||
log(f"Load driving image from {args.driving}")
|
||||
driving_rgb_lst = [driving_img_rgb]
|
||||
else:
|
||||
raise Exception(f"{args.driving} is not a supported type!")
|
||||
######## make motion template ########
|
||||
log("Start making driving motion template...")
|
||||
driving_n_frames = len(driving_rgb_lst)
|
||||
if flag_is_source_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
driving_rgb_lst = driving_rgb_lst[:n_frames]
|
||||
@ -158,9 +168,8 @@ class LivePortraitPipeline(object):
|
||||
wfp_template = remove_suffix(args.driving) + '.pkl'
|
||||
dump(wfp_template, driving_template_dct)
|
||||
log(f"Dump motion template to {wfp_template}")
|
||||
|
||||
else:
|
||||
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
|
||||
raise Exception(f"{args.driving} does not exist!")
|
||||
|
||||
######## prepare for pasteback ########
|
||||
I_p_pstbk_lst = None
|
||||
@ -288,11 +297,30 @@ class LivePortraitPipeline(object):
|
||||
else:
|
||||
R_new = R_s
|
||||
else:
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
delta_new = x_s_info['exp']
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
else:
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
elif inf_cfg.animation_region == "lip":
|
||||
delta_new = x_s_info['exp']
|
||||
for lip_idx in [14, 17, 19, 20]:
|
||||
delta_new[:, lip_idx, :] += (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :])
|
||||
elif inf_cfg.animation_region == "eyes":
|
||||
delta_new = x_s_info['exp']
|
||||
for eyes_idx in [11, 13, 15, 16]:
|
||||
delta_new[:, eyes_idx, :] += (x_d_i_info['exp'][:, eyes_idx, :] - x_d_0_info['exp'][:, eyes_idx, :])
|
||||
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
if inf_cfg.animation_region == "all":
|
||||
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
else:
|
||||
scale_new = x_s_info['scale']
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
else:
|
||||
t_new = x_s_info['t']
|
||||
else:
|
||||
if flag_is_source_video:
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
@ -300,15 +328,31 @@ class LivePortraitPipeline(object):
|
||||
else:
|
||||
R_new = R_s
|
||||
else:
|
||||
R_new = R_d_i
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
delta_new = x_s_info['exp']
|
||||
R_new = R_d_i
|
||||
else:
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
|
||||
elif inf_cfg.animation_region == "lip":
|
||||
delta_new = x_s_info['exp']
|
||||
for lip_idx in [14, 17, 19, 20]:
|
||||
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
|
||||
elif inf_cfg.animation_region == "eyes":
|
||||
delta_new = x_s_info['exp']
|
||||
for eyes_idx in [11, 13, 15, 16]:
|
||||
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
|
||||
scale_new = x_s_info['scale']
|
||||
t_new = x_d_i_info['t']
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
t_new = x_d_i_info['t']
|
||||
else:
|
||||
t_new = x_s_info['t']
|
||||
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
||||
|
||||
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
|
||||
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
|
||||
if i == 0:
|
||||
x_d_0_new = x_d_i_new
|
||||
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
|
||||
@ -373,50 +417,63 @@ class LivePortraitPipeline(object):
|
||||
|
||||
mkdir(args.output_dir)
|
||||
wfp_concat = None
|
||||
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
######### build the final concatenation result #########
|
||||
# driving frame | source frame | generation, or source frame | generation
|
||||
# driving frame | source frame | generation
|
||||
if flag_is_source_video:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
|
||||
else:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
|
||||
# NOTE: update output fps
|
||||
output_fps = source_fps if flag_is_source_video else output_fps
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
if flag_is_driving_video:
|
||||
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}, concat mode")
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
# NOTE: update output fps
|
||||
output_fps = source_fps if flag_is_source_video else output_fps
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}, concat mode")
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}")
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}")
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
|
||||
cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
|
||||
else:
|
||||
cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
|
||||
# final log
|
||||
log(f'Animated image: {wfp}')
|
||||
log(f'Animated image with concat: {wfp_concat}')
|
||||
|
||||
return wfp, wfp_concat
|
||||
|
Loading…
Reference in New Issue
Block a user