feat: refine and upgrade gradio (#6)

ZhizhouZhong 2024-07-05 11:36:03 +08:00 committed by GitHub
parent 6473a3b8b5
commit 293cb9ee31
11 changed files with 167 additions and 172 deletions

app.py

@@ -4,10 +4,9 @@
The entry point of the Gradio demo
"""
import os
import os.path as osp
import gradio as gr
import tyro
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
@@ -43,18 +42,24 @@ data_examples = [
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
]
#################### interface logic ####################
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
retargeting_input_image = gr.Image(type="numpy")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
output_video_concat = gr.Video()
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Reference Portrait"):
image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
image_input = gr.Image(type="filepath")
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video(label="Please upload the driving video here.")
video_input = gr.Video()
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
@@ -63,16 +68,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video = gr.Video(label="The animated video after pasted back.")
output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat = gr.Video(label="The animated video and driving video.")
with gr.Row():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
output_video_concat.render()
with gr.Row():
# Examples
gr.Markdown("## You can choose from the examples below ⬇️")
@@ -90,27 +96,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
with gr.Row():
with gr.Column():
process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
# with gr.Column():
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Column():
with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
output_image.render()
# binding functions for buttons
process_button_close_ratio.click(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider],
show_progress=True
with gr.Row():
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
retargeting_input_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
fn=gradio_pipeline.execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider],
outputs=output_image,
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
@@ -125,8 +139,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
outputs=[output_video, output_video_concat],
show_progress=True
)
process_button_reset.click()
process_button_reset_retargeting
image_input.change(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
)
##########################################################
demo.launch(
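For orientation, the restructured UI relies on Gradio's deferred-render pattern: components are constructed before the layout, placed with .render(), and wired with .click()/.change(). A minimal self-contained sketch of that pattern (the handler and labels below are placeholders, not the project's):

import numpy as np
import gradio as gr

# components are built first, then placed into the layout with .render()
ratio_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="ratio")
preview = gr.Image(type="numpy")

def on_change(ratio):
    # placeholder handler: a gray image whose brightness tracks the slider
    return np.full((64, 64, 3), int(255 * ratio / 0.8), dtype=np.uint8)

with gr.Blocks(theme=gr.themes.Soft()) as sketch_demo:
    ratio_slider.render()   # place the pre-built components here
    preview.render()
    ratio_slider.change(fn=on_change, inputs=ratio_slider, outputs=preview)

sketch_demo.launch()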

assets/gradio_description_retargeting.md

@@ -1,7 +1 @@
<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:</span>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">1. Please <strong>first</strong> press the <strong>🤖 Calculate the eye-close and lip-close ratio</strong> button, and wait for the result shown in the sliders.</span>
</div>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">2. Please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. Then the result would be shown in the middle block. You can try running it multiple times!</span>
</div>
<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result will be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong></span>

assets/gradio_title.md

@@ -2,7 +2,7 @@
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href=""><img src="https://img.shields.io/badge/arXiv-XXXX.XXXX-red"></a>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>

src/gradio_pipeline.py

@@ -3,13 +3,14 @@
"""
Pipeline for gradio
"""
import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .utils.rprint import rlog as log
def update_args(args, user_args):
"""update the args according to user inputs
@@ -26,10 +27,15 @@ class GradioPipeline(LivePortraitPipeline):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
# for single image retargeting
self.start_prepare = False
self.f_s_user = None
self.x_s_info_user = None
self.x_s_user = None
self.source_lmk_user = None
self.mask_ori = None
self.img_rgb = None
self.crop_M_c2o = None
def execute_video(
self,
@@ -41,6 +47,7 @@ class GradioPipeline(LivePortraitPipeline):
):
""" for video-driven portrait animation
"""
if input_image_path is not None and input_video_path is not None:
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
@@ -54,11 +61,28 @@ class GradioPipeline(LivePortraitPipeline):
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
return video_path, video_path_concat
gr.Info("Run completed successfully!", duration=2)
return video_path, video_path_concat
else:
raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
""" for single image retargeting
"""
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
elif self.f_s_user is None:
if self.start_prepare:
raise gr.Error(
"The reference portrait is still being processed 💥! Please wait a moment.",
duration=5
)
else:
raise gr.Error(
"The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
duration=5
)
else:
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
@@ -71,11 +95,17 @@ class GradioPipeline(LivePortraitPipeline):
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
return out
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
gr.Info("Run completed successfully!", duration=2)
return out, out_to_ori_blend
def prepare_retargeting(self, input_image_path, flag_do_crop=True):
""" for single image retargeting
"""
if input_image_path is not None:
gr.Info("Uploaded successfully!", duration=2)
self.start_prepare = True
inference_cfg = self.live_portrait_wrapper.cfg
######## process reference portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
@@ -94,11 +124,17 @@ class GradioPipeline(LivePortraitPipeline):
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
self.img_rgb = img_rgb
self.crop_M_c2o = crop_info['M_c2o']
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
return eye_close_ratio, lip_close_ratio
# for vis
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
return eye_close_ratio, lip_close_ratio, self.I_s_vis
else:
# reached when the clear button is pressed
return 0.8, 0.8, self.I_s_vis
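Taken together, the new retargeting path runs in two stages: an upload seeds the cached features and slider values via prepare_retargeting, and each press of the Retargeting button re-warps from that cache via execute_image. A usage sketch, assuming gradio_pipeline is the GradioPipeline instance built in app.py and the path below is a placeholder:

# stage 1: triggered by image_input.change — caches f_s_user, x_s_user, etc.
eye_ratio, lip_ratio, preview = gradio_pipeline.prepare_retargeting("path/to/portrait.jpg")

# stage 2: triggered by the 🚗 Retargeting button — warp, decode, then paste back
out, out_paste_back = gradio_pipeline.execute_image(eye_ratio, lip_ratio)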

src/live_portrait_pipeline.py

@@ -20,10 +20,10 @@ from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames
from .utils.crop import _transform_img
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
from .utils.rprint import rlog as log
from .live_portrait_wrapper import LivePortraitWrapper
@@ -90,10 +90,7 @@ class LivePortraitPipeline(object):
######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
if inference_cfg.mask_crop is None:
inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
mask_ori = mask_ori.astype(np.float32) / 255.
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = []
#########################################
@@ -172,9 +169,7 @@ class LivePortraitPipeline(object):
I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback:
I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
I_p_paste_lst.append(I_p_i_to_ori_blend)
mkdir(args.output_dir)

src/live_portrait_wrapper.py

@@ -12,7 +12,6 @@ import yaml
from src.utils.timer import Timer
from src.utils.helper import load_model, concat_feat
from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
@@ -211,33 +210,6 @@ class LivePortraitWrapper(object):
return delta
def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp):
# TODO: GPT style, refactor it...
if self.cfg.flag_eye_retargeting:
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source)
else:
# α_eyes = 0
eye_delta = None
if self.cfg.flag_lip_retargeting:
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source)
else:
# α_lip = 0
lip_delta = None
if self.cfg.flag_relative: # use x_s
new_driving_kp = kp_source + \
(eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
(lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
else: # use x_d,i
new_driving_kp = driving_transformed_kp + \
(eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
(lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
return new_driving_kp
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3

src/utils/crop.py

@@ -4,14 +4,17 @@
cropping and the related preprocessing functions
"""
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
import numpy as np
from .rprint import rprint as print
import os.path as osp
from math import sin, cos, acos, degrees
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
from .rprint import rprint as print
DTYPE = np.float32
CV2_INTERP = cv2.INTER_LINEAR
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
""" apply a similarity or affine transformation to the image; no border operation is performed!
@ -391,3 +394,19 @@ def average_bbox_lst(bbox_lst):
bbox_arr = np.array(bbox_lst)
return np.mean(bbox_arr, axis=0).tolist()
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
"""prepare mask for later image paste back
"""
if mask_crop is None:
mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
mask_ori = mask_ori.astype(np.float32) / 255.
return mask_ori
def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
"""paste back the image
"""
dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
return result
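The two helpers are designed to be used together: prepare the soft mask once per source image, then alpha-blend every generated frame back into the original. A usage sketch with dummy data (shapes illustrative; run inside the repo so the default mask template resolves):

import numpy as np

img_rgb = np.zeros((480, 640, 3), dtype=np.uint8)       # original full frame (H, W, 3)
M_c2o = np.eye(2, 3, dtype=np.float32)                   # 2x3 crop-to-original affine
generated = np.zeros((480, 640, 3), dtype=np.uint8)      # frame produced in crop space

# prepare the mask once per source image ...
mask_ori = prepare_paste_back(None, M_c2o, dsize=(img_rgb.shape[1], img_rgb.shape[0]))
# ... then blend each generated frame: mask * warped + (1 - mask) * original
blended = paste_back(generated, M_c2o, img_rgb, mask_ori)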

src/utils/cropper.py

@@ -1,5 +1,6 @@
# coding: utf-8
import gradio as gr
import numpy as np
import os.path as osp
from typing import List, Union, Tuple
@@ -72,6 +73,7 @@ class Cropper(object):
if len(src_face) == 0:
log('No face detected in the source image.')
raise gr.Error("No face detected in the source image 💥!", duration=5)
raise Exception("No face detected in the source image!")
elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {direction}.')

src/utils/helper.py

@@ -154,22 +154,3 @@ def load_description(fp):
content = f.read()
return content
def resize_to_limit(img, max_dim=1280, n=2):
h, w = img.shape[:2]
if max_dim > 0 and max(h, w) > max_dim:
if h > w:
new_h = max_dim
new_w = int(w * (max_dim / h))
else:
new_w = max_dim
new_h = int(h * (max_dim / w))
img = cv2.resize(img, (new_w, new_h))
n = max(n, 1)
new_h = img.shape[0] - (img.shape[0] % n)
new_w = img.shape[1] - (img.shape[1] % n)
if new_h == 0 or new_w == 0:
return img
if new_h != img.shape[0] or new_w != img.shape[1]:
img = img[:new_h, :new_w]
return img

src/utils/io.py

@@ -40,7 +40,7 @@ def contiguous(obj):
return obj
def _resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
"""
adjust the size of the image so that its maximum dimension does not exceed max_dim, and its width and height are multiples of n.
:param img: the image to be processed.
@@ -87,7 +87,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
img = obj
# Resize image to satisfy constraints
img = _resize_to_limit(img, max_dim=max_dim, n=n)
img = resize_to_limit(img, max_dim=max_dim, n=n)
if mode.lower() == "bgr":
return contiguous(img)
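Assuming the relocated resize_to_limit keeps the semantics of the helper.py version removed above (clamp the longest side to max_dim, then crop both sides down to multiples of n), a worked example with the max_dim=1280, n=16 settings that load_img_online uses on the Gradio path:

import numpy as np

img = np.zeros((1500, 1000, 3), dtype=np.uint8)    # H=1500 exceeds max_dim=1280
out = resize_to_limit(img, max_dim=1280, n=16)
# resize: 1500x1000 -> 1280x853, since int(1000 * 1280 / 1500) == 853
# crop to multiples of 16: 1280x853 -> 1280x848
print(out.shape)   # (1280, 848, 3)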

src/utils/retargeting_utils.py

@@ -4,7 +4,6 @@ Functions to compute distance ratios between specific pairs of facial landmarks
"""
import numpy as np
import torch
def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
@@ -53,24 +52,3 @@ def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
np.ndarray: Calculated lip-close ratio.
"""
return calculate_distance_ratio(lmk, 90, 102, 48, 66)
def compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source):
input_eye_ratio = input_eye_ratios[frame_idx][0][0]
eye_close_ratio = calc_eye_close_ratio(source_landmarks[None])
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(portrait_wrapper.device_id)
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio]).reshape(1, 1).cuda(portrait_wrapper.device_id)
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
# print(combined_eye_ratio_tensor.mean())
eye_delta = portrait_wrapper.retarget_eye(kp_source, combined_eye_ratio_tensor)
return eye_delta
def compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source):
input_lip_ratio = input_lip_ratios[frame_idx][0]
lip_close_ratio = calc_lip_close_ratio(source_landmarks[None])
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(portrait_wrapper.device_id)
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio]).cuda(portrait_wrapper.device_id)
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
lip_delta = portrait_wrapper.retarget_lip(kp_source, combined_lip_ratio_tensor)
return lip_delta
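calculate_distance_ratio itself sits outside this hunk; judging from its signature and the callers above, it presumably returns the ratio between the distances of two landmark pairs per batch item. A sketch of that contract (an assumption, not the commit's code):

import numpy as np

def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int,
                             idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    # lmk: (B, N, 2-or-3) landmarks; returns (B, 1) ratio of pairwise distances
    d12 = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    d34 = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True)
    return d12 / (d34 + eps)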