feat: image driven, regional animation

This commit is contained in:
zhangdingyun 2024-08-19 14:13:17 +08:00
parent fbb8830b65
commit d1cb6bbc27

View File

@ -437,7 +437,7 @@ class GradioPipeline(LivePortraitPipeline):
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Silencing video...', total=n_frames):
for i in track(range(n_frames), description='Silencing lip...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
x_d_i_new = x_d_i_new_lst[i]
@ -570,19 +570,22 @@ class GradioPipeline(LivePortraitPipeline):
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
for i in track(range(n_frames), description='Preparing silence lip...', total=n_frames):
for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames):
x_s_info = source_template_dct['motion'][i]
x_s_info = dct2device(x_s_info, device)
scale_s = x_s_info['scale']
x_s_user = x_s_info['x_s']
delta_new = torch.zeros_like(x_s_info['exp'])
x_c_s = x_s_info['kp']
R_s = x_s_info['R']
t_s = x_s_info['t']
delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_d_i_new = x_s_info['scale'] * (x_s_info['kp'] @ x_s_info['R'] + delta_new + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)) + x_s_info['t']
x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
else: