17
.gitignore
vendored
@ -9,9 +9,26 @@ __pycache__/
|
||||
**/*.pth
|
||||
**/*.onnx
|
||||
|
||||
pretrained_weights/*.md
|
||||
pretrained_weights/docs
|
||||
pretrained_weights/liveportrait
|
||||
pretrained_weights/liveportrait_animals
|
||||
|
||||
# Ipython notebook
|
||||
*.ipynb
|
||||
|
||||
# Temporary files or benchmark resources
|
||||
animations/*
|
||||
tmp/*
|
||||
.vscode/launch.json
|
||||
**/*.DS_Store
|
||||
gradio_temp/**
|
||||
|
||||
# Windows dependencies
|
||||
ffmpeg/
|
||||
LivePortrait_env/
|
||||
|
||||
# XPose build files
|
||||
src/utils/dependencies/XPose/models/UniPose/ops/build
|
||||
src/utils/dependencies/XPose/models/UniPose/ops/dist
|
||||
src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info
|
||||
|
9
LICENSE
@ -19,3 +19,12 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
---
|
||||
|
||||
The code of InsightFace is released under the MIT License.
|
||||
The models of InsightFace are for non-commercial research purposes only.
|
||||
|
||||
If you want to use the LivePortrait project for commercial purposes, you
|
||||
should remove and replace InsightFace’s detection models to fully comply with
|
||||
the MIT license.
|
||||
|
450
app.py
@ -1,10 +1,12 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
The entrance of the gradio
|
||||
The entrance of the gradio for human
|
||||
"""
|
||||
|
||||
import os
|
||||
import tyro
|
||||
import subprocess
|
||||
import gradio as gr
|
||||
import os.path as osp
|
||||
from src.utils.helper import load_description
|
||||
@ -18,137 +20,451 @@ def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
|
||||
if osp.exists(ffmpeg_dir):
|
||||
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||
# global_tab_selection = None
|
||||
|
||||
gradio_pipeline = GradioPipeline(
|
||||
inference_cfg=inference_cfg,
|
||||
crop_cfg=crop_cfg,
|
||||
args=args
|
||||
)
|
||||
|
||||
if args.gradio_temp_dir not in (None, ''):
|
||||
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
|
||||
os.makedirs(args.gradio_temp_dir, exist_ok=True)
|
||||
|
||||
|
||||
def gpu_wrapped_execute_video(*args, **kwargs):
|
||||
return gradio_pipeline.execute_video(*args, **kwargs)
|
||||
|
||||
|
||||
def gpu_wrapped_execute_image_retargeting(*args, **kwargs):
|
||||
return gradio_pipeline.execute_image_retargeting(*args, **kwargs)
|
||||
|
||||
|
||||
def gpu_wrapped_execute_video_retargeting(*args, **kwargs):
|
||||
return gradio_pipeline.execute_video_retargeting(*args, **kwargs)
|
||||
|
||||
|
||||
def reset_sliders(*args, **kwargs):
|
||||
return 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5, True, True
|
||||
|
||||
|
||||
# assets
|
||||
title_md = "assets/gradio_title.md"
|
||||
title_md = "assets/gradio/gradio_title.md"
|
||||
example_portrait_dir = "assets/examples/source"
|
||||
example_video_dir = "assets/examples/driving"
|
||||
data_examples = [
|
||||
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
|
||||
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
|
||||
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
|
||||
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
|
||||
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
|
||||
data_examples_i2v = [
|
||||
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
|
||||
]
|
||||
data_examples_v2v = [
|
||||
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
|
||||
]
|
||||
#################### interface logic ####################
|
||||
|
||||
# Define components first
|
||||
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
|
||||
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
|
||||
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
|
||||
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
|
||||
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
||||
retargeting_input_image = gr.Image(type="numpy")
|
||||
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
||||
head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
|
||||
head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
|
||||
head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
|
||||
mov_x = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="x-axis movement")
|
||||
mov_y = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="y-axis movement")
|
||||
mov_z = gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.01, label="z-axis movement")
|
||||
lip_variation_zero = gr.Slider(minimum=-0.09, maximum=0.09, value=0, step=0.01, label="pouting")
|
||||
lip_variation_one = gr.Slider(minimum=-20.0, maximum=15.0, value=0, step=0.01, label="pursing 😐")
|
||||
lip_variation_two = gr.Slider(minimum=0.0, maximum=15.0, value=0, step=0.01, label="grin 😁")
|
||||
lip_variation_three = gr.Slider(minimum=-90.0, maximum=120.0, value=0, step=1.0, label="lip close <-> open")
|
||||
smile = gr.Slider(minimum=-0.3, maximum=1.3, value=0, step=0.01, label="smile 😄")
|
||||
wink = gr.Slider(minimum=0, maximum=39, value=0, step=0.01, label="wink 😉")
|
||||
eyebrow = gr.Slider(minimum=-30, maximum=30, value=0, step=0.01, label="eyebrow 🤨")
|
||||
eyeball_direction_x = gr.Slider(minimum=-30.0, maximum=30.0, value=0, step=0.01, label="eye gaze (horizontal) 👀")
|
||||
eyeball_direction_y = gr.Slider(minimum=-63.0, maximum=63.0, value=0, step=0.01, label="eye gaze (vertical) 🙄")
|
||||
retargeting_input_image = gr.Image(type="filepath")
|
||||
retargeting_input_video = gr.Video()
|
||||
output_image = gr.Image(type="numpy")
|
||||
output_image_paste_back = gr.Image(type="numpy")
|
||||
output_video = gr.Video()
|
||||
output_video_concat = gr.Video()
|
||||
retargeting_output_image = gr.Image(type="numpy")
|
||||
retargeting_output_image_paste_back = gr.Image(type="numpy")
|
||||
output_video = gr.Video(autoplay=False)
|
||||
output_video_paste_back = gr.Video(autoplay=False)
|
||||
output_video_i2v = gr.Video(autoplay=False)
|
||||
output_video_concat_i2v = gr.Video(autoplay=False)
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
|
||||
gr.HTML(load_description(title_md))
|
||||
gr.Markdown(load_description("assets/gradio_description_upload.md"))
|
||||
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_upload.md"))
|
||||
with gr.Row():
|
||||
with gr.Accordion(open=True, label="Source Portrait"):
|
||||
image_input = gr.Image(type="filepath")
|
||||
with gr.Accordion(open=True, label="Driving Video"):
|
||||
video_input = gr.Video()
|
||||
gr.Markdown(load_description("assets/gradio_description_animation.md"))
|
||||
with gr.Column():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("🖼️ Source Image") as tab_image:
|
||||
with gr.Accordion(open=True, label="Source Image"):
|
||||
source_image_input = gr.Image(type="filepath")
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s9.jpg")],
|
||||
[osp.join(example_portrait_dir, "s6.jpg")],
|
||||
[osp.join(example_portrait_dir, "s10.jpg")],
|
||||
[osp.join(example_portrait_dir, "s5.jpg")],
|
||||
[osp.join(example_portrait_dir, "s7.jpg")],
|
||||
[osp.join(example_portrait_dir, "s12.jpg")],
|
||||
[osp.join(example_portrait_dir, "s22.jpg")],
|
||||
[osp.join(example_portrait_dir, "s23.jpg")],
|
||||
],
|
||||
inputs=[source_image_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
with gr.TabItem("🎞️ Source Video") as tab_video:
|
||||
with gr.Accordion(open=True, label="Source Video"):
|
||||
source_video_input = gr.Video()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s13.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s14.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s15.mp4")],
|
||||
[osp.join(example_portrait_dir, "s18.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s19.mp4")],
|
||||
[osp.join(example_portrait_dir, "s20.mp4")],
|
||||
],
|
||||
inputs=[source_video_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
tab_selection = gr.Textbox(visible=False)
|
||||
tab_image.select(lambda: "Image", None, tab_selection)
|
||||
tab_video.select(lambda: "Video", None, tab_selection)
|
||||
with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"):
|
||||
with gr.Row():
|
||||
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
|
||||
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
|
||||
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("🎞️ Driving Video") as v_tab_video:
|
||||
with gr.Accordion(open=True, label="Driving Video"):
|
||||
driving_video_input = gr.Video()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_video_dir, "d0.mp4")],
|
||||
[osp.join(example_video_dir, "d18.mp4")],
|
||||
[osp.join(example_video_dir, "d19.mp4")],
|
||||
[osp.join(example_video_dir, "d14.mp4")],
|
||||
[osp.join(example_video_dir, "d6.mp4")],
|
||||
[osp.join(example_video_dir, "d20.mp4")],
|
||||
],
|
||||
inputs=[driving_video_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
|
||||
with gr.Accordion(open=True, label="Driving Pickle"):
|
||||
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_video_dir, "d1.pkl")],
|
||||
[osp.join(example_video_dir, "d2.pkl")],
|
||||
[osp.join(example_video_dir, "d5.pkl")],
|
||||
[osp.join(example_video_dir, "d7.pkl")],
|
||||
[osp.join(example_video_dir, "d8.pkl")],
|
||||
],
|
||||
inputs=[driving_video_pickle_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
v_tab_selection = gr.Textbox(visible=False)
|
||||
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
|
||||
v_tab_video.select(lambda: "Video", None, v_tab_selection)
|
||||
# with gr.Accordion(open=False, label="Animation Instructions"):
|
||||
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
|
||||
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
|
||||
with gr.Row():
|
||||
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
|
||||
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
|
||||
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Accordion(open=True, label="Animation Options"):
|
||||
with gr.Row():
|
||||
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
|
||||
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
|
||||
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
|
||||
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
|
||||
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
|
||||
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
|
||||
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
|
||||
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
|
||||
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
||||
with gr.Column():
|
||||
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
|
||||
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video in the original image space"):
|
||||
output_video.render()
|
||||
output_video_i2v.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video"):
|
||||
output_video_concat.render()
|
||||
output_video_concat_i2v.render()
|
||||
with gr.Row():
|
||||
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
|
||||
|
||||
with gr.Row():
|
||||
# Examples
|
||||
gr.Markdown("## You could choose the examples below ⬇️")
|
||||
with gr.Row():
|
||||
gr.Examples(
|
||||
examples=data_examples,
|
||||
inputs=[
|
||||
image_input,
|
||||
video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input
|
||||
],
|
||||
examples_per_page=5
|
||||
)
|
||||
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
|
||||
gr.Markdown("## You could also choose the examples below by one click ⬇️")
|
||||
with gr.Row():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("🖼️ Portrait Animation"):
|
||||
gr.Examples(
|
||||
examples=data_examples_i2v,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
source_image_input,
|
||||
driving_video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input,
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
examples_per_page=len(data_examples_i2v),
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.TabItem("🎞️ Portrait Video Editing"):
|
||||
gr.Examples(
|
||||
examples=data_examples_v2v,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
source_video_input,
|
||||
driving_video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input,
|
||||
flag_video_editing_head_rotation,
|
||||
driving_smooth_observation_variance,
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
examples_per_page=len(data_examples_v2v),
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
# Retargeting Image
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
|
||||
with gr.Row(visible=True):
|
||||
flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)")
|
||||
flag_stitching_retargeting_input = gr.Checkbox(value=True, label="stitching")
|
||||
retargeting_source_scale.render()
|
||||
eye_retargeting_slider.render()
|
||||
lip_retargeting_slider.render()
|
||||
with gr.Row():
|
||||
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
|
||||
with gr.Row(visible=True):
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Facial movement sliders"):
|
||||
with gr.Row(visible=True):
|
||||
head_pitch_slider.render()
|
||||
head_yaw_slider.render()
|
||||
head_roll_slider.render()
|
||||
with gr.Row(visible=True):
|
||||
mov_x.render()
|
||||
mov_y.render()
|
||||
mov_z.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Facial expression sliders"):
|
||||
with gr.Row(visible=True):
|
||||
lip_variation_zero.render()
|
||||
lip_variation_one.render()
|
||||
lip_variation_two.render()
|
||||
with gr.Row(visible=True):
|
||||
lip_variation_three.render()
|
||||
smile.render()
|
||||
wink.render()
|
||||
with gr.Row(visible=True):
|
||||
eyebrow.render()
|
||||
eyeball_direction_x.render()
|
||||
eyeball_direction_y.render()
|
||||
with gr.Row(visible=True):
|
||||
reset_button = gr.Button("🔄 Reset")
|
||||
reset_button.click(
|
||||
fn=reset_sliders,
|
||||
inputs=None,
|
||||
outputs=[
|
||||
head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
|
||||
lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
|
||||
retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
|
||||
]
|
||||
)
|
||||
with gr.Row(visible=True):
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Image Input"):
|
||||
retargeting_input_image.render()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s9.jpg")],
|
||||
[osp.join(example_portrait_dir, "s6.jpg")],
|
||||
[osp.join(example_portrait_dir, "s10.jpg")],
|
||||
[osp.join(example_portrait_dir, "s5.jpg")],
|
||||
[osp.join(example_portrait_dir, "s7.jpg")],
|
||||
[osp.join(example_portrait_dir, "s12.jpg")],
|
||||
[osp.join(example_portrait_dir, "s22.jpg")],
|
||||
# [osp.join(example_portrait_dir, "s23.jpg")],
|
||||
[osp.join(example_portrait_dir, "s42.jpg")],
|
||||
],
|
||||
inputs=[retargeting_input_image],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Result"):
|
||||
retargeting_output_image.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Paste-back Result"):
|
||||
retargeting_output_image_paste_back.render()
|
||||
with gr.Row(visible=True):
|
||||
process_button_reset_retargeting = gr.ClearButton(
|
||||
[
|
||||
eye_retargeting_slider,
|
||||
lip_retargeting_slider,
|
||||
retargeting_input_image,
|
||||
output_image,
|
||||
output_image_paste_back
|
||||
retargeting_output_image,
|
||||
retargeting_output_image_paste_back,
|
||||
],
|
||||
value="🧹 Clear"
|
||||
)
|
||||
with gr.Row():
|
||||
|
||||
# Retargeting Video
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting_video.md"), visible=True)
|
||||
with gr.Row(visible=True):
|
||||
flag_do_crop_input_retargeting_video = gr.Checkbox(value=True, label="do crop (source)")
|
||||
video_retargeting_source_scale.render()
|
||||
video_lip_retargeting_slider.render()
|
||||
driving_smooth_observation_variance_retargeting.render()
|
||||
with gr.Row(visible=True):
|
||||
process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
|
||||
with gr.Row(visible=True):
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Input"):
|
||||
retargeting_input_image.render()
|
||||
with gr.Accordion(open=True, label="Retargeting Video Input"):
|
||||
retargeting_input_video.render()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s13.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s18.mp4")],
|
||||
[osp.join(example_portrait_dir, "s20.mp4")],
|
||||
[osp.join(example_portrait_dir, "s29.mp4")],
|
||||
[osp.join(example_portrait_dir, "s32.mp4")],
|
||||
],
|
||||
inputs=[retargeting_input_video],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Result"):
|
||||
output_image.render()
|
||||
output_video.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Paste-back Result"):
|
||||
output_image_paste_back.render()
|
||||
output_video_paste_back.render()
|
||||
with gr.Row(visible=True):
|
||||
process_button_reset_retargeting = gr.ClearButton(
|
||||
[
|
||||
video_lip_retargeting_slider,
|
||||
retargeting_input_video,
|
||||
output_video,
|
||||
output_video_paste_back
|
||||
],
|
||||
value="🧹 Clear"
|
||||
)
|
||||
|
||||
# binding functions for buttons
|
||||
process_button_retargeting.click(
|
||||
fn=gradio_pipeline.execute_image,
|
||||
inputs=[eye_retargeting_slider, lip_retargeting_slider],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
show_progress=True
|
||||
)
|
||||
process_button_animation.click(
|
||||
fn=gradio_pipeline.execute_video,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
image_input,
|
||||
video_input,
|
||||
source_image_input,
|
||||
source_video_input,
|
||||
driving_video_pickle_input,
|
||||
driving_video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input
|
||||
flag_remap_input,
|
||||
flag_stitching_input,
|
||||
driving_option_input,
|
||||
driving_multiplier,
|
||||
flag_crop_driving_video_input,
|
||||
flag_video_editing_head_rotation,
|
||||
scale,
|
||||
vx_ratio,
|
||||
vy_ratio,
|
||||
scale_crop_driving_video,
|
||||
vx_ratio_crop_driving_video,
|
||||
vy_ratio_crop_driving_video,
|
||||
driving_smooth_observation_variance,
|
||||
tab_selection,
|
||||
v_tab_selection,
|
||||
],
|
||||
outputs=[output_video, output_video_concat],
|
||||
outputs=[output_video_i2v, output_video_concat_i2v],
|
||||
show_progress=True
|
||||
)
|
||||
image_input.change(
|
||||
fn=gradio_pipeline.prepare_retargeting,
|
||||
inputs=image_input,
|
||||
outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
|
||||
|
||||
retargeting_input_image.change(
|
||||
fn=gradio_pipeline.init_retargeting_image,
|
||||
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
|
||||
outputs=[eye_retargeting_slider, lip_retargeting_slider]
|
||||
)
|
||||
|
||||
##########################################################
|
||||
sliders = [eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y]
|
||||
for slider in sliders:
|
||||
# NOTE: gradio >= 4.0.0 may cause slow response
|
||||
slider.change(
|
||||
fn=gpu_wrapped_execute_image_retargeting,
|
||||
inputs=[
|
||||
eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
|
||||
lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
|
||||
retargeting_input_image, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
|
||||
],
|
||||
outputs=[retargeting_output_image, retargeting_output_image_paste_back],
|
||||
)
|
||||
|
||||
process_button_retargeting_video.click(
|
||||
fn=gpu_wrapped_execute_video_retargeting,
|
||||
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
|
||||
outputs=[output_video, output_video_paste_back],
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
server_name=args.server_name,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
server_name=args.server_name
|
||||
)
|
||||
|
248
app_animals.py
Normal file
@ -0,0 +1,248 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
The entrance of the gradio for animal
|
||||
"""
|
||||
|
||||
import os
|
||||
import tyro
|
||||
import subprocess
|
||||
import gradio as gr
|
||||
import os.path as osp
|
||||
from src.utils.helper import load_description
|
||||
from src.gradio_pipeline import GradioPipelineAnimal
|
||||
from src.config.crop_config import CropConfig
|
||||
from src.config.argument_config import ArgumentConfig
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
|
||||
def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
|
||||
if osp.exists(ffmpeg_dir):
|
||||
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||
|
||||
gradio_pipeline_animal: GradioPipelineAnimal = GradioPipelineAnimal(
|
||||
inference_cfg=inference_cfg,
|
||||
crop_cfg=crop_cfg,
|
||||
args=args
|
||||
)
|
||||
|
||||
if args.gradio_temp_dir not in (None, ''):
|
||||
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
|
||||
os.makedirs(args.gradio_temp_dir, exist_ok=True)
|
||||
|
||||
def gpu_wrapped_execute_video(*args, **kwargs):
|
||||
return gradio_pipeline_animal.execute_video(*args, **kwargs)
|
||||
|
||||
|
||||
# assets
|
||||
title_md = "assets/gradio/gradio_title.md"
|
||||
example_portrait_dir = "assets/examples/source"
|
||||
example_video_dir = "assets/examples/driving"
|
||||
data_examples_i2v = [
|
||||
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "d3.mp4"), True, False, False, False],
|
||||
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "d6.mp4"), True, False, False, False],
|
||||
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "d19.mp4"), True, False, False, False],
|
||||
]
|
||||
data_examples_i2v_pickle = [
|
||||
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "wink.pkl"), True, False, False, False],
|
||||
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "talking.pkl"), True, False, False, False],
|
||||
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "aggrieved.pkl"), True, False, False, False],
|
||||
]
|
||||
#################### interface logic ####################
|
||||
|
||||
# Define components first
|
||||
output_image = gr.Image(type="numpy")
|
||||
output_image_paste_back = gr.Image(type="numpy")
|
||||
output_video_i2v = gr.Video(autoplay=False)
|
||||
output_video_concat_i2v = gr.Video(autoplay=False)
|
||||
output_video_i2v_gif = gr.Image(type="numpy")
|
||||
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
|
||||
gr.HTML(load_description(title_md))
|
||||
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_upload_animal.md"))
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="🐱 Source Animal Image"):
|
||||
source_image_input = gr.Image(type="filepath")
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s25.jpg")],
|
||||
[osp.join(example_portrait_dir, "s30.jpg")],
|
||||
[osp.join(example_portrait_dir, "s31.jpg")],
|
||||
[osp.join(example_portrait_dir, "s32.jpg")],
|
||||
[osp.join(example_portrait_dir, "s39.jpg")],
|
||||
[osp.join(example_portrait_dir, "s40.jpg")],
|
||||
[osp.join(example_portrait_dir, "s41.jpg")],
|
||||
[osp.join(example_portrait_dir, "s38.jpg")],
|
||||
[osp.join(example_portrait_dir, "s36.jpg")],
|
||||
],
|
||||
inputs=[source_image_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
with gr.Accordion(open=True, label="Cropping Options for Source Image"):
|
||||
with gr.Row():
|
||||
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
|
||||
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
|
||||
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("📁 Driving Pickle") as tab_pickle:
|
||||
with gr.Accordion(open=True, label="Driving Pickle"):
|
||||
driving_video_pickle_input = gr.File()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_video_dir, "wink.pkl")],
|
||||
[osp.join(example_video_dir, "shy.pkl")],
|
||||
[osp.join(example_video_dir, "aggrieved.pkl")],
|
||||
[osp.join(example_video_dir, "open_lip.pkl")],
|
||||
[osp.join(example_video_dir, "laugh.pkl")],
|
||||
[osp.join(example_video_dir, "talking.pkl")],
|
||||
[osp.join(example_video_dir, "shake_face.pkl")],
|
||||
],
|
||||
inputs=[driving_video_pickle_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.TabItem("🎞️ Driving Video") as tab_video:
|
||||
with gr.Accordion(open=True, label="Driving Video"):
|
||||
driving_video_input = gr.Video()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
# [osp.join(example_video_dir, "d0.mp4")],
|
||||
# [osp.join(example_video_dir, "d18.mp4")],
|
||||
[osp.join(example_video_dir, "d19.mp4")],
|
||||
[osp.join(example_video_dir, "d14.mp4")],
|
||||
[osp.join(example_video_dir, "d6.mp4")],
|
||||
[osp.join(example_video_dir, "d3.mp4")],
|
||||
],
|
||||
inputs=[driving_video_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
tab_selection = gr.Textbox(visible=False)
|
||||
tab_pickle.select(lambda: "Pickle", None, tab_selection)
|
||||
tab_video.select(lambda: "Video", None, tab_selection)
|
||||
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
|
||||
with gr.Row():
|
||||
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
|
||||
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
|
||||
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Accordion(open=False, label="Animation Options"):
|
||||
with gr.Row():
|
||||
flag_stitching = gr.Checkbox(value=False, label="stitching (not recommended)")
|
||||
flag_remap_input = gr.Checkbox(value=False, label="paste-back (not recommended)")
|
||||
driving_multiplier = gr.Number(value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)
|
||||
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
|
||||
with gr.Row():
|
||||
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video in the cropped image space"):
|
||||
output_video_i2v.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
|
||||
output_video_i2v_gif.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video"):
|
||||
output_video_concat_i2v.render()
|
||||
with gr.Row():
|
||||
process_button_reset = gr.ClearButton([source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], value="🧹 Clear")
|
||||
|
||||
with gr.Row():
|
||||
# Examples
|
||||
gr.Markdown("## You could also choose the examples below by one click ⬇️")
|
||||
with gr.Row():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("📁 Driving Pickle") as tab_video:
|
||||
gr.Examples(
|
||||
examples=data_examples_i2v_pickle,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
source_image_input,
|
||||
driving_video_pickle_input,
|
||||
flag_do_crop_input,
|
||||
flag_stitching,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input,
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
|
||||
examples_per_page=len(data_examples_i2v_pickle),
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.TabItem("🎞️ Driving Video") as tab_video:
|
||||
gr.Examples(
|
||||
examples=data_examples_i2v,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
source_image_input,
|
||||
driving_video_input,
|
||||
flag_do_crop_input,
|
||||
flag_stitching,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input,
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
|
||||
examples_per_page=len(data_examples_i2v),
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
process_button_animation.click(
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
source_image_input,
|
||||
driving_video_input,
|
||||
driving_video_pickle_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
driving_multiplier,
|
||||
flag_stitching,
|
||||
flag_crop_driving_video_input,
|
||||
scale,
|
||||
vx_ratio,
|
||||
vy_ratio,
|
||||
scale_crop_driving_video,
|
||||
vx_ratio_crop_driving_video,
|
||||
vy_ratio_crop_driving_video,
|
||||
tab_selection,
|
||||
],
|
||||
outputs=[output_video_i2v, output_video_concat_i2v, output_video_i2v_gif],
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
server_name=args.server_name
|
||||
)
|
2
assets/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
examples/driving/*.pkl
|
||||
examples/driving/*_crop.mp4
|
BIN
assets/docs/LivePortrait-Gradio-2024-07-19.jpg
Normal file
After Width: | Height: | Size: 364 KiB |
BIN
assets/docs/animals-mode-gradio-2024-08-02.jpg
Normal file
After Width: | Height: | Size: 344 KiB |
22
assets/docs/changelog/2024-07-10.md
Normal file
@ -0,0 +1,22 @@
|
||||
## 2024/07/10
|
||||
|
||||
**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
|
||||
The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
|
||||
|
||||
### Updates
|
||||
|
||||
- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
|
||||
|
||||
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`.
|
||||
|
||||
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option.
|
||||
|
||||
|
||||
### About driving video
|
||||
|
||||
- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
|
||||
|
||||
|
||||
### Others
|
||||
|
||||
- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
|
24
assets/docs/changelog/2024-07-19.md
Normal file
@ -0,0 +1,24 @@
|
||||
## 2024/07/19
|
||||
|
||||
**Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉**
|
||||
We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk).
|
||||
|
||||
### Updates
|
||||
|
||||
- <strong>Portrait video editing (v2v):</strong> Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `--flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`.
|
||||
|
||||
- <strong>More options in Gradio:</strong> We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.
|
||||
|
||||
<p align="center">
|
||||
<img src="../LivePortrait-Gradio-2024-07-19.jpg" alt="LivePortrait" width="800px">
|
||||
<br>
|
||||
The Gradio Interface for LivePortrait
|
||||
</p>
|
||||
|
||||
|
||||
### Community Contributions
|
||||
|
||||
- **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance:
|
||||
- [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150))
|
||||
- [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142))
|
||||
- **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119).
|
12
assets/docs/changelog/2024-07-24.md
Normal file
@ -0,0 +1,12 @@
|
||||
## 2024/07/24
|
||||
|
||||
### Updates
|
||||
|
||||
- **Portrait pose editing:** You can change the `relative pitch`, `relative yaw`, and `relative roll` in the Gradio interface to adjust the pose of the source portrait.
|
||||
- **Detection threshold:** We have added a `--det_thresh` argument with a default value of 0.15 to increase recall, meaning more types of faces (e.g., monkeys, human-like) will be detected. You can set it to other values, e.g., 0.5, by using `python app.py --det_thresh 0.5`.
|
||||
|
||||
<p align="center">
|
||||
<img src="../pose-edit-2024-07-24.jpg" alt="LivePortrait" width="960px">
|
||||
<br>
|
||||
Pose Editing in the Gradio Interface
|
||||
</p>
|
75
assets/docs/changelog/2024-08-02.md
Normal file
@ -0,0 +1,75 @@
|
||||
## 2024/08/02
|
||||
|
||||
<table class="center" style="width: 80%; margin-left: auto; margin-right: auto;">
|
||||
<tr>
|
||||
<td style="text-align: center"><b>Animals Singing Dance Monkey 🎤</b></td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td style="border: none; text-align: center;">
|
||||
<video controls loop src="https://github.com/user-attachments/assets/38d5b6e5-d29b-458d-9f2c-4dd52546cb41" muted="false" style="width: 60%;"></video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
🎉 We are excited to announce the release of a new version featuring animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪 We also provided an one-click installer for Windows users, checkout the details [here](./2024-08-05.md).
|
||||
|
||||
### Updates on Animals mode
|
||||
We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode.
|
||||
|
||||
> Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. _This may be addressed in future updates._ Therefore, we recommend **disabling stitching by setting the `--no_flag_stitching`** option when running the model. Additionally, `paste-back` is also not recommended.
|
||||
|
||||
#### Install X-Pose
|
||||
We have chosen [X-Pose](https://github.com/IDEA-Research/X-Pose) as the keypoints detector for animals. This relies on `transformers==4.22.0` and `pillow>=10.2.0` (which are already updated in `requirements.txt`) and requires building an OP named `MultiScaleDeformableAttention`.
|
||||
|
||||
Refer to the [PyTorch installation](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#for-linux-or-windows-users) for Linux and Windows users.
|
||||
|
||||
|
||||
Next, build the OP `MultiScaleDeformableAttention` by running:
|
||||
```bash
|
||||
cd src/utils/dependencies/XPose/models/UniPose/ops
|
||||
python setup.py build install
|
||||
cd - # this returns to the previous directory
|
||||
```
|
||||
|
||||
To run the model, use the `inference_animals.py` script:
|
||||
```bash
|
||||
python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75
|
||||
```
|
||||
|
||||
Alternatively, you can use Gradio for a more user-friendly interface. Launch it with:
|
||||
```bash
|
||||
python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> [X-Pose](https://github.com/IDEA-Research/X-Pose) is only for Non-commercial Scientific Research Purposes, you should remove and replace it with other detectors if you use it for commercial purposes.
|
||||
|
||||
### Updates on Humans mode
|
||||
|
||||
- **Driving Options**: We have introduced an `expression-friendly` driving option to **reduce head wobbling**, now set as the default. While it may be less effective with large head poses, you can also select the `pose-friendly` option, which is the same as the previous version. This can be set using `--driving_option` or selected in the Gradio interface. Additionally, we added a `--driving_multiplier` option to adjust driving intensity, with a default value of 1, which can also be set in the Gradio interface.
|
||||
|
||||
- **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐.
|
||||
|
||||
### Others
|
||||
|
||||
- [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260).
|
||||
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0.
|
||||
- [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) features real-time portrait pose/expression editing and animation, and is registered with ComfyUI-Manager.
|
||||
|
||||
|
||||
|
||||
**Below are some screenshots of the new features and improvements:**
|
||||
|
||||
|  |
|
||||
|:---:|
|
||||
| **The Gradio Interface of Animals Mode** |
|
||||
|
||||
|  |
|
||||
|:---:|
|
||||
| **Driving Options and Multiplier** |
|
||||
|
||||
|  |
|
||||
|:---:|
|
||||
| **The Feature of Retargeting Video** |
|
18
assets/docs/changelog/2024-08-05.md
Normal file
@ -0,0 +1,18 @@
|
||||
## One-click Windows Installer
|
||||
|
||||
### Download the installer from HuggingFace
|
||||
```bash
|
||||
# !pip install -U "huggingface_hub[cli]"
|
||||
huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip
|
||||
```
|
||||
|
||||
If you cannot access to Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download:
|
||||
```bash
|
||||
# !pip install -U "huggingface_hub[cli]"
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip
|
||||
```
|
||||
|
||||
Alternatively, you can manually download it from the [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) page.
|
||||
|
||||
Then, simply unzip the package `LivePortrait-Windows-v20240806.zip` and double-click `run_windows_human.bat` for the Humans mode, or `run_windows_animal.bat` for the **Animals mode**.
|
9
assets/docs/changelog/2024-08-06.md
Normal file
@ -0,0 +1,9 @@
|
||||
## Precise Portrait Editing
|
||||
|
||||
Inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) ([@PowerHouseMan](https://github.com/PowerHouseMan)), we have implemented a version of Precise Portrait Editing in the Gradio interface. With each adjustment of the slider, the edited image updates in real-time. You can click the `🔄 Reset` button to reset all slider parameters. However, the performance may not be as fast as the ComfyUI plugin.
|
||||
|
||||
<p align="center">
|
||||
<img src="../editing-portrait-2024-08-06.jpg" alt="LivePortrait" width="960px">
|
||||
<br>
|
||||
Preciese Portrait Editing in the Gradio Interface
|
||||
</p>
|
28
assets/docs/directory-structure.md
Normal file
@ -0,0 +1,28 @@
|
||||
## The directory structure of `pretrained_weights`
|
||||
|
||||
```text
|
||||
pretrained_weights
|
||||
├── insightface
|
||||
│ └── models
|
||||
│ └── buffalo_l
|
||||
│ ├── 2d106det.onnx
|
||||
│ └── det_10g.onnx
|
||||
├── liveportrait
|
||||
│ ├── base_models
|
||||
│ │ ├── appearance_feature_extractor.pth
|
||||
│ │ ├── motion_extractor.pth
|
||||
│ │ ├── spade_generator.pth
|
||||
│ │ └── warping_module.pth
|
||||
│ ├── landmark.onnx
|
||||
│ └── retargeting_models
|
||||
│ └── stitching_retargeting_module.pth
|
||||
└── liveportrait_animals
|
||||
├── base_models
|
||||
│ ├── appearance_feature_extractor.pth
|
||||
│ ├── motion_extractor.pth
|
||||
│ ├── spade_generator.pth
|
||||
│ └── warping_module.pth
|
||||
├── retargeting_models
|
||||
│ └── stitching_retargeting_module.pth
|
||||
└── xpose.pth
|
||||
```
|
BIN
assets/docs/driving-option-multiplier-2024-08-02.jpg
Normal file
After Width: | Height: | Size: 82 KiB |
BIN
assets/docs/editing-portrait-2024-08-06.jpg
Normal file
After Width: | Height: | Size: 301 KiB |
29
assets/docs/how-to-install-ffmpeg.md
Normal file
@ -0,0 +1,29 @@
|
||||
## Install FFmpeg
|
||||
|
||||
Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below.
|
||||
|
||||
> [!Note]
|
||||
> The installation is copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗
|
||||
|
||||
### Conda Users
|
||||
|
||||
```bash
|
||||
conda install ffmpeg
|
||||
```
|
||||
|
||||
### Ubuntu/Debian Users
|
||||
|
||||
```bash
|
||||
sudo apt install ffmpeg
|
||||
sudo apt install libsox-dev
|
||||
conda install -c conda-forge 'ffmpeg<7'
|
||||
```
|
||||
|
||||
### Windows Users
|
||||
|
||||
Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the GPT-SoVITS root.
|
||||
|
||||
### MacOS Users
|
||||
```bash
|
||||
brew install ffmpeg
|
||||
```
|
BIN
assets/docs/inference-animals.gif
Normal file
After Width: | Height: | Size: 491 KiB |
BIN
assets/docs/pose-edit-2024-07-24.jpg
Normal file
After Width: | Height: | Size: 217 KiB |
BIN
assets/docs/retargeting-video-2024-08-02.jpg
Normal file
After Width: | Height: | Size: 115 KiB |
13
assets/docs/speed.md
Normal file
@ -0,0 +1,13 @@
|
||||
### Speed
|
||||
|
||||
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
|
||||
|
||||
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|
||||
|-----------------------------------|:-------------:|:--------------:|:-------------:|
|
||||
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
|
||||
| Motion Extractor | 28.12 | 108 | 0.84 |
|
||||
| Spade Generator | 55.37 | 212 | 7.59 |
|
||||
| Warping Module | 45.53 | 174 | 5.21 |
|
||||
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
|
||||
|
||||
*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
|
BIN
assets/examples/driving/aggrieved.pkl
Normal file
BIN
assets/examples/driving/d1.pkl
Normal file
BIN
assets/examples/driving/d10.mp4
Normal file
BIN
assets/examples/driving/d11.mp4
Normal file
BIN
assets/examples/driving/d12.mp4
Normal file
BIN
assets/examples/driving/d13.mp4
Normal file
BIN
assets/examples/driving/d14.mp4
Normal file
BIN
assets/examples/driving/d18.mp4
Normal file
BIN
assets/examples/driving/d19.mp4
Normal file
BIN
assets/examples/driving/d2.pkl
Normal file
BIN
assets/examples/driving/d20.mp4
Normal file
BIN
assets/examples/driving/d5.pkl
Normal file
BIN
assets/examples/driving/d7.pkl
Normal file
BIN
assets/examples/driving/d8.pkl
Normal file
BIN
assets/examples/driving/laugh.pkl
Normal file
BIN
assets/examples/driving/open_lip.pkl
Normal file
BIN
assets/examples/driving/shake_face.pkl
Normal file
BIN
assets/examples/driving/shy.pkl
Normal file
BIN
assets/examples/driving/talking.pkl
Normal file
BIN
assets/examples/driving/wink.pkl
Normal file
BIN
assets/examples/source/s11.jpg
Normal file
After Width: | Height: | Size: 102 KiB |
BIN
assets/examples/source/s12.jpg
Normal file
After Width: | Height: | Size: 49 KiB |
BIN
assets/examples/source/s13.mp4
Normal file
BIN
assets/examples/source/s18.mp4
Normal file
BIN
assets/examples/source/s20.mp4
Normal file
BIN
assets/examples/source/s22.jpg
Normal file
After Width: | Height: | Size: 156 KiB |
BIN
assets/examples/source/s23.jpg
Normal file
After Width: | Height: | Size: 82 KiB |
BIN
assets/examples/source/s25.jpg
Normal file
After Width: | Height: | Size: 390 KiB |
BIN
assets/examples/source/s29.mp4
Normal file
BIN
assets/examples/source/s30.jpg
Normal file
After Width: | Height: | Size: 96 KiB |
BIN
assets/examples/source/s31.jpg
Normal file
After Width: | Height: | Size: 99 KiB |
BIN
assets/examples/source/s32.jpg
Normal file
After Width: | Height: | Size: 115 KiB |
BIN
assets/examples/source/s32.mp4
Normal file
BIN
assets/examples/source/s36.jpg
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
assets/examples/source/s38.jpg
Normal file
After Width: | Height: | Size: 500 KiB |
BIN
assets/examples/source/s39.jpg
Normal file
After Width: | Height: | Size: 457 KiB |
BIN
assets/examples/source/s40.jpg
Normal file
After Width: | Height: | Size: 220 KiB |
BIN
assets/examples/source/s41.jpg
Normal file
After Width: | Height: | Size: 111 KiB |
BIN
assets/examples/source/s42.jpg
Normal file
After Width: | Height: | Size: 88 KiB |
6
assets/gradio/gradio_description_animate_clear.md
Normal file
@ -0,0 +1,6 @@
|
||||
<div style="font-size: 1.2em; text-align: center;">
|
||||
Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
|
||||
</div>
|
||||
<!-- <div style="font-size: 1.1em; text-align: center;">
|
||||
<strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
|
||||
</div> -->
|
19
assets/gradio/gradio_description_animation.md
Normal file
@ -0,0 +1,19 @@
|
||||
<span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
3. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
4. If you want to upload your own driving video, <strong>the best practice</strong>:
|
||||
|
||||
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
|
||||
- Focus on the head area, similar to the example videos.
|
||||
- Minimize shoulder movement.
|
||||
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
|
||||
|
||||
</div>
|
13
assets/gradio/gradio_description_retargeting.md
Normal file
@ -0,0 +1,13 @@
|
||||
<br>
|
||||
|
||||
<!-- ## Retargeting -->
|
||||
<!-- <span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span> -->
|
||||
|
||||
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
|
||||
<div>
|
||||
<h2>Retargeting and Editing Portraits</h2>
|
||||
<p>Upload a source portrait, and the <code>eyes-open ratio</code> and <code>lip-open ratio</code> will be auto-calculated. Adjust the sliders to see instant edits. Feel free to experiment! 🎨</p>
|
||||
<strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
|
||||
</div>
|
||||
</div>
|
9
assets/gradio/gradio_description_retargeting_video.md
Normal file
@ -0,0 +1,9 @@
|
||||
<br>
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
|
||||
<div>
|
||||
<h2>Retargeting Video</h2>
|
||||
<p>Upload a Source Video as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting Video</strong> button. You can try running it multiple times.
|
||||
<br>
|
||||
<strong>🤐 Set target lip-open ratio to 0 to see what's going on!</strong></p>
|
||||
</div>
|
||||
</div>
|
19
assets/gradio/gradio_description_upload.md
Normal file
@ -0,0 +1,19 @@
|
||||
<br>
|
||||
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
|
||||
<div style="flex: 1; text-align: center; margin-right: 20px;">
|
||||
<div style="display: inline-block;">
|
||||
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
|
||||
</div>
|
||||
<div style="display: inline-block; font-size: 0.8em;">
|
||||
<strong>Note:</strong> Better if Source Video has <strong>the same FPS</strong> as the Driving Video.
|
||||
</div>
|
||||
</div>
|
||||
<div style="flex: 1; text-align: center; margin-left: 20px;">
|
||||
<div style="display: inline-block;">
|
||||
Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
|
||||
</div>
|
||||
<div style="display: inline-block; font-size: 0.8em;">
|
||||
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
16
assets/gradio/gradio_description_upload_animal.md
Normal file
@ -0,0 +1,16 @@
|
||||
<br>
|
||||
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
|
||||
<div style="flex: 1; text-align: center; margin-right: 20px;">
|
||||
<div style="display: inline-block;">
|
||||
Step 1: Upload a <strong>Source Animal Image</strong> (any aspect ratio) ⬇️
|
||||
</div>
|
||||
</div>
|
||||
<div style="flex: 1; text-align: center; margin-left: 20px;">
|
||||
<div style="display: inline-block;">
|
||||
Step 2: Upload a <strong>Driving Pickle</strong> or <strong>Driving Video</strong> (any aspect ratio) ⬇️
|
||||
</div>
|
||||
<div style="display: inline-block; font-size: 0.8em;">
|
||||
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
20
assets/gradio/gradio_title.md
Normal file
@ -0,0 +1,20 @@
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
||||
<div>
|
||||
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
|
||||
<!-- <span>Add mimics and lip sync to your static portrait driven by a video</span> -->
|
||||
<!-- <span>Efficient Portrait Animation with Stitching and Retargeting Control</span> -->
|
||||
<!-- <br> -->
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
||||
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
|
||||
|
||||
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
|
||||
|
||||
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
||||
|
||||
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
|
||||
|
||||
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait
|
||||
"></a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
@ -1,7 +0,0 @@
|
||||
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
|
||||
</div>
|
@ -1 +0,0 @@
|
||||
<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
|
@ -1,2 +0,0 @@
|
||||
## 🤗 This is the official gradio demo for **LivePortrait**.
|
||||
<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
|
@ -1,10 +0,0 @@
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
||||
<div>
|
||||
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
|
||||
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
|
||||
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
|
||||
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
38
inference.py
@ -1,6 +1,12 @@
|
||||
# coding: utf-8
|
||||
"""
|
||||
for human
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import tyro
|
||||
import subprocess
|
||||
from src.config.argument_config import ArgumentConfig
|
||||
from src.config.inference_config import InferenceConfig
|
||||
from src.config.crop_config import CropConfig
|
||||
@ -11,14 +17,40 @@ def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def fast_check_args(args: ArgumentConfig):
|
||||
if not osp.exists(args.source):
|
||||
raise FileNotFoundError(f"source info not found: {args.source}")
|
||||
if not osp.exists(args.driving):
|
||||
raise FileNotFoundError(f"driving info not found: {args.driving}")
|
||||
|
||||
|
||||
def main():
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
|
||||
if osp.exists(ffmpeg_dir):
|
||||
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
|
||||
fast_check_args(args)
|
||||
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||
|
||||
live_portrait_pipeline = LivePortraitPipeline(
|
||||
inference_cfg=inference_cfg,
|
||||
@ -29,5 +61,5 @@ def main():
|
||||
live_portrait_pipeline.execute(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
65
inference_animals.py
Normal file
@ -0,0 +1,65 @@
|
||||
# coding: utf-8
|
||||
"""
|
||||
for animal
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import tyro
|
||||
import subprocess
|
||||
from src.config.argument_config import ArgumentConfig
|
||||
from src.config.inference_config import InferenceConfig
|
||||
from src.config.crop_config import CropConfig
|
||||
from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal
|
||||
|
||||
|
||||
def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def fast_check_args(args: ArgumentConfig):
|
||||
if not osp.exists(args.source):
|
||||
raise FileNotFoundError(f"source info not found: {args.source}")
|
||||
if not osp.exists(args.driving):
|
||||
raise FileNotFoundError(f"driving info not found: {args.driving}")
|
||||
|
||||
|
||||
def main():
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
|
||||
if osp.exists(ffmpeg_dir):
|
||||
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
|
||||
fast_check_args(args)
|
||||
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||
|
||||
live_portrait_pipeline_animal = LivePortraitPipelineAnimal(
|
||||
inference_cfg=inference_cfg,
|
||||
crop_cfg=crop_cfg
|
||||
)
|
||||
|
||||
# run
|
||||
live_portrait_pipeline_animal.execute(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
218
readme.md
@ -2,9 +2,9 @@
|
||||
|
||||
<div align='center'>
|
||||
<a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup> 
|
||||
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup> 
|
||||
<a href='https://github.com/Mystery099' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup> 
|
||||
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup> 
|
||||
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup> 
|
||||
<a href='https://github.com/zzzweakman' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup> 
|
||||
<a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup> 
|
||||
</div>
|
||||
|
||||
@ -16,6 +16,9 @@
|
||||
<div align='center'>
|
||||
<sup>1 </sup>Kuaishou Technology  <sup>2 </sup>University of Science and Technology of China  <sup>3 </sup>Fudan University 
|
||||
</div>
|
||||
<div align='center'>
|
||||
<small><sup>†</sup> Corresponding author</small>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
<div align="center">
|
||||
@ -23,6 +26,7 @@
|
||||
<a href='https://arxiv.org/pdf/2407.03168'><img src='https://img.shields.io/badge/arXiv-LivePortrait-red'></a>
|
||||
<a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-LivePortrait-green'></a>
|
||||
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
||||
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait"></a>
|
||||
</div>
|
||||
<br>
|
||||
|
||||
@ -33,55 +37,102 @@
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
## 🔥 Updates
|
||||
- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned!
|
||||
- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
|
||||
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, insipred by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
|
||||
- **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md).
|
||||
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
|
||||
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main) or [BaiduYun](https://pan.baidu.com/s/1FWsWqKe0eNfXrwjEhhCqlw?pwd=86q2). Simply unzip and double-click `run_windows.bat` to enjoy!
|
||||
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We’ve also lowered the default detection threshold to increase recall. [Have fun](assets/docs/changelog/2024-07-24.md)!
|
||||
- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
|
||||
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
|
||||
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
|
||||
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
|
||||
- **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned!
|
||||
- **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
|
||||
|
||||
## Introduction
|
||||
|
||||
|
||||
## Introduction 📖
|
||||
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
|
||||
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
|
||||
|
||||
## 🔥 Getting Started
|
||||
### 1. Clone the code and prepare the environment
|
||||
## Getting Started 🏁
|
||||
### 1. Clone the code and prepare the environment 🛠️
|
||||
|
||||
> [!Note]
|
||||
> Make sure your system has [`git`](https://git-scm.com/), [`conda`](https://anaconda.org/anaconda/conda), and [`FFmpeg`](https://ffmpeg.org/download.html) installed. For details on FFmpeg installation, see [**how to install FFmpeg**](assets/docs/how-to-install-ffmpeg.md).
|
||||
|
||||
```bash
|
||||
git clone https://github.com/KwaiVGI/LivePortrait
|
||||
cd LivePortrait
|
||||
|
||||
# create env using conda
|
||||
conda create -n LivePortrait python==3.9.18
|
||||
conda create -n LivePortrait python=3.9
|
||||
conda activate LivePortrait
|
||||
# install dependencies with pip
|
||||
```
|
||||
|
||||
#### For Linux or Windows Users
|
||||
[X-Pose](https://github.com/IDEA-Research/X-Pose) requires your `torch` version to be compatible with the CUDA version.
|
||||
|
||||
Firstly, check your current CUDA version by:
|
||||
```bash
|
||||
nvcc -V # example versions: 11.1, 11.8, 12.1, etc.
|
||||
```
|
||||
|
||||
Then, install the corresponding torch version. Here are examples for different CUDA versions. Visit the [PyTorch Official Website](https://pytorch.org/get-started/previous-versions) for installation commands if your CUDA version is not listed:
|
||||
```bash
|
||||
# for CUDA 11.1
|
||||
pip install torch==1.10.1+cu111 torchvision==0.11.2 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
|
||||
# for CUDA 11.8
|
||||
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
|
||||
# for CUDA 12.1
|
||||
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
|
||||
# ...
|
||||
```
|
||||
|
||||
Finally, install the remaining dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. Download pretrained weights
|
||||
Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
|
||||
```text
|
||||
pretrained_weights
|
||||
├── insightface
|
||||
│ └── models
|
||||
│ └── buffalo_l
|
||||
│ ├── 2d106det.onnx
|
||||
│ └── det_10g.onnx
|
||||
└── liveportrait
|
||||
├── base_models
|
||||
│ ├── appearance_feature_extractor.pth
|
||||
│ ├── motion_extractor.pth
|
||||
│ ├── spade_generator.pth
|
||||
│ └── warping_module.pth
|
||||
├── landmark.onnx
|
||||
└── retargeting_models
|
||||
└── stitching_retargeting_module.pth
|
||||
#### For macOS with Apple Silicon Users
|
||||
The [X-Pose](https://github.com/IDEA-Research/X-Pose) dependency does not support macOS, so you can skip its installation. While Humans mode works as usual, Animals mode is not supported. Use the provided requirements file for macOS with Apple Silicon:
|
||||
```bash
|
||||
# for macOS with Apple Silicon users
|
||||
pip install -r requirements_macOS.txt
|
||||
```
|
||||
|
||||
### 2. Download pretrained weights 📥
|
||||
|
||||
The easiest way to download the pretrained weights is from HuggingFace:
|
||||
```bash
|
||||
# !pip install -U "huggingface_hub[cli]"
|
||||
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
|
||||
```
|
||||
|
||||
If you cannot access to Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download:
|
||||
```bash
|
||||
# !pip install -U "huggingface_hub[cli]"
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
|
||||
```
|
||||
|
||||
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn) (WIP). Unzip and place them in `./pretrained_weights`.
|
||||
|
||||
Ensuring the directory structure is as or contains [**this**](assets/docs/directory-structure.md).
|
||||
|
||||
### 3. Inference 🚀
|
||||
|
||||
#### Fast hands-on (humans) 👤
|
||||
```bash
|
||||
# For Linux and Windows users
|
||||
python inference.py
|
||||
|
||||
# For macOS users with Apple Silicon (Intel is not tested). NOTE: this maybe 20x slower than RTX 4090
|
||||
PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
|
||||
```
|
||||
|
||||
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
|
||||
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/docs/inference.gif" alt="image">
|
||||
@ -90,55 +141,122 @@ If the script runs successfully, you will get an output mp4 file named `animatio
|
||||
Or, you can change the input by specifying the `-s` and `-d` arguments:
|
||||
|
||||
```bash
|
||||
# source input is an image
|
||||
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
|
||||
|
||||
# or disable pasting back
|
||||
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
|
||||
# source input is a video ✨
|
||||
python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4
|
||||
|
||||
# more options to see
|
||||
python inference.py -h
|
||||
```
|
||||
|
||||
**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊
|
||||
#### Fast hands-on (animals) 🐱🐶
|
||||
Animals mode is ONLY tested on Linux and Windows with NVIDIA GPU.
|
||||
|
||||
### 4. Gradio interface
|
||||
You need to build an OP named `MultiScaleDeformableAttention` first, which is used by [X-Pose](https://github.com/IDEA-Research/X-Pose), a general keypoint detection framework.
|
||||
```bash
|
||||
cd src/utils/dependencies/XPose/models/UniPose/ops
|
||||
python setup.py build install
|
||||
cd - # equal to cd ../../../../../../../
|
||||
```
|
||||
|
||||
We also provide a Gradio interface for a better experience, just run by:
|
||||
Then
|
||||
```bash
|
||||
python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --driving_multiplier 1.75 --no_flag_stitching
|
||||
```
|
||||
If the script runs successfully, you will get an output mp4 file named `animations/s39--wink_concat.mp4`.
|
||||
<p align="center">
|
||||
<img src="./assets/docs/inference-animals.gif" alt="image">
|
||||
</p>
|
||||
|
||||
#### Driving video auto-cropping 📢📢📢
|
||||
> [!IMPORTANT]
|
||||
> To use your own driving video, we **recommend**: ⬇️
|
||||
> - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
|
||||
> - Focus on the head area, similar to the example videos.
|
||||
> - Minimize shoulder movement.
|
||||
> - Make sure the first frame of driving video is a frontal face with **neutral expression**.
|
||||
|
||||
Below is an auto-cropping case by `--flag_crop_driving_video`:
|
||||
```bash
|
||||
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
|
||||
```
|
||||
|
||||
If you find the results of auto-cropping is not well, you can modify the `--scale_crop_driving_video`, `--vy_ratio_crop_driving_video` options to adjust the scale and offset, or do it manually.
|
||||
|
||||
#### Motion template making
|
||||
You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
|
||||
```bash
|
||||
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation
|
||||
python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing
|
||||
```
|
||||
|
||||
### 4. Gradio interface 🤗
|
||||
|
||||
We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run by:
|
||||
|
||||
```bash
|
||||
python app.py
|
||||
# For Linux and Windows users (and macOS with Intel??)
|
||||
python app.py # humans mode
|
||||
|
||||
# For macOS with Apple Silicon users, Intel not supported, this maybe 20x slower than RTX 4090
|
||||
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py # humans mode
|
||||
```
|
||||
|
||||
We also provide a Gradio interface of animals mode, which is only tested on Linux with NVIDIA GPU:
|
||||
```bash
|
||||
python app_animals.py # animals mode 🐱🐶
|
||||
```
|
||||
|
||||
You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
|
||||
|
||||
🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster. Performance gains may vary with different CUDA versions.
|
||||
```bash
|
||||
# enable torch.compile for faster inference
|
||||
python app.py --flag_do_torch_compile
|
||||
```
|
||||
**Note**: This method is not supported on Windows and macOS.
|
||||
|
||||
**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
|
||||
|
||||
### 5. Inference speed evaluation 🚀🚀🚀
|
||||
We have also provided a script to evaluate the inference speed of each module:
|
||||
|
||||
```bash
|
||||
# For NVIDIA GPU
|
||||
python speed.py
|
||||
```
|
||||
|
||||
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
|
||||
The results are [**here**](./assets/docs/speed.md).
|
||||
|
||||
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|
||||
|-----------------------------------|:-------------:|:--------------:|:-------------:|
|
||||
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
|
||||
| Motion Extractor | 28.12 | 108 | 0.84 |
|
||||
| Spade Generator | 55.37 | 212 | 7.59 |
|
||||
| Warping Module | 45.53 | 174 | 5.21 |
|
||||
| Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
|
||||
## Community Resources 🤗
|
||||
|
||||
*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
|
||||
Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
|
||||
|
||||
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
|
||||
- [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) by [@PowerHouseMan](https://github.com/PowerHouseMan).
|
||||
- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
|
||||
- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
|
||||
- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
|
||||
- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
|
||||
- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
|
||||
|
||||
## Acknowledgements
|
||||
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
|
||||
And many more amazing contributions from our community!
|
||||
|
||||
## Acknowledgements 💐
|
||||
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions.
|
||||
|
||||
## Citation 💖
|
||||
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
|
||||
```bibtex
|
||||
@article{guo2024live,
|
||||
@article{guo2024liveportrait,
|
||||
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
|
||||
author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
|
||||
year = {2024},
|
||||
journal = {arXiv preprint:2407.03168},
|
||||
author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di},
|
||||
journal = {arXiv preprint arXiv:2407.03168},
|
||||
year = {2024}
|
||||
}
|
||||
```
|
||||
|
||||
## Contact 📧
|
||||
[**Jianzhu Guo (郭建珠)**](https://guojianzhu.com); **guojianzhu1994@gmail.com**
|
||||
|
@ -1,22 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.3.0
|
||||
torchvision==0.18.0
|
||||
torchaudio==2.3.0
|
||||
-r requirements_base.txt
|
||||
|
||||
numpy==1.26.4
|
||||
pyyaml==6.0.1
|
||||
opencv-python==4.10.0.84
|
||||
scipy==1.13.1
|
||||
imageio==2.34.2
|
||||
lmdb==1.4.1
|
||||
tqdm==4.66.4
|
||||
rich==13.7.1
|
||||
ffmpeg==1.4
|
||||
onnxruntime-gpu==1.18.0
|
||||
onnx==1.16.1
|
||||
scikit-image==0.24.0
|
||||
albumentations==1.4.10
|
||||
matplotlib==3.9.0
|
||||
imageio-ffmpeg==0.5.1
|
||||
tyro==0.8.5
|
||||
gradio==4.37.1
|
||||
transformers==4.22.0
|
||||
|
18
requirements_base.txt
Normal file
@ -0,0 +1,18 @@
|
||||
numpy==1.26.4
|
||||
pyyaml==6.0.1
|
||||
opencv-python==4.10.0.84
|
||||
scipy==1.13.1
|
||||
imageio==2.34.2
|
||||
lmdb==1.4.1
|
||||
tqdm==4.66.4
|
||||
rich==13.7.1
|
||||
ffmpeg-python==0.2.0
|
||||
onnx==1.16.1
|
||||
scikit-image==0.24.0
|
||||
albumentations==1.4.10
|
||||
matplotlib==3.9.0
|
||||
imageio-ffmpeg==0.5.1
|
||||
tyro==0.8.5
|
||||
gradio==4.37.1
|
||||
pykalman==0.9.7
|
||||
pillow>=10.2.0
|
7
requirements_macOS.txt
Normal file
@ -0,0 +1,7 @@
|
||||
-r requirements_base.txt
|
||||
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.3.0
|
||||
torchvision==0.18.0
|
||||
torchaudio==2.3.0
|
||||
onnxruntime-silicon==1.16.3
|
33
speed.py
@ -6,25 +6,28 @@ Benchmark the inference speed of each module in LivePortrait.
|
||||
TODO: heavy GPT style, need to refactor
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
|
||||
import yaml
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from src.utils.helper import load_model, concat_feat
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
|
||||
def initialize_inputs(batch_size=1):
|
||||
def initialize_inputs(batch_size=1, device_id=0):
|
||||
"""
|
||||
Generate random input tensors and move them to GPU
|
||||
"""
|
||||
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
|
||||
kp_source = torch.randn(batch_size, 21, 3).cuda().half()
|
||||
kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
|
||||
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
|
||||
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
|
||||
eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
|
||||
lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
|
||||
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
|
||||
kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
|
||||
generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
|
||||
eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
|
||||
lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
|
||||
feat_stitching = concat_feat(kp_source, kp_driving).half()
|
||||
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
|
||||
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
|
||||
@ -99,7 +102,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
|
||||
Measure inference times for each model
|
||||
"""
|
||||
times = {name: [] for name in compiled_models.keys()}
|
||||
times['Retargeting Models'] = []
|
||||
times['Stitching and Retargeting Modules'] = []
|
||||
|
||||
overall_times = []
|
||||
|
||||
@ -133,7 +136,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
|
||||
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
||||
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
||||
torch.cuda.synchronize()
|
||||
times['Retargeting Models'].append(time.time() - start)
|
||||
times['Stitching and Retargeting Modules'].append(time.time() - start)
|
||||
|
||||
overall_times.append(time.time() - overall_start)
|
||||
|
||||
@ -166,15 +169,15 @@ def main():
|
||||
"""
|
||||
Main function to benchmark speed and model parameters
|
||||
"""
|
||||
# Sample input tensors
|
||||
inputs = initialize_inputs()
|
||||
|
||||
# Load configuration
|
||||
cfg = InferenceConfig(device_id=0)
|
||||
cfg = InferenceConfig()
|
||||
model_config_path = cfg.models_config
|
||||
with open(model_config_path, 'r') as file:
|
||||
model_config = yaml.safe_load(file)
|
||||
|
||||
# Sample input tensors
|
||||
inputs = initialize_inputs(device_id = cfg.device_id)
|
||||
|
||||
# Load and compile models
|
||||
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
|
||||
|
||||
|
@ -1,44 +1,57 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
config for user
|
||||
All configs for user
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
from dataclasses import dataclass
|
||||
import tyro
|
||||
from typing_extensions import Annotated
|
||||
from typing import Optional, Literal
|
||||
from .base_config import PrintableConfig, make_abs_path
|
||||
|
||||
|
||||
@dataclass(repr=False) # use repr from PrintableConfig
|
||||
class ArgumentConfig(PrintableConfig):
|
||||
########## input arguments ##########
|
||||
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
|
||||
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
|
||||
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
|
||||
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
|
||||
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
|
||||
#####################################
|
||||
|
||||
########## inference arguments ##########
|
||||
device_id: int = 0
|
||||
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||
flag_eye_retargeting: bool = False
|
||||
flag_lip_retargeting: bool = False
|
||||
flag_stitching: bool = True # we recommend setting it to True!
|
||||
flag_relative: bool = True # whether to use relative motion
|
||||
flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
|
||||
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
|
||||
device_id: int = 0 # gpu device id
|
||||
flag_force_cpu: bool = False # force cpu inference, WIP!
|
||||
flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
|
||||
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
|
||||
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
|
||||
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
|
||||
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
|
||||
flag_relative_motion: bool = True # whether to use relative motion
|
||||
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
||||
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
|
||||
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
|
||||
driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" would adapt the driving motion with the global multiplier, and could be used when the source is a human image
|
||||
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
|
||||
########## source crop arguments ##########
|
||||
det_thresh: float = 0.15 # detection threshold
|
||||
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
|
||||
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
|
||||
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
|
||||
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
||||
#########################################
|
||||
source_max_dim: int = 1280 # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920
|
||||
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
|
||||
|
||||
########## crop arguments ##########
|
||||
dsize: int = 512
|
||||
scale: float = 2.3
|
||||
vx_ratio: float = 0 # vx ratio
|
||||
vy_ratio: float = -0.125 # vy ratio +up, -down
|
||||
####################################
|
||||
########## driving crop arguments ##########
|
||||
scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video
|
||||
vx_ratio_crop_driving_video: float = 0. # adjust y offset
|
||||
vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
|
||||
|
||||
########## gradio arguments ##########
|
||||
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
|
||||
share: bool = True
|
||||
server_name: str = "0.0.0.0"
|
||||
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
|
||||
share: bool = False # whether to share the server to public
|
||||
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
|
||||
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
|
||||
gradio_temp_dir: Optional[str] = None # directory to save gradio temp files
|
||||
|
@ -4,15 +4,32 @@
|
||||
parameters used for crop faces
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, List
|
||||
from .base_config import PrintableConfig
|
||||
|
||||
from .base_config import PrintableConfig, make_abs_path
|
||||
|
||||
|
||||
@dataclass(repr=False) # use repr from PrintableConfig
|
||||
class CropConfig(PrintableConfig):
|
||||
insightface_root: str = make_abs_path("../../pretrained_weights/insightface")
|
||||
landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx")
|
||||
xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py")
|
||||
xpose_embedding_cache_path: str = make_abs_path('../utils/resources/clip_embedding')
|
||||
|
||||
xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth")
|
||||
device_id: int = 0 # gpu device id
|
||||
flag_force_cpu: bool = False # force cpu inference, WIP
|
||||
det_thresh: float = 0.1 # detection threshold
|
||||
########## source image or video cropping option ##########
|
||||
dsize: int = 512 # crop size
|
||||
scale: float = 2.3 # scale factor
|
||||
vx_ratio: float = 0 # vx ratio
|
||||
vy_ratio: float = -0.125 # vy ratio +up, -down
|
||||
max_face_num: int = 0 # max face number, 0 mean no limit
|
||||
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
||||
animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
|
||||
########## driving video auto cropping option ##########
|
||||
scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
|
||||
vx_ratio_crop_driving_video: float = 0.0 # adjust y offset
|
||||
vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
|
||||
direction: str = "large-small" # direction of cropping
|
||||
|
@ -4,46 +4,61 @@
|
||||
config dataclass used for inference
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
from dataclasses import dataclass
|
||||
import cv2
|
||||
from numpy import ndarray
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Tuple
|
||||
from .base_config import PrintableConfig, make_abs_path
|
||||
|
||||
|
||||
@dataclass(repr=False) # use repr from PrintableConfig
|
||||
class InferenceConfig(PrintableConfig):
|
||||
# HUMAN MODEL CONFIG, NOT EXPORTED PARAMS
|
||||
models_config: str = make_abs_path('./models.yaml') # portrait animation config
|
||||
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
|
||||
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
|
||||
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
|
||||
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
|
||||
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
|
||||
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M
|
||||
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G
|
||||
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
|
||||
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip
|
||||
|
||||
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
|
||||
flag_use_half_precision: bool = True # whether to use half precision
|
||||
|
||||
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||
lip_zero_threshold: float = 0.03
|
||||
# ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS
|
||||
checkpoint_F_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
|
||||
checkpoint_M_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/motion_extractor.pth') # path to checkpoint pf M
|
||||
checkpoint_G_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/spade_generator.pth') # path to checkpoint of G
|
||||
checkpoint_W_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/warping_module.pth') # path to checkpoint of W
|
||||
checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip, NOTE: use human temporarily!
|
||||
|
||||
# EXPORTED PARAMS
|
||||
flag_use_half_precision: bool = True
|
||||
flag_crop_driving_video: bool = False
|
||||
device_id: int = 0
|
||||
flag_normalize_lip: bool = True
|
||||
flag_source_video_eye_retargeting: bool = False
|
||||
flag_video_editing_head_rotation: bool = False
|
||||
flag_eye_retargeting: bool = False
|
||||
flag_lip_retargeting: bool = False
|
||||
flag_stitching: bool = True # we recommend setting it to True!
|
||||
flag_stitching: bool = True
|
||||
flag_relative_motion: bool = True
|
||||
flag_pasteback: bool = True
|
||||
flag_do_crop: bool = True
|
||||
flag_do_rot: bool = True
|
||||
flag_force_cpu: bool = False
|
||||
flag_do_torch_compile: bool = False
|
||||
driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly"
|
||||
driving_multiplier: float = 1.0
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
source_max_dim: int = 1280 # the max dim of height and width of source image or video
|
||||
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
|
||||
|
||||
flag_relative: bool = True # whether to use relative motion
|
||||
anchor_frame: int = 0 # set this value if find_best_frame is True
|
||||
# NOT EXPORTED PARAMS
|
||||
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
|
||||
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
|
||||
anchor_frame: int = 0 # TO IMPLEMENT
|
||||
|
||||
input_shape: Tuple[int, int] = (256, 256) # input shape
|
||||
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
|
||||
output_fps: int = 30 # fps for output video
|
||||
crf: int = 15 # crf for output video
|
||||
output_fps: int = 25 # default output fps
|
||||
|
||||
flag_write_result: bool = True # whether to write output video
|
||||
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
||||
mask_crop = None
|
||||
flag_write_gif: bool = False
|
||||
size_gif: int = 256
|
||||
ref_max_shape: int = 1280
|
||||
ref_shape_n: int = 2
|
||||
|
||||
device_id: int = 0
|
||||
flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
|
||||
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
||||
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
|
||||
size_gif: int = 256 # default gif size, TO IMPLEMENT
|
||||
|
@ -3,15 +3,28 @@
|
||||
"""
|
||||
Pipeline for gradio
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import os
|
||||
import cv2
|
||||
from rich.progress import track
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .config.argument_config import ArgumentConfig
|
||||
from .live_portrait_pipeline import LivePortraitPipeline
|
||||
from .utils.io import load_img_online
|
||||
from .live_portrait_pipeline_animal import LivePortraitPipelineAnimal
|
||||
from .utils.io import load_img_online, load_video, resize_to_limit
|
||||
from .utils.filter import smooth
|
||||
from .utils.rprint import rlog as log
|
||||
from .utils.crop import prepare_paste_back, paste_back
|
||||
from .utils.camera import get_rotation_matrix
|
||||
from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video
|
||||
from .utils.helper import is_square_video, mkdir, dct2device, basename
|
||||
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
||||
|
||||
|
||||
def update_args(args, user_args):
|
||||
"""update the args according to user inputs
|
||||
"""
|
||||
@ -20,40 +33,179 @@ def update_args(args, user_args):
|
||||
setattr(args, k, v)
|
||||
return args
|
||||
|
||||
|
||||
class GradioPipeline(LivePortraitPipeline):
|
||||
"""gradio for human
|
||||
"""
|
||||
|
||||
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
|
||||
super().__init__(inference_cfg, crop_cfg)
|
||||
# self.live_portrait_wrapper = self.live_portrait_wrapper
|
||||
self.args = args
|
||||
# for single image retargeting
|
||||
self.start_prepare = False
|
||||
self.f_s_user = None
|
||||
self.x_c_s_info_user = None
|
||||
self.x_s_user = None
|
||||
self.source_lmk_user = None
|
||||
self.mask_ori = None
|
||||
self.img_rgb = None
|
||||
self.crop_M_c2o = None
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
|
||||
if eyeball_direction_x > 0:
|
||||
delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
|
||||
delta_new[0, 15, 0] += eyeball_direction_x * 0.001
|
||||
else:
|
||||
delta_new[0, 11, 0] += eyeball_direction_x * 0.001
|
||||
delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
|
||||
|
||||
delta_new[0, 11, 1] += eyeball_direction_y * -0.001
|
||||
delta_new[0, 15, 1] += eyeball_direction_y * -0.001
|
||||
blink = -eyeball_direction_y / 2.
|
||||
|
||||
delta_new[0, 11, 1] += blink * -0.001
|
||||
delta_new[0, 13, 1] += blink * 0.0003
|
||||
delta_new[0, 15, 1] += blink * -0.001
|
||||
delta_new[0, 16, 1] += blink * 0.0003
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_smile(self, smile, delta_new, **kwargs):
|
||||
delta_new[0, 20, 1] += smile * -0.01
|
||||
delta_new[0, 14, 1] += smile * -0.02
|
||||
delta_new[0, 17, 1] += smile * 0.0065
|
||||
delta_new[0, 17, 2] += smile * 0.003
|
||||
delta_new[0, 13, 1] += smile * -0.00275
|
||||
delta_new[0, 16, 1] += smile * -0.00275
|
||||
delta_new[0, 3, 1] += smile * -0.0035
|
||||
delta_new[0, 7, 1] += smile * -0.0035
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_wink(self, wink, delta_new, **kwargs):
|
||||
delta_new[0, 11, 1] += wink * 0.001
|
||||
delta_new[0, 13, 1] += wink * -0.0003
|
||||
delta_new[0, 17, 0] += wink * 0.0003
|
||||
delta_new[0, 17, 1] += wink * 0.0003
|
||||
delta_new[0, 3, 1] += wink * -0.0003
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs):
|
||||
if eyebrow > 0:
|
||||
delta_new[0, 1, 1] += eyebrow * 0.001
|
||||
delta_new[0, 2, 1] += eyebrow * -0.001
|
||||
else:
|
||||
delta_new[0, 1, 0] += eyebrow * -0.001
|
||||
delta_new[0, 2, 0] += eyebrow * 0.001
|
||||
delta_new[0, 1, 1] += eyebrow * 0.0003
|
||||
delta_new[0, 2, 1] += eyebrow * -0.0003
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs):
|
||||
delta_new[0, 19, 0] += lip_variation_zero
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs):
|
||||
delta_new[0, 14, 1] += lip_variation_one * 0.001
|
||||
delta_new[0, 3, 1] += lip_variation_one * -0.0005
|
||||
delta_new[0, 7, 1] += lip_variation_one * -0.0005
|
||||
delta_new[0, 17, 2] += lip_variation_one * -0.0005
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs):
|
||||
delta_new[0, 20, 2] += lip_variation_two * -0.001
|
||||
delta_new[0, 20, 1] += lip_variation_two * -0.001
|
||||
delta_new[0, 14, 1] += lip_variation_two * -0.001
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs):
|
||||
delta_new[0, 19, 1] += lip_variation_three * 0.001
|
||||
delta_new[0, 19, 2] += lip_variation_three * 0.0001
|
||||
delta_new[0, 17, 1] += lip_variation_three * -0.0001
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs):
|
||||
delta_new[0, 5, 0] += mov_x
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs):
|
||||
delta_new[0, 5, 1] += mov_y
|
||||
|
||||
return delta_new
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_video(
|
||||
self,
|
||||
input_image_path,
|
||||
input_video_path,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
):
|
||||
""" for video driven potrait animation
|
||||
input_source_image_path=None,
|
||||
input_source_video_path=None,
|
||||
input_driving_video_pickle_path=None,
|
||||
input_driving_video_path=None,
|
||||
flag_relative_input=True,
|
||||
flag_do_crop_input=True,
|
||||
flag_remap_input=True,
|
||||
flag_stitching_input=True,
|
||||
driving_option_input="pose-friendly",
|
||||
driving_multiplier=1.0,
|
||||
flag_crop_driving_video_input=True,
|
||||
flag_video_editing_head_rotation=False,
|
||||
scale=2.3,
|
||||
vx_ratio=0.0,
|
||||
vy_ratio=-0.125,
|
||||
scale_crop_driving_video=2.2,
|
||||
vx_ratio_crop_driving_video=0.0,
|
||||
vy_ratio_crop_driving_video=-0.1,
|
||||
driving_smooth_observation_variance=3e-7,
|
||||
tab_selection=None,
|
||||
v_tab_selection=None
|
||||
):
|
||||
""" for video-driven portrait animation or video editing
|
||||
"""
|
||||
if input_image_path is not None and input_video_path is not None:
|
||||
if tab_selection == 'Image':
|
||||
input_source_path = input_source_image_path
|
||||
elif tab_selection == 'Video':
|
||||
input_source_path = input_source_video_path
|
||||
else:
|
||||
input_source_path = input_source_image_path
|
||||
|
||||
if v_tab_selection == 'Video':
|
||||
input_driving_path = input_driving_video_path
|
||||
elif v_tab_selection == 'Pickle':
|
||||
input_driving_path = input_driving_video_pickle_path
|
||||
else:
|
||||
input_driving_path = input_driving_video_path
|
||||
|
||||
if input_source_path is not None and input_driving_path is not None:
|
||||
if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False:
|
||||
flag_crop_driving_video_input = True
|
||||
log("The driving video is not square, it will be cropped to square automatically.")
|
||||
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
|
||||
|
||||
args_user = {
|
||||
'source_image': input_image_path,
|
||||
'driving_info': input_video_path,
|
||||
'flag_relative': flag_relative_input,
|
||||
'source': input_source_path,
|
||||
'driving': input_driving_path,
|
||||
'flag_relative_motion': flag_relative_input,
|
||||
'flag_do_crop': flag_do_crop_input,
|
||||
'flag_pasteback': flag_remap_input,
|
||||
'flag_stitching': flag_stitching_input,
|
||||
'driving_option': driving_option_input,
|
||||
'driving_multiplier': driving_multiplier,
|
||||
'flag_crop_driving_video': flag_crop_driving_video_input,
|
||||
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
|
||||
'scale': scale,
|
||||
'vx_ratio': vx_ratio,
|
||||
'vy_ratio': vy_ratio,
|
||||
'scale_crop_driving_video': scale_crop_driving_video,
|
||||
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
|
||||
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
|
||||
'driving_smooth_observation_variance': driving_smooth_observation_variance,
|
||||
}
|
||||
# update config from user input
|
||||
self.args = update_args(self.args, args_user)
|
||||
@ -62,79 +214,368 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
# video driven animation
|
||||
video_path, video_path_concat = self.execute(self.args)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return video_path, video_path_concat,
|
||||
return video_path, video_path_concat
|
||||
else:
|
||||
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
|
||||
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
|
||||
|
||||
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
|
||||
@torch.no_grad()
|
||||
def execute_image_retargeting(
|
||||
self,
|
||||
input_eye_ratio: float,
|
||||
input_lip_ratio: float,
|
||||
input_head_pitch_variation: float,
|
||||
input_head_yaw_variation: float,
|
||||
input_head_roll_variation: float,
|
||||
mov_x: float,
|
||||
mov_y: float,
|
||||
mov_z: float,
|
||||
lip_variation_zero: float,
|
||||
lip_variation_one: float,
|
||||
lip_variation_two: float,
|
||||
lip_variation_three: float,
|
||||
smile: float,
|
||||
wink: float,
|
||||
eyebrow: float,
|
||||
eyeball_direction_x: float,
|
||||
eyeball_direction_y: float,
|
||||
input_image,
|
||||
retargeting_source_scale: float,
|
||||
flag_stitching_retargeting_input=True,
|
||||
flag_do_crop_input_retargeting_image=True):
|
||||
""" for single image retargeting
|
||||
"""
|
||||
if input_eye_ratio is None or input_eye_ratio is None:
|
||||
if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
|
||||
raise gr.Error("Invalid relative pose input 💥!", duration=5)
|
||||
# disposable feature
|
||||
f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
|
||||
self.prepare_retargeting_image(
|
||||
input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
|
||||
|
||||
if input_eye_ratio is None or input_lip_ratio is None:
|
||||
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
||||
elif self.f_s_user is None:
|
||||
if self.start_prepare:
|
||||
raise gr.Error(
|
||||
"The source portrait is under processing 💥! Please wait for a second.",
|
||||
duration=5
|
||||
)
|
||||
else:
|
||||
raise gr.Error(
|
||||
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
|
||||
duration=5
|
||||
)
|
||||
else:
|
||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
|
||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
|
||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
|
||||
num_kp = self.x_s_user.shape[1]
|
||||
# default: use x_s
|
||||
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
|
||||
# D(W(f_s; x_s, x′_d))
|
||||
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
|
||||
device = self.live_portrait_wrapper.device
|
||||
# inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
x_s_user = x_s_user.to(device)
|
||||
f_s_user = f_s_user.to(device)
|
||||
R_s_user = R_s_user.to(device)
|
||||
R_d_user = R_d_user.to(device)
|
||||
mov_x = torch.tensor(mov_x).to(device)
|
||||
mov_y = torch.tensor(mov_y).to(device)
|
||||
mov_z = torch.tensor(mov_z).to(device)
|
||||
eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device)
|
||||
eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device)
|
||||
smile = torch.tensor(smile).to(device)
|
||||
wink = torch.tensor(wink).to(device)
|
||||
eyebrow = torch.tensor(eyebrow).to(device)
|
||||
lip_variation_zero = torch.tensor(lip_variation_zero).to(device)
|
||||
lip_variation_one = torch.tensor(lip_variation_one).to(device)
|
||||
lip_variation_two = torch.tensor(lip_variation_two).to(device)
|
||||
lip_variation_three = torch.tensor(lip_variation_three).to(device)
|
||||
|
||||
x_c_s = x_s_info['kp'].to(device)
|
||||
delta_new = x_s_info['exp'].to(device)
|
||||
scale_new = x_s_info['scale'].to(device)
|
||||
t_new = x_s_info['t'].to(device)
|
||||
R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
|
||||
|
||||
if eyeball_direction_x != 0 or eyeball_direction_y != 0:
|
||||
delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new)
|
||||
if smile != 0:
|
||||
delta_new = self.update_delta_new_smile(smile, delta_new)
|
||||
if wink != 0:
|
||||
delta_new = self.update_delta_new_wink(wink, delta_new)
|
||||
if eyebrow != 0:
|
||||
delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new)
|
||||
if lip_variation_zero != 0:
|
||||
delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new)
|
||||
if lip_variation_one != 0:
|
||||
delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new)
|
||||
if lip_variation_two != 0:
|
||||
delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new)
|
||||
if lip_variation_three != 0:
|
||||
delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new)
|
||||
if mov_x != 0:
|
||||
delta_new = self.update_delta_new_mov_x(-mov_x, delta_new)
|
||||
if mov_y !=0 :
|
||||
delta_new = self.update_delta_new_mov_y(mov_y, delta_new)
|
||||
|
||||
x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new
|
||||
eyes_delta, lip_delta = None, None
|
||||
if input_eye_ratio != self.source_eye_ratio:
|
||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
|
||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
|
||||
if input_lip_ratio != self.source_lip_ratio:
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
|
||||
x_d_new = x_d_new + \
|
||||
(eyes_delta if eyes_delta is not None else 0) + \
|
||||
(lip_delta if lip_delta is not None else 0)
|
||||
|
||||
if flag_stitching_retargeting_input:
|
||||
x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
|
||||
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
if flag_do_crop_input_retargeting_image:
|
||||
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
|
||||
else:
|
||||
out_to_ori_blend = out
|
||||
return out, out_to_ori_blend
|
||||
|
||||
|
||||
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
|
||||
@torch.no_grad()
|
||||
def prepare_retargeting_image(
|
||||
self,
|
||||
input_image,
|
||||
input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation,
|
||||
retargeting_source_scale,
|
||||
flag_do_crop=True):
|
||||
""" for single image retargeting
|
||||
"""
|
||||
if input_image_path is not None:
|
||||
gr.Info("Upload successfully!", duration=2)
|
||||
self.start_prepare = True
|
||||
inference_cfg = self.live_portrait_wrapper.cfg
|
||||
if input_image is not None:
|
||||
# gr.Info("Upload successfully!", duration=2)
|
||||
args_user = {'scale': retargeting_source_scale}
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
######## process source portrait ########
|
||||
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
|
||||
log(f"Load source image from {input_image_path}.")
|
||||
crop_info = self.cropper.crop_single_image(img_rgb)
|
||||
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
|
||||
if flag_do_crop:
|
||||
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
|
||||
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
|
||||
source_lmk_user = crop_info['lmk_crop']
|
||||
crop_M_c2o = crop_info['M_c2o']
|
||||
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
else:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
||||
source_lmk_user = self.cropper.calc_lmk_from_cropped_image(img_rgb)
|
||||
crop_M_c2o = None
|
||||
mask_ori = None
|
||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
|
||||
x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
|
||||
x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation
|
||||
R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll)
|
||||
############################################
|
||||
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
|
||||
else:
|
||||
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)
|
||||
|
||||
# record global info for next time use
|
||||
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
self.x_s_info_user = x_s_info
|
||||
self.source_lmk_user = crop_info['lmk_crop']
|
||||
self.img_rgb = img_rgb
|
||||
self.crop_M_c2o = crop_info['M_c2o']
|
||||
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
# update slider
|
||||
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
|
||||
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
|
||||
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
|
||||
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
|
||||
# for vis
|
||||
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
|
||||
return eye_close_ratio, lip_close_ratio, self.I_s_vis
|
||||
@torch.no_grad()
|
||||
def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio:float, input_image = None):
|
||||
""" initialize the retargeting slider
|
||||
"""
|
||||
if input_image != None:
|
||||
args_user = {'scale': retargeting_source_scale}
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
# inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
######## process source portrait ########
|
||||
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
|
||||
log(f"Load source image from {input_image}.")
|
||||
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
|
||||
if crop_info is None:
|
||||
raise gr.Error("Source portrait NO face detected", duration=2)
|
||||
source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
|
||||
source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
|
||||
self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2)
|
||||
self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2)
|
||||
log("Calculating eyes-open and lip-open ratios successfully!")
|
||||
return self.source_eye_ratio, self.source_lip_ratio
|
||||
else:
|
||||
return source_eye_ratio, source_lip_ratio
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
|
||||
""" retargeting the lip-open ratio of each source frame
|
||||
"""
|
||||
# disposable feature
|
||||
device = self.live_portrait_wrapper.device
|
||||
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
|
||||
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
|
||||
|
||||
if input_lip_ratio is None:
|
||||
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
||||
else:
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
|
||||
I_p_pstbk_lst = None
|
||||
if flag_do_crop_input_retargeting_video:
|
||||
I_p_pstbk_lst = []
|
||||
I_p_lst = []
|
||||
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
|
||||
x_s_user_i = x_s_user_lst[i].to(device)
|
||||
f_s_user_i = f_s_user_lst[i].to(device)
|
||||
|
||||
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
|
||||
x_d_i_new = x_s_user_i + lip_delta_retargeting
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
|
||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
I_p_lst.append(I_p_i)
|
||||
|
||||
if flag_do_crop_input_retargeting_video:
|
||||
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
|
||||
mkdir(self.args.output_dir)
|
||||
flag_source_has_audio = has_audio_stream(input_video)
|
||||
|
||||
######### build the final concatenation result #########
|
||||
# source frame | generation
|
||||
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
|
||||
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
|
||||
|
||||
if flag_source_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
|
||||
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=source_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio:
|
||||
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
|
||||
add_audio_to_video(wfp, input_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return wfp_concat, wfp
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
|
||||
""" for video retargeting
|
||||
"""
|
||||
if input_video is not None:
|
||||
# gr.Info("Upload successfully!", duration=2)
|
||||
args_user = {'scale': retargeting_source_scale}
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
######## process source video ########
|
||||
source_rgb_lst = load_video(input_video)
|
||||
source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
|
||||
source_fps = int(get_fps(input_video))
|
||||
n_frames = len(source_rgb_lst)
|
||||
log(f"Load source video from {input_video}. FPS is {source_fps}")
|
||||
|
||||
if flag_do_crop:
|
||||
ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
|
||||
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_s["frame_crop_lst"]) != n_frames:
|
||||
n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
|
||||
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
|
||||
mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
|
||||
else:
|
||||
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
|
||||
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
|
||||
source_M_c2o_lst, mask_ori_lst = None, None
|
||||
|
||||
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
|
||||
# save the motion template
|
||||
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
|
||||
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
|
||||
|
||||
c_d_lip_retargeting = [input_lip_ratio]
|
||||
f_s_user_lst, x_s_user_lst, lip_delta_retargeting_lst = [], [], []
|
||||
for i in track(range(n_frames), description='Preparing retargeting video...', total=n_frames):
|
||||
x_s_info = source_template_dct['motion'][i]
|
||||
x_s_info = dct2device(x_s_info, device)
|
||||
x_s_user = x_s_info['x_s']
|
||||
|
||||
source_lmk = source_lmk_crop_lst[i]
|
||||
img_crop_256x256 = img_crop_256x256_lst[i]
|
||||
I_s = I_s_lst[i]
|
||||
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
|
||||
combined_lip_ratio_tensor_retargeting = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_retargeting, source_lmk)
|
||||
lip_delta_retargeting = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor_retargeting)
|
||||
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
|
||||
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
|
||||
|
||||
|
||||
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
|
||||
else:
|
||||
# when press the clear button, go here
|
||||
return 0.8, 0.8, self.I_s_vis
|
||||
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
|
||||
|
||||
class GradioPipelineAnimal(LivePortraitPipelineAnimal):
|
||||
"""gradio for animal
|
||||
"""
|
||||
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
|
||||
inference_cfg.flag_crop_driving_video = True # ensure the face_analysis_wrapper is enabled
|
||||
super().__init__(inference_cfg, crop_cfg)
|
||||
# self.live_portrait_wrapper_animal = self.live_portrait_wrapper_animal
|
||||
self.args = args
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_video(
|
||||
self,
|
||||
input_source_image_path=None,
|
||||
input_driving_video_path=None,
|
||||
input_driving_video_pickle_path=None,
|
||||
flag_do_crop_input=False,
|
||||
flag_remap_input=False,
|
||||
driving_multiplier=1.0,
|
||||
flag_stitching=False,
|
||||
flag_crop_driving_video_input=False,
|
||||
scale=2.3,
|
||||
vx_ratio=0.0,
|
||||
vy_ratio=-0.125,
|
||||
scale_crop_driving_video=2.2,
|
||||
vx_ratio_crop_driving_video=0.0,
|
||||
vy_ratio_crop_driving_video=-0.1,
|
||||
tab_selection=None,
|
||||
):
|
||||
""" for video-driven potrait animation
|
||||
"""
|
||||
input_source_path = input_source_image_path
|
||||
|
||||
if tab_selection == 'Video':
|
||||
input_driving_path = input_driving_video_path
|
||||
elif tab_selection == 'Pickle':
|
||||
input_driving_path = input_driving_video_pickle_path
|
||||
else:
|
||||
input_driving_path = input_driving_video_pickle_path
|
||||
|
||||
if input_source_path is not None and input_driving_path is not None:
|
||||
if osp.exists(input_driving_path) and tab_selection == 'Video' and is_square_video(input_driving_path) is False:
|
||||
flag_crop_driving_video_input = True
|
||||
log("The driving video is not square, it will be cropped to square automatically.")
|
||||
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
|
||||
|
||||
args_user = {
|
||||
'source': input_source_path,
|
||||
'driving': input_driving_path,
|
||||
'flag_do_crop': flag_do_crop_input,
|
||||
'flag_pasteback': flag_remap_input,
|
||||
'driving_multiplier': driving_multiplier,
|
||||
'flag_stitching': flag_stitching,
|
||||
'flag_crop_driving_video': flag_crop_driving_video_input,
|
||||
'scale': scale,
|
||||
'vx_ratio': vx_ratio,
|
||||
'vy_ratio': vy_ratio,
|
||||
'scale_crop_driving_video': scale_crop_driving_video,
|
||||
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
|
||||
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
|
||||
}
|
||||
# update config from user input
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.live_portrait_wrapper_animal.update_config(self.args.__dict__)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
# video driven animation
|
||||
video_path, video_path_concat, video_gif_path = self.execute(self.args)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return video_path, video_path_concat, video_gif_path
|
||||
else:
|
||||
raise gr.Error("Please upload the source animal image, and driving video 🤗🤗🤗", duration=5)
|
||||
|
@ -1,16 +1,15 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Pipeline of LivePortrait
|
||||
Pipeline of LivePortrait (Human)
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# 1. 当前假定所有的模板都是已经裁好的,需要修改下
|
||||
# 2. pick样例图 source + driving
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
|
||||
|
||||
import cv2
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
import os.path as osp
|
||||
from rich.progress import track
|
||||
|
||||
@ -19,12 +18,13 @@ from .config.inference_config import InferenceConfig
|
||||
from .config.crop_config import CropConfig
|
||||
from .utils.cropper import Cropper
|
||||
from .utils.camera import get_rotation_matrix
|
||||
from .utils.video import images2video, concat_frames
|
||||
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
||||
from .utils.retargeting_utils import calc_lip_close_ratio
|
||||
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
|
||||
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
|
||||
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
|
||||
from .utils.crop import prepare_paste_back, paste_back
|
||||
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
|
||||
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video, calc_motion_multiplier
|
||||
from .utils.filter import smooth
|
||||
from .utils.rprint import rlog as log
|
||||
# from .utils.viz import viz_lmk
|
||||
from .live_portrait_wrapper import LivePortraitWrapper
|
||||
|
||||
|
||||
@ -35,156 +35,388 @@ def make_abs_path(fn):
|
||||
class LivePortraitPipeline(object):
|
||||
|
||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
||||
self.cropper = Cropper(crop_cfg=crop_cfg)
|
||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
|
||||
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
|
||||
|
||||
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
|
||||
n_frames = I_lst.shape[0]
|
||||
template_dct = {
|
||||
'n_frames': n_frames,
|
||||
'output_fps': kwargs.get('output_fps', 25),
|
||||
'motion': [],
|
||||
'c_eyes_lst': [],
|
||||
'c_lip_lst': [],
|
||||
}
|
||||
|
||||
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
|
||||
# collect s, R, δ and t for inference
|
||||
I_i = I_lst[i]
|
||||
x_i_info = self.live_portrait_wrapper.get_kp_info(I_i)
|
||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_i_info)
|
||||
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
|
||||
|
||||
item_dct = {
|
||||
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
|
||||
'R': R_i.cpu().numpy().astype(np.float32),
|
||||
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
|
||||
't': x_i_info['t'].cpu().numpy().astype(np.float32),
|
||||
'kp': x_i_info['kp'].cpu().numpy().astype(np.float32),
|
||||
'x_s': x_s.cpu().numpy().astype(np.float32),
|
||||
}
|
||||
|
||||
template_dct['motion'].append(item_dct)
|
||||
|
||||
c_eyes = c_eyes_lst[i].astype(np.float32)
|
||||
template_dct['c_eyes_lst'].append(c_eyes)
|
||||
|
||||
c_lip = c_lip_lst[i].astype(np.float32)
|
||||
template_dct['c_lip_lst'].append(c_lip)
|
||||
|
||||
|
||||
return template_dct
|
||||
|
||||
def execute(self, args: ArgumentConfig):
|
||||
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
||||
######## process source portrait ########
|
||||
img_rgb = load_image_rgb(args.source_image)
|
||||
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
||||
log(f"Load source image from {args.source_image}")
|
||||
crop_info = self.cropper.crop_single_image(img_rgb)
|
||||
source_lmk = crop_info['lmk_crop']
|
||||
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
||||
if inference_cfg.flag_do_crop:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||
else:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
# for convenience
|
||||
inf_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
device = self.live_portrait_wrapper.device
|
||||
crop_cfg = self.cropper.crop_cfg
|
||||
|
||||
if inference_cfg.flag_lip_zero:
|
||||
# let lip-open scalar to be 0 at first
|
||||
c_d_lip_before_animation = [0.]
|
||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
||||
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
||||
inference_cfg.flag_lip_zero = False
|
||||
else:
|
||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||
############################################
|
||||
######## load source input ########
|
||||
flag_is_source_video = False
|
||||
source_fps = None
|
||||
if is_image(args.source):
|
||||
flag_is_source_video = False
|
||||
img_rgb = load_image_rgb(args.source)
|
||||
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
|
||||
log(f"Load source image from {args.source}")
|
||||
source_rgb_lst = [img_rgb]
|
||||
elif is_video(args.source):
|
||||
flag_is_source_video = True
|
||||
source_rgb_lst = load_video(args.source)
|
||||
source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst]
|
||||
source_fps = int(get_fps(args.source))
|
||||
log(f"Load source video from {args.source}, FPS is {source_fps}")
|
||||
else: # source input is an unknown format
|
||||
raise Exception(f"Unknown source format: {args.source}")
|
||||
|
||||
######## process driving info ########
|
||||
if is_video(args.driving_info):
|
||||
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
|
||||
# TODO: 这里track一下驱动视频 -> 构建模板
|
||||
driving_rgb_lst = load_driving_info(args.driving_info)
|
||||
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
|
||||
n_frames = I_d_lst.shape[0]
|
||||
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
|
||||
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
||||
elif is_template(args.driving_info):
|
||||
log(f"Load from video templates {args.driving_info}")
|
||||
with open(args.driving_info, 'rb') as f:
|
||||
template_lst, driving_lmk_lst = pickle.load(f)
|
||||
n_frames = template_lst[0]['n_frames']
|
||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
||||
flag_load_from_template = is_template(args.driving)
|
||||
driving_rgb_crop_256x256_lst = None
|
||||
wfp_template = None
|
||||
|
||||
if flag_load_from_template:
|
||||
# NOTE: load from template, it is fast, but the cropping video is None
|
||||
log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
|
||||
driving_template_dct = load(args.driving)
|
||||
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
|
||||
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
|
||||
driving_n_frames = driving_template_dct['n_frames']
|
||||
if flag_is_source_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
else:
|
||||
n_frames = driving_n_frames
|
||||
|
||||
# set output_fps
|
||||
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
|
||||
log(f'The FPS of template: {output_fps}')
|
||||
|
||||
if args.flag_crop_driving_video:
|
||||
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||
|
||||
elif osp.exists(args.driving) and is_video(args.driving):
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
driving_n_frames = len(driving_rgb_lst)
|
||||
|
||||
######## make motion template ########
|
||||
log("Start making driving motion template...")
|
||||
if flag_is_source_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
driving_rgb_lst = driving_rgb_lst[:n_frames]
|
||||
else:
|
||||
n_frames = driving_n_frames
|
||||
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
|
||||
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
|
||||
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_d["frame_crop_lst"]) is not n_frames:
|
||||
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
|
||||
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
|
||||
else:
|
||||
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
|
||||
#######################################
|
||||
|
||||
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst)
|
||||
# save the motion template
|
||||
I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst)
|
||||
driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
|
||||
|
||||
wfp_template = remove_suffix(args.driving) + '.pkl'
|
||||
dump(wfp_template, driving_template_dct)
|
||||
log(f"Dump motion template to {wfp_template}")
|
||||
|
||||
else:
|
||||
raise Exception("Unsupported driving types!")
|
||||
#########################################
|
||||
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
|
||||
|
||||
######## prepare for pasteback ########
|
||||
if inference_cfg.flag_pasteback:
|
||||
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
I_p_paste_lst = []
|
||||
#########################################
|
||||
I_p_pstbk_lst = None
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
I_p_pstbk_lst = []
|
||||
log("Prepared pasteback mask done.")
|
||||
|
||||
I_p_lst = []
|
||||
R_d_0, x_d_0_info = None, None
|
||||
for i in track(range(n_frames), description='Animating...', total=n_frames):
|
||||
if is_video(args.driving_info):
|
||||
# extract kp info by M
|
||||
I_d_i = I_d_lst[i]
|
||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
||||
else:
|
||||
# from template
|
||||
x_d_i_info = template_lst[i]
|
||||
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
|
||||
R_d_i = x_d_i_info['R_d']
|
||||
flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite
|
||||
flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite
|
||||
lip_delta_before_animation, eye_delta_before_animation = None, None
|
||||
|
||||
if i == 0:
|
||||
######## process source info ########
|
||||
if flag_is_source_video:
|
||||
log(f"Start making source motion template...")
|
||||
|
||||
source_rgb_lst = source_rgb_lst[:n_frames]
|
||||
if inf_cfg.flag_do_crop:
|
||||
ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg)
|
||||
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_s["frame_crop_lst"]) is not n_frames:
|
||||
n_frames = min(n_frames, len(ret_s["frame_crop_lst"]))
|
||||
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
|
||||
else:
|
||||
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
|
||||
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
|
||||
|
||||
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
|
||||
# save the motion template
|
||||
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
|
||||
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
|
||||
|
||||
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
|
||||
if inf_cfg.flag_relative_motion:
|
||||
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
else:
|
||||
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
|
||||
else: # if the input is a source image, process it only once
|
||||
if inf_cfg.flag_do_crop:
|
||||
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
|
||||
if crop_info is None:
|
||||
raise Exception("No face detected in the source image!")
|
||||
source_lmk = crop_info['lmk_crop']
|
||||
img_crop_256x256 = crop_info['img_crop_256x256']
|
||||
else:
|
||||
source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0])
|
||||
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
|
||||
# let lip-open scalar to be 0 at first
|
||||
if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None:
|
||||
c_d_lip_before_animation = [0.]
|
||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
||||
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
|
||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
|
||||
|
||||
######## animate ########
|
||||
log(f"The animated video consists of {n_frames} frames.")
|
||||
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
|
||||
if flag_is_source_video: # source video
|
||||
x_s_info = source_template_dct['motion'][i]
|
||||
x_s_info = dct2device(x_s_info, device)
|
||||
|
||||
source_lmk = source_lmk_crop_lst[i]
|
||||
img_crop_256x256 = img_crop_256x256_lst[i]
|
||||
I_s = I_s_lst[i]
|
||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = x_s_info['R']
|
||||
x_s =x_s_info['x_s']
|
||||
|
||||
# let lip-open scalar to be 0 at first if the input is a video
|
||||
if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None:
|
||||
c_d_lip_before_animation = [0.]
|
||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
||||
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
|
||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||
else:
|
||||
lip_delta_before_animation = None
|
||||
|
||||
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
|
||||
if flag_source_video_eye_retargeting and source_lmk is not None:
|
||||
if i == 0:
|
||||
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
|
||||
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
|
||||
if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold:
|
||||
c_d_eye_before_animation_frame_zero = [[0.39]]
|
||||
combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
|
||||
eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
|
||||
|
||||
x_d_i_info = driving_template_dct['motion'][i]
|
||||
x_d_i_info = dct2device(x_d_i_info, device)
|
||||
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
|
||||
|
||||
if i == 0: # cache the first frame
|
||||
R_d_0 = R_d_i
|
||||
x_d_0_info = x_d_i_info
|
||||
|
||||
if inference_cfg.flag_relative:
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
if inf_cfg.flag_relative_motion:
|
||||
if flag_is_source_video:
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
R_new = x_d_r_lst_smooth[i]
|
||||
else:
|
||||
R_new = R_s
|
||||
else:
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
else:
|
||||
R_new = R_d_i
|
||||
delta_new = x_d_i_info['exp']
|
||||
if flag_is_source_video:
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
R_new = x_d_r_lst_smooth[i]
|
||||
else:
|
||||
R_new = R_s
|
||||
else:
|
||||
R_new = R_d_i
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
|
||||
scale_new = x_s_info['scale']
|
||||
t_new = x_d_i_info['t']
|
||||
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
||||
|
||||
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
|
||||
if i == 0:
|
||||
x_d_0_new = x_d_i_new
|
||||
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
|
||||
# motion_multiplier *= inf_cfg.driving_multiplier
|
||||
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
|
||||
x_d_i_new = x_d_diff + x_s
|
||||
|
||||
# Algorithm 1:
|
||||
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
||||
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||
# without stitching or retargeting
|
||||
if inference_cfg.flag_lip_zero:
|
||||
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||
if flag_normalize_lip and lip_delta_before_animation is not None:
|
||||
x_d_i_new += lip_delta_before_animation
|
||||
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
||||
x_d_i_new += eye_delta_before_animation
|
||||
else:
|
||||
pass
|
||||
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
||||
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||
# with stitching and without retargeting
|
||||
if inference_cfg.flag_lip_zero:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||
if flag_normalize_lip and lip_delta_before_animation is not None:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation
|
||||
else:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
||||
x_d_i_new += eye_delta_before_animation
|
||||
else:
|
||||
eyes_delta, lip_delta = None, None
|
||||
if inference_cfg.flag_eye_retargeting:
|
||||
c_d_eyes_i = input_eye_ratio_lst[i]
|
||||
if inf_cfg.flag_eye_retargeting and source_lmk is not None:
|
||||
c_d_eyes_i = c_d_eyes_lst[i]
|
||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
||||
if inference_cfg.flag_lip_retargeting:
|
||||
c_d_lip_i = input_lip_ratio_lst[i]
|
||||
if inf_cfg.flag_lip_retargeting and source_lmk is not None:
|
||||
c_d_lip_i = c_d_lip_lst[i]
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
||||
|
||||
if inference_cfg.flag_relative: # use x_s
|
||||
if inf_cfg.flag_relative_motion: # use x_s
|
||||
x_d_i_new = x_s + \
|
||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||
(eyes_delta if eyes_delta is not None else 0) + \
|
||||
(lip_delta if lip_delta is not None else 0)
|
||||
else: # use x_d,i
|
||||
x_d_i_new = x_d_i_new + \
|
||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||
(eyes_delta if eyes_delta is not None else 0) + \
|
||||
(lip_delta if lip_delta is not None else 0)
|
||||
|
||||
if inference_cfg.flag_stitching:
|
||||
if inf_cfg.flag_stitching:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||
|
||||
x_d_i_new = x_s + (x_d_i_new - x_s) * inf_cfg.driving_multiplier
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
I_p_lst.append(I_p_i)
|
||||
|
||||
if inference_cfg.flag_pasteback:
|
||||
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
||||
I_p_paste_lst.append(I_p_i_to_ori_blend)
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
# TODO: the paste back procedure is slow, considering optimize it using multi-threading or GPU
|
||||
if flag_is_source_video:
|
||||
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float)
|
||||
else:
|
||||
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float)
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
|
||||
mkdir(args.output_dir)
|
||||
wfp_concat = None
|
||||
if is_video(args.driving_info):
|
||||
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
|
||||
# save (driving frames, source image, drived frames) result
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat)
|
||||
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
# save drived result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
||||
if inference_cfg.flag_pasteback:
|
||||
images2video(I_p_paste_lst, wfp=wfp)
|
||||
######### build the final concatenation result #########
|
||||
# driving frame | source frame | generation, or source frame | generation
|
||||
if flag_is_source_video:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp)
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
|
||||
# NOTE: update output fps
|
||||
output_fps = source_fps if flag_is_source_video else output_fps
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}, concat mode")
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}")
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
|
||||
return wfp, wfp_concat
|
||||
|
237
src/live_portrait_pipeline_animal.py
Normal file
@ -0,0 +1,237 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Pipeline of LivePortrait (Animal)
|
||||
"""
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
|
||||
warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly.")
|
||||
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
|
||||
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
|
||||
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
from rich.progress import track
|
||||
|
||||
from .config.argument_config import ArgumentConfig
|
||||
from .config.inference_config import InferenceConfig
|
||||
from .config.crop_config import CropConfig
|
||||
from .utils.cropper import Cropper
|
||||
from .utils.camera import get_rotation_matrix
|
||||
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream, video2gif
|
||||
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
||||
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
|
||||
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, calc_motion_multiplier
|
||||
from .utils.rprint import rlog as log
|
||||
# from .utils.viz import viz_lmk
|
||||
from .live_portrait_wrapper import LivePortraitWrapperAnimal
|
||||
|
||||
|
||||
def make_abs_path(fn):
|
||||
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
||||
|
||||
class LivePortraitPipelineAnimal(object):
|
||||
|
||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
||||
self.live_portrait_wrapper_animal: LivePortraitWrapperAnimal = LivePortraitWrapperAnimal(inference_cfg=inference_cfg)
|
||||
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg, image_type='animal_face', flag_use_half_precision=inference_cfg.flag_use_half_precision)
|
||||
|
||||
def make_motion_template(self, I_lst, **kwargs):
|
||||
n_frames = I_lst.shape[0]
|
||||
template_dct = {
|
||||
'n_frames': n_frames,
|
||||
'output_fps': kwargs.get('output_fps', 25),
|
||||
'motion': [],
|
||||
}
|
||||
|
||||
for i in track(range(n_frames), description='Making driving motion templates...', total=n_frames):
|
||||
# collect s, R, δ and t for inference
|
||||
I_i = I_lst[i]
|
||||
x_i_info = self.live_portrait_wrapper_animal.get_kp_info(I_i)
|
||||
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
|
||||
|
||||
item_dct = {
|
||||
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
|
||||
'R': R_i.cpu().numpy().astype(np.float32),
|
||||
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
|
||||
't': x_i_info['t'].cpu().numpy().astype(np.float32),
|
||||
}
|
||||
|
||||
template_dct['motion'].append(item_dct)
|
||||
|
||||
return template_dct
|
||||
|
||||
def execute(self, args: ArgumentConfig):
|
||||
# for convenience
|
||||
inf_cfg = self.live_portrait_wrapper_animal.inference_cfg
|
||||
device = self.live_portrait_wrapper_animal.device
|
||||
crop_cfg = self.cropper.crop_cfg
|
||||
|
||||
######## load source input ########
|
||||
if is_image(args.source):
|
||||
img_rgb = load_image_rgb(args.source)
|
||||
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
|
||||
log(f"Load source image from {args.source}")
|
||||
else: # source input is an unknown format
|
||||
raise Exception(f"Unknown source format: {args.source}")
|
||||
|
||||
######## process driving info ########
|
||||
flag_load_from_template = is_template(args.driving)
|
||||
driving_rgb_crop_256x256_lst = None
|
||||
wfp_template = None
|
||||
|
||||
if flag_load_from_template:
|
||||
# NOTE: load from template, it is fast, but the cropping video is None
|
||||
log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
|
||||
driving_template_dct = load(args.driving)
|
||||
n_frames = driving_template_dct['n_frames']
|
||||
|
||||
# set output_fps
|
||||
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
|
||||
log(f'The FPS of template: {output_fps}')
|
||||
|
||||
if args.flag_crop_driving_video:
|
||||
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||
|
||||
elif osp.exists(args.driving) and is_video(args.driving):
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
n_frames = len(driving_rgb_lst)
|
||||
|
||||
######## make motion template ########
|
||||
log("Start making driving motion template...")
|
||||
if inf_cfg.flag_crop_driving_video:
|
||||
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
|
||||
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_d["frame_crop_lst"]) is not n_frames:
|
||||
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
|
||||
driving_rgb_crop_lst = ret_d['frame_crop_lst']
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
|
||||
else:
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
|
||||
#######################################
|
||||
|
||||
# save the motion template
|
||||
I_d_lst = self.live_portrait_wrapper_animal.prepare_videos(driving_rgb_crop_256x256_lst)
|
||||
driving_template_dct = self.make_motion_template(I_d_lst, output_fps=output_fps)
|
||||
|
||||
wfp_template = remove_suffix(args.driving) + '.pkl'
|
||||
dump(wfp_template, driving_template_dct)
|
||||
log(f"Dump motion template to {wfp_template}")
|
||||
|
||||
else:
|
||||
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
|
||||
|
||||
######## prepare for pasteback ########
|
||||
I_p_pstbk_lst = None
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
I_p_pstbk_lst = []
|
||||
log("Prepared pasteback mask done.")
|
||||
|
||||
######## process source info ########
|
||||
if inf_cfg.flag_do_crop:
|
||||
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
|
||||
if crop_info is None:
|
||||
raise Exception("No animal face detected in the source image!")
|
||||
img_crop_256x256 = crop_info['img_crop_256x256']
|
||||
else:
|
||||
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
|
||||
I_s = self.live_portrait_wrapper_animal.prepare_source(img_crop_256x256)
|
||||
x_s_info = self.live_portrait_wrapper_animal.get_kp_info(I_s)
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
f_s = self.live_portrait_wrapper_animal.extract_feature_3d(I_s)
|
||||
x_s = self.live_portrait_wrapper_animal.transform_keypoint(x_s_info)
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
|
||||
######## animate ########
|
||||
I_p_lst = []
|
||||
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
|
||||
|
||||
x_d_i_info = driving_template_dct['motion'][i]
|
||||
x_d_i_info = dct2device(x_d_i_info, device)
|
||||
|
||||
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
|
||||
delta_new = x_d_i_info['exp']
|
||||
t_new = x_d_i_info['t']
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
scale_new = x_s_info['scale']
|
||||
|
||||
x_d_i = scale_new * (x_c_s @ R_d_i + delta_new) + t_new
|
||||
|
||||
if i == 0:
|
||||
x_d_0 = x_d_i
|
||||
motion_multiplier = calc_motion_multiplier(x_s, x_d_0)
|
||||
|
||||
x_d_diff = (x_d_i - x_d_0) * motion_multiplier
|
||||
x_d_i = x_d_diff + x_s
|
||||
|
||||
if not inf_cfg.flag_stitching:
|
||||
pass
|
||||
else:
|
||||
x_d_i = self.live_portrait_wrapper_animal.stitching(x_s, x_d_i)
|
||||
|
||||
x_d_i = x_s + (x_d_i - x_s) * inf_cfg.driving_multiplier
|
||||
out = self.live_portrait_wrapper_animal.warp_decode(f_s, x_s, x_d_i)
|
||||
I_p_i = self.live_portrait_wrapper_animal.parse_output(out['out'])[0]
|
||||
I_p_lst.append(I_p_i)
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
|
||||
mkdir(args.output_dir)
|
||||
wfp_concat = None
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
######### build the final concatenation result #########
|
||||
# driving frame | source image | generation
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
|
||||
if flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
|
||||
# build the gif
|
||||
wfp_gif = video2gif(wfp)
|
||||
log(f'Animated gif: {wfp_gif}')
|
||||
|
||||
|
||||
return wfp, wfp_concat, wfp_gif
|
@ -1,9 +1,10 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Wrapper for LivePortrait core functions
|
||||
Wrappers for LivePortrait core functions
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import cv2
|
||||
@ -19,46 +20,73 @@ from .utils.rprint import rlog as log
|
||||
|
||||
|
||||
class LivePortraitWrapper(object):
|
||||
"""
|
||||
Wrapper for Human
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: InferenceConfig):
|
||||
def __init__(self, inference_cfg: InferenceConfig):
|
||||
|
||||
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
||||
self.inference_cfg = inference_cfg
|
||||
self.device_id = inference_cfg.device_id
|
||||
self.compile = inference_cfg.flag_do_torch_compile
|
||||
if inference_cfg.flag_force_cpu:
|
||||
self.device = 'cpu'
|
||||
else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = 'mps'
|
||||
else:
|
||||
self.device = 'cuda:' + str(self.device_id)
|
||||
except:
|
||||
self.device = 'cuda:' + str(self.device_id)
|
||||
|
||||
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
||||
# init F
|
||||
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
||||
log(f'Load appearance_feature_extractor done.')
|
||||
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
|
||||
log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F)} done.')
|
||||
# init M
|
||||
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
||||
log(f'Load motion_extractor done.')
|
||||
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
|
||||
log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M)} done.')
|
||||
# init W
|
||||
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
||||
log(f'Load warping_module done.')
|
||||
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
|
||||
log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W)} done.')
|
||||
# init G
|
||||
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
||||
log(f'Load spade_generator done.')
|
||||
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
|
||||
log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G)} done.')
|
||||
# init S and R
|
||||
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
|
||||
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
||||
log(f'Load stitching_retargeting_module done.')
|
||||
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
|
||||
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
|
||||
log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S)} done.')
|
||||
else:
|
||||
self.stitching_retargeting_module = None
|
||||
# Optimize for inference
|
||||
if self.compile:
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
|
||||
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
|
||||
|
||||
self.cfg = cfg
|
||||
self.device_id = cfg.device_id
|
||||
self.timer = Timer()
|
||||
|
||||
def inference_ctx(self):
|
||||
if self.device == "mps":
|
||||
ctx = contextlib.nullcontext()
|
||||
else:
|
||||
ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
|
||||
enabled=self.inference_cfg.flag_use_half_precision)
|
||||
return ctx
|
||||
|
||||
def update_config(self, user_args):
|
||||
for k, v in user_args.items():
|
||||
if hasattr(self.cfg, k):
|
||||
setattr(self.cfg, k, v)
|
||||
if hasattr(self.inference_cfg, k):
|
||||
setattr(self.inference_cfg, k, v)
|
||||
|
||||
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
||||
""" construct the input as standard
|
||||
img: HxWx3, uint8, 256x256
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
|
||||
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
|
||||
if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
|
||||
x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
|
||||
else:
|
||||
x = img.copy()
|
||||
|
||||
@ -70,10 +98,10 @@ class LivePortraitWrapper(object):
|
||||
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
||||
x = np.clip(x, 0, 1) # clip to 0~1
|
||||
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
||||
x = x.cuda(self.device_id)
|
||||
x = x.to(self.device)
|
||||
return x
|
||||
|
||||
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
||||
def prepare_videos(self, imgs) -> torch.Tensor:
|
||||
""" construct the input as standard
|
||||
imgs: NxBxHxWx3, uint8
|
||||
"""
|
||||
@ -87,7 +115,7 @@ class LivePortraitWrapper(object):
|
||||
y = _imgs.astype(np.float32) / 255.
|
||||
y = np.clip(y, 0, 1) # clip to 0~1
|
||||
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
||||
y = y.cuda(self.device_id)
|
||||
y = y.to(self.device)
|
||||
|
||||
return y
|
||||
|
||||
@ -95,9 +123,8 @@ class LivePortraitWrapper(object):
|
||||
""" get the appearance feature of the image by F
|
||||
x: Bx3xHxW, normalized to 0~1
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
||||
feature_3d = self.appearance_feature_extractor(x)
|
||||
with torch.no_grad(), self.inference_ctx():
|
||||
feature_3d = self.appearance_feature_extractor(x)
|
||||
|
||||
return feature_3d.float()
|
||||
|
||||
@ -107,11 +134,10 @@ class LivePortraitWrapper(object):
|
||||
flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
|
||||
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
||||
kp_info = self.motion_extractor(x)
|
||||
with torch.no_grad(), self.inference_ctx():
|
||||
kp_info = self.motion_extractor(x)
|
||||
|
||||
if self.cfg.flag_use_half_precision:
|
||||
if self.inference_cfg.flag_use_half_precision:
|
||||
# float the dict
|
||||
for k, v in kp_info.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
@ -189,26 +215,27 @@ class LivePortraitWrapper(object):
|
||||
"""
|
||||
kp_source: BxNx3
|
||||
eye_close_ratio: Bx3
|
||||
Return: Bx(3*num_kp+2)
|
||||
Return: Bx(3*num_kp)
|
||||
"""
|
||||
feat_eye = concat_feat(kp_source, eye_close_ratio)
|
||||
|
||||
with torch.no_grad():
|
||||
delta = self.stitching_retargeting_module['eye'](feat_eye)
|
||||
|
||||
return delta
|
||||
return delta.reshape(-1, kp_source.shape[1], 3)
|
||||
|
||||
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
kp_source: BxNx3
|
||||
lip_close_ratio: Bx2
|
||||
Return: Bx(3*num_kp)
|
||||
"""
|
||||
feat_lip = concat_feat(kp_source, lip_close_ratio)
|
||||
|
||||
with torch.no_grad():
|
||||
delta = self.stitching_retargeting_module['lip'](feat_lip)
|
||||
|
||||
return delta
|
||||
return delta.reshape(-1, kp_source.shape[1], 3)
|
||||
|
||||
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@ -253,15 +280,17 @@ class LivePortraitWrapper(object):
|
||||
kp_driving: BxNx3
|
||||
"""
|
||||
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
||||
# get decoder input
|
||||
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
||||
# decode
|
||||
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
||||
with torch.no_grad(), self.inference_ctx():
|
||||
if self.compile:
|
||||
# Mark the beginning of a new CUDA Graph step
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
# get decoder input
|
||||
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
||||
# decode
|
||||
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
||||
|
||||
# float the dict
|
||||
if self.cfg.flag_use_half_precision:
|
||||
if self.inference_cfg.flag_use_half_precision:
|
||||
for k, v in ret_dct.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
ret_dct[k] = v.float()
|
||||
@ -278,30 +307,78 @@ class LivePortraitWrapper(object):
|
||||
|
||||
return out
|
||||
|
||||
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
|
||||
def calc_ratio(self, lmk_lst):
|
||||
input_eye_ratio_lst = []
|
||||
input_lip_ratio_lst = []
|
||||
for lmk in driving_lmk_lst:
|
||||
for lmk in lmk_lst:
|
||||
# for eyes retargeting
|
||||
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
|
||||
# for lip retargeting
|
||||
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
||||
return input_eye_ratio_lst, input_lip_ratio_lst
|
||||
|
||||
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
|
||||
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
|
||||
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
|
||||
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
|
||||
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
|
||||
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
|
||||
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
|
||||
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
|
||||
# [c_s,eyes, c_d,eyes,i]
|
||||
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
|
||||
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
|
||||
return combined_eye_ratio_tensor
|
||||
|
||||
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
|
||||
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
|
||||
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
|
||||
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
|
||||
c_s_lip = calc_lip_close_ratio(source_lmk[None])
|
||||
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
|
||||
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
|
||||
# [c_s,lip, c_d,lip,i]
|
||||
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
|
||||
if input_lip_ratio_tensor.shape != [1, 1]:
|
||||
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
|
||||
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
|
||||
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
|
||||
return combined_lip_ratio_tensor
|
||||
|
||||
|
||||
class LivePortraitWrapperAnimal(LivePortraitWrapper):
|
||||
"""
|
||||
Wrapper for Animal
|
||||
"""
|
||||
def __init__(self, inference_cfg: InferenceConfig):
|
||||
# super().__init__(inference_cfg) # 调用父类的初始化方法
|
||||
|
||||
self.inference_cfg = inference_cfg
|
||||
self.device_id = inference_cfg.device_id
|
||||
self.compile = inference_cfg.flag_do_torch_compile
|
||||
if inference_cfg.flag_force_cpu:
|
||||
self.device = 'cpu'
|
||||
else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = 'mps'
|
||||
else:
|
||||
self.device = 'cuda:' + str(self.device_id)
|
||||
except:
|
||||
self.device = 'cuda:' + str(self.device_id)
|
||||
|
||||
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
||||
# init F
|
||||
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F_animal, model_config, self.device, 'appearance_feature_extractor')
|
||||
log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F_animal)} done.')
|
||||
# init M
|
||||
self.motion_extractor = load_model(inference_cfg.checkpoint_M_animal, model_config, self.device, 'motion_extractor')
|
||||
log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M_animal)} done.')
|
||||
# init W
|
||||
self.warping_module = load_model(inference_cfg.checkpoint_W_animal, model_config, self.device, 'warping_module')
|
||||
log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W_animal)} done.')
|
||||
# init G
|
||||
self.spade_generator = load_model(inference_cfg.checkpoint_G_animal, model_config, self.device, 'spade_generator')
|
||||
log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G_animal)} done.')
|
||||
# init S and R
|
||||
if inference_cfg.checkpoint_S_animal is not None and osp.exists(inference_cfg.checkpoint_S_animal):
|
||||
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S_animal, model_config, self.device, 'stitching_retargeting_module')
|
||||
log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S_animal)} done.')
|
||||
else:
|
||||
self.stitching_retargeting_module = None
|
||||
|
||||
# Optimize for inference
|
||||
if self.compile:
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
|
||||
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
|
||||
|
||||
self.timer = Timer()
|
||||
|
@ -59,7 +59,7 @@ class DenseMotionNetwork(nn.Module):
|
||||
heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
|
||||
|
||||
# adding background feature
|
||||
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
|
||||
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
|
||||
heatmap = torch.cat([zeros, heatmap], dim=1)
|
||||
heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
|
||||
return heatmap
|
||||
|
@ -11,7 +11,8 @@ import torch
|
||||
import torch.nn.utils.spectral_norm as spectral_norm
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import collections.abc
|
||||
from itertools import repeat
|
||||
|
||||
def kp2gaussian(kp, spatial_size, kp_variance):
|
||||
"""
|
||||
@ -439,3 +440,13 @@ class DropPath(nn.Module):
|
||||
|
||||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
@ -1,65 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Make video template
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pickle
|
||||
from rich.progress import track
|
||||
from .utils.cropper import Cropper
|
||||
|
||||
from .utils.io import load_driving_info
|
||||
from .utils.camera import get_rotation_matrix
|
||||
from .utils.helper import mkdir, basename
|
||||
from .utils.rprint import rlog as log
|
||||
from .config.crop_config import CropConfig
|
||||
from .config.inference_config import InferenceConfig
|
||||
from .live_portrait_wrapper import LivePortraitWrapper
|
||||
|
||||
class TemplateMaker:
|
||||
|
||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
||||
self.cropper = Cropper(crop_cfg=crop_cfg)
|
||||
|
||||
def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
|
||||
""" make video template (.pkl format)
|
||||
video_fp: driving video file path
|
||||
output_path: where to save the pickle file
|
||||
"""
|
||||
|
||||
driving_rgb_lst = load_driving_info(video_fp)
|
||||
driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
||||
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst)
|
||||
|
||||
n_frames = I_d_lst.shape[0]
|
||||
|
||||
templates = []
|
||||
|
||||
|
||||
for i in track(range(n_frames), description='Making templates...', total=n_frames):
|
||||
I_d_i = I_d_lst[i]
|
||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
||||
# collect s_d, R_d, δ_d and t_d for inference
|
||||
template_dct = {
|
||||
'n_frames': n_frames,
|
||||
'frames_index': i,
|
||||
}
|
||||
template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32)
|
||||
template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32)
|
||||
template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32)
|
||||
template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32)
|
||||
|
||||
templates.append(template_dct)
|
||||
|
||||
mkdir(output_path)
|
||||
# Save the dictionary as a pickle file
|
||||
pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
|
||||
with open(pickle_fp, 'wb') as f:
|
||||
pickle.dump([templates, driving_lmk_lst], f)
|
||||
log(f"Template saved at {pickle_fp}")
|
138
src/utils/animal_landmark_runner.py
Normal file
@ -0,0 +1,138 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
face detectoin and alignment using XPose
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.ops import nms
|
||||
|
||||
from .timer import Timer
|
||||
from .rprint import rlog as log
|
||||
from .helper import clean_state_dict
|
||||
|
||||
from .dependencies.XPose import transforms as T
|
||||
from .dependencies.XPose.models import build_model
|
||||
from .dependencies.XPose.predefined_keypoints import *
|
||||
from .dependencies.XPose.util import box_ops
|
||||
from .dependencies.XPose.util.config import Config
|
||||
|
||||
|
||||
class XPoseRunner(object):
|
||||
def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
|
||||
self.device_id = kwargs.get("device_id", 0)
|
||||
self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
|
||||
self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
|
||||
self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
|
||||
self.timer = Timer()
|
||||
# Load cached embeddings if available
|
||||
try:
|
||||
with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
|
||||
self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
|
||||
with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
|
||||
self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
|
||||
print("Loaded cached embeddings from file.")
|
||||
except Exception:
|
||||
raise ValueError("Could not load clip embeddings from file, please check your file path.")
|
||||
|
||||
def load_animal_model(self, model_config_path, model_checkpoint_path, device):
|
||||
args = Config.fromfile(model_config_path)
|
||||
args.device = device
|
||||
model = build_model(args)
|
||||
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
|
||||
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_image(self, input_image):
|
||||
image_pil = input_image.convert("RGB")
|
||||
transform = T.Compose([
|
||||
T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
image, _ = transform(image_pil, None)
|
||||
return image_pil, image
|
||||
|
||||
def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
|
||||
instance_list = instance_text_prompt.split(',')
|
||||
|
||||
if len(keypoint_text_prompt) == 9:
|
||||
# torch.Size([1, 512]) torch.Size([9, 512])
|
||||
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
|
||||
elif len(keypoint_text_prompt) ==68:
|
||||
# torch.Size([1, 512]) torch.Size([68, 512])
|
||||
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
|
||||
else:
|
||||
raise ValueError("Invalid number of keypoint embeddings.")
|
||||
target = {
|
||||
"instance_text_prompt": instance_list,
|
||||
"keypoint_text_prompt": keypoint_text_prompt,
|
||||
"object_embeddings_text": ins_text_embeddings.float(),
|
||||
"kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0),
|
||||
"kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
|
||||
}
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
image = image.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
|
||||
outputs = self.model(image[None], [target])
|
||||
|
||||
logits = outputs["pred_logits"].sigmoid()[0]
|
||||
boxes = outputs["pred_boxes"][0]
|
||||
keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]
|
||||
|
||||
logits_filt = logits.cpu().clone()
|
||||
boxes_filt = boxes.cpu().clone()
|
||||
keypoints_filt = keypoints.cpu().clone()
|
||||
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
||||
logits_filt = logits_filt[filt_mask]
|
||||
boxes_filt = boxes_filt[filt_mask]
|
||||
keypoints_filt = keypoints_filt[filt_mask]
|
||||
|
||||
keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold)
|
||||
|
||||
filtered_boxes = boxes_filt[keep_indices]
|
||||
filtered_keypoints = keypoints_filt[keep_indices]
|
||||
|
||||
return filtered_boxes, filtered_keypoints
|
||||
|
||||
def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
|
||||
if keypoint_text_example in globals():
|
||||
keypoint_dict = globals()[keypoint_text_example]
|
||||
elif instance_text_prompt in globals():
|
||||
keypoint_dict = globals()[instance_text_prompt]
|
||||
else:
|
||||
keypoint_dict = globals()["animal"]
|
||||
|
||||
keypoint_text_prompt = keypoint_dict.get("keypoints")
|
||||
keypoint_skeleton = keypoint_dict.get("skeleton")
|
||||
|
||||
image_pil, image = self.load_image(input_image)
|
||||
boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold)
|
||||
|
||||
size = image_pil.size
|
||||
H, W = size[1], size[0]
|
||||
keypoints_filt = keypoints_filt[0].squeeze(0)
|
||||
kp = np.array(keypoints_filt.cpu())
|
||||
num_kpts = len(keypoint_text_prompt)
|
||||
Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
|
||||
Z = Z.reshape(num_kpts * 2)
|
||||
x = Z[0::2]
|
||||
y = Z[1::2]
|
||||
return np.stack((x, y), axis=1)
|
||||
|
||||
def warmup(self):
|
||||
self.timer.tic()
|
||||
|
||||
img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
|
||||
self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
|
||||
|
||||
elapse = self.timer.toc()
|
||||
log(f'XPoseRunner warmup time: {elapse:.3f}s')
|
@ -31,8 +31,6 @@ def headpose_pred_to_degree(pred):
|
||||
def get_rotation_matrix(pitch_, yaw_, roll_):
|
||||
""" the input is in degree
|
||||
"""
|
||||
# calculate the rotation matrix: vps @ rot
|
||||
|
||||
# transform to radian
|
||||
pitch = pitch_ / 180 * PI
|
||||
yaw = yaw_ / 180 * PI
|
||||
|
18
src/utils/check_windows_port.py
Normal file
@ -0,0 +1,18 @@
|
||||
import socket
|
||||
import sys
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python check_port.py <port>")
|
||||
sys.exit(1)
|
||||
|
||||
port = int(sys.argv[1])
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
result = sock.connect_ex(('127.0.0.1', port))
|
||||
|
||||
if result == 0:
|
||||
print("LISTENING")
|
||||
else:
|
||||
print("NOT LISTENING")
|
||||
sock.close
|
@ -136,6 +136,29 @@ def parse_pt2_from_pt5(pt5, use_lip=True):
|
||||
], axis=0)
|
||||
return pt2
|
||||
|
||||
def parse_pt2_from_pt9(pt9, use_lip=True):
|
||||
'''
|
||||
parsing the 2 points according to the 9 points, which cancels the roll
|
||||
['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip']
|
||||
'''
|
||||
if use_lip:
|
||||
pt9 = np.stack([
|
||||
(pt9[2] + pt9[3]) / 2, # left eye
|
||||
(pt9[0] + pt9[1]) / 2, # right eye
|
||||
pt9[4],
|
||||
(pt9[5] + pt9[6] ) / 2 # lip
|
||||
], axis=0)
|
||||
pt2 = np.stack([
|
||||
(pt9[0] + pt9[1]) / 2, # eye
|
||||
pt9[3] # lip
|
||||
], axis=0)
|
||||
else:
|
||||
pt2 = np.stack([
|
||||
(pt9[2] + pt9[3]) / 2,
|
||||
(pt9[0] + pt9[1]) / 2,
|
||||
], axis=0)
|
||||
|
||||
return pt2
|
||||
|
||||
def parse_pt2_from_pt_x(pts, use_lip=True):
|
||||
if pts.shape[0] == 101:
|
||||
@ -151,6 +174,8 @@ def parse_pt2_from_pt_x(pts, use_lip=True):
|
||||
elif pts.shape[0] > 101:
|
||||
# take the first 101 points
|
||||
pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
|
||||
elif pts.shape[0] == 9:
|
||||
pt2 = parse_pt2_from_pt9(pts, use_lip=use_lip)
|
||||
else:
|
||||
raise Exception(f'Unknow shape: {pts.shape}')
|
||||
|
||||
@ -281,11 +306,10 @@ def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=Fals
|
||||
dtype=DTYPE
|
||||
)
|
||||
|
||||
if flag_rot and angle is None:
|
||||
print('angle is None, but flag_rotate is True', style="bold yellow")
|
||||
# if flag_rot and angle is None:
|
||||
# print('angle is None, but flag_rotate is True', style="bold yellow")
|
||||
|
||||
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
|
||||
|
||||
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
|
||||
|
||||
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
|
||||
@ -362,17 +386,6 @@ def crop_image(img, pts: np.ndarray, **kwargs):
|
||||
flag_do_rot=kwargs.get('flag_do_rot', True),
|
||||
)
|
||||
|
||||
if img is None:
|
||||
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
|
||||
M = np.linalg.inv(M_INV_H)
|
||||
ret_dct = {
|
||||
'M': M[:2, ...], # from the original image to the cropped image
|
||||
'M_o2c': M[:2, ...], # from the cropped image to the original image
|
||||
'img_crop': None,
|
||||
'pt_crop': None,
|
||||
}
|
||||
return ret_dct
|
||||
|
||||
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
|
||||
pt_crop = _transform_pts(pts, M_INV)
|
||||
|
||||
@ -397,16 +410,14 @@ def average_bbox_lst(bbox_lst):
|
||||
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
|
||||
"""prepare mask for later image paste back
|
||||
"""
|
||||
if mask_crop is None:
|
||||
mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
|
||||
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
|
||||
mask_ori = mask_ori.astype(np.float32) / 255.
|
||||
return mask_ori
|
||||
|
||||
def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
|
||||
def paste_back(img_crop, M_c2o, img_ori, mask_ori):
|
||||
"""paste back the image
|
||||
"""
|
||||
dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
|
||||
result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
|
||||
result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
|
||||
return result
|
||||
dsize = (img_ori.shape[1], img_ori.shape[0])
|
||||
result = _transform_img(img_crop, M_c2o, dsize=dsize)
|
||||
result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
|
||||
return result
|
||||
|
@ -1,21 +1,25 @@
|
||||
# coding: utf-8
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from typing import List, Union, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
from .landmark_runner import LandmarkRunner
|
||||
from .face_analysis_diy import FaceAnalysisDIY
|
||||
from .helper import prefix
|
||||
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
|
||||
from .timer import Timer
|
||||
from .rprint import rlog as log
|
||||
from .io import load_image_rgb
|
||||
from .video import VideoWriter, get_fps, change_video_fps
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config.crop_config import CropConfig
|
||||
from .crop import (
|
||||
average_bbox_lst,
|
||||
crop_image,
|
||||
crop_image_by_bbox,
|
||||
parse_bbox_from_landmark,
|
||||
)
|
||||
from .io import contiguous
|
||||
from .rprint import rlog as log
|
||||
from .face_analysis_diy import FaceAnalysisDIY
|
||||
from .human_landmark_runner import LandmarkRunner as HumanLandmark
|
||||
|
||||
def make_abs_path(fn):
|
||||
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
||||
@ -23,123 +27,287 @@ def make_abs_path(fn):
|
||||
|
||||
@dataclass
|
||||
class Trajectory:
|
||||
start: int = -1 # 起始帧 闭区间
|
||||
end: int = -1 # 结束帧 闭区间
|
||||
start: int = -1 # start frame
|
||||
end: int = -1 # end frame
|
||||
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
|
||||
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
|
||||
M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list
|
||||
|
||||
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
|
||||
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
|
||||
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
|
||||
|
||||
|
||||
class Cropper(object):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
device_id = kwargs.get('device_id', 0)
|
||||
self.landmark_runner = LandmarkRunner(
|
||||
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
|
||||
onnx_provider='cuda',
|
||||
device_id=device_id
|
||||
)
|
||||
self.landmark_runner.warmup()
|
||||
|
||||
self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
|
||||
self.image_type = kwargs.get("image_type", 'human_face')
|
||||
device_id = kwargs.get("device_id", 0)
|
||||
flag_force_cpu = kwargs.get("flag_force_cpu", False)
|
||||
if flag_force_cpu:
|
||||
device = "cpu"
|
||||
face_analysis_wrapper_provider = ["CPUExecutionProvider"]
|
||||
else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
# Shape inference currently fails with CoreMLExecutionProvider
|
||||
# for the retinaface model
|
||||
device = "mps"
|
||||
face_analysis_wrapper_provider = ["CPUExecutionProvider"]
|
||||
else:
|
||||
device = "cuda"
|
||||
face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
|
||||
except:
|
||||
device = "cuda"
|
||||
face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
|
||||
self.face_analysis_wrapper = FaceAnalysisDIY(
|
||||
name='buffalo_l',
|
||||
root=make_abs_path('../../pretrained_weights/insightface'),
|
||||
providers=["CUDAExecutionProvider"]
|
||||
)
|
||||
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
|
||||
name="buffalo_l",
|
||||
root=self.crop_cfg.insightface_root,
|
||||
providers=face_analysis_wrapper_provider,
|
||||
)
|
||||
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
|
||||
self.face_analysis_wrapper.warmup()
|
||||
|
||||
self.crop_cfg = kwargs.get('crop_cfg', None)
|
||||
self.human_landmark_runner = HumanLandmark(
|
||||
ckpt_path=self.crop_cfg.landmark_ckpt_path,
|
||||
onnx_provider=device,
|
||||
device_id=device_id,
|
||||
)
|
||||
self.human_landmark_runner.warmup()
|
||||
|
||||
if self.image_type == "animal_face":
|
||||
from .animal_landmark_runner import XPoseRunner as AnimalLandmarkRunner
|
||||
self.animal_landmark_runner = AnimalLandmarkRunner(
|
||||
model_config_path=self.crop_cfg.xpose_config_file_path,
|
||||
model_checkpoint_path=self.crop_cfg.xpose_ckpt_path,
|
||||
embeddings_cache_path=self.crop_cfg.xpose_embedding_cache_path,
|
||||
flag_use_half_precision=kwargs.get("flag_use_half_precision", True),
|
||||
)
|
||||
self.animal_landmark_runner.warmup()
|
||||
|
||||
def update_config(self, user_args):
|
||||
for k, v in user_args.items():
|
||||
if hasattr(self.crop_cfg, k):
|
||||
setattr(self.crop_cfg, k, v)
|
||||
|
||||
def crop_single_image(self, obj, **kwargs):
|
||||
direction = kwargs.get('direction', 'large-small')
|
||||
def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
|
||||
# crop a source image and get neccessary information
|
||||
img_rgb = img_rgb_.copy() # copy it
|
||||
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# crop and align a single image
|
||||
if isinstance(obj, str):
|
||||
img_rgb = load_image_rgb(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
img_rgb = obj
|
||||
if self.image_type == "human_face":
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
img_bgr,
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=crop_cfg.direction,
|
||||
max_face_num=crop_cfg.max_face_num,
|
||||
)
|
||||
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
img_rgb,
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=direction
|
||||
)
|
||||
if len(src_face) == 0:
|
||||
log("No face detected in the source image.")
|
||||
return None
|
||||
elif len(src_face) > 1:
|
||||
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
|
||||
|
||||
if len(src_face) == 0:
|
||||
log('No face detected in the source image.')
|
||||
raise gr.Error("No face detected in the source image 💥!", duration=5)
|
||||
raise Exception("No face detected in the source image!")
|
||||
elif len(src_face) > 1:
|
||||
log(f'More than one face detected in the image, only pick one face by rule {direction}.')
|
||||
# NOTE: temporarily only pick the first face, to support multiple face in the future
|
||||
src_face = src_face[0]
|
||||
lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
|
||||
else:
|
||||
tmp_dct = {
|
||||
'animal_face_9': 'animal_face',
|
||||
'animal_face_68': 'face'
|
||||
}
|
||||
|
||||
src_face = src_face[0]
|
||||
pts = src_face.landmark_2d_106
|
||||
img_rgb_pil = Image.fromarray(img_rgb)
|
||||
lmk = self.animal_landmark_runner.run(
|
||||
img_rgb_pil,
|
||||
'face',
|
||||
tmp_dct[crop_cfg.animal_face_type],
|
||||
0,
|
||||
0
|
||||
)
|
||||
|
||||
# crop the face
|
||||
ret_dct = crop_image(
|
||||
img_rgb, # ndarray
|
||||
pts, # 106x2 or Nx2
|
||||
dsize=kwargs.get('dsize', 512),
|
||||
scale=kwargs.get('scale', 2.3),
|
||||
vy_ratio=kwargs.get('vy_ratio', -0.15),
|
||||
lmk, # 106x2 or Nx2
|
||||
dsize=crop_cfg.dsize,
|
||||
scale=crop_cfg.scale,
|
||||
vx_ratio=crop_cfg.vx_ratio,
|
||||
vy_ratio=crop_cfg.vy_ratio,
|
||||
flag_do_rot=crop_cfg.flag_do_rot,
|
||||
)
|
||||
# update a 256x256 version for network input or else
|
||||
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
|
||||
ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
|
||||
|
||||
recon_ret = self.landmark_runner.run(img_rgb, pts)
|
||||
lmk = recon_ret['pts']
|
||||
ret_dct['lmk_crop'] = lmk
|
||||
# update a 256x256 version for network input
|
||||
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
|
||||
if self.image_type == "human_face":
|
||||
lmk = self.human_landmark_runner.run(img_rgb, lmk)
|
||||
ret_dct["lmk_crop"] = lmk
|
||||
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
|
||||
else:
|
||||
# 68x2 or 9x2
|
||||
ret_dct["lmk_crop"] = lmk
|
||||
|
||||
return ret_dct
|
||||
|
||||
def get_retargeting_lmk_info(self, driving_rgb_lst):
|
||||
# TODO: implement a tracking-based version
|
||||
driving_lmk_lst = []
|
||||
for driving_image in driving_rgb_lst:
|
||||
ret_dct = self.crop_single_image(driving_image)
|
||||
driving_lmk_lst.append(ret_dct['lmk_crop'])
|
||||
return driving_lmk_lst
|
||||
def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs):
|
||||
direction = kwargs.get("direction", "large-small")
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
contiguous(img_rgb_[..., ::-1]), # convert to BGR
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=direction,
|
||||
)
|
||||
if len(src_face) == 0:
|
||||
log("No face detected in the source image.")
|
||||
return None
|
||||
elif len(src_face) > 1:
|
||||
log(f"More than one face detected in the image, only pick one face by rule {direction}.")
|
||||
src_face = src_face[0]
|
||||
lmk = src_face.landmark_2d_106
|
||||
lmk = self.human_landmark_runner.run(img_rgb_, lmk)
|
||||
|
||||
def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
|
||||
return lmk
|
||||
|
||||
# TODO: support skipping frame with NO FACE
|
||||
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
|
||||
"""Tracking based landmarks/alignment and cropping"""
|
||||
trajectory = Trajectory()
|
||||
direction = kwargs.get('direction', 'large-small')
|
||||
for idx, driving_image in enumerate(driving_rgb_lst):
|
||||
direction = kwargs.get("direction", "large-small")
|
||||
for idx, frame_rgb in enumerate(source_rgb_lst):
|
||||
if idx == 0 or trajectory.start == -1:
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
driving_image,
|
||||
contiguous(frame_rgb[..., ::-1]),
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=direction
|
||||
direction=crop_cfg.direction,
|
||||
max_face_num=crop_cfg.max_face_num,
|
||||
)
|
||||
if len(src_face) == 0:
|
||||
# No face detected in the driving_image
|
||||
log(f"No face detected in the frame #{idx}")
|
||||
continue
|
||||
elif len(src_face) > 1:
|
||||
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
|
||||
log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.")
|
||||
src_face = src_face[0]
|
||||
pts = src_face.landmark_2d_106
|
||||
lmk_203 = self.landmark_runner(driving_image, pts)['pts']
|
||||
lmk = src_face.landmark_2d_106
|
||||
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
|
||||
trajectory.start, trajectory.end = idx, idx
|
||||
else:
|
||||
lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts']
|
||||
# TODO: add IOU check for tracking
|
||||
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
|
||||
trajectory.end = idx
|
||||
|
||||
trajectory.lmk_lst.append(lmk_203)
|
||||
ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=elf.crop_cfg.vy_ratio)['bbox']
|
||||
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4,
|
||||
trajectory.lmk_lst.append(lmk)
|
||||
|
||||
# crop the face
|
||||
ret_dct = crop_image(
|
||||
frame_rgb, # ndarray
|
||||
lmk, # 106x2 or Nx2
|
||||
dsize=crop_cfg.dsize,
|
||||
scale=crop_cfg.scale,
|
||||
vx_ratio=crop_cfg.vx_ratio,
|
||||
vy_ratio=crop_cfg.vy_ratio,
|
||||
flag_do_rot=crop_cfg.flag_do_rot,
|
||||
)
|
||||
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
|
||||
ret_dct["lmk_crop"] = lmk
|
||||
|
||||
# update a 256x256 version for network input
|
||||
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
|
||||
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
|
||||
|
||||
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"])
|
||||
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"])
|
||||
trajectory.M_c2o_lst.append(ret_dct['M_c2o'])
|
||||
|
||||
return {
|
||||
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
|
||||
"lmk_crop_lst": trajectory.lmk_crop_lst,
|
||||
"M_c2o_lst": trajectory.M_c2o_lst,
|
||||
}
|
||||
|
||||
def crop_driving_video(self, driving_rgb_lst, **kwargs):
|
||||
"""Tracking based landmarks/alignment and cropping"""
|
||||
trajectory = Trajectory()
|
||||
direction = kwargs.get("direction", "large-small")
|
||||
for idx, frame_rgb in enumerate(driving_rgb_lst):
|
||||
if idx == 0 or trajectory.start == -1:
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
contiguous(frame_rgb[..., ::-1]),
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=direction,
|
||||
)
|
||||
if len(src_face) == 0:
|
||||
log(f"No face detected in the frame #{idx}")
|
||||
continue
|
||||
elif len(src_face) > 1:
|
||||
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
|
||||
src_face = src_face[0]
|
||||
lmk = src_face.landmark_2d_106
|
||||
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
|
||||
trajectory.start, trajectory.end = idx, idx
|
||||
else:
|
||||
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
|
||||
trajectory.end = idx
|
||||
|
||||
trajectory.lmk_lst.append(lmk)
|
||||
ret_bbox = parse_bbox_from_landmark(
|
||||
lmk,
|
||||
scale=self.crop_cfg.scale_crop_driving_video,
|
||||
vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video,
|
||||
vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video,
|
||||
)["bbox"]
|
||||
bbox = [
|
||||
ret_bbox[0, 0],
|
||||
ret_bbox[0, 1],
|
||||
ret_bbox[2, 0],
|
||||
ret_bbox[2, 1],
|
||||
] # 4,
|
||||
trajectory.bbox_lst.append(bbox) # bbox
|
||||
trajectory.frame_rgb_lst.append(driving_image)
|
||||
trajectory.frame_rgb_lst.append(frame_rgb)
|
||||
|
||||
global_bbox = average_bbox_lst(trajectory.bbox_lst)
|
||||
|
||||
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
|
||||
ret_dct = crop_image_by_bbox(
|
||||
frame_rgb, global_bbox, lmk=lmk,
|
||||
dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue
|
||||
frame_rgb,
|
||||
global_bbox,
|
||||
lmk=lmk,
|
||||
dsize=kwargs.get("dsize", 512),
|
||||
flag_rot=False,
|
||||
borderValue=(0, 0, 0),
|
||||
)
|
||||
frame_rgb_crop = ret_dct['img_crop']
|
||||
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
|
||||
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
|
||||
|
||||
return {
|
||||
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
|
||||
"lmk_crop_lst": trajectory.lmk_crop_lst,
|
||||
}
|
||||
|
||||
|
||||
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
|
||||
"""Tracking based landmarks/alignment"""
|
||||
trajectory = Trajectory()
|
||||
direction = kwargs.get("direction", "large-small")
|
||||
|
||||
for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
|
||||
if idx == 0 or trajectory.start == -1:
|
||||
src_face = self.face_analysis_wrapper.get(
|
||||
contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR
|
||||
flag_do_landmark_2d_106=True,
|
||||
direction=direction,
|
||||
)
|
||||
if len(src_face) == 0:
|
||||
log(f"No face detected in the frame #{idx}")
|
||||
raise Exception(f"No face detected in the frame #{idx}")
|
||||
elif len(src_face) > 1:
|
||||
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
|
||||
src_face = src_face[0]
|
||||
lmk = src_face.landmark_2d_106
|
||||
lmk = self.human_landmark_runner.run(frame_rgb_crop, lmk)
|
||||
trajectory.start, trajectory.end = idx, idx
|
||||
else:
|
||||
lmk = self.human_landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
|
||||
trajectory.end = idx
|
||||
|
||||
trajectory.lmk_lst.append(lmk)
|
||||
return trajectory.lmk_lst
|
||||
|
125
src/utils/dependencies/XPose/config_model/UniPose_SwinT.py
Normal file
@ -0,0 +1,125 @@
|
||||
_base_ = ['coco_transformer.py']
|
||||
|
||||
use_label_enc = True
|
||||
|
||||
num_classes=2
|
||||
|
||||
lr = 0.0001
|
||||
param_dict_type = 'default'
|
||||
lr_backbone = 1e-05
|
||||
lr_backbone_names = ['backbone.0']
|
||||
lr_linear_proj_names = ['reference_points', 'sampling_offsets']
|
||||
lr_linear_proj_mult = 0.1
|
||||
ddetr_lr_param = False
|
||||
batch_size = 2
|
||||
weight_decay = 0.0001
|
||||
epochs = 12
|
||||
lr_drop = 11
|
||||
save_checkpoint_interval = 100
|
||||
clip_max_norm = 0.1
|
||||
onecyclelr = False
|
||||
multi_step_lr = False
|
||||
lr_drop_list = [33, 45]
|
||||
|
||||
|
||||
modelname = 'UniPose'
|
||||
frozen_weights = None
|
||||
backbone = 'swin_T_224_1k'
|
||||
|
||||
|
||||
dilation = False
|
||||
position_embedding = 'sine'
|
||||
pe_temperatureH = 20
|
||||
pe_temperatureW = 20
|
||||
return_interm_indices = [1, 2, 3]
|
||||
backbone_freeze_keywords = None
|
||||
enc_layers = 6
|
||||
dec_layers = 6
|
||||
unic_layers = 0
|
||||
pre_norm = False
|
||||
dim_feedforward = 2048
|
||||
hidden_dim = 256
|
||||
dropout = 0.0
|
||||
nheads = 8
|
||||
num_queries = 900
|
||||
query_dim = 4
|
||||
num_patterns = 0
|
||||
pdetr3_bbox_embed_diff_each_layer = False
|
||||
pdetr3_refHW = -1
|
||||
random_refpoints_xy = False
|
||||
fix_refpoints_hw = -1
|
||||
dabdetr_yolo_like_anchor_update = False
|
||||
dabdetr_deformable_encoder = False
|
||||
dabdetr_deformable_decoder = False
|
||||
use_deformable_box_attn = False
|
||||
box_attn_type = 'roi_align'
|
||||
dec_layer_number = None
|
||||
num_feature_levels = 4
|
||||
enc_n_points = 4
|
||||
dec_n_points = 4
|
||||
decoder_layer_noise = False
|
||||
dln_xy_noise = 0.2
|
||||
dln_hw_noise = 0.2
|
||||
add_channel_attention = False
|
||||
add_pos_value = False
|
||||
two_stage_type = 'standard'
|
||||
two_stage_pat_embed = 0
|
||||
two_stage_add_query_num = 0
|
||||
two_stage_bbox_embed_share = False
|
||||
two_stage_class_embed_share = False
|
||||
two_stage_learn_wh = False
|
||||
two_stage_default_hw = 0.05
|
||||
two_stage_keep_all_tokens = False
|
||||
num_select = 50
|
||||
transformer_activation = 'relu'
|
||||
batch_norm_type = 'FrozenBatchNorm2d'
|
||||
masks = False
|
||||
|
||||
decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
|
||||
matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
|
||||
decoder_module_seq = ['sa', 'ca', 'ffn']
|
||||
nms_iou_threshold = -1
|
||||
|
||||
dec_pred_bbox_embed_share = True
|
||||
dec_pred_class_embed_share = True
|
||||
|
||||
|
||||
use_dn = True
|
||||
dn_number = 100
|
||||
dn_box_noise_scale = 1.0
|
||||
dn_label_noise_ratio = 0.5
|
||||
dn_label_coef=1.0
|
||||
dn_bbox_coef=1.0
|
||||
embed_init_tgt = True
|
||||
dn_labelbook_size = 2000
|
||||
|
||||
match_unstable_error = True
|
||||
|
||||
# for ema
|
||||
use_ema = True
|
||||
ema_decay = 0.9997
|
||||
ema_epoch = 0
|
||||
|
||||
use_detached_boxes_dec_out = False
|
||||
|
||||
max_text_len = 256
|
||||
shuffle_type = None
|
||||
|
||||
use_text_enhancer = True
|
||||
use_fusion_layer = True
|
||||
|
||||
use_checkpoint = False # True
|
||||
use_transformer_ckpt = True
|
||||
text_encoder_type = 'bert-base-uncased'
|
||||
|
||||
use_text_cross_attention = True
|
||||
text_dropout = 0.0
|
||||
fusion_dropout = 0.0
|
||||
fusion_droppath = 0.1
|
||||
|
||||
num_body_points=68
|
||||
binary_query_selection = False
|
||||
use_cdn = True
|
||||
ffn_extra_layernorm = False
|
||||
|
||||
fix_size=False
|
@ -0,0 +1,8 @@
|
||||
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
|
||||
data_aug_max_size = 1333
|
||||
data_aug_scales2_resize = [400, 500, 600]
|
||||
data_aug_scales2_crop = [384, 600]
|
||||
|
||||
|
||||
data_aug_scale_overlap = None
|
||||
|