feat: v2v and gradio upgrade (#172)

* feat: v2v

* chore: format and package

* chore: refactor codebase

* feat: v2v

* feat: gradio tab auto select

* chore: log auto crop

* doc: update changelog

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
Jianzhu Guo 2024-07-19 23:39:05 +08:00 committed by GitHub
parent 24ce67d652
commit 24dcfafdc7
28 changed files with 689 additions and 312 deletions

app.py

@@ -26,18 +26,20 @@ def fast_check_ffmpeg():
except:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
if not fast_check_ffmpeg():
raise ImportError(
-"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
+# global_tab_selection = None
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,

@@ -55,10 +57,10 @@ def gpu_wrapped_execute_image(*args, **kwargs):
# assets
-title_md = "assets/gradio_title.md"
+title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
-data_examples = [
+data_examples_i2v = [
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],

@@ -66,6 +68,14 @@ data_examples = [
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
+data_examples_v2v = [
+[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6],
+# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-6],
+# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-6],
+[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6],
+# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6],
+[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6],
+]
#################### interface logic ####################
# Define components first
@@ -74,80 +84,148 @@ lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="tar
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
-output_video = gr.Video()
-output_video_concat = gr.Video()
+output_video_i2v = gr.Video(autoplay=True)
+output_video_concat_i2v = gr.Video(autoplay=True)
+output_video_v2v = gr.Video(autoplay=True)
+output_video_concat_v2v = gr.Video(autoplay=True)
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
+with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
-gr.Markdown(load_description("assets/gradio_description_upload.md"))
+gr.Markdown(load_description("assets/gradio/gradio_description_upload.md"))
with gr.Row():
-with gr.Accordion(open=True, label="Source Portrait"):
-image_input = gr.Image(type="filepath")
-gr.Examples(
-examples=[
-[osp.join(example_portrait_dir, "s9.jpg")],
-[osp.join(example_portrait_dir, "s6.jpg")],
-[osp.join(example_portrait_dir, "s10.jpg")],
-[osp.join(example_portrait_dir, "s5.jpg")],
-[osp.join(example_portrait_dir, "s7.jpg")],
-[osp.join(example_portrait_dir, "s12.jpg")],
-],
-inputs=[image_input],
-cache_examples=False,
-)
-with gr.Accordion(open=True, label="Driving Video"):
-video_input = gr.Video()
-gr.Examples(
-examples=[
-[osp.join(example_video_dir, "d0.mp4")],
-[osp.join(example_video_dir, "d18.mp4")],
-[osp.join(example_video_dir, "d19.mp4")],
-[osp.join(example_video_dir, "d14.mp4")],
-[osp.join(example_video_dir, "d6.mp4")],
-],
-inputs=[video_input],
-cache_examples=False,
-)
+with gr.Column():
+with gr.Tabs():
+with gr.TabItem("🖼️ Source Image") as tab_image:
+with gr.Accordion(open=True, label="Source Image"):
+source_image_input = gr.Image(type="filepath")
+gr.Examples(
+examples=[
+[osp.join(example_portrait_dir, "s9.jpg")],
+[osp.join(example_portrait_dir, "s6.jpg")],
+[osp.join(example_portrait_dir, "s10.jpg")],
+[osp.join(example_portrait_dir, "s5.jpg")],
+[osp.join(example_portrait_dir, "s7.jpg")],
+[osp.join(example_portrait_dir, "s12.jpg")],
+],
+inputs=[source_image_input],
+cache_examples=False,
+)
+with gr.TabItem("🎞️ Source Video") as tab_video:
+with gr.Accordion(open=True, label="Source Video"):
+source_video_input = gr.Video()
+gr.Examples(
+examples=[
+[osp.join(example_portrait_dir, "s13.mp4")],
+# [osp.join(example_portrait_dir, "s14.mp4")],
+# [osp.join(example_portrait_dir, "s15.mp4")],
+[osp.join(example_portrait_dir, "s18.mp4")],
+# [osp.join(example_portrait_dir, "s19.mp4")],
+[osp.join(example_portrait_dir, "s20.mp4")],
+],
+inputs=[source_video_input],
+cache_examples=False,
+)
+tab_selection = gr.Textbox(visible=False)
+tab_image.select(lambda: "Image", None, tab_selection)
+tab_video.select(lambda: "Video", None, tab_selection)
+with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"):
+with gr.Row():
+flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
+scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=2.9, step=0.05)
+vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
+vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
+with gr.Column():
+with gr.Accordion(open=True, label="Driving Video"):
+driving_video_input = gr.Video()
+gr.Examples(
+examples=[
+[osp.join(example_video_dir, "d0.mp4")],
+[osp.join(example_video_dir, "d18.mp4")],
+[osp.join(example_video_dir, "d19.mp4")],
+[osp.join(example_video_dir, "d14.mp4")],
+[osp.join(example_video_dir, "d6.mp4")],
+],
+inputs=[driving_video_input],
+cache_examples=False,
+)
+# with gr.Accordion(open=False, label="Animation Instructions"):
+# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
+with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
+with gr.Row():
+flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
+scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=2.9, step=0.05)
+vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
+vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Row():
-with gr.Accordion(open=False, label="Animation Instructions and Options"):
-gr.Markdown(load_description("assets/gradio_description_animation.md"))
+with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
-flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
-flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
+flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
+driving_smooth_observation_variance = gr.Number(value=3e-6, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-11)
+gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
-process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
+process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
-output_video.render()
+output_video_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
-output_video_concat.render()
+output_video_concat_i2v.render()
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
-gr.Examples(
-examples=data_examples,
-fn=gpu_wrapped_execute_video,
-inputs=[
-image_input,
-video_input,
-flag_relative_input,
-flag_do_crop_input,
-flag_remap_input,
-flag_crop_driving_video_input
-],
-outputs=[output_image, output_image_paste_back],
-examples_per_page=len(data_examples),
-cache_examples=False,
-)
-gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
+with gr.Tabs():
+with gr.TabItem("🖼️ Portrait Animation"):
+gr.Examples(
+examples=data_examples_i2v,
+fn=gpu_wrapped_execute_video,
+inputs=[
+source_image_input,
+driving_video_input,
+flag_relative_input,
+flag_do_crop_input,
+flag_remap_input,
+flag_crop_driving_video_input,
+],
+outputs=[output_image, output_image_paste_back],
+examples_per_page=len(data_examples_i2v),
+cache_examples=False,
+)
+with gr.TabItem("🎞️ Portrait Video Editing"):
+gr.Examples(
+examples=data_examples_v2v,
+fn=gpu_wrapped_execute_video,
+inputs=[
+source_video_input,
+driving_video_input,
+flag_relative_input,
+flag_do_crop_input,
+flag_remap_input,
+flag_crop_driving_video_input,
+flag_video_editing_head_rotation,
+driving_smooth_observation_variance,
+],
+outputs=[output_image, output_image_paste_back],
+examples_per_page=len(data_examples_v2v),
+cache_examples=False,
+)
+# Retargeting
+gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
eye_retargeting_slider.render()
lip_retargeting_slider.render()
@@ -185,6 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,

@@ -196,18 +275,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
-image_input,
-video_input,
+source_image_input,
+source_video_input,
+driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
-flag_crop_driving_video_input
+flag_crop_driving_video_input,
+flag_video_editing_head_rotation,
+scale,
+vx_ratio,
+vy_ratio,
+scale_crop_driving_video,
+vx_ratio_crop_driving_video,
+vy_ratio_crop_driving_video,
+driving_smooth_observation_variance,
+tab_selection,
],
-outputs=[output_video, output_video_concat],
+outputs=[output_video_i2v, output_video_concat_i2v],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,
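The config setup at the top of app.py builds `InferenceConfig` and `CropConfig` from the parsed `ArgumentConfig` via a `partial_fields` helper that sits outside the hunks shown here. A minimal sketch of the idea, with the filtering rule treated as an assumption since the helper body is not part of this diff:

```python
from dataclasses import fields

def partial_fields(target_class, kwargs: dict):
    # Keep only the keys that are fields of the target dataclass, then build it.
    # Sketch only: the real helper may filter differently (e.g. via hasattr).
    valid = {f.name for f in fields(target_class)}
    return target_class(**{k: v for k, v in kwargs.items() if k in valid})

# Mirrors the calls in app.py:
#   inference_cfg = partial_fields(InferenceConfig, args.__dict__)
#   crop_cfg = partial_fields(CropConfig, args.__dict__)
```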

assets/docs/changelog/2024-07-10.md

@@ -9,7 +9,7 @@ The popularity of LivePortrait has exceeded our expectations. If you encounter a
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`.
-- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
+- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option.
### About driving video
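The motion template described above is produced by `make_motion_template` in src/live_portrait_pipeline.py later in this commit. Assuming the `.pkl` is a plain pickle payload (an assumption: the repo's `dump`/`load` helpers in `src/utils/io.py` are not part of this diff), a template generated by the code in this commit could be inspected like this:

```python
import pickle

# Hypothetical inspection of an auto-generated template, e.g. d0.mp4 -> d0.pkl.
with open("assets/examples/driving/d0.pkl", "rb") as f:
    template = pickle.load(f)

print(template["n_frames"], template.get("output_fps", 25))
print(template["motion"][0].keys())   # per-frame 'scale', 'R', 'exp', 't'
print(len(template["c_eyes_lst"]), len(template["c_lip_lst"]))
```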

assets/docs/changelog/2024-07-19.md

@@ -0,0 +1,18 @@
## 2024/07/19
**Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉**
We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk).
### Updates
- <strong>Portrait video editing (v2v):</strong> Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`.
- <strong>More options in Gradio:</strong> We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.
### Community Contributions
- **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance:
- [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150))
- [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142))
- **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119).
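The `--driving_smooth_observation_variance` option above feeds a Kalman smoother built on `pykalman`, which this commit adds to `requirements_base.txt`. The repo's actual `smooth` helper in `src/utils/filter.py` is not shown in this diff, so the signature and state model below are assumptions sketching the idea:

```python
import numpy as np
from pykalman import KalmanFilter

def smooth_motion(observations, observation_variance=3e-6):
    """Kalman-smooth a (n_frames, dim) sequence of motion parameters.

    Sketch only. A larger observation variance tells the filter to trust the
    per-frame observations less, yielding a smoother (but less accurate)
    trajectory, which matches how driving_smooth_observation_variance is
    described in the configs of this commit.
    """
    obs = np.asarray(observations, dtype=np.float64)
    dim = obs.shape[1]
    kf = KalmanFilter(
        transition_matrices=np.eye(dim),        # simple constant-state model
        observation_matrices=np.eye(dim),
        observation_covariance=observation_variance * np.eye(dim),
        initial_state_mean=obs[0],
    )
    smoothed_means, _ = kf.smooth(obs)
    return smoothed_means
```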

Binary file not shown.

Binary file not shown.

Binary file not shown.

assets/gradio/gradio_description_animate_clear.md

@@ -0,0 +1,6 @@
<div style="font-size: 1.2em; text-align: center;">
Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
</div>
<!-- <div style="font-size: 1.1em; text-align: center;">
<strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
</div> -->

assets/gradio/gradio_description_animation.md

@@ -0,0 +1,19 @@
<span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
4. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
</div>

assets/gradio/gradio_description_upload.md

@@ -0,0 +1,13 @@
<br>
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: center; margin-right: 20px;">
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
</div>
</div>
</div>

assets/gradio/gradio_title.md

@@ -0,0 +1,20 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<!-- <span>Add mimics and lip sync to your static portrait driven by a video</span> -->
<!-- <span>Efficient Portrait Animation with Stitching and Retargeting Control</span> -->
<!-- <br> -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
&nbsp;
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
&nbsp;
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait
"></a>
</div>
</div>
</div>

assets/gradio_description_animation.md

@@ -1,16 +0,0 @@
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
</div>

assets/gradio_description_upload.md

@@ -1,2 +0,0 @@
## 🤗 This is the official gradio demo for **LivePortrait**.
<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>

assets/gradio_title.md

@@ -1,11 +0,0 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
</div>
</div>
</div>

inference.py

@@ -22,10 +22,10 @@ def fast_check_ffmpeg():
def fast_check_args(args: ArgumentConfig):
-if not osp.exists(args.source_image):
-raise FileNotFoundError(f"source image not found: {args.source_image}")
-if not osp.exists(args.driving_info):
-raise FileNotFoundError(f"driving info not found: {args.driving_info}")
+if not osp.exists(args.source):
+raise FileNotFoundError(f"source info not found: {args.source}")
+if not osp.exists(args.driving):
+raise FileNotFoundError(f"driving info not found: {args.driving}")
def main():

@@ -35,10 +35,9 @@ def main():
if not fast_check_ffmpeg():
raise ImportError(
-"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
-# fast check the args
fast_check_args(args)
# specify configs for inference
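The updated error message now calls out both ffmpeg and ffprobe. The body of `fast_check_ffmpeg` is largely outside this diff, so the following is only a sketch of a check along those lines:

```python
import subprocess

def fast_check_ffmpeg() -> bool:
    # Sketch: confirm both binaries the pipeline relies on are callable.
    for exe in ("ffmpeg", "ffprobe"):
        try:
            subprocess.run([exe, "-version"], capture_output=True, check=True)
        except Exception:
            return False
    return True
```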

readme.md

@@ -39,6 +39,7 @@
## 🔥 Updates
+- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!

@@ -47,11 +48,11 @@
-## Introduction
+## Introduction 📖
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
-## 🔥 Getting Started
+## Getting Started 🏁
### 1. Clone the code and prepare the environment
```bash
git clone https://github.com/KwaiVGI/LivePortrait

@@ -61,9 +62,10 @@ cd LivePortrait
conda create -n LivePortrait python==3.9
conda activate LivePortrait
-# install dependencies with pip (for Linux and Windows)
+# install dependencies with pip
+# for Linux and Windows users
pip install -r requirements.txt
-# for macOS with Apple Silicon
+# for macOS with Apple Silicon users
pip install -r requirements_macOS.txt
```

@@ -113,7 +115,7 @@ python inference.py
PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
```
-If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
+If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.
<p align="center">
<img src="./assets/docs/inference.gif" alt="image">

@@ -122,18 +124,18 @@ If the script runs successfully, you will get an output mp4 file named `animatio
Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash
+# source input is an image
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
-# disable pasting back to run faster
-python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
+# source input is a video ✨
+python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4
# more options to see
python inference.py -h
```
-#### Driving video auto-cropping
-To use your own driving video, we **recommend**: ⬇️
+#### Driving video auto-cropping 📢📢📢
+📕 To use your own driving video, we **recommend**:
- Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.

@@ -144,25 +146,24 @@ Below is a auto-cropping case by `--flag_crop_driving_video`:
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
```
-If you find the results of auto-cropping is not well, you can modify the `--scale_crop_video`, `--vy_ratio_crop_video` options to adjust the scale and offset, or do it manually.
+If you find the results of auto-cropping is not well, you can modify the `--scale_crop_driving_video`, `--vy_ratio_crop_driving_video` options to adjust the scale and offset, or do it manually.
#### Motion template making
You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
```bash
-python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl
+python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation
+python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing
```
-**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊
### 4. Gradio interface 🤗
We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run by:
```bash
-# For Linux and Windows:
+# For Linux and Windows users (and macOS with Intel??)
python app.py
-# For macOS with Apple Silicon, Intel not supported, this maybe 20x slower than RTX 4090
+# For macOS with Apple Silicon users, Intel not supported, this maybe 20x slower than RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py
```

@@ -210,7 +211,7 @@ Discover the invaluable resources contributed by our community to enhance your L
And many more amazing contributions from our community!
-## Acknowledgements
+## Acknowledgements 💐
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
## Citation 💖

requirements_base.txt

@@ -19,3 +19,4 @@ matplotlib==3.9.0
imageio-ffmpeg==0.5.1
tyro==0.8.5
gradio==4.37.1
+pykalman==0.9.7

src/config/argument_config.py

@@ -14,35 +14,39 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
-source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
-driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
+source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait or video
+driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
########## inference arguments ##########
flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP!
-flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
-flag_eye_retargeting: bool = False # not recommend to be True, WIP
-flag_lip_retargeting: bool = False # not recommend to be True, WIP
+flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
+flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
+flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
+flag_eye_retargeting: bool = False # not recommend to be True, WIP
+flag_lip_retargeting: bool = False # not recommend to be True, WIP
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
-flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
+flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
+driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
-########## crop arguments ##########
+########## source crop arguments ##########
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
+flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
-scale_crop_video: float = 2.2 # scale factor for cropping video
-vx_ratio_crop_video: float = 0. # adjust y offset
-vy_ratio_crop_video: float = -0.1 # adjust x offset
+########## driving crop arguments ##########
+scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video
+vx_ratio_crop_driving_video: float = 0. # adjust y offset
+vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
########## gradio arguments ##########
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
share: bool = False # whether to share the server to public
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
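With `source`/`driving` replacing `source_image`/`driving_info` and the new v2v fields above, the config can also be built directly in Python. A small illustration, assuming the `src/config/` package layout indicated by the imports in this commit; the example paths mirror the v2v command in the README diff earlier:

```python
from src.config.argument_config import ArgumentConfig

# Programmatic equivalent of:
#   python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4
# A source *video* (rather than an image) is what triggers portrait video editing (v2v).
args = ArgumentConfig(
    source="assets/examples/source/s13.mp4",
    driving="assets/examples/driving/d0.mp4",
    flag_video_editing_head_rotation=True,      # inherit relative head rotation from the driving video
    driving_smooth_observation_variance=3e-6,   # larger -> smoother, less accurate motion
)
print(args)  # PrintableConfig supplies the repr used for logging
```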

src/config/crop_config.py

@@ -15,15 +15,16 @@ class CropConfig(PrintableConfig):
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
-########## source image cropping option ##########
+########## source image or video cropping option ##########
dsize: int = 512 # crop size
-scale: float = 2.5 # scale factor
+scale: float = 2.8 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 mean no limit
+flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
########## driving video auto cropping option ##########
-scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video
-vx_ratio_crop_video: float = 0.0 # adjust y offset
-vy_ratio_crop_video: float = -0.1 # adjust x offset
+scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
+vx_ratio_crop_driving_video: float = 0.0 # adjust y offset
+vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
direction: str = "large-small" # direction of cropping

src/config/inference_config.py

@@ -25,7 +25,9 @@ class InferenceConfig(PrintableConfig):
flag_use_half_precision: bool = True
flag_crop_driving_video: bool = False
device_id: int = 0
-flag_lip_zero: bool = True
+flag_normalize_lip: bool = True
+flag_source_video_eye_retargeting: bool = False
+flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True

@@ -37,7 +39,9 @@ class InferenceConfig(PrintableConfig):
flag_do_torch_compile: bool = False
# NOT EXPORTED PARAMS
-lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
+lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
+source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
+driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape

@@ -47,5 +51,5 @@ class InferenceConfig(PrintableConfig):
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
size_gif: int = 256 # default gif size, TO IMPLEMENT
-source_max_dim: int = 1280 # the max dim of height and width of source image
-source_division: int = 2 # make sure the height and width of source image can be divided by this number
+source_max_dim: int = 1280 # the max dim of height and width of source image or video
+source_division: int = 2 # make sure the height and width of source image or video can be divided by this number

src/gradio_pipeline.py

@@ -3,6 +3,8 @@
"""
Pipeline for gradio
"""
+import os.path as osp
import gradio as gr
from .config.argument_config import ArgumentConfig

@@ -11,6 +13,7 @@ from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
+from .utils.helper import is_square_video
def update_args(args, user_args):

@@ -31,23 +34,53 @@ class GradioPipeline(LivePortraitPipeline):
def execute_video(
self,
-input_image_path,
-input_video_path,
-flag_relative_input,
-flag_do_crop_input,
-flag_remap_input,
-flag_crop_driving_video_input
+input_source_image_path=None,
+input_source_video_path=None,
+input_driving_video_path=None,
+flag_relative_input=True,
+flag_do_crop_input=True,
+flag_remap_input=True,
+flag_crop_driving_video_input=True,
+flag_video_editing_head_rotation=False,
+scale=2.3,
+vx_ratio=0.0,
+vy_ratio=-0.125,
+scale_crop_driving_video=2.2,
+vx_ratio_crop_driving_video=0.0,
+vy_ratio_crop_driving_video=-0.1,
+driving_smooth_observation_variance=3e-6,
+tab_selection=None,
):
-""" for video driven potrait animation
+""" for video-driven potrait animation or video editing
"""
-if input_image_path is not None and input_video_path is not None:
+if tab_selection == 'Image':
+input_source_path = input_source_image_path
+elif tab_selection == 'Video':
+input_source_path = input_source_video_path
+else:
+input_source_path = input_source_image_path
+if input_source_path is not None and input_driving_video_path is not None:
+if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
+flag_crop_driving_video_input = True
+log("The source video is not square, the driving video will be cropped to square automatically.")
+gr.Info("The source video is not square, the driving video will be cropped to square automatically.", duration=2)
args_user = {
-'source_image': input_image_path,
-'driving_info': input_video_path,
-'flag_relative': flag_relative_input,
+'source': input_source_path,
+'driving': input_driving_video_path,
+'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
-'flag_crop_driving_video': flag_crop_driving_video_input
+'flag_crop_driving_video': flag_crop_driving_video_input,
+'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
+'scale': scale,
+'vx_ratio': vx_ratio,
+'vy_ratio': vy_ratio,
+'scale_crop_driving_video': scale_crop_driving_video,
+'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
+'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
+'driving_smooth_observation_variance': driving_smooth_observation_variance,
}
# update config from user input
self.args = update_args(self.args, args_user)

@@ -58,7 +91,7 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat,
else:
-raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
+raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
""" for single image retargeting

@@ -79,9 +112,8 @@ class GradioPipeline(LivePortraitPipeline):
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
-num_kp = x_s_user.shape[1]
# default: use x_s
-x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+x_d_new = x_s_user + eyes_delta + lip_delta
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]

@@ -114,4 +146,4 @@ class GradioPipeline(LivePortraitPipeline):
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
-raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
+raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

src/live_portrait_pipeline.py

@@ -20,8 +20,9 @@ from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
-from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
-from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
+from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
+from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
+from .utils.filter import smooth
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper
@@ -37,125 +38,252 @@ class LivePortraitPipeline(object):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
+def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
+n_frames = I_lst.shape[0]
+template_dct = {
+'n_frames': n_frames,
+'output_fps': kwargs.get('output_fps', 25),
+'motion': [],
+'c_eyes_lst': [],
+'c_lip_lst': [],
+'x_i_info_lst': [],
+}
+for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
+# collect s, R, δ and t for inference
+I_i = I_lst[i]
+x_i_info = self.live_portrait_wrapper.get_kp_info(I_i)
+R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
+item_dct = {
+'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
+'R': R_i.cpu().numpy().astype(np.float32),
+'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
+'t': x_i_info['t'].cpu().numpy().astype(np.float32),
+}
+template_dct['motion'].append(item_dct)
+c_eyes = c_eyes_lst[i].astype(np.float32)
+template_dct['c_eyes_lst'].append(c_eyes)
+c_lip = c_lip_lst[i].astype(np.float32)
+template_dct['c_lip_lst'].append(c_lip)
+template_dct['x_i_info_lst'].append(x_i_info)
+return template_dct
def execute(self, args: ArgumentConfig):
# for convenience
inf_cfg = self.live_portrait_wrapper.inference_cfg
device = self.live_portrait_wrapper.device
crop_cfg = self.cropper.crop_cfg
-######## process source portrait ########
-img_rgb = load_image_rgb(args.source_image)
-img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
-log(f"Load source image from {args.source_image}")
-crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
-if crop_info is None:
-raise Exception("No face detected in the source image!")
-source_lmk = crop_info['lmk_crop']
-img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
-if inf_cfg.flag_do_crop:
-I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
-else:
-img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
-I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
-x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-x_c_s = x_s_info['kp']
-R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
-x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
-flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
-if flag_lip_zero:
-# let lip-open scalar to be 0 at first
-c_d_lip_before_animation = [0.]
-combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
-if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
-flag_lip_zero = False
-else:
-lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
-############################################
+######## load source input ########
+flag_is_source_video = False
+source_fps = None
+if is_image(args.source):
+flag_is_source_video = False
+img_rgb = load_image_rgb(args.source)
+img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
+log(f"Load source image from {args.source}")
+source_rgb_lst = [img_rgb]
+elif is_video(args.source):
+flag_is_source_video = True
+source_rgb_lst = load_video(args.source)
+source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst]
+source_fps = int(get_fps(args.source))
+log(f"Load source video from {args.source}, FPS is {source_fps}")
+else: # source input is an unknown format
+raise Exception(f"Unknown source format: {args.source}")
######## process driving info ######## ######## process driving info ########
flag_load_from_template = is_template(args.driving_info) flag_load_from_template = is_template(args.driving)
driving_rgb_crop_256x256_lst = None driving_rgb_crop_256x256_lst = None
wfp_template = None wfp_template = None
if flag_load_from_template: if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None # NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
template_dct = load(args.driving_info) driving_template_dct = load(args.driving)
n_frames = template_dct['n_frames'] c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
else:
n_frames = driving_n_frames
# set output_fps
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
driving_n_frames = len(driving_rgb_lst)
######## make motion template ########
log("Start making driving motion template...")
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames]
else:
n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video:
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) != n_frames:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst)
driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} does not exist or is an unsupported driving format!")
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
I_p_lst = []
R_d_0, x_d_0_info = None, None
flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite
flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite
lip_delta_before_animation, eye_delta_before_animation = None, None
######## process source info ########
if flag_is_source_video:
log(f"Start making source motion template...")
source_rgb_lst = source_rgb_lst[:n_frames]
if inf_cfg.flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
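# NOTE: for a video source, driving deltas are composed on top of the per-frame source motion:
#   exp_i = exp_src_i + (exp_drv_i - exp_drv_0), and, when head-rotation editing is enabled,
#   R_i = (R_drv_i @ R_drv_0^T) @ R_src_i; both sequences are Kalman-smoothed (src/utils/filter.py)
#   to reduce frame-to-frame jitter before animation.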
else: # if the input is a source image, process it only once
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop_256x256 = crop_info['img_crop_256x256']
if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first
if flag_normalize_lip:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
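# NOTE: for an image source, the features above (f_s, x_s, the lip-normalization delta, and the
# pasteback mask) are computed once; for a video source their per-frame counterparts are
# recomputed inside the animation loop below.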
######## animate ########
log(f"The animated video consists of {n_frames} frames.")
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
if flag_is_source_video: # source video
x_s_info_tiny = source_template_dct['motion'][i]
x_s_info_tiny = dct2device(x_s_info_tiny, device)
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
x_s_info = source_template_dct['x_i_info_lst'][i]
x_c_s = x_s_info['kp']
R_s = x_s_info_tiny['R']
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first if the input is a video
if flag_normalize_lip:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
# let the eye-open scalar match the first frame if that frame is in an eye-open state
if flag_source_video_eye_retargeting:
if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold:
c_d_eye_before_animation_frame_zero = [[0.39]]
combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
if i == 0: # cache the first frame
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inf_cfg.flag_relative_motion:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
@ -168,16 +296,20 @@ class LivePortraitPipeline(object):
# Algorithm 1:
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new += lip_delta_before_animation
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else:
pass
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else:
eyes_delta, lip_delta = None, None
if inf_cfg.flag_eye_retargeting:
@ -193,12 +325,12 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
@ -208,38 +340,52 @@ class LivePortraitPipeline(object):
I_p_lst.append(I_p_i)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
# TODO: the paste-back procedure is slow; consider optimizing it with multi-threading or on the GPU
if flag_is_source_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float)
else:
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result #########
# driving frame | source frame | generation, or source frame | generation
if flag_is_source_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps
output_fps = source_fps if flag_is_source_video else output_fps
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.source if flag_source_has_audio else args.driving
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.source if flag_source_has_audio else args.driving
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp} with {wfp_with_audio}")
@ -250,36 +396,3 @@ class LivePortraitPipeline(object):
log(f'Animated video with concat: {wfp_concat}')
return wfp, wfp_concat
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
n_frames = I_d_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
'c_d_eyes_lst': [],
'c_d_lip_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s_d, R_d, δ_d and t_d for inference
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
item_dct = {
'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
'R_d': R_d_i.cpu().numpy().astype(np.float32),
'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
template_dct['c_d_eyes_lst'].append(c_d_eyes)
c_d_lip = c_d_lip_lst[i].astype(np.float32)
template_dct['c_d_lip_lst'].append(c_d_lip)
return template_dct
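For reference, a template dumped this way can be reused directly as the driving input on later runs (is_template() sends it down the fast template path above). A minimal sketch of inspecting one, assuming dump()/load() are thin pickle wrappers as the .pkl suffix suggests; the file name is a placeholder:

import pickle

with open("d0.pkl", "rb") as f:  # template written by make_motion_template + dump()
    template_dct = pickle.load(f)
print(template_dct["n_frames"], template_dct["output_fps"], list(template_dct["motion"][0].keys()))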


@ -95,7 +95,7 @@ class LivePortraitWrapper(object):
x = x.to(self.device)
return x
def prepare_videos(self, imgs) -> torch.Tensor:
""" construct the input as standard
imgs: NxBxHxWx3, uint8
"""
@ -216,7 +216,7 @@ class LivePortraitWrapper(object):
with torch.no_grad():
delta = self.stitching_retargeting_module['eye'](feat_eye)
return delta.reshape(-1, kp_source.shape[1], 3)
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
"""
@ -229,7 +229,7 @@ class LivePortraitWrapper(object):
with torch.no_grad():
delta = self.stitching_retargeting_module['lip'](feat_lip)
return delta.reshape(-1, kp_source.shape[1], 3)
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
@ -301,10 +301,10 @@ class LivePortraitWrapper(object):
return out
def calc_ratio(self, lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in lmk_lst:
# for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting


@ -31,6 +31,7 @@ class Trajectory:
end: int = -1 # end frame
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
@ -104,6 +105,7 @@ class Cropper(object):
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(img_rgb, lmk)
@ -115,6 +117,58 @@ class Cropper(object):
return ret_dct
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
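# NOTE: face detection runs on the first frame and keeps retrying until a face is found; after
# that, each frame is tracked by re-running the landmark model from the previous frame's landmarks.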
for idx, frame_rgb in enumerate(source_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
# crop the face
ret_dct = crop_image(
frame_rgb, # ndarray
lmk, # 106x2 or Nx2
dsize=crop_cfg.dsize,
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(frame_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"])
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"])
trajectory.M_c2o_lst.append(ret_dct['M_c2o'])
return {
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
"lmk_crop_lst": trajectory.lmk_crop_lst,
"M_c2o_lst": trajectory.M_c2o_lst,
}
def crop_driving_video(self, driving_rgb_lst, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
@ -142,9 +196,9 @@ class Cropper(object):
trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(
lmk,
scale=self.crop_cfg.scale_crop_driving_video,
vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video,
)["bbox"]
bbox = [
ret_bbox[0, 0],
@ -174,6 +228,7 @@ class Cropper(object):
"lmk_crop_lst": trajectory.lmk_crop_lst,
}
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
"""Tracking based landmarks/alignment"""
trajectory = Trajectory()

19
src/utils/filter.py Normal file

@ -0,0 +1,19 @@
# coding: utf-8
import torch
import numpy as np
from pykalman import KalmanFilter
def smooth(x_d_lst, shape, device, observation_variance=3e-6, process_variance=1e-5):
x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
x_d_stacked = np.vstack(x_d_lst_reshape)
kf = KalmanFilter(
initial_state_mean=x_d_stacked[0],
n_dim_obs=x_d_stacked.shape[1],
transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]),
observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1])
)
smoothed_state_means, _ = kf.smooth(x_d_stacked)
x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means]
return x_d_lst_smooth
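As a quick, self-contained illustration of how smooth() above can be exercised, with a synthetic noisy expression sequence that is purely for demonstration (the (1, 21, 3) shape mirrors the per-frame expression arrays but is an assumption here):

import numpy as np

n_frames, shape = 50, (1, 21, 3)
rng = np.random.default_rng(0)
noisy_exp_lst = [rng.normal(0.0, 1e-2, size=shape).astype(np.float32) for _ in range(n_frames)]
smoothed_exp_lst = smooth(noisy_exp_lst, shape, device="cpu", observation_variance=3e-6)
assert smoothed_exp_lst[0].shape == shape[-2:]  # each smoothed item is a (21, 3) torch tensor on the requested device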


@ -8,6 +8,8 @@ import os
import os.path as osp
import torch
from collections import OrderedDict
import numpy as np
import cv2
from ..modules.spade_generator import SPADEDecoder
from ..modules.warping_network import WarpingNetwork
@ -42,6 +44,11 @@ def remove_suffix(filepath):
return osp.join(osp.dirname(filepath), basename(filepath))
def is_image(file_path):
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
return file_path.lower().endswith(image_extensions)
def is_video(file_path):
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
return True
@ -143,3 +150,16 @@ def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
return content
def is_square_video(video_path):
video = cv2.VideoCapture(video_path)
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
video.release()
# if width != height:
# gr.Info(f"Uploaded video is not square, force do crop (driving) to be True")
return width == height


@ -1,7 +1,5 @@
# coding: utf-8
import os.path as osp
import imageio
import numpy as np
@ -18,23 +16,17 @@ def load_image_rgb(image_path: str):
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def load_video(video_info, n_frames=-1):
reader = imageio.get_reader(video_info, "ffmpeg")
ret = []
for idx, frame_rgb in enumerate(reader):
if n_frames > 0 and idx >= n_frames:
break
ret.append(frame_rgb)
reader.close()
return ret
def contiguous(obj):


@ -80,14 +80,15 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
return img
def concat_frames(driving_image_lst, source_image_lst, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
h, w, _ = I_p_lst[0].shape
source_image_resized_lst = [cv2.resize(img, (w, h)) for img in source_image_lst]
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
I_p = I_p_lst[idx]
source_image_resized = source_image_resized_lst[idx] if len(source_image_lst) > 1 else source_image_resized_lst[0]
if driving_image_lst is None:
out = np.hstack((source_image_resized, I_p))