feat: refine and upgrade gradio (#6)

Author: ZhizhouZhong
Date: 2024-07-05 11:36:03 +08:00 (committed by GitHub)
Parent: 6473a3b8b5
Commit: 293cb9ee31
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
11 changed files with 167 additions and 172 deletions

app.py

@@ -4,10 +4,9 @@
 The entrance of the gradio
 """
-import os
-import os.path as osp
-import gradio as gr
 import tyro
+import gradio as gr
+import os.path as osp
 from src.utils.helper import load_description
 from src.gradio_pipeline import GradioPipeline
 from src.config.crop_config import CropConfig
@@ -43,18 +42,24 @@ data_examples = [
     [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
 ]
 #################### interface logic ####################
 # Define components first
 eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
 lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
-output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
+retargeting_input_image = gr.Image(type="numpy")
+output_image = gr.Image(type="numpy")
+output_image_paste_back = gr.Image(type="numpy")
+output_video = gr.Video()
+output_video_concat = gr.Video()
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.HTML(load_description(title_md))
     gr.Markdown(load_description("assets/gradio_description_upload.md"))
     with gr.Row():
         with gr.Accordion(open=True, label="Reference Portrait"):
-            image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
+            image_input = gr.Image(type="filepath")
         with gr.Accordion(open=True, label="Driving Video"):
-            video_input = gr.Video(label="Please upload the driving video here.")
+            video_input = gr.Video()
     gr.Markdown(load_description("assets/gradio_description_animation.md"))
     with gr.Row():
         with gr.Accordion(open=True, label="Animation Options"):
@@ -63,16 +68,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             flag_remap_input = gr.Checkbox(value=True, label="paste-back")
             flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
     with gr.Row():
-        process_button_animation = gr.Button("🚀 Animate", variant="primary")
+        with gr.Column():
+            process_button_animation = gr.Button("🚀 Animate", variant="primary")
+        with gr.Column():
+            process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
     with gr.Row():
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video in the original image space"):
-                output_video = gr.Video(label="The animated video after pasted back.")
+                output_video.render()
         with gr.Column():
             with gr.Accordion(open=True, label="The animated video"):
-                output_video_concat = gr.Video(label="The animated video and driving video.")
-    with gr.Row():
-        process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
+                output_video_concat.render()
     with gr.Row():
         # Examples
         gr.Markdown("## You could choose the examples below ⬇️")
@@ -89,28 +95,36 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         examples_per_page=5
     )
     gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
+    with gr.Row():
+        eye_retargeting_slider.render()
+        lip_retargeting_slider.render()
+    with gr.Row():
+        process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
+        process_button_reset_retargeting = gr.ClearButton(
+            [
+                eye_retargeting_slider,
+                lip_retargeting_slider,
+                retargeting_input_image,
+                output_image,
+                output_image_paste_back
+            ],
+            value="🧹 Clear"
+        )
     with gr.Row():
         with gr.Column():
-            process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
-            process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
-            process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
-            # with gr.Column():
-            eye_retargeting_slider.render()
-            lip_retargeting_slider.render()
+            with gr.Accordion(open=True, label="Retargeting Input"):
+                retargeting_input_image.render()
         with gr.Column():
-            with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
+            with gr.Accordion(open=True, label="Retargeting Result"):
                 output_image.render()
+        with gr.Column():
+            with gr.Accordion(open=True, label="Paste-back Result"):
+                output_image_paste_back.render()
     # binding functions for buttons
-    process_button_close_ratio.click(
-        fn=gradio_pipeline.prepare_retargeting,
-        inputs=image_input,
-        outputs=[eye_retargeting_slider, lip_retargeting_slider],
-        show_progress=True
-    )
     process_button_retargeting.click(
         fn=gradio_pipeline.execute_image,
         inputs=[eye_retargeting_slider, lip_retargeting_slider],
-        outputs=output_image,
+        outputs=[output_image, output_image_paste_back],
        show_progress=True
     )
     process_button_animation.click(
@@ -125,8 +139,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         outputs=[output_video, output_video_concat],
         show_progress=True
     )
-    process_button_reset.click()
-    process_button_reset_retargeting
+    image_input.change(
+        fn=gradio_pipeline.prepare_retargeting,
+        inputs=image_input,
+        outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
+    )
     ##########################################################
     demo.launch(
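A note on the pattern used above, since it differs from the previous app.py: components such as output_video are now created before the gr.Blocks context and only placed into the layout with .render(), and prepare_retargeting is now triggered by image_input.change instead of a dedicated button. Below is a minimal, self-contained sketch of that pattern with made-up component names and a stub callback; it is not the repository's code.

import gradio as gr

# Pre-build components so callbacks can reference them before they are laid out.
ratio_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target ratio")
preview_image = gr.Image(type="numpy")

def on_upload(path):
    # Stand-in for gradio_pipeline.prepare_retargeting: return a slider value and a preview.
    return 0.3, None

with gr.Blocks() as demo:
    source_image = gr.Image(type="filepath")
    with gr.Row():
        ratio_slider.render()      # place the pre-built slider here
    with gr.Row():
        preview_image.render()     # place the pre-built image output here
    # Uploading (or clearing) the source image triggers preparation automatically.
    source_image.change(fn=on_upload, inputs=source_image, outputs=[ratio_slider, preview_image])

# demo.launch()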

assets/gradio_description_retargeting.md

@@ -1,7 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:</span>
-<div style="margin-left: 20px;">
-<span style="font-size: 1.2em;">1. Please <strong>first</strong> press the <strong>🤖 Calculate the eye-close and lip-close ratio</strong> button, and wait for the result shown in the sliders.</span>
-</div>
-<div style="margin-left: 20px;">
-<span style="font-size: 1.2em;">2. Please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. Then the result would be shown in the middle block. You can try running it multiple times!</span>
-</div>
+<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>

assets/gradio_title.md

@@ -2,7 +2,7 @@
 <div>
   <h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
   <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
-    <a href=""><img src="https://img.shields.io/badge/arXiv-XXXX.XXXX-red"></a>
+    <a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
     <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
     <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
   </div>

src/gradio_pipeline.py

@@ -3,13 +3,14 @@
 """
 Pipeline for gradio
 """
+import gradio as gr
 from .config.argument_config import ArgumentConfig
 from .live_portrait_pipeline import LivePortraitPipeline
 from .utils.io import load_img_online
-from .utils.rprint import rlog as log
+from .utils.crop import prepare_paste_back, paste_back
 from .utils.camera import get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
+from .utils.rprint import rlog as log

 def update_args(args, user_args):
     """update the args according to user inputs
@@ -26,10 +27,15 @@ class GradioPipeline(LivePortraitPipeline):
         # self.live_portrait_wrapper = self.live_portrait_wrapper
         self.args = args
         # for single image retargeting
+        self.start_prepare = False
         self.f_s_user = None
         self.x_c_s_info_user = None
         self.x_s_user = None
         self.source_lmk_user = None
+        self.mask_ori = None
+        self.img_rgb = None
+        self.crop_M_c2o = None

     def execute_video(
         self,
@@ -41,64 +47,94 @@ class GradioPipeline(LivePortraitPipeline):
     ):
         """ for video driven potrait animation
         """
-        args_user = {
-            'source_image': input_image_path,
-            'driving_info': input_video_path,
-            'flag_relative': flag_relative_input,
-            'flag_do_crop': flag_do_crop_input,
-            'flag_pasteback': flag_remap_input
-        }
-        # update config from user input
-        self.args = update_args(self.args, args_user)
-        self.live_portrait_wrapper.update_config(self.args.__dict__)
-        self.cropper.update_config(self.args.__dict__)
-        # video driven animation
-        video_path, video_path_concat = self.execute(self.args)
-        return video_path, video_path_concat
+        if input_image_path is not None and input_video_path is not None:
+            args_user = {
+                'source_image': input_image_path,
+                'driving_info': input_video_path,
+                'flag_relative': flag_relative_input,
+                'flag_do_crop': flag_do_crop_input,
+                'flag_pasteback': flag_remap_input
+            }
+            # update config from user input
+            self.args = update_args(self.args, args_user)
+            self.live_portrait_wrapper.update_config(self.args.__dict__)
+            self.cropper.update_config(self.args.__dict__)
+            # video driven animation
+            video_path, video_path_concat = self.execute(self.args)
+            gr.Info("Run successfully!", duration=2)
+            return video_path, video_path_concat,
+        else:
+            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)

     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
         """
-        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
-        eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
-        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
-        lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
-        num_kp = self.x_s_user.shape[1]
-        # default: use x_s
-        x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
-        # D(W(f_s; x_s, x_d))
-        out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
-        out = self.live_portrait_wrapper.parse_output(out['out'])[0]
-        return out
+        if input_eye_ratio is None or input_eye_ratio is None:
+            raise gr.Error("Invalid ratio input 💥!", duration=5)
+        elif self.f_s_user is None:
+            if self.start_prepare:
+                raise gr.Error(
+                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    duration=5
+                )
+            else:
+                raise gr.Error(
+                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    duration=5
+                )
+        else:
+            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
+            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
+            eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
+            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
+            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
+            lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
+            num_kp = self.x_s_user.shape[1]
+            # default: use x_s
+            x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+            # D(W(f_s; x_s, x_d))
+            out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
+            out = self.live_portrait_wrapper.parse_output(out['out'])[0]
+            out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
+            gr.Info("Run successfully!", duration=2)
+            return out, out_to_ori_blend

     def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         """ for single image retargeting
         """
-        inference_cfg = self.live_portrait_wrapper.cfg
-        ######## process reference portrait ########
-        img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
-        log(f"Load source image from {input_image_path}.")
-        crop_info = self.cropper.crop_single_image(img_rgb)
-        if flag_do_crop:
-            I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
-        else:
-            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
-        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-        ############################################
-        # record global info for next time use
-        self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
-        self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
-        self.x_s_info_user = x_s_info
-        self.source_lmk_user = crop_info['lmk_crop']
-        # update slider
-        eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
-        eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
-        lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
-        lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
-        return eye_close_ratio, lip_close_ratio
+        if input_image_path is not None:
+            gr.Info("Upload successfully!", duration=2)
+            self.start_prepare = True
+            inference_cfg = self.live_portrait_wrapper.cfg
+            ######## process reference portrait ########
+            img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
+            log(f"Load source image from {input_image_path}.")
+            crop_info = self.cropper.crop_single_image(img_rgb)
+            if flag_do_crop:
+                I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+            else:
+                I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+            x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
+            R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+            ############################################
+            # record global info for next time use
+            self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+            self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+            self.x_s_info_user = x_s_info
+            self.source_lmk_user = crop_info['lmk_crop']
+            self.img_rgb = img_rgb
+            self.crop_M_c2o = crop_info['M_c2o']
+            self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+            # update slider
+            eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
+            eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
+            lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
+            lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
+            # for vis
+            self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
+            return eye_close_ratio, lip_close_ratio, self.I_s_vis
+        else:
+            # when press the clear button, go here
+            return 0.8, 0.8, self.I_s_vis
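For orientation, the core of the retargeting branch above is a per-keypoint offset added to the cached source keypoints before warp_decode. A minimal numpy sketch of just that reshape-and-add arithmetic follows, with synthetic values and an assumed keypoint count of 21; the real tensors come from the wrapper, so this is purely illustrative.

import numpy as np

num_kp = 21                                                  # assumed number of implicit keypoints, for illustration
x_s = np.zeros((1, num_kp, 3), dtype=np.float32)             # cached source keypoints, shape (B, K, 3)
eyes_delta = 0.01 * np.ones((1, num_kp * 3), np.float32)     # stand-in for the retarget_eye output
lip_delta = 0.02 * np.ones((1, num_kp * 3), np.float32)      # stand-in for the retarget_lip output

# Same combination as execute_image: x_d_new = x_s + delta_eyes + delta_lip
x_d_new = x_s + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
print(x_d_new.shape)  # (1, 21, 3)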

src/live_portrait_pipeline.py

@@ -20,10 +20,10 @@ from .config.crop_config import CropConfig
 from .utils.cropper import Cropper
 from .utils.camera import get_rotation_matrix
 from .utils.video import images2video, concat_frames
-from .utils.crop import _transform_img
+from .utils.crop import _transform_img, prepare_paste_back, paste_back
 from .utils.retargeting_utils import calc_lip_close_ratio
-from .utils.io import load_image_rgb, load_driving_info
-from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
+from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
+from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
 from .utils.rprint import rlog as log
 from .live_portrait_wrapper import LivePortraitWrapper
@@ -90,10 +90,7 @@ class LivePortraitPipeline(object):
         ######## prepare for pasteback ########
         if inference_cfg.flag_pasteback:
-            if inference_cfg.mask_crop is None:
-                inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
-            mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-            mask_ori = mask_ori.astype(np.float32) / 255.
+            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
             I_p_paste_lst = []
         #########################################
@@ -172,9 +169,7 @@ class LivePortraitPipeline(object):
             I_p_lst.append(I_p_i)
             if inference_cfg.flag_pasteback:
-                I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-                I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
-                out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
+                I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                 I_p_paste_lst.append(I_p_i_to_ori_blend)

         mkdir(args.output_dir)

src/live_portrait_wrapper.py

@@ -12,7 +12,6 @@ import yaml
 from src.utils.timer import Timer
 from src.utils.helper import load_model, concat_feat
-from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta
 from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
 from src.config.inference_config import InferenceConfig
@@ -211,33 +210,6 @@ class LivePortraitWrapper(object):
         return delta

-    def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp):
-        # TODO: GPT style, refactor it...
-        if self.cfg.flag_eye_retargeting:
-            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-            eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source)
-        else:
-            # α_eyes = 0
-            eye_delta = None
-        if self.cfg.flag_lip_retargeting:
-            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-            lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source)
-        else:
-            # α_lip = 0
-            lip_delta = None
-        if self.cfg.flag_relative: # use x_s
-            new_driving_kp = kp_source + \
-                (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
-                (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
-        else: # use x_d,i
-            new_driving_kp = driving_transformed_kp + \
-                (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
-                (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
-        return new_driving_kp

     def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
         """
         kp_source: BxNx3

src/utils/crop.py

@@ -4,14 +4,17 @@
 cropping function and the related preprocess functions for cropping
 """
-import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
 import numpy as np
-from .rprint import rprint as print
+import os.path as osp
 from math import sin, cos, acos, degrees
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
+from .rprint import rprint as print

 DTYPE = np.float32
 CV2_INTERP = cv2.INTER_LINEAR

+def make_abs_path(fn):
+    return osp.join(osp.dirname(osp.realpath(__file__)), fn)

 def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
     """ conduct similarity or affine transformation to the image, do not do border operation!
@@ -391,3 +394,19 @@ def average_bbox_lst(bbox_lst):
     bbox_arr = np.array(bbox_lst)
     return np.mean(bbox_arr, axis=0).tolist()
+
+def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
+    """prepare mask for later image paste back
+    """
+    if mask_crop is None:
+        mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
+    mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
+    mask_ori = mask_ori.astype(np.float32) / 255.
+    return mask_ori
+
+def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
+    """paste back the image
+    """
+    dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
+    result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
+    result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
+    return result
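The two helpers above reduce to: warp the crop-space mask and the generated crop back into the original frame with the crop-to-original affine matrix, then alpha-blend. A self-contained sketch of that blend with synthetic data, using cv2.warpAffine in place of the repository's _transform_img (the matrix, sizes, and images below are made up):

import cv2
import numpy as np

h, w = 480, 640
rgb_ori = np.full((h, w, 3), 128, dtype=np.uint8)            # stand-in original frame
crop_result = np.zeros((256, 256, 3), dtype=np.uint8)        # stand-in generated 256x256 crop
M_c2o = np.array([[1.5, 0.0, 100.0],
                  [0.0, 1.5, 60.0]], dtype=np.float32)       # crop -> original affine (illustrative)

# prepare_paste_back: warp an all-white crop mask into the original frame, normalize to [0, 1]
mask_crop = np.full((256, 256, 3), 255, dtype=np.uint8)      # stand-in for mask_template.png
mask_ori = cv2.warpAffine(mask_crop, M_c2o, (w, h)).astype(np.float32) / 255.0

# paste_back: warp the generated crop back, then alpha-blend with the original frame
warped = cv2.warpAffine(crop_result, M_c2o, (w, h))
blended = np.clip(mask_ori * warped + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
print(blended.shape)  # (480, 640, 3)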

src/utils/cropper.py

@@ -1,5 +1,6 @@
 # coding: utf-8

+import gradio as gr
 import numpy as np
 import os.path as osp
 from typing import List, Union, Tuple
@@ -72,6 +73,7 @@ class Cropper(object):
         if len(src_face) == 0:
             log('No face detected in the source image.')
+            raise gr.Error("No face detected in the source image 💥!", duration=5)
             raise Exception("No face detected in the source image!")
         elif len(src_face) > 1:
             log(f'More than one face detected in the image, only pick one face by rule {direction}.')

src/utils/helper.py

@@ -154,22 +154,3 @@ def load_description(fp):
         content = f.read()
     return content

-def resize_to_limit(img, max_dim=1280, n=2):
-    h, w = img.shape[:2]
-    if max_dim > 0 and max(h, w) > max_dim:
-        if h > w:
-            new_h = max_dim
-            new_w = int(w * (max_dim / h))
-        else:
-            new_w = max_dim
-            new_h = int(h * (max_dim / w))
-        img = cv2.resize(img, (new_w, new_h))
-    n = max(n, 1)
-    new_h = img.shape[0] - (img.shape[0] % n)
-    new_w = img.shape[1] - (img.shape[1] % n)
-    if new_h == 0 or new_w == 0:
-        return img
-    if new_h != img.shape[0] or new_w != img.shape[1]:
-        img = img[:new_h, :new_w]
-    return img

src/utils/io.py

@@ -40,7 +40,7 @@ def contiguous(obj):
     return obj

-def _resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
+def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
     """
     ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
     :param img: the image to be processed.
@@ -87,7 +87,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
         img = obj

     # Resize image to satisfy constraints
-    img = _resize_to_limit(img, max_dim=max_dim, n=n)
+    img = resize_to_limit(img, max_dim=max_dim, n=n)

     if mode.lower() == "bgr":
         return contiguous(img)
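As a side note, the renamed resize_to_limit does two things per its docstring: cap the longest side at max_dim, then trim both sides down to multiples of n. Below is a rough, independent re-derivation of that behaviour with a worked size; it is a sketch, not the repository's implementation.

import cv2
import numpy as np

def resize_to_limit_sketch(img, max_dim=1920, n=2):
    # Cap the longest side at max_dim, keeping the aspect ratio.
    h, w = img.shape[:2]
    if max_dim > 0 and max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        img = cv2.resize(img, (int(w * scale), int(h * scale)))
    # Trim so both dimensions are multiples of n.
    n = max(n, 1)
    new_h, new_w = img.shape[0] - img.shape[0] % n, img.shape[1] - img.shape[1] % n
    return img if new_h == 0 or new_w == 0 else img[:new_h, :new_w]

img = np.zeros((2000, 1500, 3), dtype=np.uint8)
print(resize_to_limit_sketch(img, max_dim=1280, n=16).shape)  # (1280, 960, 3)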

src/utils/retargeting_utils.py

@@ -4,7 +4,6 @@ Functions to compute distance ratios between specific pairs of facial landmarks
 """

 import numpy as np
-import torch

 def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
@@ -53,24 +52,3 @@ def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
         np.ndarray: Calculated lip-close ratio.
     """
     return calculate_distance_ratio(lmk, 90, 102, 48, 66)

-def compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source):
-    input_eye_ratio = input_eye_ratios[frame_idx][0][0]
-    eye_close_ratio = calc_eye_close_ratio(source_landmarks[None])
-    eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(portrait_wrapper.device_id)
-    input_eye_ratio_tensor = torch.Tensor([input_eye_ratio]).reshape(1, 1).cuda(portrait_wrapper.device_id)
-    combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
-    # print(combined_eye_ratio_tensor.mean())
-    eye_delta = portrait_wrapper.retarget_eye(kp_source, combined_eye_ratio_tensor)
-    return eye_delta

-def compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source):
-    input_lip_ratio = input_lip_ratios[frame_idx][0]
-    lip_close_ratio = calc_lip_close_ratio(source_landmarks[None])
-    lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(portrait_wrapper.device_id)
-    input_lip_ratio_tensor = torch.Tensor([input_lip_ratio]).cuda(portrait_wrapper.device_id)
-    combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
-    lip_delta = portrait_wrapper.retarget_lip(kp_source, combined_lip_ratio_tensor)
-    return lip_delta