mirror of https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-15 05:52:58 +00:00

feat: update

This commit is contained in:
  parent 8a2ea15471
  commit 18ebcc8b61

app.py (24 changed lines)
app.py

```diff
@@ -98,6 +98,7 @@ data_examples_v2v = [
 retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
 video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
 driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
+video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent")
 eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
 lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
 video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
```
```diff
@@ -124,10 +125,6 @@ retargeting_output_image = gr.Image(type="numpy")
 retargeting_output_image_paste_back = gr.Image(type="numpy")
 output_video = gr.Video(autoplay=False)
 output_video_paste_back = gr.Video(autoplay=False)
-output_video_i2v = gr.Video(autoplay=False)
-output_video_concat_i2v = gr.Video(autoplay=False)
-output_image_i2i = gr.Image(type="numpy")
-output_image_concat_i2i = gr.Image(type="numpy")
 
 """
 The expression corresponding to each point and each dimension:
```
```diff
@@ -274,6 +271,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 [osp.join(example_video_dir, "d9.jpg")],
 [osp.join(example_video_dir, "d19.jpg")],
 [osp.join(example_video_dir, "d8.jpg")],
+[osp.join(example_video_dir, "d12.jpg")],
+[osp.join(example_video_dir, "d38.jpg")],
 ],
 inputs=[driving_image_input],
 cache_examples=False,
```
```diff
@@ -323,18 +322,14 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 process_button_animation = gr.Button("🚀 Animate", variant="primary")
 with gr.Row():
     with gr.Column():
-        with gr.Accordion(open=True, label="The animated video in the original image space"):
-            output_video_i2v.render()
+        output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space")
     with gr.Column():
-        with gr.Accordion(open=True, label="The animated video"):
-            output_video_concat_i2v.render()
+        output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video")
 with gr.Row():
     with gr.Column():
-        with gr.Accordion(open=True, label="The animated image in the original image space"):
-            output_image_i2i.render()
+        output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False)
     with gr.Column():
-        with gr.Accordion(open=True, label="The animated image"):
-            output_image_concat_i2i.render()
+        output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False)
 with gr.Row():
     process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
 
```
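This hunk drops the pre-declared components plus `.render()`-inside-`Accordion` pattern in favor of creating the output components inline, with labels and initial visibility set at construction time. A minimal sketch of the new pattern (a standalone demo, not the app's real layout):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # labeled component created in place; no Accordion, no .render()
            output_video = gr.Video(autoplay=False, label="The animated video")
        with gr.Column():
            # image panes start hidden and are revealed by a callback later
            output_image = gr.Image(type="numpy", label="The animated image", visible=False)

# demo.launch()  # uncomment to serve the layout
```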
```diff
@@ -463,6 +458,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 video_retargeting_source_scale.render()
 video_lip_retargeting_slider.render()
 driving_smooth_observation_variance_retargeting.render()
+video_retargeting_silence.render()
 with gr.Row(visible=True):
     process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
 with gr.Row(visible=True):
```
```diff
@@ -524,7 +520,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         tab_selection,
         v_tab_selection,
     ],
-    outputs=[output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i],
+    outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i],
     show_progress=True
 )
 
```
```diff
@@ -550,7 +546,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 
 process_button_retargeting_video.click(
     fn=gpu_wrapped_execute_video_retargeting,
-    inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
+    inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video],
     outputs=[output_video, output_video_paste_back],
     show_progress=True
 )
```
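Gradio passes `inputs` positionally, so the new `video_retargeting_silence` checkbox must sit in the same slot in both the `inputs` list and the callback signature, which is exactly what the `execute_video_retargeting` change further down does. A minimal, self-contained sketch of that contract (stand-in components and a stub handler, not the app's real wiring):

```python
import gradio as gr

def retarget_stub(lip_ratio, video, scale, smooth_var, silence, do_crop):
    # argument order mirrors the inputs list below, slot for slot
    return f"silence={silence}, ratio={lip_ratio}, crop={do_crop}"

with gr.Blocks() as demo:
    ratio = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
    video = gr.Video()
    scale = gr.Number(value=2.3, label="crop scale")
    smooth = gr.Number(value=3e-6, label="motion smooth strength")
    silence = gr.Checkbox(value=False, label="keeping the lip silent")
    do_crop = gr.Checkbox(value=True, label="do crop (source video)")
    out = gr.Textbox()
    gr.Button("🚗 Retargeting Video").click(
        fn=retarget_stub,
        inputs=[ratio, video, scale, smooth, silence, do_crop],
        outputs=[out],
    )
```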
src/config/argument_config.py

```diff
@@ -13,7 +13,7 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to the source portrait (human/animal) or video (human)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg')  # path to the source portrait (human/animal) or video (human)
     driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video
 
```
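For context, `ArgumentConfig` is consumed by tyro, which turns the dataclass fields and their `aliases` into CLI flags; the default swap above only changes what runs when no `-s` is passed (a source portrait instead of a driving image). A self-contained sketch, with placeholder paths standing in for the real defaults:

```python
from dataclasses import dataclass
from typing import Annotated

import tyro

@dataclass
class Args:
    # hypothetical stand-ins for the defaults in argument_config.py
    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = "assets/examples/source/s3.jpg"
    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = "assets/examples/driving/d30.jpg"
    output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = "animations/"

if __name__ == "__main__":
    args = tyro.cli(Args)  # `python demo.py -s my.jpg` overrides the default
    print(args.source, args.driving, args.output_dir)
```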
src/gradio_pipeline.py

```diff
@@ -219,9 +219,9 @@ class GradioPipeline(LivePortraitPipeline):
             output_path, output_path_concat = self.execute(self.args)
             gr.Info("Run successfully!", duration=2)
             if output_path.endswith(".jpg"):
-                return None, None, output_path, output_path_concat
+                return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
             else:
-                return output_path, output_path_concat, None, None
+                return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         else:
             raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
 
```
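The widened return tuples line up with the duplicated `outputs` list in the `@@ -524` hunk of app.py: each output component now occupies two slots, one fed its value and one fed a visibility update, so a single handler can both populate and show/hide the video and image panes. A reduced sketch of that convention (a stand-alone function, not the repo's actual handler):

```python
import gradio as gr

def route_result(output_path: str, output_path_concat: str):
    hide, show = gr.update(visible=False), gr.update(visible=True)
    if output_path.endswith(".jpg"):
        # image result: hide the video slots, fill and reveal the image slots
        return hide, hide, hide, hide, output_path, show, output_path_concat, show
    # video result: the mirror image of the above
    return output_path, show, output_path_concat, show, hide, hide, hide, hide
```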
```diff
@@ -396,29 +396,51 @@ class GradioPipeline(LivePortraitPipeline):
         return source_eye_ratio, source_lip_ratio
 
     @torch.no_grad()
-    def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
+    def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
         """ retargeting the lip-open ratio of each source frame
         """
         # disposable feature
         device = self.live_portrait_wrapper.device
-        f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
-            self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
 
-        if input_lip_ratio is None:
-            raise gr.Error("Invalid ratio input 💥!", duration=5)
+        if not video_retargeting_silence:
+            f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
+                self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
+            if input_lip_ratio is None:
+                raise gr.Error("Invalid ratio input 💥!", duration=5)
+            else:
+                inference_cfg = self.live_portrait_wrapper.inference_cfg
+
+                I_p_pstbk_lst = None
+                if flag_do_crop_input_retargeting_video:
+                    I_p_pstbk_lst = []
+                I_p_lst = []
+                for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
+                    x_s_user_i = x_s_user_lst[i].to(device)
+                    f_s_user_i = f_s_user_lst[i].to(device)
+
+                    lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
+                    x_d_i_new = x_s_user_i + lip_delta_retargeting
+                    x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
+                    out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
+                    I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
+                    I_p_lst.append(I_p_i)
+
+                    if flag_do_crop_input_retargeting_video:
+                        I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
+                        I_p_pstbk_lst.append(I_p_pstbk)
         else:
             inference_cfg = self.live_portrait_wrapper.inference_cfg
+            f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
+                self.prepare_video_lip_silence(input_video, device, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
 
             I_p_pstbk_lst = None
             if flag_do_crop_input_retargeting_video:
                 I_p_pstbk_lst = []
             I_p_lst = []
-            for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
+            for i in track(range(n_frames), description='Silencing video...', total=n_frames):
                 x_s_user_i = x_s_user_lst[i].to(device)
                 f_s_user_i = f_s_user_lst[i].to(device)
-
-                lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
-                x_d_i_new = x_s_user_i + lip_delta_retargeting
+                x_d_i_new = x_d_i_new_lst[i]
                 x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
                 out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
                 I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
```
```diff
@@ -428,37 +450,37 @@
                     I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
                     I_p_pstbk_lst.append(I_p_pstbk)
 
         mkdir(self.args.output_dir)
         flag_source_has_audio = has_audio_stream(input_video)
 
         ######### build the final concatenation result #########
         # source frame | generation
         frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
         wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
         images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
 
         if flag_source_has_audio:
             # final result with concatenation
             wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
             add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
             os.replace(wfp_concat_with_audio, wfp_concat)
             log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
 
         # save the animated result
         wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
         if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
             images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
         else:
             images2video(I_p_lst, wfp=wfp, fps=source_fps)
 
         ######### build the final result #########
         if flag_source_has_audio:
             wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
             add_audio_to_video(wfp, input_video, wfp_with_audio)
             os.replace(wfp_with_audio, wfp)
             log(f"Replace {wfp_with_audio} with {wfp}")
         gr.Info("Run successfully!", duration=2)
         return wfp_concat, wfp
 
     @torch.no_grad()
     def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
```
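The audio handling in this context block follows a write-then-mux-then-replace pattern: render the (silent) video, copy the audio track over from the input clip when one exists, and swap the result into place. A rough equivalent with plain ffmpeg via subprocess — the repo's own `add_audio_to_video` helper is not shown in this diff, so this is an assumption about its effect, not its implementation, and the paths are hypothetical:

```python
import os
import subprocess

def mux_audio(video_path: str, audio_src: str, out_path: str) -> None:
    # take the video stream from the render and the audio stream from the source clip
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path, "-i", audio_src,
         "-map", "0:v:0", "-map", "1:a:0", "-c:v", "copy", "-shortest", out_path],
        check=True,
    )

wfp = "animations/input_retargeting.mp4"
mux_audio(wfp, "input.mp4", "animations/tmp_with_audio.mp4")
os.replace("animations/tmp_with_audio.mp4", wfp)  # same replace-in-place step as the diff
```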
```diff
@@ -517,6 +539,59 @@ class GradioPipeline(LivePortraitPipeline):
             # when press the clear button, go here
             raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
 
+    @torch.no_grad()
+    def prepare_video_lip_silence(self, input_video, device, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
+        """ for keeping lips in the source video silent
+        """
+        if input_video is not None:
+            # gr.Info("Upload successfully!", duration=2)
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            ######## process source video ########
+            source_rgb_lst = load_video(input_video)
+            source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
+            source_fps = int(get_fps(input_video))
+            n_frames = len(source_rgb_lst)
+            log(f"Load source video from {input_video}. FPS is {source_fps}")
+
+            if flag_do_crop:
+                ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
+                log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
+                if len(ret_s["frame_crop_lst"]) != n_frames:
+                    n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
+                img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
+                mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
+            else:
+                source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
+                img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst]  # force to resize to 256x256
+                source_M_c2o_lst, mask_ori_lst = None, None
+
+            c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
+            # save the motion template
+            I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
+            source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
+
+            f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
+            for i in track(range(n_frames), description='Preparing silence lip...', total=n_frames):
+                x_s_info = source_template_dct['motion'][i]
+                x_s_info = dct2device(x_s_info, device)
+                x_s_user = x_s_info['x_s']
+                delta_new = torch.zeros_like(x_s_info['exp'])
+                for eyes_idx in [11, 13, 15, 16, 18]:
+                    delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
+
+                source_lmk = source_lmk_crop_lst[i]
+                img_crop_256x256 = img_crop_256x256_lst[i]
+                I_s = I_s_lst[i]
+                f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+                x_d_i_new = x_s_info['scale'] * (x_s_info['kp'] @ x_s_info['R'] + delta_new + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)) + x_s_info['t']
+
+                f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
+
+            return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
+        else:
+            # when press the clear button, go here
+            raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)
+
 class GradioPipelineAnimal(LivePortraitPipelineAnimal):
     """gradio for animal
     """
```
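The core of the new `prepare_video_lip_silence` is the expression reset: every per-frame expression delta is zeroed except the eye keypoints, and the neutral-lip offset (`inference_cfg.lip_array`) is added back, so the mouth is driven to a closed, neutral pose while blinks survive. A minimal tensor sketch of that step — shapes are assumed from the indices used in the diff, and `lip_array` is a zero stand-in here:

```python
import torch

n_kp = 21                                # keypoint indices in the diff run up to 20
exp = torch.randn(1, n_kp, 3)            # per-frame expression deltas (stand-in)
lip_array = torch.zeros(1, n_kp, 3)      # placeholder for inference_cfg.lip_array

delta_new = torch.zeros_like(exp)        # silence everything...
for eyes_idx in [11, 13, 15, 16, 18]:    # ...except the eye keypoints
    delta_new[:, eyes_idx, :] = exp[:, eyes_idx, :]

# driving keypoints are then rebuilt as in the diff:
# x_d = scale * (kp @ R + delta_new + lip_array) + t
```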
src/live_portrait_pipeline.py

```diff
@@ -217,7 +217,8 @@ class LivePortraitPipeline(object):
             x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
             x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
         else:
-            x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+            # x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+            x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
             x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
         if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
             if flag_is_driving_video:
```
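Both here and in the `@@ -340` hunk below, the data-dependent switch on `exp.mean() > 0` is replaced by a single code path that always subtracts `inf_cfg.lip_array` (the neutral-lip expression) when the driving input is a single image, so the transferred expression becomes a consistent offset from a neutral mouth. Schematically, per source frame i — shapes assumed, with a zero stand-in for the learned `lip_array`:

```python
import numpy as np

n_kp = 21
lip_array = np.zeros((1, n_kp, 3), np.float32)              # stand-in for inf_cfg.lip_array
src_exp_i = np.random.randn(1, n_kp, 3).astype(np.float32)  # source frame i expression
drv_exp_0 = np.random.randn(1, n_kp, 3).astype(np.float32)  # driving image expression

# new behavior: one code path, neutral-lip reference
x_d_exp_i = src_exp_i + (drv_exp_0 - lip_array)
```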
```diff
@@ -340,14 +341,25 @@
                         if flag_is_driving_video:
                             delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                         else:
-                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
+                            # delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
                 elif inf_cfg.animation_region == "lip":
                     for lip_idx in [6, 12, 14, 17, 19, 20]:
-                        delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
+                        if flag_is_source_video:
+                            delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
+                        elif flag_is_driving_video:
+                            delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
+                        else:
+                            # delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, lip_idx, :] if x_d_i_info['exp'].mean() > 0 else (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
+                            delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
                 elif inf_cfg.animation_region == "eyes":
                     for eyes_idx in [11, 13, 15, 16, 18]:
-                        delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
+                        if flag_is_source_video:
+                            delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
+                        elif flag_is_driving_video:
+                            delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
+                        else:
+                            delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :]
                 if inf_cfg.animation_region == "all":
                     scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                 else:
```
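The one-line conditional expressions here could only encode two cases (source video or not); splitting them into explicit `if/elif/else` makes room for the third case, a driving image, which reuses the neutral-lip reference. The update itself just overwrites the selected keypoint rows, as in this small sketch:

```python
import torch

n_kp = 21
delta_new = torch.zeros(1, n_kp, 3)   # stand-in for the current expression deltas
candidate = torch.randn(1, n_kp, 3)   # stand-in for x_s_info['exp'] + (...)

for lip_idx in [6, 12, 14, 17, 19, 20]:                 # lip keypoints, per the diff
    delta_new[:, lip_idx, :] = candidate[:, lip_idx, :]  # only these rows change
```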
```diff
@@ -384,7 +396,7 @@
                 # x_d_i_new = x_s_info['scale'] * (x_c_s @ R_s) + x_s_info['t']
                 x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
 
-            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
+            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
                 if i == 0:
                     x_d_0_new = x_d_i_new
                     motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
```
```diff
@@ -454,7 +466,10 @@
         if flag_is_source_video and flag_is_driving_video:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
         elif flag_is_source_video and not flag_is_driving_video:
-            frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
+            if flag_load_from_template:
+                frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
+            else:
+                frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
         else:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
 
```
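A plausible reading of this branch split: a template (`.pkl`) input already carries frame-aligned driving crops, so it is passed through as-is, while a genuine single driving image still needs list multiplication to supply one (shared) driving frame per output frame. A small illustration of that broadcast, with a stand-in array:

```python
import numpy as np

n_frames = 4
drv_img = np.zeros((256, 256, 3), np.uint8)    # stand-in for the cropped driving image
driving_frames = [drv_img] * n_frames          # one entry per generated frame (shared, read-only)
assert len(driving_frames) == n_frames
```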