From 293cb9ee31a07b21f59eeb7d9becdf3504436f87 Mon Sep 17 00:00:00 2001
From: ZhizhouZhong <1819489045@qq.com>
Date: Fri, 5 Jul 2024 11:36:03 +0800
Subject: [PATCH] feat: refine and upgrade gradio (#6)
---
app.py | 72 +++++++-----
assets/gradio_description_retargeting.md | 8 +-
assets/gradio_title.md | 2 +-
src/gradio_pipeline.py | 144 ++++++++++++++---------
src/live_portrait_pipeline.py | 15 +--
src/live_portrait_wrapper.py | 28 -----
src/utils/crop.py | 23 +++-
src/utils/cropper.py | 2 +
src/utils/helper.py | 19 ---
src/utils/io.py | 4 +-
src/utils/retargeting_utils.py | 22 ----
11 files changed, 167 insertions(+), 172 deletions(-)
diff --git a/app.py b/app.py
index 1998082..44b0740 100644
--- a/app.py
+++ b/app.py
@@ -4,10 +4,9 @@
The entrance of the gradio
"""
-import os
-import os.path as osp
-import gradio as gr
import tyro
+import gradio as gr
+import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
@@ -43,18 +42,24 @@ data_examples = [
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
]
#################### interface logic ####################
+
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
-output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
+retargeting_input_image = gr.Image(type="numpy")
+output_image = gr.Image(type="numpy")
+output_image_paste_back = gr.Image(type="numpy")
+output_video = gr.Video()
+output_video_concat = gr.Video()
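+# NOTE: components declared here are attached to the layout later via .render() inside the Blocks context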
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Reference Portrait"):
- image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
+ image_input = gr.Image(type="filepath")
with gr.Accordion(open=True, label="Driving Video"):
- video_input = gr.Video(label="Please upload the driving video here.")
+ video_input = gr.Video()
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
@@ -63,16 +68,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
with gr.Row():
-                process_button_animation = gr.Button("🚀 Animate", variant="primary")
+ with gr.Column():
+                    process_button_animation = gr.Button("🚀 Animate", variant="primary")
+ with gr.Column():
+                    process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
- output_video = gr.Video(label="The animated video after pasted back.")
+ output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
- output_video_concat = gr.Video(label="The animated video and driving video.")
- with gr.Row():
-            process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
+ output_video_concat.render()
with gr.Row():
# Examples
gr.Markdown("## You could choose the examples below โฌ๏ธ")
@@ -89,28 +95,36 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
examples_per_page=5
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
+ with gr.Row():
+ eye_retargeting_slider.render()
+ lip_retargeting_slider.render()
+ with gr.Row():
+        process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
+ process_button_reset_retargeting = gr.ClearButton(
+ [
+ eye_retargeting_slider,
+ lip_retargeting_slider,
+ retargeting_input_image,
+ output_image,
+ output_image_paste_back
+ ],
+            value="🧹 Clear"
+ )
with gr.Row():
with gr.Column():
-            process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
-            process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
-            process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
- # with gr.Column():
- eye_retargeting_slider.render()
- lip_retargeting_slider.render()
+ with gr.Accordion(open=True, label="Retargeting Input"):
+ retargeting_input_image.render()
with gr.Column():
- with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
+ with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
+ with gr.Column():
+ with gr.Accordion(open=True, label="Paste-back Result"):
+ output_image_paste_back.render()
# binding functions for buttons
- process_button_close_ratio.click(
- fn=gradio_pipeline.prepare_retargeting,
- inputs=image_input,
- outputs=[eye_retargeting_slider, lip_retargeting_slider],
- show_progress=True
- )
process_button_retargeting.click(
fn=gradio_pipeline.execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider],
- outputs=output_image,
+ outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
@@ -125,8 +139,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
outputs=[output_video, output_video_concat],
show_progress=True
)
- process_button_reset.click()
- process_button_reset_retargeting
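+    # re-run the retargeting preparation whenever a new reference portrait is uploaded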
+ image_input.change(
+ fn=gradio_pipeline.prepare_retargeting,
+ inputs=image_input,
+ outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
+ )
+
##########################################################
demo.launch(
diff --git a/assets/gradio_description_retargeting.md b/assets/gradio_description_retargeting.md
index 0a5dcba..5fe6ebf 100644
--- a/assets/gradio_description_retargeting.md
+++ b/assets/gradio_description_retargeting.md
@@ -1,7 +1 @@
-🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:
-
LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index 45d661b..d176cfc 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -3,13 +3,14 @@
"""
Pipeline for gradio
"""
-
+import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
+from .utils.rprint import rlog as log
+from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from .utils.rprint import rlog as log
def update_args(args, user_args):
"""update the args according to user inputs
@@ -26,10 +27,15 @@ class GradioPipeline(LivePortraitPipeline):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
# for single image retargeting
+ self.start_prepare = False
self.f_s_user = None
self.x_c_s_info_user = None
self.x_s_user = None
self.source_lmk_user = None
+ self.mask_ori = None
+ self.img_rgb = None
+        self.crop_M_c2o = None
+        self.I_s_vis = None
+
def execute_video(
self,
@@ -41,64 +47,94 @@ class GradioPipeline(LivePortraitPipeline):
):
""" for video driven potrait animation
"""
- args_user = {
- 'source_image': input_image_path,
- 'driving_info': input_video_path,
- 'flag_relative': flag_relative_input,
- 'flag_do_crop': flag_do_crop_input,
- 'flag_pasteback': flag_remap_input
- }
- # update config from user input
- self.args = update_args(self.args, args_user)
- self.live_portrait_wrapper.update_config(self.args.__dict__)
- self.cropper.update_config(self.args.__dict__)
- # video driven animation
- video_path, video_path_concat = self.execute(self.args)
- return video_path, video_path_concat
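+        # run the animation only when both the reference portrait and the driving video are provided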
+ if input_image_path is not None and input_video_path is not None:
+ args_user = {
+ 'source_image': input_image_path,
+ 'driving_info': input_video_path,
+ 'flag_relative': flag_relative_input,
+ 'flag_do_crop': flag_do_crop_input,
+ 'flag_pasteback': flag_remap_input
+ }
+ # update config from user input
+ self.args = update_args(self.args, args_user)
+ self.live_portrait_wrapper.update_config(self.args.__dict__)
+ self.cropper.update_config(self.args.__dict__)
+ # video driven animation
+ video_path, video_path_concat = self.execute(self.args)
+            gr.Info("Run completed successfully!", duration=2)
+            return video_path, video_path_concat
+ else:
+            raise gr.Error("The reference portrait or the driving video hasn't been uploaded yet 💥!", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
""" for single image retargeting
"""
-        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
- combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
- eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
-        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
- combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
- lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
- num_kp = self.x_s_user.shape[1]
- # default: use x_s
- x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
-        # D(W(f_s; x_s, x′_d))
- out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
- out = self.live_portrait_wrapper.parse_output(out['out'])[0]
- return out
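+        # validate the inputs and make sure the reference portrait has already been prepared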
+        if input_eye_ratio is None or input_lip_ratio is None:
+            raise gr.Error("Invalid ratio input 💥!", duration=5)
+ elif self.f_s_user is None:
+ if self.start_prepare:
+ raise gr.Error(
+                    "The reference portrait is still being processed 💥! Please wait a moment.",
+ duration=5
+ )
+ else:
+ raise gr.Error(
+                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload one.",
+ duration=5
+ )
+ else:
+            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
+            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
+ lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
+ num_kp = self.x_s_user.shape[1]
+ # default: use x_s
+ x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+            # D(W(f_s; x_s, x′_d))
+ out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
+ out = self.live_portrait_wrapper.parse_output(out['out'])[0]
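+            # paste the generated crop back into the original image space using the precomputed mask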
+ out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
+            gr.Info("Run completed successfully!", duration=2)
+ return out, out_to_ori_blend
+
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
""" for single image retargeting
"""
- inference_cfg = self.live_portrait_wrapper.cfg
- ######## process reference portrait ########
- img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
- log(f"Load source image from {input_image_path}.")
- crop_info = self.cropper.crop_single_image(img_rgb)
- if flag_do_crop:
- I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+ if input_image_path is not None:
+            gr.Info("Uploaded successfully!", duration=2)
+ self.start_prepare = True
+ inference_cfg = self.live_portrait_wrapper.cfg
+ ######## process reference portrait ########
+ img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
+ log(f"Load source image from {input_image_path}.")
+ crop_info = self.cropper.crop_single_image(img_rgb)
+ if flag_do_crop:
+ I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
+ else:
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+ ############################################
+
+ # record global info for next time use
+ self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+ self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+ self.x_s_info_user = x_s_info
+ self.source_lmk_user = crop_info['lmk_crop']
+ self.img_rgb = img_rgb
+ self.crop_M_c2o = crop_info['M_c2o']
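+            # precompute the soft blending mask in the original image space for later paste-back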
+ self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+ # update slider
+ eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
+ eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
+ lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
+ lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
+ # for vis
+ self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
+ return eye_close_ratio, lip_close_ratio, self.I_s_vis
else:
- I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
- x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
- R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
- ############################################
-
- # record global info for next time use
- self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
- self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
- self.x_s_info_user = x_s_info
- self.source_lmk_user = crop_info['lmk_crop']
-
- # update slider
- eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
- eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
- lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
- lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
-
- return eye_close_ratio, lip_close_ratio
+            # reached when the Clear button resets the reference portrait
+            return 0.8, 0.8, self.I_s_vis
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 668e3a3..933c911 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -20,10 +20,10 @@ from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames
-from .utils.crop import _transform_img
+from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
-from .utils.io import load_image_rgb, load_driving_info
-from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
+from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
+from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
from .utils.rprint import rlog as log
from .live_portrait_wrapper import LivePortraitWrapper
@@ -90,10 +90,7 @@ class LivePortraitPipeline(object):
######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
- if inference_cfg.mask_crop is None:
- inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
- mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
- mask_ori = mask_ori.astype(np.float32) / 255.
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = []
#########################################
@@ -172,9 +169,7 @@ class LivePortraitPipeline(object):
I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback:
- I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
- I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
- out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
+ I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
I_p_paste_lst.append(I_p_i_to_ori_blend)
mkdir(args.output_dir)
diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py
index 2cb2eab..ac3c63a 100644
--- a/src/live_portrait_wrapper.py
+++ b/src/live_portrait_wrapper.py
@@ -12,7 +12,6 @@ import yaml
from src.utils.timer import Timer
from src.utils.helper import load_model, concat_feat
-from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
@@ -211,33 +210,6 @@ class LivePortraitWrapper(object):
return delta
- def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp):
- # TODO: GPT style, refactor it...
- if self.cfg.flag_eye_retargeting:
-            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
- eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source)
- else:
-            # α_eyes = 0
- eye_delta = None
-
- if self.cfg.flag_lip_retargeting:
-            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
- lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source)
- else:
-            # α_lip = 0
- lip_delta = None
-
- if self.cfg.flag_relative: # use x_s
- new_driving_kp = kp_source + \
- (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
- (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
- else: # use x_d,i
- new_driving_kp = driving_transformed_kp + \
- (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
- (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
-
- return new_driving_kp
-
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
diff --git a/src/utils/crop.py b/src/utils/crop.py
index c061ef4..8f23363 100644
--- a/src/utils/crop.py
+++ b/src/utils/crop.py
@@ -4,14 +4,17 @@
cropping function and the related preprocess functions for cropping
"""
-import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
import numpy as np
-from .rprint import rprint as print
+import os.path as osp
from math import sin, cos, acos, degrees
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
+from .rprint import rprint as print
DTYPE = np.float32
CV2_INTERP = cv2.INTER_LINEAR
+def make_abs_path(fn):
+ return osp.join(osp.dirname(osp.realpath(__file__)), fn)
def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
""" conduct similarity or affine transformation to the image, do not do border operation!
@@ -391,3 +394,19 @@ def average_bbox_lst(bbox_lst):
bbox_arr = np.array(bbox_lst)
return np.mean(bbox_arr, axis=0).tolist()
+def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
+ """prepare mask for later image paste back
+ """
+ if mask_crop is None:
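+        # fall back to the bundled mask template when no custom crop mask is provided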
+ mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
+ mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
+ mask_ori = mask_ori.astype(np.float32) / 255.
+ return mask_ori
+
+def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
+ """paste back the image
+ """
+ dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
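+    # warp the processed crop back to the original image coordinates, then alpha-blend it with the original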
+ result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
+ result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
+ return result
\ No newline at end of file
diff --git a/src/utils/cropper.py b/src/utils/cropper.py
index e8ee194..d5d511c 100644
--- a/src/utils/cropper.py
+++ b/src/utils/cropper.py
@@ -1,5 +1,6 @@
# coding: utf-8
+import gradio as gr
import numpy as np
import os.path as osp
from typing import List, Union, Tuple
@@ -72,6 +73,7 @@ class Cropper(object):
if len(src_face) == 0:
log('No face detected in the source image.')
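+            # surface the failure in the Gradio UI instead of only logging it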
+            raise gr.Error("No face detected in the source image 💥!", duration=5)
raise Exception("No face detected in the source image!")
elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {direction}.')
diff --git a/src/utils/helper.py b/src/utils/helper.py
index 267f97f..05c991e 100644
--- a/src/utils/helper.py
+++ b/src/utils/helper.py
@@ -154,22 +154,3 @@ def load_description(fp):
content = f.read()
return content
-
-def resize_to_limit(img, max_dim=1280, n=2):
- h, w = img.shape[:2]
- if max_dim > 0 and max(h, w) > max_dim:
- if h > w:
- new_h = max_dim
- new_w = int(w * (max_dim / h))
- else:
- new_w = max_dim
- new_h = int(h * (max_dim / w))
- img = cv2.resize(img, (new_w, new_h))
- n = max(n, 1)
- new_h = img.shape[0] - (img.shape[0] % n)
- new_w = img.shape[1] - (img.shape[1] % n)
- if new_h == 0 or new_w == 0:
- return img
- if new_h != img.shape[0] or new_w != img.shape[1]:
- img = img[:new_h, :new_w]
- return img
diff --git a/src/utils/io.py b/src/utils/io.py
index f930c48..29a7e00 100644
--- a/src/utils/io.py
+++ b/src/utils/io.py
@@ -40,7 +40,7 @@ def contiguous(obj):
return obj
-def _resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
+def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
"""
    adjust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
:param img: the image to be processed.
@@ -87,7 +87,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
img = obj
# Resize image to satisfy constraints
- img = _resize_to_limit(img, max_dim=max_dim, n=n)
+ img = resize_to_limit(img, max_dim=max_dim, n=n)
if mode.lower() == "bgr":
return contiguous(img)
diff --git a/src/utils/retargeting_utils.py b/src/utils/retargeting_utils.py
index 2028590..20a1bdd 100644
--- a/src/utils/retargeting_utils.py
+++ b/src/utils/retargeting_utils.py
@@ -4,7 +4,6 @@ Functions to compute distance ratios between specific pairs of facial landmarks
"""
import numpy as np
-import torch
def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
@@ -53,24 +52,3 @@ def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
np.ndarray: Calculated lip-close ratio.
"""
return calculate_distance_ratio(lmk, 90, 102, 48, 66)
-
-
-def compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source):
- input_eye_ratio = input_eye_ratios[frame_idx][0][0]
- eye_close_ratio = calc_eye_close_ratio(source_landmarks[None])
- eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(portrait_wrapper.device_id)
- input_eye_ratio_tensor = torch.Tensor([input_eye_ratio]).reshape(1, 1).cuda(portrait_wrapper.device_id)
- combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
- # print(combined_eye_ratio_tensor.mean())
- eye_delta = portrait_wrapper.retarget_eye(kp_source, combined_eye_ratio_tensor)
- return eye_delta
-
-
-def compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source):
- input_lip_ratio = input_lip_ratios[frame_idx][0]
- lip_close_ratio = calc_lip_close_ratio(source_landmarks[None])
- lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(portrait_wrapper.device_id)
- input_lip_ratio_tensor = torch.Tensor([input_lip_ratio]).cuda(portrait_wrapper.device_id)
- combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
- lip_delta = portrait_wrapper.retarget_lip(kp_source, combined_lip_ratio_tensor)
- return lip_delta