feat: pose editing for source portraits (#212)

* feat: edit-pose

* feat: edit-pose

* feat: edit pose (#207)

* feat: edit-pose

* chore: i2v cropper

* chore: i2v cropper

* chore: update gradio

* chore: update gradio

* chore: i2v cropper

* chore: update

* chore: update

* doc: update readme

* doc: update readme

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
Co-authored-by: ZhizhouZhong <1819489045@qq.com>
This commit is contained in:
Jianzhu Guo 2024-07-24 22:22:43 +08:00 committed by GitHub
parent aae6e90fd6
commit 73ddb69d38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 158 additions and 58 deletions

1
.gitignore vendored
View File

@ -21,3 +21,4 @@ animations/*
tmp/*
.vscode/launch.json
**/*.DS_Store
gradio_temp/**

74
app.py
View File

@ -4,6 +4,7 @@
The entrance of the gradio
"""
import os
import tyro
import subprocess
import gradio as gr
@ -47,6 +48,9 @@ gradio_pipeline = GradioPipeline(
args=args
)
if args.gradio_temp_dir not in (None, ''):
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
os.makedirs(args.gradio_temp_dir, exist_ok=True)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
@ -69,25 +73,27 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
]
#################### interface logic ####################
# Define components first
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
# output_video_v2v = gr.Video(autoplay=False)
# output_video_concat_v2v = gr.Video(autoplay=False)
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
@ -108,6 +114,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
[osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
@ -149,6 +157,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
[osp.join(example_video_dir, "d20.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
@ -168,14 +177,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=1e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
@ -183,6 +189,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
# Examples
@ -227,20 +235,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
# Retargeting
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
retargeting_source_scale.render()
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row(visible=True):
head_pitch_slider.render()
head_yaw_slider.render()
head_roll_slider.render()
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
@ -253,6 +256,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
[osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
@ -263,15 +268,30 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
with gr.Row(visible=True):
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
head_pitch_slider,
head_yaw_slider,
head_roll_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
@ -296,6 +316,12 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
show_progress=True
)
retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting,
inputs=[retargeting_source_scale, retargeting_input_image],
outputs=[eye_retargeting_slider, lip_retargeting_slider]
)
demo.launch(
server_port=args.server_port,
share=args.share,

View File

@ -0,0 +1,5 @@
<p align="center">
<img src="../pose-edit-2024-07-24.jpg" alt="LivePortrait" width="960px">
<br>
Pose Editing Interface in the Gradio Interface
</p>

Binary file not shown.

After

Width:  |  Height:  |  Size: 217 KiB

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@ -9,6 +9,6 @@
<h2>Retargeting</h2>
<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times.
<br>
<strong>😊 Set both ratios to 0.8 to see what's going on!</strong></p>
<strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
</div>
</div>

View File

@ -4,6 +4,9 @@
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Note:</strong> Better if Source Video has <strong>the same FPS</strong> as the Driving Video.
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">

View File

@ -39,7 +39,8 @@
## 🔥 Updates
- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We've also lowered the default detection threshold to support more input detections. [Have fun](assets/docs/changelog/2024-07-24.md)!
- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!

View File

@ -32,9 +32,10 @@ class ArgumentConfig(PrintableConfig):
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
@ -50,3 +51,4 @@ class ArgumentConfig(PrintableConfig):
share: bool = False # whether to share the server to public
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
gradio_temp_dir: Optional[str] = None # directory to save gradio temp files

View File

@ -15,6 +15,7 @@ class CropConfig(PrintableConfig):
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
det_thresh: float = 0.1 # detection threshold
########## source image or video cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.8 # scale factor

View File

@ -41,7 +41,7 @@ class InferenceConfig(PrintableConfig):
# NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape

View File

@ -6,6 +6,7 @@ Pipeline for gradio
import os.path as osp
import gradio as gr
import torch
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
@ -14,6 +15,7 @@ from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
def update_args(args, user_args):
@ -32,6 +34,7 @@ class GradioPipeline(LivePortraitPipeline):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
@torch.no_grad()
def execute_video(
self,
input_source_image_path=None,
@ -48,7 +51,7 @@ class GradioPipeline(LivePortraitPipeline):
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
driving_smooth_observation_variance=1e-7,
driving_smooth_observation_variance=3e-7,
tab_selection=None,
):
""" for video-driven potrait animation or video editing
@ -93,27 +96,41 @@ class GradioPipeline(LivePortraitPipeline):
else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
@torch.no_grad()
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop=True):
""" for single image retargeting
"""
if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
raise gr.Error("Invalid relative pose input 💥!", duration=5)
# disposable feature
f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, flag_do_crop)
f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
device = self.live_portrait_wrapper.device
# inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(device)
f_s_user = f_s_user.to(device)
R_s_user = R_s_user.to(device)
R_d_user = R_d_user.to(device)
x_c_s = x_s_info['kp'].to(device)
delta_new = x_s_info['exp'].to(device)
scale_new = x_s_info['scale'].to(device)
t_new = x_s_info['t'].to(device)
R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
# default: use x_s
x_d_new = x_s_user + eyes_delta + lip_delta
x_d_new = x_d_new + eyes_delta + lip_delta
x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
@ -121,14 +138,18 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
def prepare_retargeting(self, input_image, flag_do_crop=True):
@torch.no_grad()
def prepare_retargeting(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
""" for single image retargeting
"""
if input_image is not None:
# gr.Info("Upload successfully!", duration=2)
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
@ -136,14 +157,37 @@ class GradioPipeline(LivePortraitPipeline):
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation
R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll)
############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)
def init_retargeting(self, retargeting_source_scale: float, input_image = None):
""" initialize the retargeting slider
"""
if input_image != None:
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if crop_info is None:
raise gr.Error("Source portrait NO face detected", duration=2)
source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2)
return 0., 0.

View File

@ -200,17 +200,16 @@ class LivePortraitPipeline(object):
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: # if the input is a source image, process it only once
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop_256x256 = crop_info['img_crop_256x256']
if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop_256x256 = crop_info['img_crop_256x256']
else:
source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
@ -218,7 +217,7 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first
if flag_normalize_lip:
if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
@ -245,14 +244,14 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first if the input is a video
if flag_normalize_lip:
if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
if flag_source_video_eye_retargeting:
if flag_source_video_eye_retargeting and source_lmk is not None:
if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
@ -312,12 +311,12 @@ class LivePortraitPipeline(object):
x_d_i_new += eye_delta_before_animation
else:
eyes_delta, lip_delta = None, None
if inf_cfg.flag_eye_retargeting:
if inf_cfg.flag_eye_retargeting and source_lmk is not None:
c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inf_cfg.flag_lip_retargeting:
if inf_cfg.flag_lip_retargeting and source_lmk is not None:
c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)

View File

@ -67,7 +67,7 @@ class Cropper(object):
root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provider,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup()
def update_config(self, user_args):
@ -117,6 +117,24 @@ class Cropper(object):
return ret_dct
def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs):
direction = kwargs.get("direction", "large-small")
src_face = self.face_analysis_wrapper.get(
contiguous(img_rgb_[..., ::-1]), # convert to BGR
flag_do_landmark_2d_106=True,
direction=direction,
)
if len(src_face) == 0:
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f"More than one face detected in the image, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(img_rgb_, lmk)
return lmk
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()

View File

@ -5,7 +5,7 @@ import numpy as np
from pykalman import KalmanFilter
def smooth(x_d_lst, shape, device, observation_variance=1e-7, process_variance=1e-5):
def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5):
x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
x_d_stacked = np.vstack(x_d_lst_reshape)
kf = KalmanFilter(