feat: image-driven, regional animation

zhangdingyun 2024-08-19 17:27:15 +08:00
parent 424a5c74e7
commit 943ee6471a


@@ -112,8 +112,6 @@ class LivePortraitPipeline(object):
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
flag_is_driving_video = True if driving_n_frames > 1 else False
-# if flag_is_source_video and not flag_is_driving_video:
-#     raise Exception(f"Animating a source video with a driving image is not supported!")
if flag_is_source_video and flag_is_driving_video:
    n_frames = min(len(source_rgb_lst), driving_n_frames)  # minimum number as the number of the animated frames
elif flag_is_source_video and not flag_is_driving_video:
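For context, this hunk removes the guard so a multi-frame source can fall through to the driving-image branch. Below is a minimal standalone sketch of the frame-count logic (not the repository's code; the elif body is cut off by the hunk, so animating every source frame there is an assumption consistent with the surrounding code):

def pick_n_frames(n_source_frames: int, driving_n_frames: int, flag_is_source_video: bool) -> int:
    # a single-frame driving template counts as an image, not a video
    flag_is_driving_video = driving_n_frames > 1
    if flag_is_source_video and flag_is_driving_video:
        # both are videos: stop at the shorter of the two
        return min(n_source_frames, driving_n_frames)
    elif flag_is_source_video and not flag_is_driving_video:
        # video source + driving image, newly allowed: animate every source frame (assumed)
        return n_source_frames
    # image source: the driving input dictates the frame count
    return driving_n_frames

assert pick_n_frames(120, 1, flag_is_source_video=True) == 120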
@@ -138,8 +136,6 @@ class LivePortraitPipeline(object):
    driving_rgb_lst = load_video(args.driving)
elif is_image(args.driving):
    flag_is_driving_video = False
-   # if flag_is_source_video:
-   #     raise Exception(f"Animating a source video with a driving image is not supported!")
    driving_img_rgb = load_image_rgb(args.driving)
    output_fps = 25
    log(f"Load driving image from {args.driving}")
@@ -217,7 +213,6 @@ class LivePortraitPipeline(object):
    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
    x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
-   # x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
    x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
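In the new else branch (the driving input is a single image), one expression delta, the driving expression minus the neutral lip_array offset from the config, is added to every source frame, so no temporal smoothing is needed. A numeric sketch with numpy standing in for the pipeline's tensors; the (1, 21, 3) shape and the constants are illustrative assumptions:

import numpy as np

n_frames = 4
src_exp = [np.zeros((1, 21, 3), dtype=np.float32) for _ in range(n_frames)]   # per-frame source expressions
drv_exp0 = np.full((1, 21, 3), 0.05, dtype=np.float32)                        # the single driving image's expression
lip_array = np.full((1, 21, 3), 0.01, dtype=np.float32)                       # neutral-lip offset (assumed values)

# one constant delta applied to every source frame: exp_i = src_i + (drv_0 - lip_array)
x_d_exp_lst = [e + (drv_exp0 - lip_array) for e in src_exp]
assert all(np.allclose(x, 0.04) for x in x_d_exp_lst)

# x_d_exp[0] in the hunk then drops the batch dimension before the torch tensors are built.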