Mirror of https://github.com/KwaiVGI/LivePortrait.git, synced 2025-03-15 05:52:58 +00:00
commit 8a2ea15471
parent 3f45c776da

    feat: update

app.py
@@ -270,9 +270,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                 driving_image_input = gr.Image(type="filepath")
                 gr.Examples(
                     examples=[
                         [osp.join(example_video_dir, "d3.jpg")],
+                        [osp.join(example_video_dir, "d30.jpg")],
                         [osp.join(example_video_dir, "d9.jpg")],
                         [osp.join(example_video_dir, "d11.jpg")],
                         [osp.join(example_video_dir, "d19.jpg")],
                         [osp.join(example_video_dir, "d8.jpg")],
                     ],
                     inputs=[driving_image_input],
                     cache_examples=False,
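For readers unfamiliar with the component wired up here: gr.Examples attaches clickable presets to an input, and selecting a row fills that input with the listed file path. A minimal standalone sketch (the directory path and component names are illustrative, not taken from app.py):

```python
import os.path as osp
import gradio as gr

example_video_dir = "assets/examples/driving"  # assumed layout, mirroring the diff

with gr.Blocks() as demo:
    # Clicking an example row fills the image input with that file path.
    driving_image_input = gr.Image(type="filepath")
    gr.Examples(
        examples=[[osp.join(example_video_dir, "d3.jpg")]],
        inputs=[driving_image_input],
        cache_examples=False,  # do not pre-render outputs for the examples
    )

demo.launch()
```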
@@ -312,7 +313,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                 flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                 flag_stitching_input = gr.Checkbox(value=True, label="stitching")
-                animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
+                animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="exp", label="animation region")
                 driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
                 driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
                 driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
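The default animation region moves from "all" to "exp", which changes what the handler receives when the user never touches the control. A small sketch of how a gr.Radio default flows into a callback; process() is a hypothetical stand-in for the pipeline call:

```python
import gradio as gr

def process(region: str) -> str:
    # Receives "exp" by default, or whatever the user picks.
    return f"animating region: {region}"

with gr.Blocks() as demo:
    region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="exp", label="animation region")
    out = gr.Textbox()
    region.change(process, inputs=region, outputs=out)

demo.launch()
```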
@@ -327,6 +328,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video"):
                 output_video_concat_i2v.render()
     with gr.Row():
         with gr.Column():
             with gr.Accordion(open=True, label="The animated image in the original image space"):
+                output_image_i2i.render()
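output_video_concat_i2v and output_image_i2i are pre-declared components that are only placed into the layout via .render(). A minimal sketch of that pattern, assuming a recent Gradio version:

```python
import gradio as gr

# Create the component up front, position it in the layout later with .render();
# this mirrors how the output components are pre-declared in app.py.
output_video = gr.Video(label="The animated video")

with gr.Blocks() as demo:
    with gr.Accordion(open=True, label="The animated video"):
        output_video.render()

demo.launch()
```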
src/config/argument_config.py
@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False) # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s18.mp4') # path to the source portrait (human/animal) or video (human)
-    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg') # path to driving video or template (.pkl format)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d30.jpg') # path to the source portrait (human/animal) or video (human)
+    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d30.jpg') # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video

     ########## inference arguments ##########
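ArgumentConfig is parsed from the command line with tyro, so the Annotated aliases become short flags. A minimal sketch of the same pattern with placeholder defaults:

```python
# Sketch of the tyro pattern used here; field names mirror the diff,
# the default paths below are placeholders.
from dataclasses import dataclass
from typing import Annotated

import tyro

@dataclass
class Args:
    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = "assets/source.jpg"
    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = "assets/driving.jpg"

args = tyro.cli(Args)  # e.g. `python run.py -s face.jpg -d clip.mp4`
print(args.source, args.driving)
```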
src/config/inference_config.py
@@ -6,10 +6,14 @@ config dataclass used for inference

 import cv2
 from numpy import ndarray
+import pickle as pkl
 from dataclasses import dataclass, field
 from typing import Literal, Tuple
 from .base_config import PrintableConfig, make_abs_path

+def load_lip_array():
+    with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
+        return pkl.load(f)

 @dataclass(repr=False) # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
@@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
     output_fps: int = 25 # default output fps

     mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
+    lip_array: ndarray = field(default_factory=load_lip_array)
     size_gif: int = 256 # default gif size, TO IMPLEMENT
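lip_array and mask_crop are mutable ndarray defaults, so they must be supplied through field(default_factory=...) rather than as plain class-level values. A small sketch of why; load_mask() is a hypothetical stand-in for the cv2.imread call above:

```python
from dataclasses import dataclass, field

import numpy as np

def load_mask() -> np.ndarray:
    # Each instance gets its own freshly built array; a plain `= np.zeros(...)`
    # default would be shared across instances (and is rejected outright as an
    # unhashable default on recent Python versions).
    return np.zeros((256, 256, 3), dtype=np.uint8)

@dataclass
class Cfg:
    mask_crop: np.ndarray = field(default_factory=load_mask)

cfg = Cfg()
print(cfg.mask_crop.shape)  # (256, 256, 3)
```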
src/gradio_pipeline.py
@@ -315,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
         if input_lip_ratio != self.source_lip_ratio:
             combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
             lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+            print(lip_delta)
         x_d_new = x_d_new + \
             (eyes_delta if eyes_delta is not None else 0) + \
             (lip_delta if lip_delta is not None else 0)
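The retargeting deltas are optional, so each one is folded in only when it was actually computed. A runnable sketch of that None-guard pattern with dummy tensors (names mirror the diff, values are made up):

```python
import torch

x_d_new = torch.zeros(1, 21, 3)        # dummy driving keypoints
eyes_delta = None                      # eye retargeting may be skipped
lip_delta = torch.full((1, 21, 3), 0.1)

# Add each retargeting delta only if it exists; adding 0 is a no-op otherwise.
x_d_new = x_d_new + \
    (eyes_delta if eyes_delta is not None else 0) + \
    (lip_delta if lip_delta is not None else 0)
print(x_d_new.mean())
```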
src/live_portrait_pipeline.py
@@ -159,7 +159,7 @@ class LivePortraitPipeline(object):
         if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
             ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
             log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
-            if len(ret_d["frame_crop_lst"]) is not n_frames:
+            if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
                 n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
             driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
             driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
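One caveat with the changed condition: `is not` compares object identity rather than value, so for frame counts outside CPython's small-int cache it can be truthy even when the lengths are equal. A hedged sketch of the value comparison the condition presumably intends (variable values here are made up):

```python
frame_crop_lst = [None] * 300           # stand-in for the cropped frame list
n_frames, flag_is_driving_video = 400, True

# `is not` tests identity; for ints above 256 it can be True even when the
# values match. `!=` expresses the intended value comparison.
if len(frame_crop_lst) != n_frames and flag_is_driving_video:
    n_frames = min(n_frames, len(frame_crop_lst))
print(n_frames)  # 300
```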
@@ -213,11 +213,19 @@ class LivePortraitPipeline(object):

             key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
             if inf_cfg.flag_relative_motion:
-                x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
-                x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                if flag_is_driving_video:
+                    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
+                    x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                else:
+                    x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+                    x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
-                    x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
-                    x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                    if flag_is_driving_video:
+                        x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
+                        x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
+                    else:
+                        x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
+                        x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
             else:
                 if flag_is_driving_video:
                     x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
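The relative-motion branch transfers only the driving sequence's deviation from its first frame, exp_t = exp_src,t + (exp_drv,t - exp_drv,0), so frame 0 reproduces the source expression exactly. A small numpy sketch with assumed array shapes:

```python
import numpy as np

n_frames, n_kp = 4, 21
src_exp = np.zeros((n_frames, n_kp, 3))       # per-frame source expression
drv_exp = np.random.randn(n_frames, n_kp, 3)  # per-frame driving expression

# Relative motion: transfer only the driving *change* since its first frame,
# so the source's own expression is the starting point.
x_d_exp = [src_exp[i] + drv_exp[i] - drv_exp[0] for i in range(n_frames)]
assert np.allclose(x_d_exp[0], src_exp[0])  # frame 0 keeps the source expression
```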
@@ -261,7 +269,10 @@ class LivePortraitPipeline(object):
                 mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))

             ######## animate ########
-            log(f"The animated video consists of {n_frames} frames.")
+            if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
+                log(f"The animated video consists of {n_frames} frames.")
+            else:
+                log(f"The output of image-driven portrait animation is an image.")
             for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
                 if flag_is_source_video: # source video
                     x_s_info = source_template_dct['motion'][i]
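The new guard reads more elaborately than it needs to: A or (B and not A) is logically equivalent to A or B, so the condition is simply "driving video or source video". A one-line truth-table check:

```python
from itertools import product

# Verify: A or (B and not A)  ==  A or B, for all boolean inputs.
for a, b in product([False, True], repeat=2):
    assert (a or (b and not a)) == (a or b)
print("equivalent")
```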
@@ -306,7 +317,9 @@ class LivePortraitPipeline(object):

                 if i == 0: # cache the first frame
                     R_d_0 = R_d_i
-                    x_d_0_info = x_d_i_info
+                    x_d_0_info = x_d_i_info.copy()
+                    # if not flag_is_driving_video:
+                    #     x_d_0_info['exp'] = 0

                 delta_new = x_s_info['exp'].clone()
                 if inf_cfg.flag_relative_motion:
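Caching x_d_0_info with .copy() instead of a bare assignment matters because the old binding aliased the live per-frame dict, so later key reassignments silently rewrote the cached first frame. A minimal demonstration:

```python
x_d_i_info = {"exp": 0.0}

cached_alias = x_d_i_info         # old code: same dict object
cached_copy = x_d_i_info.copy()   # new code: shallow copy of the dict

x_d_i_info["exp"] = 1.0           # a later frame reassigns a key

print(cached_alias["exp"])  # 1.0, the cached "frame 0" drifted
print(cached_copy["exp"])   # 0.0, the copy kept the original value
```

Note that a shallow copy only guards against key reassignment; values mutated in place (for example with tensor in-place ops) would still be shared.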
@@ -315,7 +328,20 @@ class LivePortraitPipeline(object):
                 else:
                     R_new = R_s
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
-                    delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                    # delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                    if flag_is_source_video:
+                        for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
+                            delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
+                        delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
+                        delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
+                        delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
+                        delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
+                    else:
+                        if flag_is_driving_video:
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
+                        else:
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
+
                 elif inf_cfg.animation_region == "lip":
                     for lip_idx in [6, 12, 14, 17, 19, 20]:
                         delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
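Instead of replacing the whole expression delta, the source-video path now overwrites only selected keypoints, and for some keypoints only specific x/y/z channels, from the smoothed driving expression. A torch sketch with assumed (batch, keypoint, xyz) shapes:

```python
import torch

delta_new = torch.zeros(1, 21, 3)  # (batch, keypoint, xyz) expression deltas
smoothed = torch.randn(21, 3)      # smoothed driving expression for frame i

# Copy whole keypoints that carry expression...
for idx in [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]:
    delta_new[:, idx, :] = smoothed[idx, :]
# ...and only selected channels of the remaining keypoints.
delta_new[:, 3:5, 1] = smoothed[3:5, 1]  # y channel of keypoints 3 and 4
delta_new[:, 5, 2] = smoothed[5, 2]      # z channel of keypoint 5
print(delta_new.abs().sum())
```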
@@ -355,6 +381,7 @@ class LivePortraitPipeline(object):
                     t_new = x_s_info['t']

                 t_new[..., 2].fill_(0) # zero tz
+                # x_d_i_new = x_s_info['scale'] * (x_c_s @ R_s) + x_s_info['t']
                 x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

                 if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
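The keypoint update follows x_d = scale * (x_c @ R + delta) + t: canonical keypoints are rotated, offset by the expression delta, scaled, then translated (with tz zeroed just above). A shape-checked torch sketch with dummy values:

```python
import torch

x_c_s = torch.randn(1, 21, 3)      # canonical source keypoints
R_new = torch.eye(3).unsqueeze(0)  # (1, 3, 3) head rotation
delta_new = torch.zeros(1, 21, 3)  # expression deltas
scale_new = torch.tensor(1.2)
t_new = torch.zeros(1, 1, 3)       # translation, tz zeroed as in the diff

# x_d = scale * (x_c @ R + delta) + t  ->  (1, 21, 3) driven keypoints
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
print(x_d_i_new.shape)
```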