diff --git a/.gitignore b/.gitignore
index c7612d0..a6a28fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,4 @@ animations/*
tmp/*
.vscode/launch.json
**/*.DS_Store
+gradio_temp/**
diff --git a/app.py b/app.py
index 9b59a4d..ad2b4f8 100644
--- a/app.py
+++ b/app.py
@@ -4,6 +4,7 @@
The entry point of the Gradio app
"""
+import os
import tyro
import subprocess
import gradio as gr
@@ -47,6 +48,9 @@ gradio_pipeline = GradioPipeline(
args=args
)
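+# Route Gradio's temporary files (uploads, cached examples, outputs) to a user-specified directory via the GRADIO_TEMP_DIR environment variable.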
+if args.gradio_temp_dir not in (None, ''):
+ os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
+ os.makedirs(args.gradio_temp_dir, exist_ok=True)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
@@ -69,25 +73,27 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
- [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
- # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 1e-7],
- # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 1e-7],
- [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
- # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
- [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
+ [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
+ # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
+ # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
+ [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
+ # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
+ [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
]
#################### interface logic ####################
# Define components first
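+# Retargeting controls: source crop scale, target eye/lip-open ratios, and relative pitch/yaw/roll sliders for head pose editing.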
+retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
+head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
+head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
+head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
-# output_video_v2v = gr.Video(autoplay=False)
-# output_video_concat_v2v = gr.Video(autoplay=False)
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
@@ -108,6 +114,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
+ [osp.join(example_portrait_dir, "s22.jpg")],
+ [osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
@@ -149,6 +157,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
+ [osp.join(example_video_dir, "d20.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
@@ -168,14 +177,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
- driving_smooth_observation_variance = gr.Number(value=1e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
+ driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
- with gr.Column():
- process_button_animation = gr.Button("🚀 Animate", variant="primary")
- with gr.Column():
- process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
+ process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
@@ -183,6 +189,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
+ with gr.Row():
+ process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
# Examples
@@ -227,20 +235,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
# Retargeting
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
+ retargeting_source_scale.render()
eye_retargeting_slider.render()
lip_retargeting_slider.render()
+ with gr.Row(visible=True):
+ head_pitch_slider.render()
+ head_yaw_slider.render()
+ head_roll_slider.render()
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
- process_button_reset_retargeting = gr.ClearButton(
- [
- eye_retargeting_slider,
- lip_retargeting_slider,
- retargeting_input_image,
- output_image,
- output_image_paste_back
- ],
- value="🧹 Clear"
- )
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
@@ -253,6 +256,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
+ [osp.join(example_portrait_dir, "s22.jpg")],
+ [osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
@@ -263,15 +268,30 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
+ with gr.Row(visible=True):
+ process_button_reset_retargeting = gr.ClearButton(
+ [
+ eye_retargeting_slider,
+ lip_retargeting_slider,
+ head_pitch_slider,
+ head_yaw_slider,
+ head_roll_slider,
+ retargeting_input_image,
+ output_image,
+ output_image_paste_back
+ ],
+ value="🧹 Clear"
+ )
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
- inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
+ inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
+
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
@@ -296,6 +316,12 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
show_progress=True
)
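+ # When a new retargeting source is uploaded, pre-fill the eye/lip sliders with the ratios measured from that image.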
+ retargeting_input_image.change(
+ fn=gradio_pipeline.init_retargeting,
+ inputs=[retargeting_source_scale, retargeting_input_image],
+ outputs=[eye_retargeting_slider, lip_retargeting_slider]
+ )
+
demo.launch(
server_port=args.server_port,
share=args.share,
diff --git a/assets/docs/changelog/2024-07-24.md b/assets/docs/changelog/2024-07-24.md
new file mode 100644
index 0000000..e54aa45
--- /dev/null
+++ b/assets/docs/changelog/2024-07-24.md
@@ -0,0 +1,5 @@
+<p align="center">
+  <img src="../pose-edit-2024-07-24.jpg" alt="Pose Editing Interface in the Gradio Interface">
+  <br>
+  Pose Editing Interface in the Gradio Interface
+</p>
diff --git a/assets/docs/pose-edit-2024-07-24.jpg b/assets/docs/pose-edit-2024-07-24.jpg
new file mode 100644
index 0000000..74650bc
Binary files /dev/null and b/assets/docs/pose-edit-2024-07-24.jpg differ
diff --git a/assets/examples/driving/d20.mp4 b/assets/examples/driving/d20.mp4
new file mode 100644
index 0000000..30822f9
Binary files /dev/null and b/assets/examples/driving/d20.mp4 differ
diff --git a/assets/examples/source/s22.jpg b/assets/examples/source/s22.jpg
new file mode 100644
index 0000000..9ca08bb
Binary files /dev/null and b/assets/examples/source/s22.jpg differ
diff --git a/assets/examples/source/s23.jpg b/assets/examples/source/s23.jpg
new file mode 100644
index 0000000..4e1373a
Binary files /dev/null and b/assets/examples/source/s23.jpg differ
diff --git a/assets/gradio/gradio_description_retargeting.md b/assets/gradio/gradio_description_retargeting.md
index 64f1a7c..bd7b2bb 100644
--- a/assets/gradio/gradio_description_retargeting.md
+++ b/assets/gradio/gradio_description_retargeting.md
@@ -9,6 +9,6 @@
Retargeting
Upload a Source Portrait as Retargeting Input, then drag the sliders and click the 🚗 Retargeting button. You can try running it multiple times.
- 😊 Set both ratios to 0.8 to see what's going on!
+ 😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!
diff --git a/assets/gradio/gradio_description_upload.md b/assets/gradio/gradio_description_upload.md
index f5a018a..f6b975f 100644
--- a/assets/gradio/gradio_description_upload.md
+++ b/assets/gradio/gradio_description_upload.md
@@ -4,6 +4,9 @@
Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️
+
+ Note: Results are better when the Source Video has the same FPS as the Driving Video.
+
diff --git a/readme.md b/readme.md
index b378eeb..6518aa1 100644
--- a/readme.md
+++ b/readme.md
@@ -39,7 +39,8 @@
## 🔥 Updates
- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).
+- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We also lowered the default detection threshold so that faces can be detected in more inputs. [Have fun](assets/docs/changelog/2024-07-24.md)!
+- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index 08d17a7..6653f9c 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -32,9 +32,10 @@ class ArgumentConfig(PrintableConfig):
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
- driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+ driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
########## source crop arguments ##########
+ det_thresh: float = 0.15 # face detection confidence threshold; a lower value keeps more (lower-confidence) detections
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
@@ -50,3 +51,4 @@ class ArgumentConfig(PrintableConfig):
share: bool = False # whether to share the server to public
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
+ gradio_temp_dir: Optional[str] = None # directory to save gradio temp files
diff --git a/src/config/crop_config.py b/src/config/crop_config.py
index c7d64a5..6c1f8f2 100644
--- a/src/config/crop_config.py
+++ b/src/config/crop_config.py
@@ -15,6 +15,7 @@ class CropConfig(PrintableConfig):
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
+ det_thresh: float = 0.1 # face detection confidence threshold; a lower value keeps more (lower-confidence) detections
########## source image or video cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.8 # scale factor
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
index 48bf88c..c1f8653 100644
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -41,7 +41,7 @@ class InferenceConfig(PrintableConfig):
# NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
- driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+ driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index cbe898e..003c5ca 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -6,6 +6,7 @@ Pipeline for gradio
import os.path as osp
import gradio as gr
+import torch
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
@@ -14,6 +15,7 @@ from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
+from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
def update_args(args, user_args):
@@ -32,6 +34,7 @@ class GradioPipeline(LivePortraitPipeline):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
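+ # Run the Gradio-facing methods under torch.no_grad(): they only perform inference, so skipping autograd saves memory.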
+ @torch.no_grad()
def execute_video(
self,
input_source_image_path=None,
@@ -48,7 +51,7 @@ class GradioPipeline(LivePortraitPipeline):
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
- driving_smooth_observation_variance=1e-7,
+ driving_smooth_observation_variance=3e-7,
tab_selection=None,
):
""" for video-driven potrait animation or video editing
@@ -93,27 +96,41 @@ class GradioPipeline(LivePortraitPipeline):
else:
raise gr.Error("Please upload the source portrait or source video, and driving video ๐ค๐ค๐ค", duration=5)
- def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
+ @torch.no_grad()
+ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop=True):
""" for single image retargeting
"""
+ if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
+ raise gr.Error("Invalid relative pose input ๐ฅ!", duration=5)
# disposable feature
- f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
- self.prepare_retargeting(input_image, flag_do_crop)
+ f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
+ self.prepare_retargeting(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input ๐ฅ!", duration=5)
else:
- inference_cfg = self.live_portrait_wrapper.inference_cfg
- x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
- f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
+ device = self.live_portrait_wrapper.device
+ # inference_cfg = self.live_portrait_wrapper.inference_cfg
+ x_s_user = x_s_user.to(device)
+ f_s_user = f_s_user.to(device)
+ R_s_user = R_s_user.to(device)
+ R_d_user = R_d_user.to(device)
+
+ x_c_s = x_s_info['kp'].to(device)
+ delta_new = x_s_info['exp'].to(device)
+ scale_new = x_s_info['scale'].to(device)
+ t_new = x_s_info['t'].to(device)
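+ # Compose the edited rotation relative to the source pose; R_d_user already carries the user's pitch/yaw/roll offsets from prepare_retargeting.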
+ R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
+
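+ # Re-compose the driving keypoints with the edited pose: x_d = scale * (x_c @ R_d + exp) + t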
+ x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
- combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
- combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
- # default: use x_s
- x_d_new = x_s_user + eyes_delta + lip_delta
+ x_d_new = x_d_new + eyes_delta + lip_delta
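+ # Apply the stitching module so the retargeted keypoints blend consistently with the source when pasted back.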
+ x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
# D(W(f_s; x_s, x′_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
@@ -121,14 +138,18 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
- def prepare_retargeting(self, input_image, flag_do_crop=True):
+ @torch.no_grad()
+ def prepare_retargeting(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
""" for single image retargeting
"""
if input_image is not None:
# gr.Info("Upload successfully!", duration=2)
+ args_user = {'scale': retargeting_source_scale}
+ self.args = update_args(self.args, args_user)
+ self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
- img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
@@ -136,14 +157,37 @@ class GradioPipeline(LivePortraitPipeline):
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
- R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+ x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
+ x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
+ x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation
+ R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
+ R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll)
############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
- return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
+ return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("Please upload a source portrait as the retargeting input ๐ค๐ค๐ค", duration=5)
+
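+ # Bound to retargeting_input_image.change in app.py: returns the source's own eyes-open and lip-open ratios so the sliders start from the portrait's current state.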
+ def init_retargeting(self, retargeting_source_scale: float, input_image=None):
+ """ initialize the retargeting slider
+ """
+ if input_image is not None:
+ args_user = {'scale': retargeting_source_scale}
+ self.args = update_args(self.args, args_user)
+ self.cropper.update_config(self.args.__dict__)
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
+ ######## process source portrait ########
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
+ log(f"Load source image from {input_image}.")
+ crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
+ if crop_info is None:
+ raise gr.Error("Source portrait NO face detected", duration=2)
+ source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
+ source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
+ return round(float(source_eye_ratio.mean()), 2), round(float(source_lip_ratio[0][0]), 2)
+ return 0., 0.
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index e344ffc..9eccc8e 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -200,17 +200,16 @@ class LivePortraitPipeline(object):
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: # if the input is a source image, process it only once
- crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
- if crop_info is None:
- raise Exception("No face detected in the source image!")
- source_lmk = crop_info['lmk_crop']
- img_crop_256x256 = crop_info['img_crop_256x256']
-
if inf_cfg.flag_do_crop:
- I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
+ crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
+ if crop_info is None:
+ raise Exception("No face detected in the source image!")
+ source_lmk = crop_info['lmk_crop']
+ img_crop_256x256 = crop_info['img_crop_256x256']
else:
+ source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
- I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
+ I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
@@ -218,7 +217,7 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first
- if flag_normalize_lip:
+ if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
@@ -245,14 +244,14 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first if the input is a video
- if flag_normalize_lip:
+ if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
- if flag_source_video_eye_retargeting:
+ if flag_source_video_eye_retargeting and source_lmk is not None:
if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
@@ -312,12 +311,12 @@ class LivePortraitPipeline(object):
x_d_i_new += eye_delta_before_animation
else:
eyes_delta, lip_delta = None, None
- if inf_cfg.flag_eye_retargeting:
+ if inf_cfg.flag_eye_retargeting and source_lmk is not None:
c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
- if inf_cfg.flag_lip_retargeting:
+ if inf_cfg.flag_lip_retargeting and source_lmk is not None:
c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
diff --git a/src/utils/cropper.py b/src/utils/cropper.py
index e0e3789..c42e74b 100644
--- a/src/utils/cropper.py
+++ b/src/utils/cropper.py
@@ -67,7 +67,7 @@ class Cropper(object):
root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provider,
)
- self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
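+ # Pass det_thresh through so the face detector's confidence threshold can be lowered for harder inputs.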
+ self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup()
def update_config(self, user_args):
@@ -117,6 +117,24 @@ class Cropper(object):
return ret_dct
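+ # Detect a face and refine its 2D landmarks on an image used as-is (no cropping); returns None if no face is found.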
+ def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs):
+ direction = kwargs.get("direction", "large-small")
+ src_face = self.face_analysis_wrapper.get(
+ contiguous(img_rgb_[..., ::-1]), # convert to BGR
+ flag_do_landmark_2d_106=True,
+ direction=direction,
+ )
+ if len(src_face) == 0:
+ log("No face detected in the source image.")
+ return None
+ elif len(src_face) > 1:
+ log(f"More than one face detected in the image, only pick one face by rule {direction}.")
+ src_face = src_face[0]
+ lmk = src_face.landmark_2d_106
+ lmk = self.landmark_runner.run(img_rgb_, lmk)
+
+ return lmk
+
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
diff --git a/src/utils/filter.py b/src/utils/filter.py
index 5238f49..a8e27ca 100644
--- a/src/utils/filter.py
+++ b/src/utils/filter.py
@@ -5,7 +5,7 @@ import numpy as np
from pykalman import KalmanFilter
-def smooth(x_d_lst, shape, device, observation_variance=1e-7, process_variance=1e-5):
+def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5):
x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
x_d_stacked = np.vstack(x_d_lst_reshape)
kf = KalmanFilter(