Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-22 12:22:38 +00:00)
feat: refine and upgrade gradio (#6)
parent 6473a3b8b5, commit 293cb9ee31

app.py (72 lines changed)
@@ -4,10 +4,9 @@
 The entrance of the gradio
 """
 
-import os
-import os.path as osp
-import gradio as gr
 import tyro
+import gradio as gr
+import os.path as osp
 from src.utils.helper import load_description
 from src.gradio_pipeline import GradioPipeline
 from src.config.crop_config import CropConfig
@@ -43,18 +42,24 @@ data_examples = [
     [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
 ]
 #################### interface logic ####################
 
 # Define components first
 eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
 lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
-output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
+retargeting_input_image = gr.Image(type="numpy")
+output_image = gr.Image(type="numpy")
+output_image_paste_back = gr.Image(type="numpy")
+output_video = gr.Video()
+output_video_concat = gr.Video()
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.HTML(load_description(title_md))
     gr.Markdown(load_description("assets/gradio_description_upload.md"))
     with gr.Row():
         with gr.Accordion(open=True, label="Reference Portrait"):
-            image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
+            image_input = gr.Image(type="filepath")
         with gr.Accordion(open=True, label="Driving Video"):
-            video_input = gr.Video(label="Please upload the driving video here.")
+            video_input = gr.Video()
     gr.Markdown(load_description("assets/gradio_description_animation.md"))
     with gr.Row():
         with gr.Accordion(open=True, label="Animation Options"):
@@ -63,16 +68,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                 flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
     with gr.Row():
-        process_button_animation = gr.Button("🚀 Animate", variant="primary")
+        with gr.Column():
+            process_button_animation = gr.Button("🚀 Animate", variant="primary")
+        with gr.Column():
+            process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
     with gr.Row():
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video in the original image space"):
-                output_video = gr.Video(label="The animated video after pasted back.")
+                output_video.render()
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video"):
-                output_video_concat = gr.Video(label="The animated video and driving video.")
-    with gr.Row():
-        process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
+                output_video_concat.render()
     with gr.Row():
         # Examples
         gr.Markdown("## You could choose the examples below ⬇️")
@@ -89,28 +95,36 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             examples_per_page=5
         )
     gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
+    with gr.Row():
+        eye_retargeting_slider.render()
+        lip_retargeting_slider.render()
+    with gr.Row():
+        process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
+        process_button_reset_retargeting = gr.ClearButton(
+            [
+                eye_retargeting_slider,
+                lip_retargeting_slider,
+                retargeting_input_image,
+                output_image,
+                output_image_paste_back
+            ],
+            value="🧹 Clear"
+        )
     with gr.Row():
         with gr.Column():
-            process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
-            process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
-            process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
-            # with gr.Column():
-            eye_retargeting_slider.render()
-            lip_retargeting_slider.render()
+            with gr.Accordion(open=True, label="Retargeting Input"):
+                retargeting_input_image.render()
         with gr.Column():
-            with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
+            with gr.Accordion(open=True, label="Retargeting Result"):
                 output_image.render()
+        with gr.Column():
+            with gr.Accordion(open=True, label="Paste-back Result"):
+                output_image_paste_back.render()
     # binding functions for buttons
-    process_button_close_ratio.click(
-        fn=gradio_pipeline.prepare_retargeting,
-        inputs=image_input,
-        outputs=[eye_retargeting_slider, lip_retargeting_slider],
-        show_progress=True
-    )
     process_button_retargeting.click(
         fn=gradio_pipeline.execute_image,
         inputs=[eye_retargeting_slider, lip_retargeting_slider],
-        outputs=output_image,
+        outputs=[output_image, output_image_paste_back],
         show_progress=True
     )
     process_button_animation.click(
@@ -125,8 +139,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         outputs=[output_video, output_video_concat],
         show_progress=True
     )
-    process_button_reset.click()
-    process_button_reset_retargeting
+    image_input.change(
+        fn=gradio_pipeline.prepare_retargeting,
+        inputs=image_input,
+        outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
+    )
+
 ##########################################################
 
 demo.launch(
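Note on the app.py changes above: the refactor relies on Gradio's deferred rendering, where components are constructed before the gr.Blocks context (so click/change handlers and the ClearButton can reference them from anywhere) and are placed into the page later with .render(). A minimal, self-contained sketch of that pattern (the handler below is a hypothetical stand-in, not the LivePortrait pipeline):

import gradio as gr
import numpy as np

# construct components first so event bindings can reference them anywhere
ratio_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="ratio")
result_image = gr.Image(type="numpy")

def run(ratio):
    # hypothetical stand-in for gradio_pipeline.execute_image
    return (np.ones((64, 64, 3)) * ratio * 255).astype(np.uint8)

with gr.Blocks() as demo:
    ratio_slider.render()                  # place the pre-built slider into the layout
    with gr.Accordion(open=True, label="Result"):
        result_image.render()              # and the pre-built output image here
    run_button = gr.Button("Run", variant="primary")
    run_button.click(fn=run, inputs=ratio_slider, outputs=result_image)

# demo.launch()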

assets/gradio_description_retargeting.md
@@ -1,7 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:</span>
-<div style="margin-left: 20px;">
-<span style="font-size: 1.2em;">1. Please <strong>first</strong> press the <strong>🤖 Calculate the eye-close and lip-close ratio</strong> button, and wait for the result shown in the sliders.</span>
-</div>
-<div style="margin-left: 20px;">
-<span style="font-size: 1.2em;">2. Please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. Then the result would be shown in the middle block. You can try running it multiple times!</span>
-</div>
+<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>

assets/gradio_title.md
@@ -2,7 +2,7 @@
 <div>
     <h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
     <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
-        <a href=""><img src="https://img.shields.io/badge/arXiv-XXXX.XXXX-red"></a>
+        <a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
         <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
         <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
     </div>

src/gradio_pipeline.py
@@ -3,13 +3,14 @@
 """
 Pipeline for gradio
 """
+import gradio as gr
 from .config.argument_config import ArgumentConfig
 from .live_portrait_pipeline import LivePortraitPipeline
 from .utils.io import load_img_online
+from .utils.rprint import rlog as log
+from .utils.crop import prepare_paste_back, paste_back
 from .utils.camera import get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from .utils.rprint import rlog as log
 
 def update_args(args, user_args):
     """update the args according to user inputs
@@ -26,10 +27,15 @@ class GradioPipeline(LivePortraitPipeline):
         # self.live_portrait_wrapper = self.live_portrait_wrapper
         self.args = args
         # for single image retargeting
+        self.start_prepare = False
         self.f_s_user = None
         self.x_c_s_info_user = None
         self.x_s_user = None
         self.source_lmk_user = None
+        self.mask_ori = None
+        self.img_rgb = None
+        self.crop_M_c2o = None
+
 
     def execute_video(
         self,
@@ -41,64 +47,94 @@ class GradioPipeline(LivePortraitPipeline):
     ):
         """ for video driven potrait animation
         """
-        args_user = {
-            'source_image': input_image_path,
-            'driving_info': input_video_path,
-            'flag_relative': flag_relative_input,
-            'flag_do_crop': flag_do_crop_input,
-            'flag_pasteback': flag_remap_input
-        }
-        # update config from user input
-        self.args = update_args(self.args, args_user)
-        self.live_portrait_wrapper.update_config(self.args.__dict__)
-        self.cropper.update_config(self.args.__dict__)
-        # video driven animation
-        video_path, video_path_concat = self.execute(self.args)
-        return video_path, video_path_concat
+        if input_image_path is not None and input_video_path is not None:
+            args_user = {
+                'source_image': input_image_path,
+                'driving_info': input_video_path,
+                'flag_relative': flag_relative_input,
+                'flag_do_crop': flag_do_crop_input,
+                'flag_pasteback': flag_remap_input
+            }
+            # update config from user input
+            self.args = update_args(self.args, args_user)
+            self.live_portrait_wrapper.update_config(self.args.__dict__)
+            self.cropper.update_config(self.args.__dict__)
+            # video driven animation
+            video_path, video_path_concat = self.execute(self.args)
+            gr.Info("Run successfully!", duration=2)
+            return video_path, video_path_concat,
+        else:
+            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
 
     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
         """
-        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
-        eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
-        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
-        lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
-        num_kp = self.x_s_user.shape[1]
-        # default: use x_s
-        x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
-        # D(W(f_s; x_s, x′_d))
-        out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
-        out = self.live_portrait_wrapper.parse_output(out['out'])[0]
-        return out
+        if input_eye_ratio is None or input_eye_ratio is None:
+            raise gr.Error("Invalid ratio input 💥!", duration=5)
+        elif self.f_s_user is None:
+            if self.start_prepare:
+                raise gr.Error(
+                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    duration=5
+                )
+            else:
+                raise gr.Error(
+                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    duration=5
+                )
+        else:
+            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
+            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
+            eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
+            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
+            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
+            lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
+            num_kp = self.x_s_user.shape[1]
+            # default: use x_s
+            x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+            # D(W(f_s; x_s, x′_d))
+            out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
+            out = self.live_portrait_wrapper.parse_output(out['out'])[0]
+            out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
+            gr.Info("Run successfully!", duration=2)
+            return out, out_to_ori_blend
 
     def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         """ for single image retargeting
         """
-        inference_cfg = self.live_portrait_wrapper.cfg
-        ######## process reference portrait ########
-        img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
-        log(f"Load source image from {input_image_path}.")
-        crop_info = self.cropper.crop_single_image(img_rgb)
-        if flag_do_crop:
-            I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
-        else:
-            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
-        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-        ############################################
-
-        # record global info for next time use
-        self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
-        self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
-        self.x_s_info_user = x_s_info
-        self.source_lmk_user = crop_info['lmk_crop']
-
-        # update slider
-        eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
-        eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
-        lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
-        lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
-
-        return eye_close_ratio, lip_close_ratio
+        if input_image_path is not None:
+            gr.Info("Upload successfully!", duration=2)
+            self.start_prepare = True
+            inference_cfg = self.live_portrait_wrapper.cfg
+            ######## process reference portrait ########
+            img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
+            log(f"Load source image from {input_image_path}.")
+            crop_info = self.cropper.crop_single_image(img_rgb)
+            if flag_do_crop:
+                I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+            else:
+                I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+            x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
+            R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+            ############################################
+
+            # record global info for next time use
+            self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+            self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+            self.x_s_info_user = x_s_info
+            self.source_lmk_user = crop_info['lmk_crop']
+            self.img_rgb = img_rgb
+            self.crop_M_c2o = crop_info['M_c2o']
+            self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+            # update slider
+            eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
+            eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
+            lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
+            lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
+            # for vis
+            self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
+            return eye_close_ratio, lip_close_ratio, self.I_s_vis
+        else:
+            # when press the clear button, go here
+            return 0.8, 0.8, self.I_s_vis
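The rewritten GradioPipeline methods surface problems through Gradio's notification API (gr.Info and gr.Error with a duration, exactly as in the diff) instead of failing silently when the cached source features are missing. A stripped-down sketch of the same guard pattern, using a hypothetical session class rather than the real pipeline:

import gradio as gr

class RetargetingSession:
    def __init__(self):
        self.f_s_user = None        # filled once the reference portrait has been prepared
        self.start_prepare = False  # set as soon as an upload starts

    def execute_image(self, eye_ratio, lip_ratio):
        if eye_ratio is None or lip_ratio is None:
            raise gr.Error("Invalid ratio input 💥!", duration=5)
        if self.f_s_user is None:
            # distinguish "still preparing" from "nothing uploaded yet"
            if self.start_prepare:
                raise gr.Error("The reference portrait is under processing 💥!", duration=5)
            raise gr.Error("The reference portrait hasn't been prepared yet 💥!", duration=5)
        gr.Info("Run successfully!", duration=2)
        return self.f_s_user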

src/live_portrait_pipeline.py
@@ -20,10 +20,10 @@ from .config.crop_config import CropConfig
 from .utils.cropper import Cropper
 from .utils.camera import get_rotation_matrix
 from .utils.video import images2video, concat_frames
-from .utils.crop import _transform_img
+from .utils.crop import _transform_img, prepare_paste_back, paste_back
 from .utils.retargeting_utils import calc_lip_close_ratio
-from .utils.io import load_image_rgb, load_driving_info
-from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
+from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
+from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
 from .utils.rprint import rlog as log
 from .live_portrait_wrapper import LivePortraitWrapper
@@ -90,10 +90,7 @@ class LivePortraitPipeline(object):
 
         ######## prepare for pasteback ########
         if inference_cfg.flag_pasteback:
-            if inference_cfg.mask_crop is None:
-                inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
-            mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-            mask_ori = mask_ori.astype(np.float32) / 255.
+            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
             I_p_paste_lst = []
         #########################################
 
@@ -172,9 +169,7 @@ class LivePortraitPipeline(object):
             I_p_lst.append(I_p_i)
 
             if inference_cfg.flag_pasteback:
-                I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-                I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
-                out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
+                I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                 I_p_paste_lst.append(I_p_i_to_ori_blend)
 
         mkdir(args.output_dir)

src/live_portrait_wrapper.py
@@ -12,7 +12,6 @@ import yaml
 
 from src.utils.timer import Timer
 from src.utils.helper import load_model, concat_feat
-from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta
 from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
 from src.config.inference_config import InferenceConfig
@@ -211,33 +210,6 @@ class LivePortraitWrapper(object):
 
         return delta
 
-    def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp):
-        # TODO: GPT style, refactor it...
-        if self.cfg.flag_eye_retargeting:
-            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-            eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source)
-        else:
-            # α_eyes = 0
-            eye_delta = None
-
-        if self.cfg.flag_lip_retargeting:
-            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-            lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source)
-        else:
-            # α_lip = 0
-            lip_delta = None
-
-        if self.cfg.flag_relative:  # use x_s
-            new_driving_kp = kp_source + \
-                (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
-                (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
-        else:  # use x_d,i
-            new_driving_kp = driving_transformed_kp + \
-                (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
-                (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
-
-        return new_driving_kp
-
     def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
         """
         kp_source: BxNx3

src/utils/crop.py
@@ -4,14 +4,17 @@
 cropping function and the related preprocess functions for cropping
 """
 
-import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)  # NOTE: enforce single thread
 import numpy as np
-from .rprint import rprint as print
+import os.path as osp
 from math import sin, cos, acos, degrees
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)  # NOTE: enforce single thread
+from .rprint import rprint as print
 
 DTYPE = np.float32
 CV2_INTERP = cv2.INTER_LINEAR
 
+def make_abs_path(fn):
+    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
+
 def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
     """ conduct similarity or affine transformation to the image, do not do border operation!
@@ -391,3 +394,19 @@ def average_bbox_lst(bbox_lst):
     bbox_arr = np.array(bbox_lst)
     return np.mean(bbox_arr, axis=0).tolist()
 
+def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
+    """prepare mask for later image paste back
+    """
+    if mask_crop is None:
+        mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
+    mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
+    mask_ori = mask_ori.astype(np.float32) / 255.
+    return mask_ori
+
+def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
+    """paste back the image
+    """
+    dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
+    result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
+    result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
+    return result
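The two helpers added to src/utils/crop.py split paste-back into a one-time mask preparation and a per-frame blend: prepare_paste_back warps the crop mask into the original image's coordinate frame once, and paste_back then warps each generated crop with the same crop-to-original affine and alpha-blends it over the original. A rough usage sketch, assuming the repo's mask template is available on disk and using a dummy 2x3 affine and dummy frames:

import numpy as np
from src.utils.crop import prepare_paste_back, paste_back

M_c2o = np.array([[2.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0]], dtype=np.float32)   # crop -> original, here a plain 2x upscale
img_rgb = np.zeros((512, 512, 3), dtype=np.uint8)        # original frame
generated_crop = np.full((256, 256, 3), 255, np.uint8)   # model output in crop space

# mask_crop=None makes prepare_paste_back fall back to ./resources/mask_template.png
mask_ori = prepare_paste_back(None, M_c2o, dsize=(img_rgb.shape[1], img_rgb.shape[0]))
blended = paste_back(generated_crop, M_c2o, img_rgb, mask_ori)
print(blended.shape, blended.dtype)  # (512, 512, 3) uint8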

src/utils/cropper.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 
+import gradio as gr
 import numpy as np
 import os.path as osp
 from typing import List, Union, Tuple
@@ -72,6 +73,7 @@ class Cropper(object):
 
         if len(src_face) == 0:
             log('No face detected in the source image.')
+            raise gr.Error("No face detected in the source image 💥!", duration=5)
             raise Exception("No face detected in the source image!")
         elif len(src_face) > 1:
             log(f'More than one face detected in the image, only pick one face by rule {direction}.')

src/utils/helper.py
@@ -154,22 +154,3 @@ def load_description(fp):
         content = f.read()
     return content
 
-
-def resize_to_limit(img, max_dim=1280, n=2):
-    h, w = img.shape[:2]
-    if max_dim > 0 and max(h, w) > max_dim:
-        if h > w:
-            new_h = max_dim
-            new_w = int(w * (max_dim / h))
-        else:
-            new_w = max_dim
-            new_h = int(h * (max_dim / w))
-        img = cv2.resize(img, (new_w, new_h))
-    n = max(n, 1)
-    new_h = img.shape[0] - (img.shape[0] % n)
-    new_w = img.shape[1] - (img.shape[1] % n)
-    if new_h == 0 or new_w == 0:
-        return img
-    if new_h != img.shape[0] or new_w != img.shape[1]:
-        img = img[:new_h, :new_w]
-    return img

src/utils/io.py
@@ -40,7 +40,7 @@ def contiguous(obj):
     return obj
 
 
-def _resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
+def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
     """
     ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
     :param img: the image to be processed.
@@ -87,7 +87,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
         img = obj
 
     # Resize image to satisfy constraints
-    img = _resize_to_limit(img, max_dim=max_dim, n=n)
+    img = resize_to_limit(img, max_dim=max_dim, n=n)
 
     if mode.lower() == "bgr":
         return contiguous(img)
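resize_to_limit (the former _resize_to_limit, now exported from src/utils/io.py and imported by src/live_portrait_pipeline.py) caps the longest side at max_dim and trims both sides to multiples of n, which is why prepare_retargeting calls load_img_online with max_dim=1280, n=16. A quick sanity check, assuming the function behaves as its docstring states:

import numpy as np
from src.utils.io import resize_to_limit

img = np.zeros((1080, 1920, 3), dtype=np.uint8)  # full-HD frame, longest side 1920
out = resize_to_limit(img, max_dim=1280, n=16)
# 1920 shrinks to 1280, 1080 shrinks proportionally to 720; both are already multiples of 16
print(out.shape)  # (720, 1280, 3)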

src/utils/retargeting_utils.py
@@ -4,7 +4,6 @@ Functions to compute distance ratios between specific pairs of facial landmarks
 """
 
 import numpy as np
-import torch
 
 
 def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
@@ -53,24 +52,3 @@ def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
         np.ndarray: Calculated lip-close ratio.
     """
     return calculate_distance_ratio(lmk, 90, 102, 48, 66)
-
-
-def compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source):
-    input_eye_ratio = input_eye_ratios[frame_idx][0][0]
-    eye_close_ratio = calc_eye_close_ratio(source_landmarks[None])
-    eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(portrait_wrapper.device_id)
-    input_eye_ratio_tensor = torch.Tensor([input_eye_ratio]).reshape(1, 1).cuda(portrait_wrapper.device_id)
-    combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
-    # print(combined_eye_ratio_tensor.mean())
-    eye_delta = portrait_wrapper.retarget_eye(kp_source, combined_eye_ratio_tensor)
-    return eye_delta
-
-
-def compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source):
-    input_lip_ratio = input_lip_ratios[frame_idx][0]
-    lip_close_ratio = calc_lip_close_ratio(source_landmarks[None])
-    lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(portrait_wrapper.device_id)
-    input_lip_ratio_tensor = torch.Tensor([input_lip_ratio]).cuda(portrait_wrapper.device_id)
-    combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
-    lip_delta = portrait_wrapper.retarget_lip(kp_source, combined_lip_ratio_tensor)
-    return lip_delta
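calc_lip_close_ratio, which this commit keeps while dropping the torch-based compute_*_delta helpers, reduces to a distance ratio over four landmark indices: the vertical lip gap (points 90 and 102) over the mouth width (points 48 and 66). A small worked example on a synthetic landmark array; the 203-point layout and the exact ratio formula are assumptions inferred from the calls above, not taken from this diff:

import numpy as np
from src.utils.retargeting_utils import calc_lip_close_ratio

lmk = np.zeros((203, 2), dtype=np.float32)  # assumed landmark count; only four points matter here
lmk[90] = [0.0, 10.0]     # upper lip
lmk[102] = [0.0, 20.0]    # lower lip
lmk[48] = [-30.0, 15.0]   # left mouth corner
lmk[66] = [30.0, 15.0]    # right mouth corner

ratio = calc_lip_close_ratio(lmk[None])  # batched call, as in prepare_retargeting
print(float(ratio.squeeze(0).mean()))    # about 10 / 60 ≈ 0.17 if the ratio is |p90 - p102| / |p48 - p66|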