Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2025-03-15 14:02:12 +00:00)

feat: update

commit 2638f3b10a (parent 8a7682aaa4)
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path

 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg')  # path to the source portrait (human/animal) or video (human)
-    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4')  # path to the driving video, image, or template (.pkl format)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg')  # path to the source portrait (human/animal) or video (human)
+    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg')  # path to the driving video, image, or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save the output video

     ########## inference arguments ##########
@@ -35,6 +35,7 @@ class ArgumentConfig(PrintableConfig):
     driving_multiplier: float = 1.0  # used only when driving_option is "expression-friendly"
     driving_smooth_observation_variance: float = 3e-7  # smoothing strength for the animated video when the input is a source video; the larger the value, the smoother the result, but too much smoothing loses motion accuracy
     audio_priority: Literal['source', 'driving'] = 'driving'  # whether to use the audio from the source or the driving video
+    animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "pose"  # the region where the animation is performed: "exp" means the expression, "pose" means the head pose
     ########## source crop arguments ##########
     det_thresh: float = 0.15  # detection threshold
     scale: float = 2.3  # the larger the scale, the smaller the portion of the crop the face occupies
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -49,6 +49,7 @@ class InferenceConfig(PrintableConfig):
     driving_smooth_observation_variance: float = 3e-7  # smoothing strength for the animated video when the input is a source video; the larger the value, the smoother the result, but too much smoothing loses motion accuracy
     source_max_dim: int = 1280  # the max dimension of the height and width of the source image or video
     source_division: int = 2  # make sure the height and width of the source image or video are divisible by this number
+    animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all"  # the region where the animation is performed: "exp" means the expression, "pose" means the head pose

     # NOT EXPORTED PARAMS
     lip_normalize_threshold: float = 0.03  # threshold for flag_normalize_lip
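For orientation, here is a minimal sketch of how the new `animation_region` option reaches user code. tyro turns the dataclass fields directly into CLI flags; the entry-point name (`inference.py`) and the module path `src.config.argument_config` are assumptions based on the repository layout, not part of this commit:

    # Hedged sketch: parse ArgumentConfig from the command line, e.g.
    #   python inference.py -s s3.jpg -d d3.jpg --animation_region lip
    import tyro
    from src.config.argument_config import ArgumentConfig  # assumed module path

    def main():
        args = tyro.cli(ArgumentConfig)
        # animation_region is one of: "exp", "pose", "lip", "eyes", "all"
        print(f"animating region: {args.animation_region}")

    if __name__ == "__main__":
        main()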
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -111,6 +111,7 @@ class LivePortraitPipeline(object):
             c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst']  # compatible with previous keys
             c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
             driving_n_frames = driving_template_dct['n_frames']
+            flag_is_driving_video = True if driving_n_frames > 1 else False
             if flag_is_source_video:
                 n_frames = min(len(source_rgb_lst), driving_n_frames)  # use the smaller count as the number of animated frames
             else:
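A small aside: the new flag is derived from the template's frame count, so a saved single-frame `.pkl` template is treated as a driving image. The ternary in the added line is redundant, since the comparison already yields a bool; an equivalent, slightly tidier form would be:

    # Equivalent to the added line: `>` already returns a bool.
    flag_is_driving_video = driving_n_frames > 1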
@@ -123,16 +124,25 @@ class LivePortraitPipeline(object):
             if args.flag_crop_driving_video:
                 log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")

-        elif osp.exists(args.driving) and is_video(args.driving):
-            # load from video file, AND make motion template
-            output_fps = int(get_fps(args.driving))
-            log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
-
-            driving_rgb_lst = load_video(args.driving)
-            driving_n_frames = len(driving_rgb_lst)
+        elif osp.exists(args.driving):
+            if is_video(args.driving):
+                flag_is_driving_video = True
+                # load from video file, AND make motion template
+                output_fps = int(get_fps(args.driving))
+                log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
+
+                driving_rgb_lst = load_video(args.driving)
+            elif is_image(args.driving):
+                flag_is_driving_video = False
+                driving_img_rgb = load_image_rgb(args.driving)
+                output_fps = 1
+                log(f"Load driving image from {args.driving}")
+                driving_rgb_lst = [driving_img_rgb]
+            else:
+                raise Exception(f"{args.driving} is not a supported type!")
             ######## make motion template ########
             log("Start making driving motion template...")
+            driving_n_frames = len(driving_rgb_lst)
             if flag_is_source_video:
                 n_frames = min(len(source_rgb_lst), driving_n_frames)  # use the smaller count as the number of animated frames
                 driving_rgb_lst = driving_rgb_lst[:n_frames]
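Condensed for readability, the new dispatch reads as below. This is a sketch, not the repository code: the helper import paths under `src.utils` are assumptions, and `load_driving_frames` is a name invented for illustration:

    import os.path as osp

    # Assumed helper locations; the pipeline file already imports these names.
    from src.utils.io import load_video, load_image_rgb    # assumption
    from src.utils.helper import is_video, is_image        # assumption
    from src.utils.video import get_fps                    # assumption

    def load_driving_frames(driving_path: str):
        """Return (rgb_frames, output_fps, flag_is_driving_video)."""
        if not osp.exists(driving_path):
            raise Exception(f"{driving_path} does not exist!")
        if is_video(driving_path):
            # videos keep their native frame rate
            return load_video(driving_path), int(get_fps(driving_path)), True
        if is_image(driving_path):
            # a still image becomes a one-frame "video" at 1 fps
            return [load_image_rgb(driving_path)], 1, False
        raise Exception(f"{driving_path} is not a supported type!")

Either way, the downstream motion-template code sees a uniform list of RGB frames plus an fps, which is why the hunk can leave the template-making block untouched.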
@@ -158,9 +168,8 @@ class LivePortraitPipeline(object):
                 wfp_template = remove_suffix(args.driving) + '.pkl'
                 dump(wfp_template, driving_template_dct)
                 log(f"Dump motion template to {wfp_template}")

         else:
-            raise Exception(f"{args.driving} not exists or unsupported driving info types!")
+            raise Exception(f"{args.driving} does not exist!")

         ######## prepare for pasteback ########
         I_p_pstbk_lst = None
@@ -288,11 +297,30 @@ class LivePortraitPipeline(object):
                     else:
                         R_new = R_s
                 else:
-                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
-                delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
-                scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
-                t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
+                    if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                        delta_new = x_s_info['exp']
+                        R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
+                    else:
+                        R_new = R_s
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
+                    delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                elif inf_cfg.animation_region == "lip":
+                    delta_new = x_s_info['exp']
+                    for lip_idx in [14, 17, 19, 20]:
+                        delta_new[:, lip_idx, :] += (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :])
+                elif inf_cfg.animation_region == "eyes":
+                    delta_new = x_s_info['exp']
+                    for eyes_idx in [11, 13, 15, 16]:
+                        delta_new[:, eyes_idx, :] += (x_d_i_info['exp'][:, eyes_idx, :] - x_d_0_info['exp'][:, eyes_idx, :])
+
+                if inf_cfg.animation_region == "all":
+                    scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
+                else:
+                    scale_new = x_s_info['scale']
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                    t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
+                else:
+                    t_new = x_s_info['t']
             else:
                 if flag_is_source_video:
                     if inf_cfg.flag_video_editing_head_rotation:
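The lip and eye branches touch only a fixed subset of the implicit-keypoint expression deltas. Below is a self-contained sketch of the same gating logic for the relative-motion case; the keypoint count (21) and the index sets come straight from this hunk, while the function itself is invented for illustration. Note that the hunk modifies `x_s_info['exp']` in place via `+=`, whereas the sketch clones first:

    import torch

    LIP_IDX = [14, 17, 19, 20]    # lip keypoints (from this commit)
    EYES_IDX = [11, 13, 15, 16]   # eye keypoints (from this commit)

    def region_gated_delta(region: str,
                           exp_s: torch.Tensor,    # source expression, (1, 21, 3)
                           exp_d_i: torch.Tensor,  # driving frame i expression
                           exp_d_0: torch.Tensor   # driving frame 0 expression
                           ) -> torch.Tensor:
        """Relative-motion expression delta, gated by animation_region."""
        if region in ("all", "exp"):
            return exp_s + (exp_d_i - exp_d_0)      # move every keypoint
        delta = exp_s.clone()                       # the hunk edits in place
        if region == "lip":
            idx = LIP_IDX
        elif region == "eyes":
            idx = EYES_IDX
        else:                                       # "pose": expression untouched
            return delta
        for k in idx:
            delta[:, k, :] += exp_d_i[:, k, :] - exp_d_0[:, k, :]
        return delta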
@@ -300,15 +328,31 @@ class LivePortraitPipeline(object):
                     else:
                         R_new = R_s
                 else:
+                    if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                        delta_new = x_s_info['exp']
                         R_new = R_d_i
+                    else:
+                        R_new = R_s
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
                     delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
+                elif inf_cfg.animation_region == "lip":
+                    delta_new = x_s_info['exp']
+                    for lip_idx in [14, 17, 19, 20]:
+                        delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
+                elif inf_cfg.animation_region == "eyes":
+                    delta_new = x_s_info['exp']
+                    for eyes_idx in [11, 13, 15, 16]:
+                        delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
                 scale_new = x_s_info['scale']
+                if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
                     t_new = x_d_i_info['t']
+                else:
+                    t_new = x_s_info['t']

             t_new[..., 2].fill_(0)  # zero tz
             x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

-            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
+            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
                 if i == 0:
                     x_d_0_new = x_d_i_new
                     motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
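Both branches converge on the same keypoint composition, `x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new`, with tz zeroed so the driving translation stays in the image plane. A minimal numeric sketch of the shapes involved (the 21-keypoint count is assumed from the model; all values are dummies):

    import torch

    x_c_s = torch.randn(1, 21, 3)          # canonical source keypoints (assumed shape)
    R_new = torch.eye(3).unsqueeze(0)      # (1, 3, 3) head rotation
    delta_new = torch.zeros(1, 21, 3)      # expression deltas
    scale_new = 1.0
    t_new = torch.zeros(1, 3)              # translation, broadcast over keypoints
    t_new[..., 2].fill_(0)                 # zero tz, exactly as in the hunk

    x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
    print(x_d_i_new.shape)                 # torch.Size([1, 21, 3])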
@@ -373,15 +417,17 @@ class LivePortraitPipeline(object):

         mkdir(args.output_dir)
         wfp_concat = None
-        flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
-        flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)

         ######### build the final concatenation result #########
-        # driving frame | source frame | generation, or source frame | generation
+        # driving frame | source frame | generation
         if flag_is_source_video:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
         else:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)

+        if flag_is_driving_video:
+            flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
+            flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
+
             wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')

             # NOTE: update output fps
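Worth noting: the audio-stream probes now run only when the driving input is a video, since a still image has no audio track for has_audio_stream to inspect. The video branch keeps the .mp4 concat path, while the new image branch in the next hunk writes .jpg files instead.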
@@ -418,5 +464,16 @@ class LivePortraitPipeline(object):
                 log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
             log(f'Animated video: {wfp}')
             log(f'Animated video with concat: {wfp_concat}')
+        else:
+            wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
+            cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
+            wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
+            if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
+                cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
+            else:
+                cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
+            # final log
+            log(f'Animated image: {wfp}')
+            log(f'Animated image with concat: {wfp_concat}')

         return wfp, wfp_concat
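One detail worth calling out in the new image branch: the pipeline holds frames in RGB order, while OpenCV's imwrite expects BGR, hence the `[..., ::-1]` channel reversal before every write. A tiny standalone demonstration:

    import cv2
    import numpy as np

    rgb = np.zeros((64, 64, 3), dtype=np.uint8)
    rgb[..., 0] = 255                       # red channel in RGB order

    # Without the flip, OpenCV would read the array as BGR and save blue.
    cv2.imwrite('red.jpg', rgb[..., ::-1])  # reverse channels: RGB -> BGR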