Merge pull request #1 from KwaiVGI/main

update fork
This commit is contained in:
_v3 2024-08-14 18:58:58 +01:00 committed by GitHub
commit 7d2201cb64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
147 changed files with 12313 additions and 788 deletions

17
.gitignore vendored
View File

@ -9,9 +9,26 @@ __pycache__/
**/*.pth
**/*.onnx
pretrained_weights/*.md
pretrained_weights/docs
pretrained_weights/liveportrait
pretrained_weights/liveportrait_animals
# Ipython notebook
*.ipynb
# Temporary files or benchmark resources
animations/*
tmp/*
.vscode/launch.json
**/*.DS_Store
gradio_temp/**
# Windows dependencies
ffmpeg/
LivePortrait_env/
# XPose build files
src/utils/dependencies/XPose/models/UniPose/ops/build
src/utils/dependencies/XPose/models/UniPose/ops/dist
src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info

View File

@ -19,3 +19,12 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
The code of InsightFace is released under the MIT License.
The models of InsightFace are for non-commercial research purposes only.
If you want to use the LivePortrait project for commercial purposes, you
should remove and replace InsightFace's detection models to fully comply with
the MIT license.

434
app.py
View File

@ -1,10 +1,12 @@
# coding: utf-8
"""
The entrance of the gradio for human
"""
import os
import tyro
import subprocess
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
@ -18,137 +20,451 @@ def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
# global_tab_selection = None
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
if args.gradio_temp_dir not in (None, ''):
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
os.makedirs(args.gradio_temp_dir, exist_ok=True)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
def gpu_wrapped_execute_image_retargeting(*args, **kwargs):
return gradio_pipeline.execute_image_retargeting(*args, **kwargs)
def gpu_wrapped_execute_video_retargeting(*args, **kwargs):
return gradio_pipeline.execute_video_retargeting(*args, **kwargs)
def reset_sliders(*args, **kwargs):
return 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5, True, True
# assets
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples_i2v = [
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
]
#################### interface logic ####################
# Define components first
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
mov_x = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="x-axis movement")
mov_y = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="y-axis movement")
mov_z = gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.01, label="z-axis movement")
lip_variation_zero = gr.Slider(minimum=-0.09, maximum=0.09, value=0, step=0.01, label="pouting")
lip_variation_one = gr.Slider(minimum=-20.0, maximum=15.0, value=0, step=0.01, label="pursing 😐")
lip_variation_two = gr.Slider(minimum=0.0, maximum=15.0, value=0, step=0.01, label="grin 😁")
lip_variation_three = gr.Slider(minimum=-90.0, maximum=120.0, value=0, step=1.0, label="lip close <-> open")
smile = gr.Slider(minimum=-0.3, maximum=1.3, value=0, step=0.01, label="smile 😄")
wink = gr.Slider(minimum=0, maximum=39, value=0, step=0.01, label="wink 😉")
eyebrow = gr.Slider(minimum=-30, maximum=30, value=0, step=0.01, label="eyebrow 🤨")
eyeball_direction_x = gr.Slider(minimum=-30.0, maximum=30.0, value=0, step=0.01, label="eye gaze (horizontal) 👀")
eyeball_direction_y = gr.Slider(minimum=-63.0, maximum=63.0, value=0, step=0.01, label="eye gaze (vertical) 🙄")
retargeting_input_image = gr.Image(type="filepath")
retargeting_input_video = gr.Video()
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
retargeting_output_image = gr.Image(type="numpy")
retargeting_output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video(autoplay=False)
output_video_paste_back = gr.Video(autoplay=False)
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
gr.Markdown(load_description("assets/gradio/gradio_description_upload.md"))
with gr.Row():
with gr.Column():
with gr.Tabs():
with gr.TabItem("🖼️ Source Image") as tab_image:
with gr.Accordion(open=True, label="Source Image"):
source_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
[osp.join(example_portrait_dir, "s23.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
)
with gr.TabItem("🎞️ Source Video") as tab_video:
with gr.Accordion(open=True, label="Source Video"):
source_video_input = gr.Video()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s14.mp4")],
# [osp.join(example_portrait_dir, "s15.mp4")],
[osp.join(example_portrait_dir, "s18.mp4")],
# [osp.join(example_portrait_dir, "s19.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")],
],
inputs=[source_video_input],
cache_examples=False,
)
tab_selection = gr.Textbox(visible=False)
tab_image.select(lambda: "Image", None, tab_selection)
tab_video.select(lambda: "Video", None, tab_selection)
with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"):
with gr.Row():
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Column():
with gr.Tabs():
with gr.TabItem("🎞️ Driving Video") as v_tab_video:
with gr.Accordion(open=True, label="Driving Video"): with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video() driving_video_input = gr.Video()
gr.Markdown(load_description("assets/gradio_description_animation.md")) gr.Examples(
examples=[
[osp.join(example_video_dir, "d0.mp4")],
[osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
[osp.join(example_video_dir, "d20.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
)
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
gr.Examples(
examples=[
[osp.join(example_video_dir, "d1.pkl")],
[osp.join(example_video_dir, "d2.pkl")],
[osp.join(example_video_dir, "d5.pkl")],
[osp.join(example_video_dir, "d7.pkl")],
[osp.join(example_video_dir, "d8.pkl")],
],
inputs=[driving_video_pickle_input],
cache_examples=False,
)
v_tab_selection = gr.Textbox(visible=False)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
v_tab_video.select(lambda: "Video", None, v_tab_selection)
# with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
with gr.Row():
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
with gr.Tabs():
with gr.TabItem("🖼️ Portrait Animation"):
gr.Examples(
examples=data_examples_i2v,
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples_i2v),
cache_examples=False,
)
with gr.TabItem("🎞️ Portrait Video Editing"):
gr.Examples(
examples=data_examples_v2v,
fn=gpu_wrapped_execute_video,
inputs=[
source_video_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
driving_smooth_observation_variance,
],
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples_v2v),
cache_examples=False,
)
# Retargeting Image
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)")
flag_stitching_retargeting_input = gr.Checkbox(value=True, label="stitching")
retargeting_source_scale.render()
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Facial movement sliders"):
with gr.Row(visible=True):
head_pitch_slider.render()
head_yaw_slider.render()
head_roll_slider.render()
with gr.Row(visible=True):
mov_x.render()
mov_y.render()
mov_z.render()
with gr.Column():
with gr.Accordion(open=True, label="Facial expression sliders"):
with gr.Row(visible=True):
lip_variation_zero.render()
lip_variation_one.render()
lip_variation_two.render()
with gr.Row(visible=True):
lip_variation_three.render()
smile.render()
wink.render()
with gr.Row(visible=True):
eyebrow.render()
eyeball_direction_x.render()
eyeball_direction_y.render()
with gr.Row(visible=True):
reset_button = gr.Button("🔄 Reset")
reset_button.click(
fn=reset_sliders,
inputs=None,
outputs=[
head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
]
)
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Image Input"):
retargeting_input_image.render()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
[osp.join(example_portrait_dir, "s22.jpg")],
# [osp.join(example_portrait_dir, "s23.jpg")],
[osp.join(example_portrait_dir, "s42.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
)
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
retargeting_output_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
retargeting_output_image_paste_back.render()
with gr.Row(visible=True):
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
retargeting_output_image,
retargeting_output_image_paste_back,
],
value="🧹 Clear"
)
with gr.Row():
# Retargeting Video
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting_video.md"), visible=True)
with gr.Row(visible=True):
flag_do_crop_input_retargeting_video = gr.Checkbox(value=True, label="do crop (source)")
video_retargeting_source_scale.render()
video_lip_retargeting_slider.render()
driving_smooth_observation_variance_retargeting.render()
with gr.Row(visible=True):
process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Video Input"):
retargeting_input_video.render()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s18.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")],
[osp.join(example_portrait_dir, "s29.mp4")],
[osp.join(example_portrait_dir, "s32.mp4")],
],
inputs=[retargeting_input_video],
cache_examples=False,
)
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_video_paste_back.render()
with gr.Row(visible=True):
process_button_reset_retargeting = gr.ClearButton(
[
video_lip_retargeting_slider,
retargeting_input_video,
output_video,
output_video_paste_back
],
value="🧹 Clear"
)
# binding functions for buttons
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
source_video_input,
driving_video_pickle_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_stitching_input,
driving_option_input,
driving_multiplier,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale,
vx_ratio,
vy_ratio,
scale_crop_driving_video,
vx_ratio_crop_driving_video,
vy_ratio_crop_driving_video,
driving_smooth_observation_variance,
tab_selection,
v_tab_selection,
],
outputs=[output_video_i2v, output_video_concat_i2v],
show_progress=True
)
image_input.change(
retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting_image,
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
outputs=[eye_retargeting_slider, lip_retargeting_slider]
)
sliders = [eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y]
for slider in sliders:
# NOTE: gradio >= 4.0.0 may cause slow response
slider.change(
fn=gpu_wrapped_execute_image_retargeting,
inputs=[
eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
retargeting_input_image, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
],
outputs=[retargeting_output_image, retargeting_output_image_paste_back],
)
process_button_retargeting_video.click(
fn=gpu_wrapped_execute_video_retargeting,
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
outputs=[output_video, output_video_paste_back],
show_progress=True
)
demo.launch(
server_name=args.server_name,
server_port=args.server_port,
share=args.share,
)

248
app_animals.py Normal file
View File

@ -0,0 +1,248 @@
# coding: utf-8
"""
The entrance of the gradio for animal
"""
import os
import tyro
import subprocess
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipelineAnimal
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
gradio_pipeline_animal: GradioPipelineAnimal = GradioPipelineAnimal(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
if args.gradio_temp_dir not in (None, ''):
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
os.makedirs(args.gradio_temp_dir, exist_ok=True)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline_animal.execute_video(*args, **kwargs)
# assets
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples_i2v = [
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "d3.mp4"), True, False, False, False],
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "d6.mp4"), True, False, False, False],
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "d19.mp4"), True, False, False, False],
]
data_examples_i2v_pickle = [
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "wink.pkl"), True, False, False, False],
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "talking.pkl"), True, False, False, False],
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "aggrieved.pkl"), True, False, False, False],
]
#################### interface logic ####################
# Define components first
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
output_video_i2v_gif = gr.Image(type="numpy")
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio/gradio_description_upload_animal.md"))
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="🐱 Source Animal Image"):
source_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s25.jpg")],
[osp.join(example_portrait_dir, "s30.jpg")],
[osp.join(example_portrait_dir, "s31.jpg")],
[osp.join(example_portrait_dir, "s32.jpg")],
[osp.join(example_portrait_dir, "s39.jpg")],
[osp.join(example_portrait_dir, "s40.jpg")],
[osp.join(example_portrait_dir, "s41.jpg")],
[osp.join(example_portrait_dir, "s38.jpg")],
[osp.join(example_portrait_dir, "s36.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
)
with gr.Accordion(open=True, label="Cropping Options for Source Image"):
with gr.Row():
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Column():
with gr.Tabs():
with gr.TabItem("📁 Driving Pickle") as tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File()
gr.Examples(
examples=[
[osp.join(example_video_dir, "wink.pkl")],
[osp.join(example_video_dir, "shy.pkl")],
[osp.join(example_video_dir, "aggrieved.pkl")],
[osp.join(example_video_dir, "open_lip.pkl")],
[osp.join(example_video_dir, "laugh.pkl")],
[osp.join(example_video_dir, "talking.pkl")],
[osp.join(example_video_dir, "shake_face.pkl")],
],
inputs=[driving_video_pickle_input],
cache_examples=False,
)
with gr.TabItem("🎞️ Driving Video") as tab_video:
with gr.Accordion(open=True, label="Driving Video"):
driving_video_input = gr.Video()
gr.Examples(
examples=[
# [osp.join(example_video_dir, "d0.mp4")],
# [osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
[osp.join(example_video_dir, "d3.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
)
tab_selection = gr.Textbox(visible=False)
tab_pickle.select(lambda: "Pickle", None, tab_selection)
tab_video.select(lambda: "Video", None, tab_selection)
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
with gr.Row():
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Row():
with gr.Accordion(open=False, label="Animation Options"):
with gr.Row():
flag_stitching = gr.Checkbox(value=False, label="stitching (not recommended)")
flag_remap_input = gr.Checkbox(value=False, label="paste-back (not recommended)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the cropped image space"):
output_video_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
output_video_i2v_gif.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], value="🧹 Clear")
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
with gr.Tabs():
with gr.TabItem("📁 Driving Pickle") as tab_video:
gr.Examples(
examples=data_examples_i2v_pickle,
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_pickle_input,
flag_do_crop_input,
flag_stitching,
flag_remap_input,
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
examples_per_page=len(data_examples_i2v_pickle),
cache_examples=False,
)
with gr.TabItem("🎞️ Driving Video") as tab_video:
gr.Examples(
examples=data_examples_i2v,
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_input,
flag_do_crop_input,
flag_stitching,
flag_remap_input,
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
examples_per_page=len(data_examples_i2v),
cache_examples=False,
)
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_input,
driving_video_pickle_input,
flag_do_crop_input,
flag_remap_input,
driving_multiplier,
flag_stitching,
flag_crop_driving_video_input,
scale,
vx_ratio,
vy_ratio,
scale_crop_driving_video,
vx_ratio_crop_driving_video,
vy_ratio_crop_driving_video,
tab_selection,
],
outputs=[output_video_i2v, output_video_concat_i2v, output_video_i2v_gif],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,
server_name=args.server_name
)

2
assets/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
examples/driving/*.pkl
examples/driving/*_crop.mp4

(Two binary image files added: 364 KiB and 344 KiB; previews not shown.)

View File

@ -0,0 +1,22 @@
## 2024/07/10
**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
The popularity of LivePortrait has exceeded our expectations. If you encounter any issues and we do not respond promptly, please accept our apologies; we are still actively updating and improving this repository.
### Updates
- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`.
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option.
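To make the two flags above concrete, here is a hedged sketch (the file names are the repo's bundled examples, and the exact location of the generated `.pkl` may differ in your setup):

```bash
# First run: auto-crop the driving video and generate its motion template (d0.mp4 -> d0.pkl)
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --flag_crop_driving_video

# Later runs: reuse the motion template instead of the original driving video
# (adjust the path to wherever d0.pkl was saved on your machine)
python inference.py -s assets/examples/source/s6.jpg -d assets/examples/driving/d0.pkl
```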
### About driving video
- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
### Others
- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
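If you hit this black box issue, a minimal hedged example (the paths are the repo's bundled samples) would be:

```bash
# Disable half-precision (fp16) inference to avoid the black box artifact on some GPUs
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_use_half_precision
```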

View File

@ -0,0 +1,24 @@
## 2024/07/19
**Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉**
We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk).
### Updates
- <strong>Portrait video editing (v2v):</strong> Implemented a version of Portrait Video Editing (aka v2v). Ensure you have the `pykalman` package installed, which has been added to [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `--flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`. An example command combining these options is sketched after this list.
- <strong>More options in Gradio:</strong> We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.
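A hedged v2v example based on the options above (the example files and the smoothing value mirror the Gradio examples shipped in the repo):

```bash
# Portrait Video Editing: drive a source *video* with a driving video
python inference.py \
  -s assets/examples/source/s13.mp4 \
  -d assets/examples/driving/d0.mp4 \
  --driving_smooth_observation_variance 3e-7 \
  --flag_video_editing_head_rotation
```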
<p align="center">
<img src="../LivePortrait-Gradio-2024-07-19.jpg" alt="LivePortrait" width="800px">
<br>
The Gradio Interface for LivePortrait
</p>
### Community Contributions
- **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance:
- [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150))
- [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142))
- **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119).

View File

@ -0,0 +1,12 @@
## 2024/07/24
### Updates
- **Portrait pose editing:** You can change the `relative pitch`, `relative yaw`, and `relative roll` in the Gradio interface to adjust the pose of the source portrait.
- **Detection threshold:** We have added a `--det_thresh` argument with a default value of 0.15 to increase recall, meaning more types of faces (e.g., monkeys, human-like) will be detected. You can set it to other values, e.g., 0.5, by using `python app.py --det_thresh 0.5`.
<p align="center">
<img src="../pose-edit-2024-07-24.jpg" alt="LivePortrait" width="960px">
<br>
Pose Editing in the Gradio Interface
</p>

View File

@ -0,0 +1,75 @@
## 2024/08/02
<table class="center" style="width: 80%; margin-left: auto; margin-right: auto;">
<tr>
<td style="text-align: center"><b>Animals Singing Dance Monkey 🎤</b></td>
</tr>
<tr>
<td style="border: none; text-align: center;">
<video controls loop src="https://github.com/user-attachments/assets/38d5b6e5-d29b-458d-9f2c-4dd52546cb41" muted="false" style="width: 60%;"></video>
</td>
</tr>
</table>
🎉 We are excited to announce the release of a new version featuring an animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪 We also provide a one-click installer for Windows users; check out the details [here](./2024-08-05.md).
### Updates on Animals mode
We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode.
> Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. _This may be addressed in future updates._ Therefore, we recommend **disabling stitching by setting the `--no_flag_stitching`** option when running the model. Additionally, `paste-back` is also not recommended.
#### Install X-Pose
We have chosen [X-Pose](https://github.com/IDEA-Research/X-Pose) as the keypoints detector for animals. This relies on `transformers==4.22.0` and `pillow>=10.2.0` (which are already updated in `requirements.txt`) and requires building an OP named `MultiScaleDeformableAttention`.
Refer to the [PyTorch installation](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#for-linux-or-windows-users) for Linux and Windows users.
Next, build the OP `MultiScaleDeformableAttention` by running:
```bash
cd src/utils/dependencies/XPose/models/UniPose/ops
python setup.py build install
cd - # this returns to the previous directory
```
To run the model, use the `inference_animals.py` script:
```bash
python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75
```
Alternatively, you can use Gradio for a more user-friendly interface. Launch it with:
```bash
python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share
```
> [!WARNING]
> [X-Pose](https://github.com/IDEA-Research/X-Pose) is only for Non-commercial Scientific Research Purposes; if you use LivePortrait for commercial purposes, you should remove X-Pose and replace it with another detector.
### Updates on Humans mode
- **Driving Options**: We have introduced an `expression-friendly` driving option to **reduce head wobbling**, now set as the default. While it may be less effective with large head poses, you can also select the `pose-friendly` option, which is the same as the previous version. This can be set using `--driving_option` or selected in the Gradio interface. Additionally, we added a `--driving_multiplier` option to adjust driving intensity, with a default value of 1, which can also be set in the Gradio interface. A hedged example command using these options follows this list.
- **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐.
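As referenced above, here is a sketch of the new human-mode driving controls on the command line (the values are illustrative; the same options are exposed in the Gradio interface):

```bash
# 'expression-friendly' is the new default; 'pose-friendly' matches the previous behavior
python inference.py \
  -s assets/examples/source/s9.jpg \
  -d assets/examples/driving/d0.mp4 \
  --driving_option pose-friendly \
  --driving_multiplier 1.2
```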
### Others
- [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260).
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0.
- [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) features real-time portrait pose/expression editing and animation, and is registered with ComfyUI-Manager.
**Below are some screenshots of the new features and improvements:**
| ![The Gradio Interface of Animals Mode](../animals-mode-gradio-2024-08-02.jpg) |
|:---:|
| **The Gradio Interface of Animals Mode** |
| ![Driving Options and Multiplier](../driving-option-multiplier-2024-08-02.jpg) |
|:---:|
| **Driving Options and Multiplier** |
| ![The Feature of Retargeting Video](../retargeting-video-2024-08-02.jpg) |
|:---:|
| **The Feature of Retargeting Video** |

View File

@ -0,0 +1,18 @@
## One-click Windows Installer
### Download the installer from HuggingFace
```bash
# !pip install -U "huggingface_hub[cli]"
huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip
```
If you cannot access Hugging Face, you can use [hf-mirror](https://hf-mirror.com/) to download it:
```bash
# !pip install -U "huggingface_hub[cli]"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip
```
Alternatively, you can manually download it from the [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) page.
Then, simply unzip the package `LivePortrait-Windows-v20240806.zip` and double-click `run_windows_human.bat` for the Humans mode, or `run_windows_animal.bat` for the **Animals mode**.

View File

@ -0,0 +1,9 @@
## Precise Portrait Editing
Inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) ([@PowerHouseMan](https://github.com/PowerHouseMan)), we have implemented a version of Precise Portrait Editing in the Gradio interface. With each adjustment of the slider, the edited image updates in real-time. You can click the `🔄 Reset` button to reset all slider parameters. However, the performance may not be as fast as the ComfyUI plugin.
<p align="center">
<img src="../editing-portrait-2024-08-06.jpg" alt="LivePortrait" width="960px">
<br>
Precise Portrait Editing in the Gradio Interface
</p>

View File

@ -0,0 +1,28 @@
## The directory structure of `pretrained_weights`
```text
pretrained_weights
├── insightface
│ └── models
│ └── buffalo_l
│ ├── 2d106det.onnx
│ └── det_10g.onnx
├── liveportrait
│ ├── base_models
│ │ ├── appearance_feature_extractor.pth
│ │ ├── motion_extractor.pth
│ │ ├── spade_generator.pth
│ │ └── warping_module.pth
│ ├── landmark.onnx
│ └── retargeting_models
│ └── stitching_retargeting_module.pth
└── liveportrait_animals
├── base_models
│ ├── appearance_feature_extractor.pth
│ ├── motion_extractor.pth
│ ├── spade_generator.pth
│ └── warping_module.pth
├── retargeting_models
│ └── stitching_retargeting_module.pth
└── xpose.pth
```
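One possible way to populate this directory is with `huggingface-cli`; this is a sketch that assumes the [KwaiVGI/LivePortrait](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) HuggingFace repository mirrors the tree above:

```bash
# !pip install -U "huggingface_hub[cli]"
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights
```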

(Two binary image files added: 82 KiB and 301 KiB; previews not shown.)

View File

@ -0,0 +1,29 @@
## Install FFmpeg
Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below.
> [!Note]
> The installation is copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗
### Conda Users
```bash
conda install ffmpeg
```
### Ubuntu/Debian Users
```bash
sudo apt install ffmpeg
sudo apt install libsox-dev
conda install -c conda-forge 'ffmpeg<7'
```
### Windows Users
Download [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) and place them in the LivePortrait root directory (or in an `ffmpeg/` folder inside it, which the launch scripts add to `PATH`).
### MacOS Users
```bash
brew install ffmpeg
```

(Three binary image files added: 491 KiB, 217 KiB, and 115 KiB; previews not shown.)

13
assets/docs/speed.md Normal file
View File

@ -0,0 +1,13 @@
### Speed
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|-----------------------------------|:-------------:|:--------------:|:-------------:|
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
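For reference, a minimal timing sketch in the spirit of the table above, using a hypothetical stand-in module rather than the actual LivePortrait networks:

```python
import time
import torch

# Hypothetical stand-in; the numbers in the table come from the real LivePortrait modules.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 3, 3, padding=1),
).cuda().eval()

compiled = torch.compile(model)  # PyTorch 2.x graph compilation
x = torch.randn(1, 3, 256, 256, device="cuda")

with torch.no_grad():
    for _ in range(10):          # warm-up (triggers compilation)
        compiled(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        compiled(x)
    torch.cuda.synchronize()
    print(f"avg inference: {(time.time() - start) / 100 * 1000:.2f} ms")
```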

(Numerous binary asset files changed, including several newly added images ranging from 39 KiB to 500 KiB; file names and previews are not shown in this view.)

View File

@ -0,0 +1,6 @@
<div style="font-size: 1.2em; text-align: center;">
Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
</div>
<!-- <div style="font-size: 1.1em; text-align: center;">
<strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
</div> -->

View File

@ -0,0 +1,19 @@
<span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
4. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
</div>

View File

@ -0,0 +1,13 @@
<br>
<!-- ## Retargeting -->
<!-- <span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span> -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
<div>
<h2>Retargeting and Editing Portraits</h2>
<p>Upload a source portrait, and the <code>eyes-open ratio</code> and <code>lip-open ratio</code> will be auto-calculated. Adjust the sliders to see instant edits. Feel free to experiment! 🎨</p>
<strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
</div>
</div>

View File

@ -0,0 +1,9 @@
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
<div>
<h2>Retargeting Video</h2>
<p>Upload a Source Video as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting Video</strong> button. You can try running it multiple times.
<br>
<strong>🤐 Set target lip-open ratio to 0 to see what's going on!</strong></p>
</div>
</div>

View File

@ -0,0 +1,19 @@
<br>
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: center; margin-right: 20px;">
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Note:</strong> Better if Source Video has <strong>the same FPS</strong> as the Driving Video.
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
</div>
</div>
</div>

View File

@ -0,0 +1,16 @@
<br>
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: center; margin-right: 20px;">
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Animal Image</strong> (any aspect ratio) ⬇️
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
Step 2: Upload a <strong>Driving Pickle</strong> or <strong>Driving Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
</div>
</div>
</div>

View File

@ -0,0 +1,20 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<!-- <span>Add mimics and lip sync to your static portrait driven by a video</span> -->
<!-- <span>Efficient Portrait Animation with Stitching and Retargeting Control</span> -->
<!-- <br> -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
&nbsp;
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
&nbsp;
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait
"></a>
</div>
</div>
</div>

View File

@ -1,7 +0,0 @@
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
</div>

View File

@ -1 +0,0 @@
<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>

View File

@ -1,2 +0,0 @@
## 🤗 This is the official gradio demo for **LivePortrait**.
<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>

View File

@ -1,10 +0,0 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>
</div>
</div>

View File

@ -1,6 +1,12 @@
# coding: utf-8 # coding: utf-8
"""
for human
"""
import os
import os.path as osp
import tyro import tyro
import subprocess
from src.config.argument_config import ArgumentConfig from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig from src.config.crop_config import CropConfig
@ -11,14 +17,40 @@ def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source):
raise FileNotFoundError(f"source info not found: {args.source}")
if not osp.exists(args.driving):
raise FileNotFoundError(f"driving info not found: {args.driving}")
def main(): def main():
# set tyro theme # set tyro theme
tyro.extras.set_accent_color("bright_cyan") tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig) args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
fast_check_args(args)
# specify configs for inference # specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig crop_cfg = partial_fields(CropConfig, args.__dict__)
live_portrait_pipeline = LivePortraitPipeline( live_portrait_pipeline = LivePortraitPipeline(
inference_cfg=inference_cfg, inference_cfg=inference_cfg,
@ -29,5 +61,5 @@ def main():
live_portrait_pipeline.execute(args) live_portrait_pipeline.execute(args)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

65
inference_animals.py Normal file
View File

@ -0,0 +1,65 @@
# coding: utf-8
"""
for animal
"""
import os
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source):
raise FileNotFoundError(f"source info not found: {args.source}")
if not osp.exists(args.driving):
raise FileNotFoundError(f"driving info not found: {args.driving}")
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
fast_check_args(args)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
live_portrait_pipeline_animal = LivePortraitPipelineAnimal(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg
)
# run
live_portrait_pipeline_animal.execute(args)
if __name__ == "__main__":
main()

218
readme.md
View File

@ -2,9 +2,9 @@
<div align='center'> <div align='center'>
<a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup>&emsp; <a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup>&emsp; <a href='https://github.com/Mystery099' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup>&emsp; <a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>&emsp; <a href='https://github.com/zzzweakman' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>&emsp;
<a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup>&emsp; <a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup>&emsp;
</div> </div>
@ -16,6 +16,9 @@
<div align='center'> <div align='center'>
<sup>1 </sup>Kuaishou Technology&emsp; <sup>2 </sup>University of Science and Technology of China&emsp; <sup>3 </sup>Fudan University&emsp; <sup>1 </sup>Kuaishou Technology&emsp; <sup>2 </sup>University of Science and Technology of China&emsp; <sup>3 </sup>Fudan University&emsp;
</div> </div>
<div align='center'>
<small><sup></sup> Corresponding author</small>
</div>
<br> <br>
<div align="center"> <div align="center">
@ -23,6 +26,7 @@
<a href='https://arxiv.org/pdf/2407.03168'><img src='https://img.shields.io/badge/arXiv-LivePortrait-red'></a> <a href='https://arxiv.org/pdf/2407.03168'><img src='https://img.shields.io/badge/arXiv-LivePortrait-red'></a>
<a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-LivePortrait-green'></a> <a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-LivePortrait-green'></a>
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait"></a>
</div> </div>
<br> <br>
@ -33,55 +37,102 @@
</p> </p>
## 🔥 Updates ## 🔥 Updates
- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned! - **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168). - **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode**! For details, see [**here**](./assets/docs/changelog/2024-08-05.md).
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main) or [BaiduYun](https://pan.baidu.com/s/1FWsWqKe0eNfXrwjEhhCqlw?pwd=86q2). Simply unzip and double-click `run_windows.bat` to enjoy!
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We've also lowered the default detection threshold to increase recall. [Have fun](assets/docs/changelog/2024-07-24.md)!
- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
- **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned!
- **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
## Introduction
## Introduction 📖
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168). This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖. We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
## 🔥 Getting Started ## Getting Started 🏁
### 1. Clone the code and prepare the environment ### 1. Clone the code and prepare the environment 🛠️
> [!Note]
> Make sure your system has [`git`](https://git-scm.com/), [`conda`](https://anaconda.org/anaconda/conda), and [`FFmpeg`](https://ffmpeg.org/download.html) installed. For details on FFmpeg installation, see [**how to install FFmpeg**](assets/docs/how-to-install-ffmpeg.md).
```bash ```bash
git clone https://github.com/KwaiVGI/LivePortrait git clone https://github.com/KwaiVGI/LivePortrait
cd LivePortrait cd LivePortrait
# create env using conda # create env using conda
conda create -n LivePortrait python==3.9.18 conda create -n LivePortrait python=3.9
conda activate LivePortrait conda activate LivePortrait
# install dependencies with pip ```
#### For Linux or Windows Users
[X-Pose](https://github.com/IDEA-Research/X-Pose) requires your `torch` version to be compatible with the CUDA version.
Firstly, check your current CUDA version by:
```bash
nvcc -V # example versions: 11.1, 11.8, 12.1, etc.
```
Then, install the corresponding torch version. Here are examples for different CUDA versions. Visit the [PyTorch Official Website](https://pytorch.org/get-started/previous-versions) for installation commands if your CUDA version is not listed:
```bash
# for CUDA 11.1
pip install torch==1.10.1+cu111 torchvision==0.11.2 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
# for CUDA 11.8
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
# for CUDA 12.1
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
# ...
```
Finally, install the remaining dependencies:
```bash
pip install -r requirements.txt pip install -r requirements.txt
``` ```
### 2. Download pretrained weights #### For macOS with Apple Silicon Users
Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows: The [X-Pose](https://github.com/IDEA-Research/X-Pose) dependency does not support macOS, so you can skip its installation. While Humans mode works as usual, Animals mode is not supported. Use the provided requirements file for macOS with Apple Silicon:
```text ```bash
pretrained_weights # for macOS with Apple Silicon users
├── insightface pip install -r requirements_macOS.txt
│ └── models
│ └── buffalo_l
│ ├── 2d106det.onnx
│ └── det_10g.onnx
└── liveportrait
├── base_models
│ ├── appearance_feature_extractor.pth
│ ├── motion_extractor.pth
│ ├── spade_generator.pth
│ └── warping_module.pth
├── landmark.onnx
└── retargeting_models
└── stitching_retargeting_module.pth
``` ```
### 2. Download pretrained weights 📥
The easiest way to download the pretrained weights is from HuggingFace:
```bash
# !pip install -U "huggingface_hub[cli]"
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
```
If you cannot access Hugging Face, you can use [hf-mirror](https://hf-mirror.com/) to download:
```bash
# !pip install -U "huggingface_hub[cli]"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
```
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn) (WIP). Unzip and place them in `./pretrained_weights`.
Ensure the directory structure matches or contains [**this**](assets/docs/directory-structure.md).
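For a quick sanity check after downloading, listing the expected folders should show the files referenced by the default configs. This is only a convenience sketch; the authoritative layout is the linked directory-structure doc, and the `liveportrait_animals` folder is only needed for Animals mode:
```bash
# optional sanity check; paths mirror the defaults in the configs
ls pretrained_weights/insightface/models/buffalo_l      # 2d106det.onnx, det_10g.onnx
ls pretrained_weights/liveportrait/base_models          # appearance_feature_extractor.pth, motion_extractor.pth, spade_generator.pth, warping_module.pth
ls pretrained_weights/liveportrait/retargeting_models   # stitching_retargeting_module.pth
ls pretrained_weights/liveportrait/landmark.onnx
ls pretrained_weights/liveportrait_animals              # e.g., xpose.pth and animal base_models (Animals mode only)
```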
### 3. Inference 🚀 ### 3. Inference 🚀
#### Fast hands-on (humans) 👤
```bash ```bash
# For Linux and Windows users
python inference.py python inference.py
# For macOS users with Apple Silicon (Intel is not tested). NOTE: this may be 20x slower than an RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
``` ```
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result. If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.
<p align="center"> <p align="center">
<img src="./assets/docs/inference.gif" alt="image"> <img src="./assets/docs/inference.gif" alt="image">
@ -90,55 +141,122 @@ If the script runs successfully, you will get an output mp4 file named `animatio
Or, you can change the input by specifying the `-s` and `-d` arguments: Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash ```bash
# source input is an image
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
# or disable pasting back # source input is a video ✨
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4
# more options to see # more options to see
python inference.py -h python inference.py -h
``` ```
**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊 #### Fast hands-on (animals) 🐱🐶
Animals mode is ONLY tested on Linux and Windows with an NVIDIA GPU.
### 4. Gradio interface You need to build an OP named `MultiScaleDeformableAttention` first, which is used by [X-Pose](https://github.com/IDEA-Research/X-Pose), a general keypoint detection framework.
```bash
cd src/utils/dependencies/XPose/models/UniPose/ops
python setup.py build install
cd - # equal to cd ../../../../../../../
```
We also provide a Gradio interface for a better experience, just run by: Then
```bash
python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --driving_multiplier 1.75 --no_flag_stitching
```
If the script runs successfully, you will get an output mp4 file named `animations/s39--wink_concat.mp4`.
<p align="center">
<img src="./assets/docs/inference-animals.gif" alt="image">
</p>
#### Driving video auto-cropping 📢📢📢
> [!IMPORTANT]
> To use your own driving video, we **recommend**: ⬇️
> - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
> - Focus on the head area, similar to the example videos.
> - Minimize shoulder movement.
> - Make sure the first frame of driving video is a frontal face with **neutral expression**.
Below is an auto-cropping example using `--flag_crop_driving_video`:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
```
If the auto-cropping results are not satisfactory, you can modify the `--scale_crop_driving_video` and `--vy_ratio_crop_driving_video` options to adjust the scale and offset, or crop the video manually.
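For example, a hypothetical tuning run that nudges the crop scale and vertical offset might look like this (the values are illustrative, not recommendations):
```bash
# illustrative values; tune them for your own driving video
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 \
    --flag_crop_driving_video --scale_crop_driving_video 2.0 --vy_ratio_crop_driving_video -0.12
```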
#### Motion template making
You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation
python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing
```
### 4. Gradio interface 🤗
We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run:
```bash ```bash
python app.py # For Linux and Windows users (and macOS with Intel??)
python app.py # humans mode
# For macOS with Apple Silicon users, Intel is not supported; this may be 20x slower than an RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py # humans mode
``` ```
We also provide a Gradio interface for animals mode, which has only been tested on Linux with an NVIDIA GPU:
```bash
python app_animals.py # animals mode 🐱🐶
```
You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
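For example, a hypothetical launch that binds to all interfaces on a custom port and enables a public share link could look like this (flag names follow the gradio arguments in `src/config/argument_config.py`):
```bash
# hypothetical example; adjust to your needs
python app.py --server_name 0.0.0.0 --server_port 8890 --share
```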
🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster. Performance gains may vary with different CUDA versions.
```bash
# enable torch.compile for faster inference
python app.py --flag_do_torch_compile
```
**Note**: This method is not supported on Windows and macOS.
**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
### 5. Inference speed evaluation 🚀🚀🚀 ### 5. Inference speed evaluation 🚀🚀🚀
We have also provided a script to evaluate the inference speed of each module: We have also provided a script to evaluate the inference speed of each module:
```bash ```bash
# For NVIDIA GPU
python speed.py python speed.py
``` ```
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`: The results are [**here**](./assets/docs/speed.md).
| Model | Parameters(M) | Model Size(MB) | Inference(ms) | ## Community Resources 🤗
|-----------------------------------|:-------------:|:--------------:|:-------------:|
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.* Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
- [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) by [@PowerHouseMan](https://github.com/PowerHouseMan).
- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
## Acknowledgements And many more amazing contributions from our community!
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
## Acknowledgements 💐
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions.
## Citation 💖 ## Citation 💖
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
```bibtex ```bibtex
@article{guo2024live, @article{guo2024liveportrait,
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control}, title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang}, author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di},
year = {2024}, journal = {arXiv preprint arXiv:2407.03168},
journal = {arXiv preprint:2407.03168}, year = {2024}
} }
``` ```
## Contact 📧
[**Jianzhu Guo (郭建珠)**](https://guojianzhu.com); **guojianzhu1994@gmail.com**

View File

@ -1,22 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu118 -r requirements_base.txt
torch==2.3.0
torchvision==0.18.0
torchaudio==2.3.0
numpy==1.26.4
pyyaml==6.0.1
opencv-python==4.10.0.84
scipy==1.13.1
imageio==2.34.2
lmdb==1.4.1
tqdm==4.66.4
rich==13.7.1
ffmpeg==1.4
onnxruntime-gpu==1.18.0 onnxruntime-gpu==1.18.0
onnx==1.16.1 transformers==4.22.0
scikit-image==0.24.0
albumentations==1.4.10
matplotlib==3.9.0
imageio-ffmpeg==0.5.1
tyro==0.8.5
gradio==4.37.1

18
requirements_base.txt Normal file
View File

@ -0,0 +1,18 @@
numpy==1.26.4
pyyaml==6.0.1
opencv-python==4.10.0.84
scipy==1.13.1
imageio==2.34.2
lmdb==1.4.1
tqdm==4.66.4
rich==13.7.1
ffmpeg-python==0.2.0
onnx==1.16.1
scikit-image==0.24.0
albumentations==1.4.10
matplotlib==3.9.0
imageio-ffmpeg==0.5.1
tyro==0.8.5
gradio==4.37.1
pykalman==0.9.7
pillow>=10.2.0

7
requirements_macOS.txt Normal file
View File

@ -0,0 +1,7 @@
-r requirements_base.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.3.0
torchvision==0.18.0
torchaudio==2.3.0
onnxruntime-silicon==1.16.3

View File

@ -6,25 +6,28 @@ Benchmark the inference speed of each module in LivePortrait.
TODO: heavy GPT style, need to refactor TODO: heavy GPT style, need to refactor
""" """
import yaml
import torch import torch
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
import yaml
import time import time
import numpy as np import numpy as np
from src.utils.helper import load_model, concat_feat from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1): def initialize_inputs(batch_size=1, device_id=0):
""" """
Generate random input tensors and move them to GPU Generate random input tensors and move them to GPU
""" """
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half() feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
kp_source = torch.randn(batch_size, 21, 3).cuda().half() kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
kp_driving = torch.randn(batch_size, 21, 3).cuda().half() kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half() source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half() generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
eye_close_ratio = torch.randn(batch_size, 3).cuda().half() eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
lip_close_ratio = torch.randn(batch_size, 2).cuda().half() lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
feat_stitching = concat_feat(kp_source, kp_driving).half() feat_stitching = concat_feat(kp_source, kp_driving).half()
feat_eye = concat_feat(kp_source, eye_close_ratio).half() feat_eye = concat_feat(kp_source, eye_close_ratio).half()
feat_lip = concat_feat(kp_source, lip_close_ratio).half() feat_lip = concat_feat(kp_source, lip_close_ratio).half()
@ -99,7 +102,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
Measure inference times for each model Measure inference times for each model
""" """
times = {name: [] for name in compiled_models.keys()} times = {name: [] for name in compiled_models.keys()}
times['Retargeting Models'] = [] times['Stitching and Retargeting Modules'] = []
overall_times = [] overall_times = []
@ -133,7 +136,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
stitching_retargeting_module['eye'](inputs['feat_eye']) stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip']) stitching_retargeting_module['lip'](inputs['feat_lip'])
torch.cuda.synchronize() torch.cuda.synchronize()
times['Retargeting Models'].append(time.time() - start) times['Stitching and Retargeting Modules'].append(time.time() - start)
overall_times.append(time.time() - overall_start) overall_times.append(time.time() - overall_start)
@ -166,15 +169,15 @@ def main():
""" """
Main function to benchmark speed and model parameters Main function to benchmark speed and model parameters
""" """
# Sample input tensors
inputs = initialize_inputs()
# Load configuration # Load configuration
cfg = InferenceConfig(device_id=0) cfg = InferenceConfig()
model_config_path = cfg.models_config model_config_path = cfg.models_config
with open(model_config_path, 'r') as file: with open(model_config_path, 'r') as file:
model_config = yaml.safe_load(file) model_config = yaml.safe_load(file)
# Sample input tensors
inputs = initialize_inputs(device_id = cfg.device_id)
# Load and compile models # Load and compile models
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config) compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

View File

@ -1,44 +1,57 @@
# coding: utf-8 # coding: utf-8
""" """
config for user All configs for user
""" """
import os.path as osp
from dataclasses import dataclass from dataclasses import dataclass
import tyro import tyro
from typing_extensions import Annotated from typing_extensions import Annotated
from typing import Optional, Literal
from .base_config import PrintableConfig, make_abs_path from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig): class ArgumentConfig(PrintableConfig):
########## input arguments ########## ########## input arguments ##########
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
#####################################
########## inference arguments ########## ########## inference arguments ##########
device_id: int = 0 flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
flag_eye_retargeting: bool = False device_id: int = 0 # gpu device id
flag_lip_retargeting: bool = False flag_force_cpu: bool = False # force cpu inference, WIP!
flag_stitching: bool = True # we recommend setting it to True! flag_normalize_lip: bool = True # whether to close the lips before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False
flag_relative: bool = True # whether to use relative motion flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" would adapt the driving motion with the global multiplier, and could be used when the source is a human image
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
######################################### source_max_dim: int = 1280 # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
########## crop arguments ########## ########## driving crop arguments ##########
dsize: int = 512 scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video
scale: float = 2.3 vx_ratio_crop_driving_video: float = 0. # adjust x (horizontal) offset
vx_ratio: float = 0 # vx ratio vy_ratio_crop_driving_video: float = -0.1 # adjust y (vertical) offset
vy_ratio: float = -0.125 # vy ratio +up, -down
####################################
########## gradio arguments ########## ########## gradio arguments ##########
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
share: bool = True share: bool = False # whether to share the server to public
server_name: str = "0.0.0.0" server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
gradio_temp_dir: Optional[str] = None # directory to save gradio temp files
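Since `tyro` exposes each field of this dataclass as a CLI flag of the same name, a hypothetical run combining several of the options above might look like this (values are illustrative):
```bash
# hypothetical example; every flag maps 1:1 to an ArgumentConfig field
python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4 \
    --driving_option expression-friendly --driving_multiplier 1.3 --audio_priority source
```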

View File

@ -4,15 +4,32 @@
parameters used for crop faces parameters used for crop faces
""" """
import os.path as osp
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, List
from .base_config import PrintableConfig from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig): class CropConfig(PrintableConfig):
insightface_root: str = make_abs_path("../../pretrained_weights/insightface")
landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx")
xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py")
xpose_embedding_cache_path: str = make_abs_path('../utils/resources/clip_embedding')
xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth")
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
det_thresh: float = 0.1 # detection threshold
########## source image or video cropping option ##########
dsize: int = 512 # crop size dsize: int = 512 # crop size
scale: float = 2.3 # scale factor scale: float = 2.3 # scale factor
vx_ratio: float = 0 # vx ratio vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 means no limit
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
########## driving video auto cropping option ##########
scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
vx_ratio_crop_driving_video: float = 0.0 # adjust x (horizontal) offset
vy_ratio_crop_driving_video: float = -0.1 # adjust y (vertical) offset
direction: str = "large-small" # direction of cropping

View File

@ -4,46 +4,61 @@
config dataclass used for inference config dataclass used for inference
""" """
import os.path as osp import cv2
from dataclasses import dataclass from numpy import ndarray
from dataclasses import dataclass, field
from typing import Literal, Tuple from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig): class InferenceConfig(PrintableConfig):
# HUMAN MODEL CONFIG, NOT EXPORTED PARAMS
models_config: str = make_abs_path('./models.yaml') # portrait animation config models_config: str = make_abs_path('./models.yaml') # portrait animation config
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint of M
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint # ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS
flag_use_half_precision: bool = True # whether to use half precision checkpoint_F_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
checkpoint_M_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/motion_extractor.pth') # path to checkpoint of M
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False checkpoint_G_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/spade_generator.pth') # path to checkpoint of G
lip_zero_threshold: float = 0.03 checkpoint_W_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/warping_module.pth') # path to checkpoint of W
checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip, NOTE: use human temporarily!
# EXPORTED PARAMS
flag_use_half_precision: bool = True
flag_crop_driving_video: bool = False
device_id: int = 0
flag_normalize_lip: bool = True
flag_source_video_eye_retargeting: bool = False
flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True! flag_stitching: bool = True
flag_relative_motion: bool = True
flag_pasteback: bool = True
flag_do_crop: bool = True
flag_do_rot: bool = True
flag_force_cpu: bool = False
flag_do_torch_compile: bool = False
driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly"
driving_multiplier: float = 1.0
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
flag_relative: bool = True # whether to use relative motion # NOT EXPORTED PARAMS
anchor_frame: int = 0 # set this value if find_best_frame is True lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape input_shape: Tuple[int, int] = (256, 256) # input shape
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
output_fps: int = 30 # fps for output video
crf: int = 15 # crf for output video crf: int = 15 # crf for output video
output_fps: int = 25 # default output fps
flag_write_result: bool = True # whether to write output video mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space size_gif: int = 256 # default gif size, TO IMPLEMENT
mask_crop = None
flag_write_gif: bool = False
size_gif: int = 256
ref_max_shape: int = 1280
ref_shape_n: int = 2
device_id: int = 0
flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True

View File

@ -3,15 +3,28 @@
""" """
Pipeline for gradio Pipeline for gradio
""" """
import os.path as osp
import os
import cv2
from rich.progress import track
import gradio as gr import gradio as gr
import numpy as np
import torch
from .config.argument_config import ArgumentConfig from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online from .live_portrait_pipeline_animal import LivePortraitPipelineAnimal
from .utils.io import load_img_online, load_video, resize_to_limit
from .utils.filter import smooth
from .utils.rprint import rlog as log from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix from .utils.camera import get_rotation_matrix
from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video
from .utils.helper import is_square_video, mkdir, dct2device, basename
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
def update_args(args, user_args): def update_args(args, user_args):
"""update the args according to user inputs """update the args according to user inputs
""" """
@ -20,40 +33,179 @@ def update_args(args, user_args):
setattr(args, k, v) setattr(args, k, v)
return args return args
class GradioPipeline(LivePortraitPipeline): class GradioPipeline(LivePortraitPipeline):
"""gradio for human
"""
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
super().__init__(inference_cfg, crop_cfg) super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper = self.live_portrait_wrapper # self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args self.args = args
# for single image retargeting
self.start_prepare = False
self.f_s_user = None
self.x_c_s_info_user = None
self.x_s_user = None
self.source_lmk_user = None
self.mask_ori = None
self.img_rgb = None
self.crop_M_c2o = None
@torch.no_grad()
def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
if eyeball_direction_x > 0:
delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
delta_new[0, 15, 0] += eyeball_direction_x * 0.001
else:
delta_new[0, 11, 0] += eyeball_direction_x * 0.001
delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
delta_new[0, 11, 1] += eyeball_direction_y * -0.001
delta_new[0, 15, 1] += eyeball_direction_y * -0.001
blink = -eyeball_direction_y / 2.
delta_new[0, 11, 1] += blink * -0.001
delta_new[0, 13, 1] += blink * 0.0003
delta_new[0, 15, 1] += blink * -0.001
delta_new[0, 16, 1] += blink * 0.0003
return delta_new
@torch.no_grad()
def update_delta_new_smile(self, smile, delta_new, **kwargs):
delta_new[0, 20, 1] += smile * -0.01
delta_new[0, 14, 1] += smile * -0.02
delta_new[0, 17, 1] += smile * 0.0065
delta_new[0, 17, 2] += smile * 0.003
delta_new[0, 13, 1] += smile * -0.00275
delta_new[0, 16, 1] += smile * -0.00275
delta_new[0, 3, 1] += smile * -0.0035
delta_new[0, 7, 1] += smile * -0.0035
return delta_new
@torch.no_grad()
def update_delta_new_wink(self, wink, delta_new, **kwargs):
delta_new[0, 11, 1] += wink * 0.001
delta_new[0, 13, 1] += wink * -0.0003
delta_new[0, 17, 0] += wink * 0.0003
delta_new[0, 17, 1] += wink * 0.0003
delta_new[0, 3, 1] += wink * -0.0003
return delta_new
@torch.no_grad()
def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs):
if eyebrow > 0:
delta_new[0, 1, 1] += eyebrow * 0.001
delta_new[0, 2, 1] += eyebrow * -0.001
else:
delta_new[0, 1, 0] += eyebrow * -0.001
delta_new[0, 2, 0] += eyebrow * 0.001
delta_new[0, 1, 1] += eyebrow * 0.0003
delta_new[0, 2, 1] += eyebrow * -0.0003
return delta_new
@torch.no_grad()
def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs):
delta_new[0, 19, 0] += lip_variation_zero
return delta_new
@torch.no_grad()
def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs):
delta_new[0, 14, 1] += lip_variation_one * 0.001
delta_new[0, 3, 1] += lip_variation_one * -0.0005
delta_new[0, 7, 1] += lip_variation_one * -0.0005
delta_new[0, 17, 2] += lip_variation_one * -0.0005
return delta_new
@torch.no_grad()
def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs):
delta_new[0, 20, 2] += lip_variation_two * -0.001
delta_new[0, 20, 1] += lip_variation_two * -0.001
delta_new[0, 14, 1] += lip_variation_two * -0.001
return delta_new
@torch.no_grad()
def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs):
delta_new[0, 19, 1] += lip_variation_three * 0.001
delta_new[0, 19, 2] += lip_variation_three * 0.0001
delta_new[0, 17, 1] += lip_variation_three * -0.0001
return delta_new
@torch.no_grad()
def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs):
delta_new[0, 5, 0] += mov_x
return delta_new
@torch.no_grad()
def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs):
delta_new[0, 5, 1] += mov_y
return delta_new
@torch.no_grad()
def execute_video( def execute_video(
self, self,
input_image_path, input_source_image_path=None,
input_video_path, input_source_video_path=None,
flag_relative_input, input_driving_video_pickle_path=None,
flag_do_crop_input, input_driving_video_path=None,
flag_remap_input, flag_relative_input=True,
flag_do_crop_input=True,
flag_remap_input=True,
flag_stitching_input=True,
driving_option_input="pose-friendly",
driving_multiplier=1.0,
flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
driving_smooth_observation_variance=3e-7,
tab_selection=None,
v_tab_selection=None
): ):
""" for video driven potrait animation """ for video-driven portrait animation or video editing
""" """
if input_image_path is not None and input_video_path is not None: if tab_selection == 'Image':
input_source_path = input_source_image_path
elif tab_selection == 'Video':
input_source_path = input_source_video_path
else:
input_source_path = input_source_image_path
if v_tab_selection == 'Video':
input_driving_path = input_driving_video_path
elif v_tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
input_driving_path = input_driving_video_path
if input_source_path is not None and input_driving_path is not None:
if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False:
flag_crop_driving_video_input = True
log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
args_user = { args_user = {
'source_image': input_image_path, 'source': input_source_path,
'driving_info': input_video_path, 'driving': input_driving_path,
'flag_relative': flag_relative_input, 'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input, 'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input, 'flag_pasteback': flag_remap_input,
'flag_stitching': flag_stitching_input,
'driving_option': driving_option_input,
'driving_multiplier': driving_multiplier,
'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
'scale_crop_driving_video': scale_crop_driving_video,
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
'driving_smooth_observation_variance': driving_smooth_observation_variance,
} }
# update config from user input # update config from user input
self.args = update_args(self.args, args_user) self.args = update_args(self.args, args_user)
@ -62,79 +214,368 @@ class GradioPipeline(LivePortraitPipeline):
# video driven animation # video driven animation
video_path, video_path_concat = self.execute(self.args) video_path, video_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2) gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat, return video_path, video_path_concat
else: else:
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5) raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float): @torch.no_grad()
def execute_image_retargeting(
self,
input_eye_ratio: float,
input_lip_ratio: float,
input_head_pitch_variation: float,
input_head_yaw_variation: float,
input_head_roll_variation: float,
mov_x: float,
mov_y: float,
mov_z: float,
lip_variation_zero: float,
lip_variation_one: float,
lip_variation_two: float,
lip_variation_three: float,
smile: float,
wink: float,
eyebrow: float,
eyeball_direction_x: float,
eyeball_direction_y: float,
input_image,
retargeting_source_scale: float,
flag_stitching_retargeting_input=True,
flag_do_crop_input_retargeting_image=True):
""" for single image retargeting """ for single image retargeting
""" """
if input_eye_ratio is None or input_eye_ratio is None: if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
raise gr.Error("Invalid relative pose input 💥!", duration=5)
# disposable feature
f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting_image(
input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5) raise gr.Error("Invalid ratio input 💥!", duration=5)
elif self.f_s_user is None:
if self.start_prepare:
raise gr.Error(
"The source portrait is under processing 💥! Please wait for a second.",
duration=5
)
else: else:
raise gr.Error( device = self.live_portrait_wrapper.device
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.", # inference_cfg = self.live_portrait_wrapper.inference_cfg
duration=5 x_s_user = x_s_user.to(device)
) f_s_user = f_s_user.to(device)
else: R_s_user = R_s_user.to(device)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) R_d_user = R_d_user.to(device)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user) mov_x = torch.tensor(mov_x).to(device)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor) mov_y = torch.tensor(mov_y).to(device)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) mov_z = torch.tensor(mov_z).to(device)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user) eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device)
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor) eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device)
num_kp = self.x_s_user.shape[1] smile = torch.tensor(smile).to(device)
# default: use x_s wink = torch.tensor(wink).to(device)
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) eyebrow = torch.tensor(eyebrow).to(device)
# D(W(f_s; x_s, x_d)) lip_variation_zero = torch.tensor(lip_variation_zero).to(device)
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new) lip_variation_one = torch.tensor(lip_variation_one).to(device)
lip_variation_two = torch.tensor(lip_variation_two).to(device)
lip_variation_three = torch.tensor(lip_variation_three).to(device)
x_c_s = x_s_info['kp'].to(device)
delta_new = x_s_info['exp'].to(device)
scale_new = x_s_info['scale'].to(device)
t_new = x_s_info['t'].to(device)
R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
if eyeball_direction_x != 0 or eyeball_direction_y != 0:
delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new)
if smile != 0:
delta_new = self.update_delta_new_smile(smile, delta_new)
if wink != 0:
delta_new = self.update_delta_new_wink(wink, delta_new)
if eyebrow != 0:
delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new)
if lip_variation_zero != 0:
delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new)
if lip_variation_one != 0:
delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new)
if lip_variation_two != 0:
delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new)
if lip_variation_three != 0:
delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new)
if mov_x != 0:
delta_new = self.update_delta_new_mov_x(-mov_x, delta_new)
if mov_y !=0 :
delta_new = self.update_delta_new_mov_y(mov_y, delta_new)
x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new
eyes_delta, lip_delta = None, None
if input_eye_ratio != self.source_eye_ratio:
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
if input_lip_ratio != self.source_lip_ratio:
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
x_d_new = x_d_new + \
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
if flag_stitching_retargeting_input:
x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0] out = self.live_portrait_wrapper.parse_output(out['out'])[0]
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
gr.Info("Run successfully!", duration=2)
if flag_do_crop_input_retargeting_image:
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
else:
out_to_ori_blend = out
return out, out_to_ori_blend return out, out_to_ori_blend
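For reference, the keypoint edit above reduces to one affine composition followed by additive retargeting offsets. A minimal NumPy sketch of that composition with illustrative shapes only (this is not the pipeline's code; the real offsets come from the retarget_eye/retarget_lip modules):

import numpy as np

def compose_keypoints(x_c_s, R, delta, t, scale, mov_z=1.0, extra_deltas=()):
    # x_d = mov_z * scale * (x_c_s @ R + delta) + t, plus optional (1, K, 3)
    # eye/lip retargeting offsets added afterwards
    x_d = mov_z * scale * (x_c_s @ R + delta) + t
    for d in extra_deltas:
        x_d = x_d + d
    return x_d

K = 21  # illustrative keypoint count
x_c_s = np.random.randn(1, K, 3).astype(np.float32)
R = np.eye(3, dtype=np.float32)[None]           # identity rotation
delta = np.zeros((1, K, 3), dtype=np.float32)   # no expression offset
t = np.zeros((1, 1, 3), dtype=np.float32)
x_d_new = compose_keypoints(x_c_s, R, delta, t, scale=1.0)
print(x_d_new.shape)  # (1, 21, 3)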
@torch.no_grad()
def prepare_retargeting(self, input_image_path, flag_do_crop = True): def prepare_retargeting_image(
self,
input_image,
input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation,
retargeting_source_scale,
flag_do_crop=True):
""" for single image retargeting """ for single image retargeting
""" """
if input_image_path is not None: if input_image is not None:
gr.Info("Upload successfully!", duration=2) # gr.Info("Upload successfully!", duration=2)
self.start_prepare = True args_user = {'scale': retargeting_source_scale}
inference_cfg = self.live_portrait_wrapper.cfg self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ######## ######## process source portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16) img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
if flag_do_crop: if flag_do_crop:
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
else: else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb) I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
source_lmk_user = self.cropper.calc_lmk_from_cropped_image(img_rgb)
crop_M_c2o = None
mask_ori = None
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation
R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll)
############################################ ############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)
# record global info for next time use @torch.no_grad()
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio:float, input_image = None):
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) """ initialize the retargeting slider
self.x_s_info_user = x_s_info """
self.source_lmk_user = crop_info['lmk_crop'] if input_image != None:
self.img_rgb = img_rgb args_user = {'scale': retargeting_source_scale}
self.crop_M_c2o = crop_info['M_c2o'] self.args = update_args(self.args, args_user)
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) self.cropper.update_config(self.args.__dict__)
# update slider # inference_cfg = self.live_portrait_wrapper.inference_cfg
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None]) ######## process source portrait ########
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean()) img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None]) log(f"Load source image from {input_image}.")
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean()) crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
# for vis if crop_info is None:
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0] raise gr.Error("Source portrait NO face detected", duration=2)
return eye_close_ratio, lip_close_ratio, self.I_s_vis source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2)
self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2)
log("Calculating eyes-open and lip-open ratios successfully!")
return self.source_eye_ratio, self.source_lip_ratio
else:
return source_eye_ratio, source_lip_ratio
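A hypothetical Gradio wiring for this initializer, showing how the returned eyes-open and lip-open ratios could be pushed back into the two sliders. All component names, ranges, and the stand-in function below are made up for illustration; only the gr.Blocks/Slider/Button calls are standard Gradio API:

import gradio as gr

def fake_init(scale, eye_ratio, lip_ratio, image):
    # stand-in for GradioPipeline.init_retargeting_image
    return 0.38, 0.02

with gr.Blocks() as demo:
    image_in = gr.Image(type="filepath", label="Source portrait")
    scale_in = gr.Slider(1.8, 3.2, value=2.3, label="Crop scale")
    eye_slider = gr.Slider(0.0, 0.8, value=0.0, label="Target eyes-open ratio")
    lip_slider = gr.Slider(0.0, 0.8, value=0.0, label="Target lip-open ratio")
    init_btn = gr.Button("Initialize retargeting")
    init_btn.click(
        fn=fake_init,
        inputs=[scale_in, eye_slider, lip_slider, image_in],
        outputs=[eye_slider, lip_slider],
    )

# demo.launch()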
@torch.no_grad()
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
""" retargeting the lip-open ratio of each source frame
"""
# disposable feature
device = self.live_portrait_wrapper.device
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
if input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
I_p_pstbk_lst = None
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
x_d_i_new = x_s_user_i + lip_delta_retargeting
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if flag_do_crop_input_retargeting_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(self.args.output_dir)
flag_source_has_audio = has_audio_stream(input_video)
######### build the final concatenation result #########
# source frame | generation
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
if flag_source_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=source_fps)
######### build the final result #########
if flag_source_has_audio:
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
add_audio_to_video(wfp, input_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
gr.Info("Run successfully!", duration=2)
return wfp_concat, wfp
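add_audio_to_video is assumed to remux the original audio onto the silent render; a rough equivalent with plain ffmpeg, using placeholder paths, would look like this:

import subprocess

def mux_audio(silent_video, audio_source, output_path):
    cmd = [
        "ffmpeg", "-y",
        "-i", silent_video,              # rendered frames, no audio
        "-i", audio_source,              # original input that still carries audio
        "-map", "0:v:0", "-map", "1:a:0?",
        "-c:v", "copy", "-c:a", "aac",
        "-shortest", output_path,
    ]
    subprocess.run(cmd, check=True)

# mux_audio("animations/input_retargeting.mp4", "inputs/input.mp4",
#           "animations/input_retargeting_with_audio.mp4")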
@torch.no_grad()
def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
""" for video retargeting
"""
if input_video is not None:
# gr.Info("Upload successfully!", duration=2)
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source video ########
source_rgb_lst = load_video(input_video)
source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
source_fps = int(get_fps(input_video))
n_frames = len(source_rgb_lst)
log(f"Load source video from {input_video}. FPS is {source_fps}")
if flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) != n_frames:
n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
source_M_c2o_lst, mask_ori_lst = None, None
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
c_d_lip_retargeting = [input_lip_ratio]
f_s_user_lst, x_s_user_lst, lip_delta_retargeting_lst = [], [], []
for i in track(range(n_frames), description='Preparing retargeting video...', total=n_frames):
x_s_info = source_template_dct['motion'][i]
x_s_info = dct2device(x_s_info, device)
x_s_user = x_s_info['x_s']
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
combined_lip_ratio_tensor_retargeting = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_retargeting, source_lmk)
lip_delta_retargeting = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor_retargeting)
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
else: else:
# when press the clear button, go here # when press the clear button, go here
return 0.8, 0.8, self.I_s_vis
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
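The smooth() helper above filters the per-frame lip deltas with an observation-variance parameter; as a simplified stand-in (not the repo's implementation), an exponential moving average over the same list of arrays illustrates the idea:

import numpy as np

def ema_smooth(deltas, alpha=0.3):
    # deltas: list of arrays with identical shape; smaller alpha = smoother
    smoothed, state = [], None
    for d in deltas:
        state = d if state is None else alpha * d + (1.0 - alpha) * state
        smoothed.append(state.copy())
    return smoothed

frames = [np.random.randn(1, 21, 3).astype(np.float32) * 0.01 for _ in range(5)]
smoothed = ema_smooth(frames)
print(len(smoothed), smoothed[0].shape)  # 5 (1, 21, 3)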
class GradioPipelineAnimal(LivePortraitPipelineAnimal):
"""gradio for animal
"""
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
inference_cfg.flag_crop_driving_video = True # ensure the face_analysis_wrapper is enabled
super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper_animal = self.live_portrait_wrapper_animal
self.args = args
@torch.no_grad()
def execute_video(
self,
input_source_image_path=None,
input_driving_video_path=None,
input_driving_video_pickle_path=None,
flag_do_crop_input=False,
flag_remap_input=False,
driving_multiplier=1.0,
flag_stitching=False,
flag_crop_driving_video_input=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
tab_selection=None,
):
""" for video-driven potrait animation
"""
input_source_path = input_source_image_path
if tab_selection == 'Video':
input_driving_path = input_driving_video_path
elif tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
input_driving_path = input_driving_video_pickle_path
if input_source_path is not None and input_driving_path is not None:
if osp.exists(input_driving_path) and tab_selection == 'Video' and is_square_video(input_driving_path) is False:
flag_crop_driving_video_input = True
log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
args_user = {
'source': input_source_path,
'driving': input_driving_path,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'driving_multiplier': driving_multiplier,
'flag_stitching': flag_stitching,
'flag_crop_driving_video': flag_crop_driving_video_input,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
'scale_crop_driving_video': scale_crop_driving_video,
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper_animal.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat, video_gif_path = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat, video_gif_path
else:
raise gr.Error("Please upload the source animal image, and driving video 🤗🤗🤗", duration=5)

@ -1,16 +1,15 @@
# coding: utf-8 # coding: utf-8
""" """
Pipeline of LivePortrait Pipeline of LivePortrait (Human)
""" """
# TODO:
# 1. Currently this assumes that all templates are already cropped; this needs to be revised
# 2. pick sample source + driving images
import torch
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
import cv2 import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np import numpy as np
import pickle import os
import os.path as osp import os.path as osp
from rich.progress import track from rich.progress import track
@ -19,12 +18,13 @@ from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig from .config.crop_config import CropConfig
from .utils.cropper import Cropper from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back from .utils.crop import prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video, calc_motion_multiplier
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template from .utils.filter import smooth
from .utils.rprint import rlog as log from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper from .live_portrait_wrapper import LivePortraitWrapper
@ -35,156 +35,388 @@ def make_abs_path(fn):
class LivePortraitPipeline(object): class LivePortraitPipeline(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg) self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg) self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
n_frames = I_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
'c_eyes_lst': [],
'c_lip_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s, R, δ and t for inference
I_i = I_lst[i]
x_i_info = self.live_portrait_wrapper.get_kp_info(I_i)
x_s = self.live_portrait_wrapper.transform_keypoint(x_i_info)
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
item_dct = {
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
'R': R_i.cpu().numpy().astype(np.float32),
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_i_info['t'].cpu().numpy().astype(np.float32),
'kp': x_i_info['kp'].cpu().numpy().astype(np.float32),
'x_s': x_s.cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
c_eyes = c_eyes_lst[i].astype(np.float32)
template_dct['c_eyes_lst'].append(c_eyes)
c_lip = c_lip_lst[i].astype(np.float32)
template_dct['c_lip_lst'].append(c_lip)
return template_dct
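The motion template produced above is a plain dict of NumPy arrays; assuming the dump/load helpers are pickle-compatible, a round-trip sketch with illustrative (not authoritative) shapes:

import pickle
import numpy as np

K = 21  # number of implicit keypoints; illustrative only
template = {
    'n_frames': 2,
    'output_fps': 25,
    'motion': [
        {
            'scale': np.ones((1, 1), np.float32),
            'R': np.eye(3, dtype=np.float32)[None],
            'exp': np.zeros((1, K, 3), np.float32),
            't': np.zeros((1, 3), np.float32),
            'kp': np.zeros((1, K, 3), np.float32),
            'x_s': np.zeros((1, K, 3), np.float32),
        }
        for _ in range(2)
    ],
    'c_eyes_lst': [np.zeros((1, 3), np.float32)] * 2,
    'c_lip_lst': [np.zeros((1, 1), np.float32)] * 2,
}

with open('driving_template.pkl', 'wb') as f:   # same role as dump(wfp_template, ...)
    pickle.dump(template, f)
with open('driving_template.pkl', 'rb') as f:   # same role as load(args.driving)
    loaded = pickle.load(f)
print(loaded['n_frames'], loaded['motion'][0]['R'].shape)  # 2 (1, 3, 3)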
def execute(self, args: ArgumentConfig): def execute(self, args: ArgumentConfig):
inference_cfg = self.live_portrait_wrapper.cfg # for convenience # for convenience
######## process source portrait ######## inf_cfg = self.live_portrait_wrapper.inference_cfg
img_rgb = load_image_rgb(args.source_image) device = self.live_portrait_wrapper.device
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n) crop_cfg = self.cropper.crop_cfg
log(f"Load source image from {args.source_image}")
crop_info = self.cropper.crop_single_image(img_rgb) ######## load source input ########
source_lmk = crop_info['lmk_crop'] flag_is_source_video = False
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] source_fps = None
if inference_cfg.flag_do_crop: if is_image(args.source):
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) flag_is_source_video = False
img_rgb = load_image_rgb(args.source)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source}")
source_rgb_lst = [img_rgb]
elif is_video(args.source):
flag_is_source_video = True
source_rgb_lst = load_video(args.source)
source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst]
source_fps = int(get_fps(args.source))
log(f"Load source video from {args.source}, FPS is {source_fps}")
else: # source input is an unknown format
raise Exception(f"Unknown source format: {args.source}")
######## process driving info ########
flag_load_from_template = is_template(args.driving)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
driving_template_dct = load(args.driving)
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
else: else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb) n_frames = driving_n_frames
# set output_fps
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
driving_n_frames = len(driving_rgb_lst)
######## make motion template ########
log("Start making driving motion template...")
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames]
else:
n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
#######################################
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst)
driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
I_p_lst = []
R_d_0, x_d_0_info = None, None
flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite
flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite
lip_delta_before_animation, eye_delta_before_animation = None, None
######## process source info ########
if flag_is_source_video:
log(f"Start making source motion template...")
source_rgb_lst = source_rgb_lst[:n_frames]
if inf_cfg.flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
if inf_cfg.flag_relative_motion:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: # if the input is a source image, process it only once
if inf_cfg.flag_do_crop:
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop_256x256 = crop_info['img_crop_256x256']
else:
source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp'] x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
if inference_cfg.flag_lip_zero:
# let lip-open scalar to be 0 at first # let lip-open scalar to be 0 at first
if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None:
c_d_lip_before_animation = [0.] c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold: if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
inference_cfg.flag_lip_zero = False
else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ######## if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
if is_video(args.driving_info): mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
# TODO: 这里track一下驱动视频 -> 构建模板 ######## animate ########
driving_rgb_lst = load_driving_info(args.driving_info) log(f"The animated video consists of {n_frames} frames.")
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256) if flag_is_source_video: # source video
n_frames = I_d_lst.shape[0] x_s_info = source_template_dct['motion'][i]
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting: x_s_info = dct2device(x_s_info, device)
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) source_lmk = source_lmk_crop_lst[i]
elif is_template(args.driving_info): img_crop_256x256 = img_crop_256x256_lst[i]
log(f"Load from video templates {args.driving_info}") I_s = I_s_lst[i]
with open(args.driving_info, 'rb') as f: f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
template_lst, driving_lmk_lst = pickle.load(f)
n_frames = template_lst[0]['n_frames'] x_c_s = x_s_info['kp']
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) R_s = x_s_info['R']
x_s = x_s_info['x_s']
# let lip-open scalar to be 0 at first if the input is a video
if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
else: else:
raise Exception("Unsupported driving types!") lip_delta_before_animation = None
#########################################
######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = []
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(n_frames), description='Animating...', total=n_frames):
if is_video(args.driving_info):
# extract kp info by M
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
else:
# from template
x_d_i_info = template_lst[i]
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
R_d_i = x_d_i_info['R_d']
# keep the eye-open scalar the same as in the first frame, if that frame is in an eye-open state
if flag_source_video_eye_retargeting and source_lmk is not None:
if i == 0: if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold:
c_d_eye_before_animation_frame_zero = [[0.39]]
combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
if i == 0: # cache the first frame
R_d_0 = R_d_i R_d_0 = R_d_i
x_d_0_info = x_d_i_info x_d_0_info = x_d_i_info
if inference_cfg.flag_relative: if inf_cfg.flag_relative_motion:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else: else:
R_new = R_d_i R_new = R_d_i
delta_new = x_d_i_info['exp'] delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
scale_new = x_s_info['scale'] scale_new = x_s_info['scale']
t_new = x_d_i_info['t'] t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
if i == 0:
x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
# motion_multiplier *= inf_cfg.driving_multiplier
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
x_d_i_new = x_d_diff + x_s
# Algorithm 1: # Algorithm 1:
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting: if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting # without stitching or retargeting
if inference_cfg.flag_lip_zero: if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) x_d_i_new += lip_delta_before_animation
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else: else:
pass pass
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting: elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting # with stitching and without retargeting
if inference_cfg.flag_lip_zero: if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation
else: else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else: else:
eyes_delta, lip_delta = None, None eyes_delta, lip_delta = None, None
if inference_cfg.flag_eye_retargeting: if inf_cfg.flag_eye_retargeting and source_lmk is not None:
c_d_eyes_i = input_eye_ratio_lst[i] c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor) eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inference_cfg.flag_lip_retargeting: if inf_cfg.flag_lip_retargeting and source_lmk is not None:
c_d_lip_i = input_lip_ratio_lst[i] c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk) combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
if inference_cfg.flag_relative: # use x_s if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \ x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ (eyes_delta if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) (lip_delta if lip_delta is not None else 0)
else: # use x_d,i else: # use x_d,i
x_d_i_new = x_d_i_new + \ x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ (eyes_delta if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) (lip_delta if lip_delta is not None else 0)
if inference_cfg.flag_stitching: if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
x_d_i_new = x_s + (x_d_i_new - x_s) * inf_cfg.driving_multiplier
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new) out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i) I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback: if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori) # TODO: the paste-back procedure is slow; consider optimizing it with multi-threading or on the GPU
I_p_paste_lst.append(I_p_i_to_ori_blend) if flag_is_source_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float)
else:
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir) mkdir(args.output_dir)
wfp_concat = None wfp_concat = None
if is_video(args.driving_info): flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256) flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
# save (driving frames, source image, driven frames) result
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat)
# save driven result ######### build the final concatenation result #########
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4') # driving frame | source frame | generation, or source frame | generation
if inference_cfg.flag_pasteback: if flag_is_source_video:
images2video(I_p_paste_lst, wfp=wfp) frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else: else:
images2video(I_p_lst, wfp=wfp) frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps
output_fps = source_fps if flag_is_source_video else output_fps
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify the `-d` argument with this template path next time to skip cropping the video and remaking the motion template, and to protect privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
return wfp, wfp_concat return wfp, wfp_concat
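The relative-motion branch of the loop above composes the driving rotation relative to its first frame with the source rotation, and treats expression, scale, and translation as offsets from frame 0. A standalone PyTorch sketch of exactly that composition, with made-up shapes:

import torch

def relative_motion(x_c_s, R_s, exp_s, scale_s, t_s,
                    R_d_i, exp_d_i, scale_d_i, t_d_i,
                    R_d_0, exp_d_0, scale_d_0, t_d_0):
    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s   # driving rotation relative to its first frame
    delta_new = exp_s + (exp_d_i - exp_d_0)          # relative expression
    scale_new = scale_s * (scale_d_i / scale_d_0)    # relative scale
    t_new = t_s + (t_d_i - t_d_0)                    # relative translation
    t_new[..., 2].fill_(0)                           # zero tz, as in the pipeline
    return scale_new * (x_c_s @ R_new + delta_new) + t_new

K = 21
eye = torch.eye(3).unsqueeze(0)
zeros = torch.zeros(1, K, 3)
ones, t0 = torch.ones(1, 1), torch.zeros(1, 1, 3)
x_d = relative_motion(torch.randn(1, K, 3), eye, zeros, ones, t0,
                      eye, zeros, ones, t0,
                      eye, zeros, ones, t0)
print(x_d.shape)  # torch.Size([1, 21, 3])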

@ -0,0 +1,237 @@
# coding: utf-8
"""
Pipeline of LivePortrait (Animal)
"""
import warnings
warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly.")
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
import torch
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import os
import os.path as osp
from rich.progress import track
from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream, video2gif
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, calc_motion_multiplier
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapperAnimal
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class LivePortraitPipelineAnimal(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper_animal: LivePortraitWrapperAnimal = LivePortraitWrapperAnimal(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg, image_type='animal_face', flag_use_half_precision=inference_cfg.flag_use_half_precision)
def make_motion_template(self, I_lst, **kwargs):
n_frames = I_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
}
for i in track(range(n_frames), description='Making driving motion templates...', total=n_frames):
# collect s, R, δ and t for inference
I_i = I_lst[i]
x_i_info = self.live_portrait_wrapper_animal.get_kp_info(I_i)
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
item_dct = {
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
'R': R_i.cpu().numpy().astype(np.float32),
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
return template_dct
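get_rotation_matrix turns the predicted pitch/yaw/roll into the 'R' entries stored in the template; a generic Euler-angle sketch (the repo's axis order and conventions may differ) shows what such a matrix looks like:

import numpy as np

def euler_to_matrix(pitch_deg, yaw_deg, roll_deg):
    p, y, r = np.deg2rad([pitch_deg, yaw_deg, roll_deg])
    Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]])
    Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]])
    Rz = np.array([[np.cos(r), -np.sin(r), 0], [np.sin(r), np.cos(r), 0], [0, 0, 1]])
    return (Rz @ Ry @ Rx).astype(np.float32)

R = euler_to_matrix(5.0, -10.0, 0.0)
print(R.shape, np.allclose(R @ R.T, np.eye(3), atol=1e-6))  # (3, 3) True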
def execute(self, args: ArgumentConfig):
# for convenience
inf_cfg = self.live_portrait_wrapper_animal.inference_cfg
device = self.live_portrait_wrapper_animal.device
crop_cfg = self.cropper.crop_cfg
######## load source input ########
if is_image(args.source):
img_rgb = load_image_rgb(args.source)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source}")
else: # source input is an unknown format
raise Exception(f"Unknown source format: {args.source}")
######## process driving info ########
flag_load_from_template = is_template(args.driving)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
driving_template_dct = load(args.driving)
n_frames = driving_template_dct['n_frames']
# set output_fps
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
n_frames = len(driving_rgb_lst)
######## make motion template ########
log("Start making driving motion template...")
if inf_cfg.flag_crop_driving_video:
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst = ret_d['frame_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
#######################################
# save the motion template
I_d_lst = self.live_portrait_wrapper_animal.prepare_videos(driving_rgb_crop_256x256_lst)
driving_template_dct = self.make_motion_template(I_d_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
######## process source info ########
if inf_cfg.flag_do_crop:
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
if crop_info is None:
raise Exception("No animal face detected in the source image!")
img_crop_256x256 = crop_info['img_crop_256x256']
else:
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper_animal.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper_animal.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper_animal.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper_animal.transform_keypoint(x_s_info)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
######## animate ########
I_p_lst = []
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
delta_new = x_d_i_info['exp']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz
scale_new = x_s_info['scale']
x_d_i = scale_new * (x_c_s @ R_d_i + delta_new) + t_new
if i == 0:
x_d_0 = x_d_i
motion_multiplier = calc_motion_multiplier(x_s, x_d_0)
x_d_diff = (x_d_i - x_d_0) * motion_multiplier
x_d_i = x_d_diff + x_s
if not inf_cfg.flag_stitching:
pass
else:
x_d_i = self.live_portrait_wrapper_animal.stitching(x_s, x_d_i)
x_d_i = x_s + (x_d_i - x_s) * inf_cfg.driving_multiplier
out = self.live_portrait_wrapper_animal.warp_decode(f_s, x_s, x_d_i)
I_p_i = self.live_portrait_wrapper_animal.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result #########
# driving frame | source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify the `-d` argument with this template path next time to skip cropping the video and remaking the motion template, and to protect privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
# build the gif
wfp_gif = video2gif(wfp)
log(f'Animated gif: {wfp_gif}')
return wfp, wfp_concat, wfp_gif
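video2gif is assumed to wrap an ffmpeg conversion; the standard palette-based recipe below is one way it could be done (placeholder paths, not the repo's implementation):

import subprocess

def to_gif(video_path, gif_path, fps=12, width=512):
    filters = f"fps={fps},scale={width}:-1:flags=lanczos"
    subprocess.run([
        "ffmpeg", "-y", "-i", video_path,
        "-vf", f"{filters},split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse",
        gif_path,
    ], check=True)

# to_gif("animations/cat--driving.mp4", "animations/cat--driving.gif")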

@ -1,9 +1,10 @@
# coding: utf-8 # coding: utf-8
""" """
Wrapper for LivePortrait core functions Wrappers for LivePortrait core functions
""" """
import contextlib
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import cv2 import cv2
@ -19,46 +20,73 @@ from .utils.rprint import rlog as log
class LivePortraitWrapper(object): class LivePortraitWrapper(object):
"""
Wrapper for Human
"""
def __init__(self, cfg: InferenceConfig): def __init__(self, inference_cfg: InferenceConfig):
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader) self.inference_cfg = inference_cfg
self.device_id = inference_cfg.device_id
self.compile = inference_cfg.flag_do_torch_compile
if inference_cfg.flag_force_cpu:
self.device = 'cpu'
else:
try:
if torch.backends.mps.is_available():
self.device = 'mps'
else:
self.device = 'cuda:' + str(self.device_id)
except:
self.device = 'cuda:' + str(self.device_id)
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F # init F
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor') self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.') log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F)} done.')
# init M # init M
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor') self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor done.') log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M)} done.')
# init W # init W
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module') self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
log(f'Load warping_module done.') log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W)} done.')
# init G # init G
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator') self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
log(f'Load spade_generator done.') log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G)} done.')
# init S and R # init S and R
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S): if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module') self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.') log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S)} done.')
else: else:
self.stitching_retargeting_module = None self.stitching_retargeting_module = None
# Optimize for inference
if self.compile:
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
self.cfg = cfg
self.device_id = cfg.device_id
self.timer = Timer() self.timer = Timer()
def inference_ctx(self):
if self.device == "mps":
ctx = contextlib.nullcontext()
else:
ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
enabled=self.inference_cfg.flag_use_half_precision)
return ctx
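A standalone illustration of the device pick plus half-precision context used by the wrapper; this simplified version only enables autocast on CUDA and falls back to a no-op context elsewhere, which is an assumption rather than the wrapper's exact behavior:

import contextlib
import torch

def pick_device(device_id=0, force_cpu=False):
    if force_cpu:
        return "cpu"
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return f"cuda:{device_id}"
    return "cpu"

def inference_ctx(device, use_half=True):
    # autocast only on CUDA here; MPS/CPU get a no-op context
    if device.startswith("cuda"):
        return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_half)
    return contextlib.nullcontext()

device = pick_device()
with torch.no_grad(), inference_ctx(device):
    y = torch.randn(2, 3, device=device) @ torch.randn(3, 4, device=device)
print(y.dtype)  # float16 under CUDA autocast, float32 otherwise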
def update_config(self, user_args): def update_config(self, user_args):
for k, v in user_args.items(): for k, v in user_args.items():
if hasattr(self.cfg, k): if hasattr(self.inference_cfg, k):
setattr(self.cfg, k, v) setattr(self.inference_cfg, k, v)
def prepare_source(self, img: np.ndarray) -> torch.Tensor: def prepare_source(self, img: np.ndarray) -> torch.Tensor:
""" construct the input as standard """ construct the input as standard
img: HxWx3, uint8, 256x256 img: HxWx3, uint8, 256x256
""" """
h, w = img.shape[:2] h, w = img.shape[:2]
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]: if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1])) x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
else: else:
x = img.copy() x = img.copy()
@ -70,10 +98,10 @@ class LivePortraitWrapper(object):
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1 x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
x = x.cuda(self.device_id) x = x.to(self.device)
return x return x
def prepare_driving_videos(self, imgs) -> torch.Tensor: def prepare_videos(self, imgs) -> torch.Tensor:
""" construct the input as standard """ construct the input as standard
imgs: NxBxHxWx3, uint8 imgs: NxBxHxWx3, uint8
""" """
@ -87,7 +115,7 @@ class LivePortraitWrapper(object):
y = _imgs.astype(np.float32) / 255. y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1 y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
y = y.cuda(self.device_id) y = y.to(self.device)
return y return y
@ -95,8 +123,7 @@ class LivePortraitWrapper(object):
""" get the appearance feature of the image by F """ get the appearance feature of the image by F
x: Bx3xHxW, normalized to 0~1 x: Bx3xHxW, normalized to 0~1
""" """
with torch.no_grad(): with torch.no_grad(), self.inference_ctx():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
feature_3d = self.appearance_feature_extractor(x) feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float() return feature_3d.float()
@ -107,11 +134,10 @@ class LivePortraitWrapper(object):
flag_refine_info: whether to transform the pose to degrees and the dimension of the reshape flag_refine_info: whether to transform the pose to degrees and the dimension of the reshape
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp' return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
""" """
with torch.no_grad(): with torch.no_grad(), self.inference_ctx():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
kp_info = self.motion_extractor(x) kp_info = self.motion_extractor(x)
if self.cfg.flag_use_half_precision: if self.inference_cfg.flag_use_half_precision:
# float the dict # float the dict
for k, v in kp_info.items(): for k, v in kp_info.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
@ -189,26 +215,27 @@ class LivePortraitWrapper(object):
""" """
kp_source: BxNx3 kp_source: BxNx3
eye_close_ratio: Bx3 eye_close_ratio: Bx3
Return: Bx(3*num_kp+2) Return: Bx(3*num_kp)
""" """
feat_eye = concat_feat(kp_source, eye_close_ratio) feat_eye = concat_feat(kp_source, eye_close_ratio)
with torch.no_grad(): with torch.no_grad():
delta = self.stitching_retargeting_module['eye'](feat_eye) delta = self.stitching_retargeting_module['eye'](feat_eye)
return delta return delta.reshape(-1, kp_source.shape[1], 3)
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
""" """
kp_source: BxNx3 kp_source: BxNx3
lip_close_ratio: Bx2 lip_close_ratio: Bx2
Return: Bx(3*num_kp)
""" """
feat_lip = concat_feat(kp_source, lip_close_ratio) feat_lip = concat_feat(kp_source, lip_close_ratio)
with torch.no_grad(): with torch.no_grad():
delta = self.stitching_retargeting_module['lip'](feat_lip) delta = self.stitching_retargeting_module['lip'](feat_lip)
return delta return delta.reshape(-1, kp_source.shape[1], 3)
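Illustrative aside (not part of the diff): the reshape above folds the flat stitching output back into per-keypoint offsets; a toy check with an assumed num_kp of 21.

import torch

B, num_kp = 2, 21
delta_flat = torch.zeros(B, 3 * num_kp)               # what the retargeting MLP returns
delta = delta_flat.reshape(-1, num_kp, 3)             # BxNx3 offsets, added to kp_source downstream
assert delta.shape == (B, num_kp, 3)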
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" """
@ -253,15 +280,17 @@ class LivePortraitWrapper(object):
kp_driving: BxNx3 kp_driving: BxNx3
""" """
# Line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i)) # Line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i))
with torch.no_grad(): with torch.no_grad(), self.inference_ctx():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision): if self.compile:
# Mark the beginning of a new CUDA Graph step
torch.compiler.cudagraph_mark_step_begin()
# get decoder input # get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving) ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode # decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out']) ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict # float the dict
if self.cfg.flag_use_half_precision: if self.inference_cfg.flag_use_half_precision:
for k, v in ret_dct.items(): for k, v in ret_dct.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
ret_dct[k] = v.float() ret_dct[k] = v.float()
@ -278,30 +307,78 @@ class LivePortraitWrapper(object):
return out return out
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst): def calc_ratio(self, lmk_lst):
input_eye_ratio_lst = [] input_eye_ratio_lst = []
input_lip_ratio_lst = [] input_lip_ratio_lst = []
for lmk in driving_lmk_lst: for lmk in lmk_lst:
# for eyes retargeting # for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting # for lip retargeting
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None])) input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst return input_eye_ratio_lst, input_lip_ratio_lst
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk): def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
eye_close_ratio = calc_eye_close_ratio(source_lmk[None]) c_s_eyes = calc_eye_close_ratio(source_lmk[None])
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id) c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id) c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
# [c_s,eyes, c_d,eyes,i] # [c_s,eyes, c_d,eyes,i]
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1) combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
return combined_eye_ratio_tensor return combined_eye_ratio_tensor
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk): def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
lip_close_ratio = calc_lip_close_ratio(source_lmk[None]) c_s_lip = calc_lip_close_ratio(source_lmk[None])
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id) c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
# [c_s,lip, c_d,lip,i] # [c_s,lip, c_d,lip,i]
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id) combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
if input_lip_ratio_tensor.shape != [1, 1]:
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
return combined_lip_ratio_tensor return combined_lip_ratio_tensor
class LivePortraitWrapperAnimal(LivePortraitWrapper):
"""
Wrapper for Animal
"""
def __init__(self, inference_cfg: InferenceConfig):
# super().__init__(inference_cfg)  # call the parent class initializer
self.inference_cfg = inference_cfg
self.device_id = inference_cfg.device_id
self.compile = inference_cfg.flag_do_torch_compile
if inference_cfg.flag_force_cpu:
self.device = 'cpu'
else:
try:
if torch.backends.mps.is_available():
self.device = 'mps'
else:
self.device = 'cuda:' + str(self.device_id)
except:
self.device = 'cuda:' + str(self.device_id)
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F_animal, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F_animal)} done.')
# init M
self.motion_extractor = load_model(inference_cfg.checkpoint_M_animal, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M_animal)} done.')
# init W
self.warping_module = load_model(inference_cfg.checkpoint_W_animal, model_config, self.device, 'warping_module')
log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W_animal)} done.')
# init G
self.spade_generator = load_model(inference_cfg.checkpoint_G_animal, model_config, self.device, 'spade_generator')
log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G_animal)} done.')
# init S and R
if inference_cfg.checkpoint_S_animal is not None and osp.exists(inference_cfg.checkpoint_S_animal):
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S_animal, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S_animal)} done.')
else:
self.stitching_retargeting_module = None
# Optimize for inference
if self.compile:
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
self.timer = Timer()
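Illustrative aside (not part of the diff): the optional compile branch above follows the standard PyTorch 2.x pattern; a generic sketch with a placeholder module.

import torch

torch._dynamo.config.suppress_errors = True           # fall back to eager execution on compile errors
module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
module = torch.compile(module, mode='max-autotune')   # autotuned kernels; first call triggers compilation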

View File

@ -59,7 +59,7 @@ class DenseMotionNetwork(nn.Module):
heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w) heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
# adding background feature # adding background feature
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device) zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
heatmap = torch.cat([zeros, heatmap], dim=1) heatmap = torch.cat([zeros, heatmap], dim=1)
heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w) heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
return heatmap return heatmap

View File

@ -11,7 +11,8 @@ import torch
import torch.nn.utils.spectral_norm as spectral_norm import torch.nn.utils.spectral_norm as spectral_norm
import math import math
import warnings import warnings
import collections.abc
from itertools import repeat
def kp2gaussian(kp, spatial_size, kp_variance): def kp2gaussian(kp, spatial_size, kp_variance):
""" """
@ -439,3 +440,13 @@ class DropPath(nn.Module):
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b) return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
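Quick usage note for the helper just added (illustrative, not part of the diff): scalars are broadcast to pairs, iterables pass through unchanged.

assert to_2tuple(7) == (7, 7)
assert to_2tuple((3, 5)) == (3, 5)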

View File

@ -1,65 +0,0 @@
# coding: utf-8
"""
Make video template
"""
import os
import cv2
import numpy as np
import pickle
from rich.progress import track
from .utils.cropper import Cropper
from .utils.io import load_driving_info
from .utils.camera import get_rotation_matrix
from .utils.helper import mkdir, basename
from .utils.rprint import rlog as log
from .config.crop_config import CropConfig
from .config.inference_config import InferenceConfig
from .live_portrait_wrapper import LivePortraitWrapper
class TemplateMaker:
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg)
def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
""" make video template (.pkl format)
video_fp: driving video file path
output_path: where to save the pickle file
"""
driving_rgb_lst = load_driving_info(video_fp)
driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst)
n_frames = I_d_lst.shape[0]
templates = []
for i in track(range(n_frames), description='Making templates...', total=n_frames):
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
# collect s_d, R_d, δ_d and t_d for inference
template_dct = {
'n_frames': n_frames,
'frames_index': i,
}
template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32)
template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32)
template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32)
template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32)
templates.append(template_dct)
mkdir(output_path)
# Save the dictionary as a pickle file
pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
with open(pickle_fp, 'wb') as f:
pickle.dump([templates, driving_lmk_lst], f)
log(f"Template saved at {pickle_fp}")

View File

@ -0,0 +1,138 @@
# coding: utf-8
"""
face detection and alignment using XPose
"""
import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from .timer import Timer
from .rprint import rlog as log
from .helper import clean_state_dict
from .dependencies.XPose import transforms as T
from .dependencies.XPose.models import build_model
from .dependencies.XPose.predefined_keypoints import *
from .dependencies.XPose.util import box_ops
from .dependencies.XPose.util.config import Config
class XPoseRunner(object):
def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
self.device_id = kwargs.get("device_id", 0)
self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
self.timer = Timer()
# Load cached embeddings if available
try:
with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
print("Loaded cached embeddings from file.")
except Exception:
raise ValueError("Could not load clip embeddings from file, please check your file path.")
def load_animal_model(self, model_config_path, model_checkpoint_path, device):
args = Config.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model
def load_image(self, input_image):
image_pil = input_image.convert("RGB")
transform = T.Compose([
T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image, _ = transform(image_pil, None)
return image_pil, image
def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
instance_list = instance_text_prompt.split(',')
if len(keypoint_text_prompt) == 9:
# torch.Size([1, 512]) torch.Size([9, 512])
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
elif len(keypoint_text_prompt) == 68:
# torch.Size([1, 512]) torch.Size([68, 512])
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
else:
raise ValueError("Invalid number of keypoint embeddings.")
target = {
"instance_text_prompt": instance_list,
"keypoint_text_prompt": keypoint_text_prompt,
"object_embeddings_text": ins_text_embeddings.float(),
"kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0),
"kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
}
self.model = self.model.to(self.device)
image = image.to(self.device)
with torch.no_grad():
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
outputs = self.model(image[None], [target])
logits = outputs["pred_logits"].sigmoid()[0]
boxes = outputs["pred_boxes"][0]
keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]
logits_filt = logits.cpu().clone()
boxes_filt = boxes.cpu().clone()
keypoints_filt = keypoints.cpu().clone()
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
logits_filt = logits_filt[filt_mask]
boxes_filt = boxes_filt[filt_mask]
keypoints_filt = keypoints_filt[filt_mask]
keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold)
filtered_boxes = boxes_filt[keep_indices]
filtered_keypoints = keypoints_filt[keep_indices]
return filtered_boxes, filtered_keypoints
def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
if keypoint_text_example in globals():
keypoint_dict = globals()[keypoint_text_example]
elif instance_text_prompt in globals():
keypoint_dict = globals()[instance_text_prompt]
else:
keypoint_dict = globals()["animal"]
keypoint_text_prompt = keypoint_dict.get("keypoints")
keypoint_skeleton = keypoint_dict.get("skeleton")
image_pil, image = self.load_image(input_image)
boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold)
size = image_pil.size
H, W = size[1], size[0]
keypoints_filt = keypoints_filt[0].squeeze(0)
kp = np.array(keypoints_filt.cpu())
num_kpts = len(keypoint_text_prompt)
Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
Z = Z.reshape(num_kpts * 2)
x = Z[0::2]
y = Z[1::2]
return np.stack((x, y), axis=1)
def warmup(self):
self.timer.tic()
img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
elapse = self.timer.toc()
log(f'XPoseRunner warmup time: {elapse:.3f}s')
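Illustrative aside (not part of the diff): a hypothetical call sequence for this runner; the paths below are placeholders, not repository paths.

# runner = XPoseRunner(model_config_path="UniPose_SwinT.py",
#                      model_checkpoint_path="xpose.pth",
#                      embeddings_cache_path="clip_embedding",
#                      cpu_only=True)
# runner.warmup()
# lmk = runner.run(Image.open("animal.jpg"), "face", "animal_face",
#                  box_threshold=0.0, IoU_threshold=0.0)   # -> (num_kpts, 2) pixel coordinates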

View File

@ -31,8 +31,6 @@ def headpose_pred_to_degree(pred):
def get_rotation_matrix(pitch_, yaw_, roll_): def get_rotation_matrix(pitch_, yaw_, roll_):
""" the input is in degree """ the input is in degree
""" """
# calculate the rotation matrix: vps @ rot
# transform to radian # transform to radian
pitch = pitch_ / 180 * PI pitch = pitch_ / 180 * PI
yaw = yaw_ / 180 * PI yaw = yaw_ / 180 * PI

View File

@ -0,0 +1,18 @@
import socket
import sys
if len(sys.argv) != 2:
print("Usage: python check_port.py <port>")
sys.exit(1)
port = int(sys.argv[1])
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
result = sock.connect_ex(('127.0.0.1', port))
if result == 0:
print("LISTENING")
else:
print("NOT LISTENING")
sock.close()
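Example invocation (the port number is arbitrary):

# python check_port.py 8890
# prints LISTENING if something accepts connections on 127.0.0.1:8890, NOT LISTENING otherwise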

View File

@ -136,6 +136,29 @@ def parse_pt2_from_pt5(pt5, use_lip=True):
], axis=0) ], axis=0)
return pt2 return pt2
def parse_pt2_from_pt9(pt9, use_lip=True):
'''
parsing the 2 points according to the 9 points, which cancels the roll
['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip']
'''
if use_lip:
pt9 = np.stack([
(pt9[2] + pt9[3]) / 2, # left eye
(pt9[0] + pt9[1]) / 2, # right eye
pt9[4],
(pt9[5] + pt9[6]) / 2  # lip
], axis=0)
pt2 = np.stack([
(pt9[0] + pt9[1]) / 2, # eye
pt9[3] # lip
], axis=0)
else:
pt2 = np.stack([
(pt9[2] + pt9[3]) / 2,
(pt9[0] + pt9[1]) / 2,
], axis=0)
return pt2
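Illustrative aside (not part of the diff): a toy sanity check for the new 9-point parser with synthetic coordinates, keypoint order as in the docstring above.

import numpy as np

pt9 = np.arange(18, dtype=np.float64).reshape(9, 2)   # 9 keypoints as (x, y)
pt2 = parse_pt2_from_pt9(pt9, use_lip=True)
# pt2[0] is the midpoint of the two eye centres, pt2[1] the lip centre ((pt9[5] + pt9[6]) / 2)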
def parse_pt2_from_pt_x(pts, use_lip=True): def parse_pt2_from_pt_x(pts, use_lip=True):
if pts.shape[0] == 101: if pts.shape[0] == 101:
@ -151,6 +174,8 @@ def parse_pt2_from_pt_x(pts, use_lip=True):
elif pts.shape[0] > 101: elif pts.shape[0] > 101:
# take the first 101 points # take the first 101 points
pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip) pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
elif pts.shape[0] == 9:
pt2 = parse_pt2_from_pt9(pts, use_lip=use_lip)
else: else:
raise Exception(f'Unknown shape: {pts.shape}') raise Exception(f'Unknown shape: {pts.shape}')
@ -281,11 +306,10 @@ def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=Fals
dtype=DTYPE dtype=DTYPE
) )
if flag_rot and angle is None: # if flag_rot and angle is None:
print('angle is None, but flag_rotate is True', style="bold yellow") # print('angle is None, but flag_rotate is True', style="bold yellow")
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None)) img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)]) M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
@ -362,17 +386,6 @@ def crop_image(img, pts: np.ndarray, **kwargs):
flag_do_rot=kwargs.get('flag_do_rot', True), flag_do_rot=kwargs.get('flag_do_rot', True),
) )
if img is None:
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
M = np.linalg.inv(M_INV_H)
ret_dct = {
'M': M[:2, ...], # from the original image to the cropped image
'M_o2c': M[:2, ...], # from the cropped image to the original image
'img_crop': None,
'pt_crop': None,
}
return ret_dct
img_crop = _transform_img(img, M_INV, dsize) # origin to crop img_crop = _transform_img(img, M_INV, dsize) # origin to crop
pt_crop = _transform_pts(pts, M_INV) pt_crop = _transform_pts(pts, M_INV)
@ -397,16 +410,14 @@ def average_bbox_lst(bbox_lst):
def prepare_paste_back(mask_crop, crop_M_c2o, dsize): def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
"""prepare mask for later image paste back """prepare mask for later image paste back
""" """
if mask_crop is None:
mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize) mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
mask_ori = mask_ori.astype(np.float32) / 255. mask_ori = mask_ori.astype(np.float32) / 255.
return mask_ori return mask_ori
def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori): def paste_back(img_crop, M_c2o, img_ori, mask_ori):
"""paste back the image """paste back the image
""" """
dsize = (rgb_ori.shape[1], rgb_ori.shape[0]) dsize = (img_ori.shape[1], img_ori.shape[0])
result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize) result = _transform_img(img_crop, M_c2o, dsize=dsize)
result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8) result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
return result return result
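Illustrative aside (not part of the diff): the paste-back is a per-pixel alpha blend of the warped crop over the original frame; an equivalent standalone sketch with the mask assumed to be float in [0, 1].

import numpy as np

def alpha_blend(warped_crop, img_ori, mask_ori):
    # warped_crop and img_ori: uint8 HxWx3 of the same size; mask_ori: float HxWx3 in [0, 1]
    return np.clip(mask_ori * warped_crop + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)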

View File

@ -1,21 +1,25 @@
# coding: utf-8 # coding: utf-8
import gradio as gr
import numpy as np
import os.path as osp import os.path as osp
from typing import List, Union, Tuple import torch
from dataclasses import dataclass, field import numpy as np
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from .landmark_runner import LandmarkRunner from PIL import Image
from .face_analysis_diy import FaceAnalysisDIY from typing import List, Tuple, Union
from .helper import prefix from dataclasses import dataclass, field
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
from .timer import Timer
from .rprint import rlog as log
from .io import load_image_rgb
from .video import VideoWriter, get_fps, change_video_fps
from ..config.crop_config import CropConfig
from .crop import (
average_bbox_lst,
crop_image,
crop_image_by_bbox,
parse_bbox_from_landmark,
)
from .io import contiguous
from .rprint import rlog as log
from .face_analysis_diy import FaceAnalysisDIY
from .human_landmark_runner import LandmarkRunner as HumanLandmark
def make_abs_path(fn): def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn) return osp.join(osp.dirname(osp.realpath(__file__)), fn)
@ -23,123 +27,287 @@ def make_abs_path(fn):
@dataclass @dataclass
class Trajectory: class Trajectory:
start: int = -1 # start frame (inclusive) start: int = -1 # start frame
end: int = -1 # end frame (inclusive) end: int = -1 # end frame
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
class Cropper(object): class Cropper(object):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
device_id = kwargs.get('device_id', 0) self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
self.landmark_runner = LandmarkRunner( self.image_type = kwargs.get("image_type", 'human_face')
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'), device_id = kwargs.get("device_id", 0)
onnx_provider='cuda', flag_force_cpu = kwargs.get("flag_force_cpu", False)
device_id=device_id if flag_force_cpu:
) device = "cpu"
self.landmark_runner.warmup() face_analysis_wrapper_provider = ["CPUExecutionProvider"]
else:
try:
if torch.backends.mps.is_available():
# Shape inference currently fails with CoreMLExecutionProvider
# for the retinaface model
device = "mps"
face_analysis_wrapper_provider = ["CPUExecutionProvider"]
else:
device = "cuda"
face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
except:
device = "cuda"
face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
self.face_analysis_wrapper = FaceAnalysisDIY( self.face_analysis_wrapper = FaceAnalysisDIY(
name='buffalo_l', name="buffalo_l",
root=make_abs_path('../../pretrained_weights/insightface'), root=self.crop_cfg.insightface_root,
providers=["CUDAExecutionProvider"] providers=face_analysis_wrapper_provider,
) )
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512)) self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup() self.face_analysis_wrapper.warmup()
self.crop_cfg = kwargs.get('crop_cfg', None) self.human_landmark_runner = HumanLandmark(
ckpt_path=self.crop_cfg.landmark_ckpt_path,
onnx_provider=device,
device_id=device_id,
)
self.human_landmark_runner.warmup()
if self.image_type == "animal_face":
from .animal_landmark_runner import XPoseRunner as AnimalLandmarkRunner
self.animal_landmark_runner = AnimalLandmarkRunner(
model_config_path=self.crop_cfg.xpose_config_file_path,
model_checkpoint_path=self.crop_cfg.xpose_ckpt_path,
embeddings_cache_path=self.crop_cfg.xpose_embedding_cache_path,
flag_use_half_precision=kwargs.get("flag_use_half_precision", True),
)
self.animal_landmark_runner.warmup()
def update_config(self, user_args): def update_config(self, user_args):
for k, v in user_args.items(): for k, v in user_args.items():
if hasattr(self.crop_cfg, k): if hasattr(self.crop_cfg, k):
setattr(self.crop_cfg, k, v) setattr(self.crop_cfg, k, v)
def crop_single_image(self, obj, **kwargs): def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
direction = kwargs.get('direction', 'large-small') # crop a source image and get necessary information
img_rgb = img_rgb_.copy() # copy it
# crop and align a single image img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
if isinstance(obj, str):
img_rgb = load_image_rgb(obj)
elif isinstance(obj, np.ndarray):
img_rgb = obj
if self.image_type == "human_face":
src_face = self.face_analysis_wrapper.get( src_face = self.face_analysis_wrapper.get(
img_rgb, img_bgr,
flag_do_landmark_2d_106=True, flag_do_landmark_2d_106=True,
direction=direction direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
) )
if len(src_face) == 0: if len(src_face) == 0:
log('No face detected in the source image.') log("No face detected in the source image.")
raise gr.Error("No face detected in the source image 💥!", duration=5) return None
raise Exception("No face detected in the source image!")
elif len(src_face) > 1: elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {direction}.') log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
# NOTE: temporarily only pick the first face, to support multiple faces in the future
src_face = src_face[0] src_face = src_face[0]
pts = src_face.landmark_2d_106 lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
else:
tmp_dct = {
'animal_face_9': 'animal_face',
'animal_face_68': 'face'
}
img_rgb_pil = Image.fromarray(img_rgb)
lmk = self.animal_landmark_runner.run(
img_rgb_pil,
'face',
tmp_dct[crop_cfg.animal_face_type],
0,
0
)
# crop the face # crop the face
ret_dct = crop_image( ret_dct = crop_image(
img_rgb, # ndarray img_rgb, # ndarray
pts, # 106x2 or Nx2 lmk, # 106x2 or Nx2
dsize=kwargs.get('dsize', 512), dsize=crop_cfg.dsize,
scale=kwargs.get('scale', 2.3), scale=crop_cfg.scale,
vy_ratio=kwargs.get('vy_ratio', -0.15), vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
) )
# update a 256x256 version for network input or else
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
recon_ret = self.landmark_runner.run(img_rgb, pts) # update a 256x256 version for network input
lmk = recon_ret['pts'] ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct['lmk_crop'] = lmk if self.image_type == "human_face":
lmk = self.human_landmark_runner.run(img_rgb, lmk)
ret_dct["lmk_crop"] = lmk
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
else:
# 68x2 or 9x2
ret_dct["lmk_crop"] = lmk
return ret_dct return ret_dct
def get_retargeting_lmk_info(self, driving_rgb_lst): def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs):
# TODO: implement a tracking-based version direction = kwargs.get("direction", "large-small")
driving_lmk_lst = []
for driving_image in driving_rgb_lst:
ret_dct = self.crop_single_image(driving_image)
driving_lmk_lst.append(ret_dct['lmk_crop'])
return driving_lmk_lst
def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
trajectory = Trajectory()
direction = kwargs.get('direction', 'large-small')
for idx, driving_image in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get( src_face = self.face_analysis_wrapper.get(
driving_image, contiguous(img_rgb_[..., ::-1]), # convert to BGR
flag_do_landmark_2d_106=True, flag_do_landmark_2d_106=True,
direction=direction direction=direction,
) )
if len(src_face) == 0: if len(src_face) == 0:
# No face detected in the driving_image log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f"More than one face detected in the image, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.human_landmark_runner.run(img_rgb_, lmk)
return lmk
# TODO: support skipping frames with NO FACE
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
direction = kwargs.get("direction", "large-small")
for idx, frame_rgb in enumerate(source_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
log(f"No face detected in the frame #{idx}")
continue continue
elif len(src_face) > 1: elif len(src_face) > 1:
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.') log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0] src_face = src_face[0]
pts = src_face.landmark_2d_106 lmk = src_face.landmark_2d_106
lmk_203 = self.landmark_runner(driving_image, pts)['pts'] lmk = self.human_landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx trajectory.start, trajectory.end = idx, idx
else: else:
lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts'] # TODO: add IOU check for tracking
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx trajectory.end = idx
trajectory.lmk_lst.append(lmk_203) trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=self.crop_cfg.vy_ratio)['bbox']
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4, # crop the face
ret_dct = crop_image(
frame_rgb, # ndarray
lmk, # 106x2 or Nx2
dsize=crop_cfg.dsize,
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"])
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"])
trajectory.M_c2o_lst.append(ret_dct['M_c2o'])
return {
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
"lmk_crop_lst": trajectory.lmk_crop_lst,
"M_c2o_lst": trajectory.M_c2o_lst,
}
def crop_driving_video(self, driving_rgb_lst, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
direction = kwargs.get("direction", "large-small")
for idx, frame_rgb in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
direction=direction,
)
if len(src_face) == 0:
log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(
lmk,
scale=self.crop_cfg.scale_crop_driving_video,
vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video,
)["bbox"]
bbox = [
ret_bbox[0, 0],
ret_bbox[0, 1],
ret_bbox[2, 0],
ret_bbox[2, 1],
] # 4,
trajectory.bbox_lst.append(bbox) # bbox trajectory.bbox_lst.append(bbox) # bbox
trajectory.frame_rgb_lst.append(driving_image) trajectory.frame_rgb_lst.append(frame_rgb)
global_bbox = average_bbox_lst(trajectory.bbox_lst) global_bbox = average_bbox_lst(trajectory.bbox_lst)
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)): for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
ret_dct = crop_image_by_bbox( ret_dct = crop_image_by_bbox(
frame_rgb, global_bbox, lmk=lmk, frame_rgb,
dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue global_bbox,
lmk=lmk,
dsize=kwargs.get("dsize", 512),
flag_rot=False,
borderValue=(0, 0, 0),
) )
frame_rgb_crop = ret_dct['img_crop'] trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
return {
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
"lmk_crop_lst": trajectory.lmk_crop_lst,
}
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
"""Tracking based landmarks/alignment"""
trajectory = Trajectory()
direction = kwargs.get("direction", "large-small")
for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR
flag_do_landmark_2d_106=True,
direction=direction,
)
if len(src_face) == 0:
log(f"No face detected in the frame #{idx}")
raise Exception(f"No face detected in the frame #{idx}")
elif len(src_face) > 1:
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.human_landmark_runner.run(frame_rgb_crop, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.human_landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
return trajectory.lmk_lst
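Illustrative aside (not part of the diff): a hypothetical end-to-end use of the reworked Cropper; crop_cfg, source_img_rgb, and driving_rgb_lst are assumed to exist.

# cropper = Cropper(crop_cfg=crop_cfg, image_type="human_face", device_id=0)
# src = cropper.crop_source_image(source_img_rgb, crop_cfg)           # dict with crops/landmarks, or None if no face
# drv = cropper.crop_driving_video(driving_rgb_lst)                   # {"frame_crop_lst", "lmk_crop_lst"}
# lmks = cropper.calc_lmks_from_cropped_video(drv["frame_crop_lst"])  # tracked per-frame landmarks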

View File

@ -0,0 +1,125 @@
_base_ = ['coco_transformer.py']
use_label_enc = True
num_classes=2
lr = 0.0001
param_dict_type = 'default'
lr_backbone = 1e-05
lr_backbone_names = ['backbone.0']
lr_linear_proj_names = ['reference_points', 'sampling_offsets']
lr_linear_proj_mult = 0.1
ddetr_lr_param = False
batch_size = 2
weight_decay = 0.0001
epochs = 12
lr_drop = 11
save_checkpoint_interval = 100
clip_max_norm = 0.1
onecyclelr = False
multi_step_lr = False
lr_drop_list = [33, 45]
modelname = 'UniPose'
frozen_weights = None
backbone = 'swin_T_224_1k'
dilation = False
position_embedding = 'sine'
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
unic_layers = 0
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
pdetr3_bbox_embed_diff_each_layer = False
pdetr3_refHW = -1
random_refpoints_xy = False
fix_refpoints_hw = -1
dabdetr_yolo_like_anchor_update = False
dabdetr_deformable_encoder = False
dabdetr_deformable_decoder = False
use_deformable_box_attn = False
box_attn_type = 'roi_align'
dec_layer_number = None
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
decoder_layer_noise = False
dln_xy_noise = 0.2
dln_hw_noise = 0.2
add_channel_attention = False
add_pos_value = False
two_stage_type = 'standard'
two_stage_pat_embed = 0
two_stage_add_query_num = 0
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
two_stage_learn_wh = False
two_stage_default_hw = 0.05
two_stage_keep_all_tokens = False
num_select = 50
transformer_activation = 'relu'
batch_norm_type = 'FrozenBatchNorm2d'
masks = False
decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
decoder_module_seq = ['sa', 'ca', 'ffn']
nms_iou_threshold = -1
dec_pred_bbox_embed_share = True
dec_pred_class_embed_share = True
use_dn = True
dn_number = 100
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef=1.0
dn_bbox_coef=1.0
embed_init_tgt = True
dn_labelbook_size = 2000
match_unstable_error = True
# for ema
use_ema = True
ema_decay = 0.9997
ema_epoch = 0
use_detached_boxes_dec_out = False
max_text_len = 256
shuffle_type = None
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = False # True
use_transformer_ckpt = True
text_encoder_type = 'bert-base-uncased'
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
num_body_points=68
binary_query_selection = False
use_cdn = True
ffn_extra_layernorm = False
fix_size=False

View File

@ -0,0 +1,8 @@
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
data_aug_max_size = 1333
data_aug_scales2_resize = [400, 500, 600]
data_aug_scales2_crop = [384, 600]
data_aug_scale_overlap = None
Some files were not shown because too many files have changed in this diff.