Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-22 12:22:38 +00:00)
feat: pose editing for source portraits (#212)
* feat: edit-pose
* feat: edit pose (#207)
* chore: i2v cropper
* chore: update gradio
* chore: update
* doc: update readme

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
Co-authored-by: ZhizhouZhong <1819489045@qq.com>
Parent: aae6e90fd6
Commit: 73ddb69d38
.gitignore (vendored, 1 line changed)

@@ -21,3 +21,4 @@ animations/*
tmp/*
.vscode/launch.json
**/*.DS_Store
gradio_temp/**
app.py (72 lines changed)

@@ -4,6 +4,7 @@
The entrance of the gradio
"""

import os
import tyro
import subprocess
import gradio as gr
@@ -47,6 +48,9 @@ gradio_pipeline = GradioPipeline(
args=args
)

if args.gradio_temp_dir not in (None, ''):
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
os.makedirs(args.gradio_temp_dir, exist_ok=True)

def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
@@ -69,25 +73,27 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7],
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
]
#################### interface logic ####################

# Define components first
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
# output_video_v2v = gr.Video(autoplay=False)
# output_video_concat_v2v = gr.Video(autoplay=False)


with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
@@ -108,6 +114,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
[osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
@@ -149,6 +157,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
[osp.join(example_video_dir, "d20.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
@@ -168,14 +177,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=1e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)

gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
@@ -183,6 +189,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")

with gr.Row():
# Examples
@@ -227,20 +235,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
# Retargeting
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
retargeting_source_scale.render()
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row(visible=True):
head_pitch_slider.render()
head_yaw_slider.render()
head_roll_slider.render()
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
@@ -253,6 +256,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
[osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
@@ -263,15 +268,30 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
with gr.Row(visible=True):
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
head_pitch_slider,
head_yaw_slider,
head_roll_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)

# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)

process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
@@ -296,6 +316,12 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
show_progress=True
)

retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting,
inputs=[retargeting_source_scale, retargeting_input_image],
outputs=[eye_retargeting_slider, lip_retargeting_slider]
)

demo.launch(
server_port=args.server_port,
share=args.share,
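The new `retargeting_input_image.change(...)` binding above relies on a Gradio pattern worth spelling out: a change callback that returns plain numbers, which Gradio assigns, in order, to the output sliders. A minimal, self-contained sketch of that pattern follows; the `init_sliders` helper and its constants are illustrative only, not the repo's code (the real pipeline derives the values from facial landmarks).

```python
import gradio as gr

def init_sliders(img_path):
    # Placeholder logic: the repository computes these ratios from detected landmarks.
    if img_path is None:
        return 0.0, 0.0
    return 0.35, 0.05

with gr.Blocks() as sketch:
    image = gr.Image(type="filepath")
    eye = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
    lip = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
    # Returning (eye_value, lip_value) updates the two sliders in order.
    image.change(fn=init_sliders, inputs=[image], outputs=[eye, lip])

# sketch.launch()
```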
assets/docs/changelog/2024-07-24.md (new file, 5 lines)

@@ -0,0 +1,5 @@
<p align="center">
<img src="../pose-edit-2024-07-24.jpg" alt="LivePortrait" width="960px">
<br>
Pose Editing Interface in the Gradio Interface
</p>
assets/docs/pose-edit-2024-07-24.jpg (new binary file, 217 KiB, not shown)
assets/examples/driving/d20.mp4 (new binary file, not shown)
assets/examples/source/s22.jpg (new binary file, 156 KiB, not shown)
assets/examples/source/s23.jpg (new binary file, 82 KiB, not shown)
@@ -9,6 +9,6 @@
<h2>Retargeting</h2>
<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times.
<br>
<strong>😊 Set both ratios to 0.8 to see what's going on!</strong></p>
<strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
</div>
</div>
@@ -4,6 +4,9 @@
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Note:</strong> Better if the Source Video has <strong>the same FPS</strong> as the Driving Video.
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
@@ -39,7 +39,8 @@


## 🔥 Updates
- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We've also lowered the default detection threshold to support more input detections. [Have fun](assets/docs/changelog/2024-07-24.md)!
- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
@@ -32,9 +32,10 @@ class ArgumentConfig(PrintableConfig):
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy

########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
@@ -50,3 +51,4 @@ class ArgumentConfig(PrintableConfig):
share: bool = False # whether to share the server to public
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
gradio_temp_dir: Optional[str] = None # directory to save gradio temp files
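These fields are plain dataclass-style attributes; app.py (which imports tyro) exposes them as command-line flags. A hedged, minimal sketch of how such a config is usually turned into a CLI with tyro; the `DemoArgs` class is invented for illustration and flag spelling follows tyro's defaults rather than anything confirmed by this diff.

```python
from dataclasses import dataclass
from typing import Optional
import tyro

@dataclass
class DemoArgs:
    det_thresh: float = 0.15                # detection threshold
    gradio_temp_dir: Optional[str] = None   # directory to save gradio temp files

if __name__ == "__main__":
    args = tyro.cli(DemoArgs)   # values can then be overridden from the command line
    print(args.det_thresh, args.gradio_temp_dir)
```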
@@ -15,6 +15,7 @@ class CropConfig(PrintableConfig):
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
det_thresh: float = 0.1 # detection threshold
########## source image or video cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.8 # scale factor
@@ -41,7 +41,7 @@ class InferenceConfig(PrintableConfig):
# NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
anchor_frame: int = 0 # TO IMPLEMENT

input_shape: Tuple[int, int] = (256, 256) # input shape
@@ -6,6 +6,7 @@ Pipeline for gradio

import os.path as osp
import gradio as gr
import torch

from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
@@ -14,6 +15,7 @@ from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio


def update_args(args, user_args):
@@ -32,6 +34,7 @@ class GradioPipeline(LivePortraitPipeline):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args

@torch.no_grad()
def execute_video(
self,
input_source_image_path=None,
@@ -48,7 +51,7 @@ class GradioPipeline(LivePortraitPipeline):
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
driving_smooth_observation_variance=1e-7,
driving_smooth_observation_variance=3e-7,
tab_selection=None,
):
""" for video-driven portrait animation or video editing
@@ -93,27 +96,41 @@ class GradioPipeline(LivePortraitPipeline):
else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)

def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
@torch.no_grad()
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop=True):
""" for single image retargeting
"""
if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
raise gr.Error("Invalid relative pose input 💥!", duration=5)
# disposable feature
f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, flag_do_crop)
f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop)

if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
device = self.live_portrait_wrapper.device
# inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(device)
f_s_user = f_s_user.to(device)
R_s_user = R_s_user.to(device)
R_d_user = R_d_user.to(device)

x_c_s = x_s_info['kp'].to(device)
delta_new = x_s_info['exp'].to(device)
scale_new = x_s_info['scale'].to(device)
t_new = x_s_info['t'].to(device)
R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user

x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
# default: use x_s
x_d_new = x_s_user + eyes_delta + lip_delta
x_d_new = x_d_new + eyes_delta + lip_delta
x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
# D(W(f_s; x_s, x′_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
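A hedged reading of the pose math in the hunk above, written out with plain NumPy so the operator order is explicit (variable names mirror the diff; this is not the repository's code). `R_d_user @ R_s_user.T` isolates the user's relative head rotation, which is then re-applied on top of the source rotation; since both matrices are built from the same source angles, the product reduces to `R_d_user`, and the factored form simply mirrors the relative-motion formula used for video driving. The edited keypoints then follow the usual transform x_d = s * (x_c @ R + delta) + t.

```python
import numpy as np

def apply_pose_edit(R_s_user, R_d_user, x_c_s, delta, scale, t):
    """R_s_user: source rotation (3, 3); R_d_user: rotation from source angles + user offsets.
    x_c_s: canonical keypoints (N, 3); delta: expression offsets (N, 3); t: translation (3,)."""
    R_rel = R_d_user @ R_s_user.T    # the user's relative rotation
    R_d_new = R_rel @ R_s_user       # re-applied on top of the source pose
    # same keypoint transform as the driving branch: x_d = s * (x_c @ R + delta) + t
    return scale * (x_c_s @ R_d_new + delta) + t
```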
@@ -121,14 +138,18 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend

def prepare_retargeting(self, input_image, flag_do_crop=True):
@torch.no_grad()
def prepare_retargeting(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
""" for single image retargeting
"""
if input_image is not None:
# gr.Info("Upload successfully!", duration=2)
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
@@ -136,14 +157,37 @@ class GradioPipeline(LivePortraitPipeline):
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation
R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll)
############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

def init_retargeting(self, retargeting_source_scale: float, input_image = None):
""" initialize the retargeting slider
"""
if input_image != None:
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if crop_info is None:
raise gr.Error("Source portrait NO face detected", duration=2)
source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2)
return 0., 0.
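For context on `get_rotation_matrix` as used in `prepare_retargeting` above: it turns the predicted pitch/yaw/roll (plus the slider offsets, in degrees) into a 3x3 rotation. Below is a hedged sketch of one common pitch-about-x, yaw-about-y, roll-about-z convention; the repository's own `utils.camera.get_rotation_matrix` may order or sign the axes differently, so treat this as illustrative only.

```python
import numpy as np

def euler_to_rotation(pitch_deg, yaw_deg, roll_deg):
    # One common convention: R = Rz(roll) @ Ry(yaw) @ Rx(pitch), angles in degrees.
    p, y, r = np.radians([pitch_deg, yaw_deg, roll_deg])
    Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]])
    Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]])
    Rz = np.array([[np.cos(r), -np.sin(r), 0], [np.sin(r), np.cos(r), 0], [0, 0, 1]])
    return Rz @ Ry @ Rx

# A +10 degree "relative yaw" slider value would then give
# R_d_user = euler_to_rotation(pitch, yaw + 10, roll) on top of the source angles.
```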
@@ -200,15 +200,14 @@ class LivePortraitPipeline(object):
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: # if the input is a source image, process it only once
if inf_cfg.flag_do_crop:
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop_256x256 = crop_info['img_crop_256x256']

if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
@@ -218,7 +217,7 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

# let lip-open scalar to be 0 at first
if flag_normalize_lip:
if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
@@ -245,14 +244,14 @@ class LivePortraitPipeline(object):
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

# let lip-open scalar to be 0 at first if the input is a video
if flag_normalize_lip:
if flag_normalize_lip and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)

# let eye-open scalar to be the same as the first frame if the latter is eye-open state
if flag_source_video_eye_retargeting:
if flag_source_video_eye_retargeting and source_lmk is not None:
if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
@@ -312,12 +311,12 @@ class LivePortraitPipeline(object):
x_d_i_new += eye_delta_before_animation
else:
eyes_delta, lip_delta = None, None
if inf_cfg.flag_eye_retargeting:
if inf_cfg.flag_eye_retargeting and source_lmk is not None:
c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inf_cfg.flag_lip_retargeting:
if inf_cfg.flag_lip_retargeting and source_lmk is not None:
c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
@@ -67,7 +67,7 @@ class Cropper(object):
root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provider,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup()

def update_config(self, user_args):
@@ -117,6 +117,24 @@ class Cropper(object):

return ret_dct

def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs):
direction = kwargs.get("direction", "large-small")
src_face = self.face_analysis_wrapper.get(
contiguous(img_rgb_[..., ::-1]), # convert to BGR
flag_do_landmark_2d_106=True,
direction=direction,
)
if len(src_face) == 0:
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f"More than one face detected in the image, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(img_rgb_, lmk)

return lmk

def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
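The `det_thresh` now forwarded to `prepare()` above is the standard InsightFace FaceAnalysis detection threshold: lowering it admits lower-confidence face candidates, which is how this update "supports more input detections", at the cost of more potential false positives. A minimal, hedged usage sketch (model pack name and image path are illustrative, not taken from the repo):

```python
import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name="buffalo_l")                          # illustrative model pack
app.prepare(ctx_id=0, det_size=(512, 512), det_thresh=0.1)    # lower threshold -> more detections

img_bgr = cv2.imread("portrait.jpg")                          # illustrative path
faces = app.get(img_bgr)
print(f"{len(faces)} face(s) detected")
```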
@@ -5,7 +5,7 @@ import numpy as np
from pykalman import KalmanFilter


def smooth(x_d_lst, shape, device, observation_variance=1e-7, process_variance=1e-5):
def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5):
x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
x_d_stacked = np.vstack(x_d_lst_reshape)
kf = KalmanFilter(
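The `observation_variance` default raised to 3e-7 here is the knob the config comments describe: it acts as the observation covariance of a Kalman smoother over the flattened per-frame motion, so a larger value trusts the raw frames less and yields a smoother but potentially less accurate trajectory. A hedged, self-contained sketch of that idea with pykalman; the `smooth_trajectory` helper below is illustrative and is not the repository's `smooth` function.

```python
import numpy as np
from pykalman import KalmanFilter

def smooth_trajectory(frames, observation_variance=3e-7, process_variance=1e-5):
    """frames: list of equally shaped arrays (e.g. per-frame 3x3 rotation matrices)."""
    shape = frames[0].shape
    obs = np.vstack([f.reshape(-1) for f in frames])               # (n_frames, dim)
    dim = obs.shape[1]
    kf = KalmanFilter(
        initial_state_mean=obs[0],
        transition_covariance=process_variance * np.eye(dim),       # trust in the motion model
        observation_covariance=observation_variance * np.eye(dim),  # trust in the raw observations
        n_dim_obs=dim,
    )
    smoothed, _ = kf.smooth(obs)   # larger observation variance -> smoother output
    return [s.reshape(shape) for s in smoothed]

# e.g. smoothed_rotations = smooth_trajectory(list_of_rotation_matrices)
```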