feat: update

zhangdingyun 2024-08-15 19:24:04 +08:00
parent 3f45c776da
commit 8a2ea15471
5 changed files with 48 additions and 13 deletions

app.py

@@ -270,9 +270,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     driving_image_input = gr.Image(type="filepath")
                     gr.Examples(
                         examples=[
-                            [osp.join(example_video_dir, "d3.jpg")],
+                            [osp.join(example_video_dir, "d30.jpg")],
                             [osp.join(example_video_dir, "d9.jpg")],
-                            [osp.join(example_video_dir, "d11.jpg")],
+                            [osp.join(example_video_dir, "d19.jpg")],
+                            [osp.join(example_video_dir, "d8.jpg")],
                         ],
                         inputs=[driving_image_input],
                         cache_examples=False,
@@ -312,7 +313,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                     flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                     flag_stitching_input = gr.Checkbox(value=True, label="stitching")
-                    animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
+                    animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="exp", label="animation region")
                     driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
                     driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
                     driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
@@ -327,6 +328,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video"):
                 output_video_concat_i2v.render()
+    with gr.Row():
         with gr.Column():
             with gr.Accordion(open=True, label="The animated image in the original image space"):
                 output_image_i2i.render()

src/config/argument_config.py

@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s18.mp4')  # path to the source portrait (human/animal) or video (human)
-    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg')  # path to driving video or template (.pkl format)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to the source portrait (human/animal) or video (human)
+    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video
     ########## inference arguments ##########

src/config/inference_config.py

@@ -6,10 +6,14 @@ config dataclass used for inference
 import cv2
 from numpy import ndarray
+import pickle as pkl
 from dataclasses import dataclass, field
 from typing import Literal, Tuple
 from .base_config import PrintableConfig, make_abs_path


+def load_lip_array():
+    with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
+        return pkl.load(f)

 @dataclass(repr=False)  # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
@@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
     output_fps: int = 25  # default output fps
     mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
+    lip_array: ndarray = field(default_factory=load_lip_array)
     size_gif: int = 256  # default gif size, TO IMPLEMENT
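Note on the pattern above: field(default_factory=...) defers the pickle read until InferenceConfig() is instantiated and satisfies the dataclass rule against mutable defaults. A minimal standalone sketch of that mechanism, using a hypothetical local 'lip_array.pkl' rather than the project's resource path:

import pickle as pkl
from dataclasses import dataclass, field

def load_array(path='lip_array.pkl'):  # hypothetical stand-in path, for illustration only
    with open(path, 'rb') as f:
        return pkl.load(f)

@dataclass
class Cfg:
    lip_array: object = field(default_factory=load_array)  # called once per Cfg() instance

with open('lip_array.pkl', 'wb') as f:  # create a toy file so the sketch runs end to end
    pkl.dump([0.0] * 63, f)
cfg = Cfg()  # the pickle is read here, not when the class is defined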

src/gradio_pipeline.py

@@ -315,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
         if input_lip_ratio != self.source_lip_ratio:
             combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
             lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+            print(lip_delta)
         x_d_new = x_d_new + \
             (eyes_delta if eyes_delta is not None else 0) + \
             (lip_delta if lip_delta is not None else 0)
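The added print(lip_delta) reads like a leftover debug statement. If the value is worth keeping visible, one hedged alternative (a sketch, not the project's actual logging setup) is to route it through the standard logging module so it can be silenced without editing code:

import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
lip_delta = None  # stand-in for the tensor computed in the hunk above
log.debug("lip_delta=%s", lip_delta)  # suppressed entirely at the default WARNING level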

src/live_portrait_pipeline.py

@@ -159,7 +159,7 @@ class LivePortraitPipeline(object):
             if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
                 ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
                 log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
-                if len(ret_d["frame_crop_lst"]) is not n_frames:
+                if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
                     n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
                 driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
                 driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
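One caveat in this hunk, unchanged by the commit: `is not` compares object identity, not value. CPython only caches small integers, so a frame count equal to n_frames can still be a distinct object and pass the `is not` test; `!=` expresses the intended comparison. A quick demonstration:

a = 1000
b = int("1000")  # same value, freshly created object
print(a == b)    # True
print(a is b)    # False in CPython: 1000 lies outside the small-int cache
# The reliable form of the guard above would be:
# if len(ret_d["frame_crop_lst"]) != n_frames and flag_is_driving_video: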
@@ -213,11 +213,19 @@ class LivePortraitPipeline(object):
             key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d'  # compatible with previous keys
             if inf_cfg.flag_relative_motion:
+                if flag_is_driving_video:
                     x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
                     x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                else:
+                    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+                    x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
+                    if flag_is_driving_video:
                         x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
                         x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                    else:
+                        x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
+                        x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
             else:
                 if flag_is_driving_video:
                     x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
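The new else branches handle a single driving image: with only one driving frame there is no sequence to smooth, so its expression (and the source's own rotation) is applied as a fixed offset to every source frame, with the expression baseline chosen as 0 or the neutral lip_array template depending on the sign of the driving expression mean. A toy sketch of that composition, assuming LivePortrait's (1, 21, 3) implicit-keypoint shape:

import numpy as np

n_frames = 4
src_exp = [np.zeros((1, 21, 3)) for _ in range(n_frames)]        # per-frame source expression
drv_exp = [0.01 * np.random.randn(1, 21, 3) for _ in range(n_frames)]

# video driving: offset each frame by the driving motion relative to driving frame 0
x_d_video = [src_exp[i] + drv_exp[i] - drv_exp[0] for i in range(n_frames)]

# image driving: one fixed offset from the single driving frame; lip_array below is a
# stand-in for the pickled neutral-lip template
lip_array = np.zeros((1, 21, 3))
baseline = 0 if drv_exp[0].mean() > 0 else lip_array
x_d_image = [src_exp[i] + (drv_exp[0] - baseline) for i in range(n_frames)]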
@@ -261,7 +269,10 @@ class LivePortraitPipeline(object):
             mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))

         ######## animate ########
+        if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
             log(f"The animated video consists of {n_frames} frames.")
+        else:
+            log(f"The output of image-driven portrait animation is an image.")
         for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
             if flag_is_source_video:  # source video
                 x_s_info = source_template_dct['motion'][i]
@@ -306,7 +317,9 @@ class LivePortraitPipeline(object):
             if i == 0:  # cache the first frame
                 R_d_0 = R_d_i
-                x_d_0_info = x_d_i_info
+                x_d_0_info = x_d_i_info.copy()
+                # if not flag_is_driving_video:
+                #     x_d_0_info['exp'] = 0

             delta_new = x_s_info['exp'].clone()
             if inf_cfg.flag_relative_motion:
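The switch to x_d_i_info.copy() matters because the old assignment aliased the cached first frame to the current frame's dict: any later mutation of x_d_i_info silently rewrote x_d_0_info. A minimal illustration (note that .copy() is shallow, so in-place edits to tensor values would still leak through):

x_d_i_info = {'exp': 0.1}
x_d_0_info = x_d_i_info            # alias: two names, one dict
x_d_i_info['exp'] = 0.5
print(x_d_0_info['exp'])           # 0.5, the cached frame-0 value drifted

x_d_i_info = {'exp': 0.1}
x_d_0_info = x_d_i_info.copy()     # shallow snapshot of frame 0
x_d_i_info['exp'] = 0.5
print(x_d_0_info['exp'])           # 0.1, cache preserved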
@@ -315,7 +328,20 @@ class LivePortraitPipeline(object):
                 else:
                     R_new = R_s
             if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
-                delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                # delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                if flag_is_source_video:
+                    for idx in [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]:
+                        delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
+                    delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
+                    delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
+                    delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
+                    delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
+                else:
+                    if flag_is_driving_video:
+                        delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                    else:
+                        delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
             elif inf_cfg.animation_region == "lip":
                 for lip_idx in [6, 12, 14, 17, 19, 20]:
                     delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
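The rewritten exp branch transfers the smoothed driving expression per keypoint and per axis instead of wholesale, leaving all other components at the source's values. A toy sketch of the indexing (the 21x3 layout follows LivePortrait's implicit keypoints; which indices correspond to lips or eyes is the commit's choice and an assumption here):

import torch

delta_new = torch.zeros(1, 21, 3)       # stand-in for the source expression x_s_info['exp']
drv = torch.randn(21, 3)                # stand-in for one smoothed frame x_d_exp_lst_smooth[i]
for idx in [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]:
    delta_new[:, idx, :] = drv[idx, :]  # full 3D offset for these keypoints
delta_new[:, 3:5, 1] = drv[3:5, 1]      # y component only for keypoints 3-4
delta_new[:, 5, 2] = drv[5, 2]          # z only for keypoint 5
delta_new[:, 8, 2] = drv[8, 2]          # z only for keypoint 8
delta_new[:, 9, 1:] = drv[9, 1:]        # y and z for keypoint 9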
@@ -355,6 +381,7 @@ class LivePortraitPipeline(object):
                 t_new = x_s_info['t']
                 t_new[..., 2].fill_(0)  # zero tz
+                # x_d_i_new = x_s_info['scale'] * (x_c_s @ R_s) + x_s_info['t']
                 x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

             if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":