feat: v2v and gradio upgrade (#172)

* feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

feat: v2v

* chore: format and package

* chore: refactor codebase

* feat: v2v

* feat: v2v

* feat: v2v

* feat: v2v

* feat: v2v

* feat: v2v

* feat: gradio tab auto select

* chore: log auto crop

* doc: update changelog

* doc: update changelog

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
This commit is contained in:
Jianzhu Guo 2024-07-19 23:39:05 +08:00 committed by GitHub
parent 24ce67d652
commit 24dcfafdc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 689 additions and 312 deletions

148
app.py
View File

@ -26,18 +26,20 @@ def fast_check_ffmpeg():
except:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
# global_tab_selection = None
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
@ -55,10 +57,10 @@ def gpu_wrapped_execute_image(*args, **kwargs):
# assets
title_md = "assets/gradio_title.md"
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
data_examples_i2v = [
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
@ -66,6 +68,14 @@ data_examples = [
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-6],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-6],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6],
]
#################### interface logic ####################
# Define components first
@ -74,15 +84,22 @@ lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="tar
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
output_video_concat = gr.Video()
output_video_i2v = gr.Video(autoplay=True)
output_video_concat_i2v = gr.Video(autoplay=True)
output_video_v2v = gr.Video(autoplay=True)
output_video_concat_v2v = gr.Video(autoplay=True)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
gr.Markdown(load_description("assets/gradio/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Source Portrait"):
image_input = gr.Image(type="filepath")
with gr.Column():
with gr.Tabs():
with gr.TabItem("🖼️ Source Image") as tab_image:
with gr.Accordion(open=True, label="Source Image"):
source_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
@ -92,11 +109,39 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
],
inputs=[image_input],
inputs=[source_image_input],
cache_examples=False,
)
with gr.TabItem("🎞️ Source Video") as tab_video:
with gr.Accordion(open=True, label="Source Video"):
source_video_input = gr.Video()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s14.mp4")],
# [osp.join(example_portrait_dir, "s15.mp4")],
[osp.join(example_portrait_dir, "s18.mp4")],
# [osp.join(example_portrait_dir, "s19.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")],
],
inputs=[source_video_input],
cache_examples=False,
)
tab_selection = gr.Textbox(visible=False)
tab_image.select(lambda: "Image", None, tab_selection)
tab_video.select(lambda: "Video", None, tab_selection)
with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"):
with gr.Row():
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=2.9, step=0.05)
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Column():
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video()
driving_video_input = gr.Video()
gr.Examples(
examples=[
[osp.join(example_video_dir, "d0.mp4")],
@ -105,49 +150,82 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
],
inputs=[video_input],
inputs=[driving_video_input],
cache_examples=False,
)
# with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
with gr.Row():
with gr.Accordion(open=False, label="Animation Instructions and Options"):
gr.Markdown(load_description("assets/gradio_description_animation.md"))
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=2.9, step=0.05)
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-6, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-11)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video.render()
output_video_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat.render()
output_video_concat_i2v.render()
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
with gr.Tabs():
with gr.TabItem("🖼️ Portrait Animation"):
gr.Examples(
examples=data_examples,
examples=data_examples_i2v,
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
source_image_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples),
examples_per_page=len(data_examples_i2v),
cache_examples=False,
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
with gr.TabItem("🎞️ Portrait Video Editing"):
gr.Examples(
examples=data_examples_v2v,
fn=gpu_wrapped_execute_video,
inputs=[
source_video_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
driving_smooth_observation_variance,
],
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples_v2v),
cache_examples=False,
)
# Retargeting
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
eye_retargeting_slider.render()
lip_retargeting_slider.render()
@ -185,6 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
@ -196,18 +275,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
source_image_input,
source_video_input,
driving_video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale,
vx_ratio,
vy_ratio,
scale_crop_driving_video,
vx_ratio_crop_driving_video,
vy_ratio_crop_driving_video,
driving_smooth_observation_variance,
tab_selection,
],
outputs=[output_video, output_video_concat],
outputs=[output_video_i2v, output_video_concat_i2v],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,

View File

@ -9,7 +9,7 @@ The popularity of LivePortrait has exceeded our expectations. If you encounter a
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`.
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option.
### About driving video

View File

@ -0,0 +1,18 @@
## 2024/07/19
**Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉**
We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk).
### Updates
- <strong>Portrait video editing (v2v):</strong> Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with`--flag_source_video_eye_retargeting`.
- <strong>More options in Gradio:</strong> We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.
### Community Contributions
- **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance:
- [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150))
- [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142))
- **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119).

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,6 @@
<div style="font-size: 1.2em; text-align: center;">
Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
</div>
<!-- <div style="font-size: 1.1em; text-align: center;">
<strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
</div> -->

View File

@ -0,0 +1,19 @@
<span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
4. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
</div>

View File

@ -0,0 +1,13 @@
<br>
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: center; margin-right: 20px;">
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
</div>
</div>
</div>

View File

@ -0,0 +1,20 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<!-- <span>Add mimics and lip sync to your static portrait driven by a video</span> -->
<!-- <span>Efficient Portrait Animation with Stitching and Retargeting Control</span> -->
<!-- <br> -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
&nbsp;
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
&nbsp;
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
&nbsp;
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/github/stars/KwaiVGI/LivePortrait
"></a>
</div>
</div>
</div>

View File

@ -1,16 +0,0 @@
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
</div>

View File

@ -1,2 +0,0 @@
## 🤗 This is the official gradio demo for **LivePortrait**.
<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>

View File

@ -1,11 +0,0 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
</div>
</div>
</div>

View File

@ -22,10 +22,10 @@ def fast_check_ffmpeg():
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source_image):
raise FileNotFoundError(f"source image not found: {args.source_image}")
if not osp.exists(args.driving_info):
raise FileNotFoundError(f"driving info not found: {args.driving_info}")
if not osp.exists(args.source):
raise FileNotFoundError(f"source info not found: {args.source}")
if not osp.exists(args.driving):
raise FileNotFoundError(f"driving info not found: {args.driving}")
def main():
@ -35,10 +35,9 @@ def main():
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# fast check the args
fast_check_args(args)
# specify configs for inference

View File

@ -39,6 +39,7 @@
## 🔥 Updates
- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md).
- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
@ -47,11 +48,11 @@
## Introduction
## Introduction 📖
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
## 🔥 Getting Started
## Getting Started 🏁
### 1. Clone the code and prepare the environment
```bash
git clone https://github.com/KwaiVGI/LivePortrait
@ -61,9 +62,10 @@ cd LivePortrait
conda create -n LivePortrait python==3.9
conda activate LivePortrait
# install dependencies with pip (for Linux and Windows)
# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
# for macOS with Apple Silicon
# for macOS with Apple Silicon users
pip install -r requirements_macOS.txt
```
@ -113,7 +115,7 @@ python inference.py
PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
```
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.
<p align="center">
<img src="./assets/docs/inference.gif" alt="image">
@ -122,18 +124,18 @@ If the script runs successfully, you will get an output mp4 file named `animatio
Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash
# source input is an image
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
# disable pasting back to run faster
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
# source input is a video ✨
python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4
# more options to see
python inference.py -h
```
#### Driving video auto-cropping
📕 To use your own driving video, we **recommend**:
#### Driving video auto-cropping 📢📢📢
To use your own driving video, we **recommend**: ⬇️
- Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
@ -144,25 +146,24 @@ Below is a auto-cropping case by `--flag_crop_driving_video`:
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
```
If you find the results of auto-cropping is not well, you can modify the `--scale_crop_video`, `--vy_ratio_crop_video` options to adjust the scale and offset, or do it manually.
If you find the results of auto-cropping is not well, you can modify the `--scale_crop_driving_video`, `--vy_ratio_crop_driving_video` options to adjust the scale and offset, or do it manually.
#### Motion template making
You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation
python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing
```
**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊
### 4. Gradio interface 🤗
We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience, just run by:
```bash
# For Linux and Windows:
# For Linux and Windows users (and macOS with Intel??)
python app.py
# For macOS with Apple Silicon, Intel not supported, this maybe 20x slower than RTX 4090
# For macOS with Apple Silicon users, Intel not supported, this maybe 20x slower than RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py
```
@ -210,7 +211,7 @@ Discover the invaluable resources contributed by our community to enhance your L
And many more amazing contributions from our community!
## Acknowledgements
## Acknowledgements 💐
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
## Citation 💖

View File

@ -19,3 +19,4 @@ matplotlib==3.9.0
imageio-ffmpeg==0.5.1
tyro==0.8.5
gradio==4.37.1
pykalman==0.9.7

View File

@ -14,8 +14,8 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait or video
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
########## inference arguments ##########
@ -23,23 +23,27 @@ class ArgumentConfig(PrintableConfig):
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP!
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP
flag_lip_retargeting: bool = False # not recommend to be True, WIP
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
########## crop arguments ##########
########## source crop arguments ##########
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
scale_crop_video: float = 2.2 # scale factor for cropping video
vx_ratio_crop_video: float = 0. # adjust y offset
vy_ratio_crop_video: float = -0.1 # adjust x offset
########## driving crop arguments ##########
scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video
vx_ratio_crop_driving_video: float = 0. # adjust y offset
vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
########## gradio arguments ##########
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server

View File

@ -15,15 +15,16 @@ class CropConfig(PrintableConfig):
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
########## source image cropping option ##########
########## source image or video cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.5 # scale factor
scale: float = 2.8 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 mean no limit
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
########## driving video auto cropping option ##########
scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video
vx_ratio_crop_video: float = 0.0 # adjust y offset
vy_ratio_crop_video: float = -0.1 # adjust x offset
scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
vx_ratio_crop_driving_video: float = 0.0 # adjust y offset
vy_ratio_crop_driving_video: float = -0.1 # adjust x offset
direction: str = "large-small" # direction of cropping

View File

@ -25,7 +25,9 @@ class InferenceConfig(PrintableConfig):
flag_use_half_precision: bool = True
flag_crop_driving_video: bool = False
device_id: int = 0
flag_lip_zero: bool = True
flag_normalize_lip: bool = True
flag_source_video_eye_retargeting: bool = False
flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True
@ -37,7 +39,9 @@ class InferenceConfig(PrintableConfig):
flag_do_torch_compile: bool = False
# NOT EXPORTED PARAMS
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape
@ -47,5 +51,5 @@ class InferenceConfig(PrintableConfig):
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
size_gif: int = 256 # default gif size, TO IMPLEMENT
source_max_dim: int = 1280 # the max dim of height and width of source image
source_division: int = 2 # make sure the height and width of source image can be divided by this number
source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number

View File

@ -3,6 +3,8 @@
"""
Pipeline for gradio
"""
import os.path as osp
import gradio as gr
from .config.argument_config import ArgumentConfig
@ -11,6 +13,7 @@ from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
def update_args(args, user_args):
@ -31,23 +34,53 @@ class GradioPipeline(LivePortraitPipeline):
def execute_video(
self,
input_image_path,
input_video_path,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
input_source_image_path=None,
input_source_video_path=None,
input_driving_video_path=None,
flag_relative_input=True,
flag_do_crop_input=True,
flag_remap_input=True,
flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
driving_smooth_observation_variance=3e-6,
tab_selection=None,
):
""" for video driven potrait animation
""" for video-driven potrait animation or video editing
"""
if input_image_path is not None and input_video_path is not None:
if tab_selection == 'Image':
input_source_path = input_source_image_path
elif tab_selection == 'Video':
input_source_path = input_source_video_path
else:
input_source_path = input_source_image_path
if input_source_path is not None and input_driving_video_path is not None:
if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
flag_crop_driving_video_input = True
log("The source video is not square, the driving video will be cropped to square automatically.")
gr.Info("The source video is not square, the driving video will be cropped to square automatically.", duration=2)
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
'flag_relative': flag_relative_input,
'source': input_source_path,
'driving': input_driving_video_path,
'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_crop_driving_video': flag_crop_driving_video_input
'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
'scale_crop_driving_video': scale_crop_driving_video,
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
'driving_smooth_observation_variance': driving_smooth_observation_variance,
}
# update config from user input
self.args = update_args(self.args, args_user)
@ -58,7 +91,7 @@ class GradioPipeline(LivePortraitPipeline):
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat,
else:
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
""" for single image retargeting
@ -79,9 +112,8 @@ class GradioPipeline(LivePortraitPipeline):
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
num_kp = x_s_user.shape[1]
# default: use x_s
x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
x_d_new = x_s_user + eyes_delta + lip_delta
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
@ -114,4 +146,4 @@ class GradioPipeline(LivePortraitPipeline):
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

View File

@ -20,8 +20,9 @@ from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
from .utils.filter import smooth
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper
@ -37,27 +38,178 @@ class LivePortraitPipeline(object):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
n_frames = I_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
'c_eyes_lst': [],
'c_lip_lst': [],
'x_i_info_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s, R, δ and t for inference
I_i = I_lst[i]
x_i_info = self.live_portrait_wrapper.get_kp_info(I_i)
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
item_dct = {
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
'R': R_i.cpu().numpy().astype(np.float32),
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
c_eyes = c_eyes_lst[i].astype(np.float32)
template_dct['c_eyes_lst'].append(c_eyes)
c_lip = c_lip_lst[i].astype(np.float32)
template_dct['c_lip_lst'].append(c_lip)
template_dct['x_i_info_lst'].append(x_i_info)
return template_dct
def execute(self, args: ArgumentConfig):
# for convenience
inf_cfg = self.live_portrait_wrapper.inference_cfg
device = self.live_portrait_wrapper.device
crop_cfg = self.cropper.crop_cfg
######## process source portrait ########
img_rgb = load_image_rgb(args.source_image)
######## load source input ########
flag_is_source_video = False
source_fps = None
if is_image(args.source):
flag_is_source_video = False
img_rgb = load_image_rgb(args.source)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source_image}")
log(f"Load source image from {args.source}")
source_rgb_lst = [img_rgb]
elif is_video(args.source):
flag_is_source_video = True
source_rgb_lst = load_video(args.source)
source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst]
source_fps = int(get_fps(args.source))
log(f"Load source video from {args.source}, FPS is {source_fps}")
else: # source input is an unknown format
raise Exception(f"Unknown source format: {args.source}")
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
######## process driving info ########
flag_load_from_template = is_template(args.driving)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
driving_template_dct = load(args.driving)
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
else:
n_frames = driving_n_frames
# set output_fps
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
driving_n_frames = len(driving_rgb_lst)
######## make motion template ########
log("Start making driving motion template...")
if flag_is_source_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames]
else:
n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video:
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
#######################################
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst)
driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
I_p_lst = []
R_d_0, x_d_0_info = None, None
flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite
flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite
lip_delta_before_animation, eye_delta_before_animation = None, None
######## process source info ########
if flag_is_source_video:
log(f"Start making source motion template...")
source_rgb_lst = source_rgb_lst[:n_frames]
if inf_cfg.flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) is not n_frames:
n_frames = min(n_frames, len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: # if the input is a source image, process it only once
crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
img_crop_256x256 = crop_info['img_crop_256x256']
if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
@ -65,97 +217,73 @@ class LivePortraitPipeline(object):
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
if flag_lip_zero:
# let lip-open scalar to be 0 at first
if flag_normalize_lip:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
flag_lip_zero = False
else:
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ########
flag_load_from_template = is_template(args.driving_info)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
template_dct = load(args.driving_info)
n_frames = template_dct['n_frames']
# set output_fps
output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving_info) and is_video(args.driving_info):
# load from video file, AND make motion template
log(f"Load video: {args.driving_info}")
if osp.isdir(args.driving_info):
output_fps = inf_cfg.output_fps
else:
output_fps = int(get_fps(args.driving_info))
log(f'The FPS of {args.driving_info} is: {output_fps}')
log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
driving_rgb_lst = load_driving_info(args.driving_info)
######## make motion template ########
log("Start making motion template...")
if inf_cfg.flag_crop_driving_video:
ret = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving_info) + '.pkl'
dump(wfp_template, template_dct)
log(f"Dump motion template to {wfp_template}")
n_frames = I_d_lst.shape[0]
else:
raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
#########################################
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
######## animate ########
log(f"The animated video consists of {n_frames} frames.")
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
x_d_i_info = template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R_d']
if flag_is_source_video: # source video
x_s_info_tiny = source_template_dct['motion'][i]
x_s_info_tiny = dct2device(x_s_info_tiny, device)
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
x_s_info = source_template_dct['x_i_info_lst'][i]
x_c_s = x_s_info['kp']
R_s = x_s_info_tiny['R']
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
# let lip-open scalar to be 0 at first if the input is a video
if flag_normalize_lip:
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
if flag_source_video_eye_retargeting:
if i == 0:
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold:
c_d_eye_before_animation_frame_zero = [[0.39]]
combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
if i == 0: # cache the first frame
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inf_cfg.flag_relative_motion:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
@ -168,16 +296,20 @@ class LivePortraitPipeline(object):
# Algorithm 1:
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
if flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new += lip_delta_before_animation
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else:
pass
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
if flag_normalize_lip and lip_delta_before_animation is not None:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
x_d_i_new += eye_delta_before_animation
else:
eyes_delta, lip_delta = None, None
if inf_cfg.flag_eye_retargeting:
@ -193,12 +325,12 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
@ -208,38 +340,52 @@ class LivePortraitPipeline(object):
I_p_lst.append(I_p_i)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
# TODO: pasteback is slow, considering optimize it using multi-threading or GPU
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
# TODO: the paste back procedure is slow, considering optimize it using multi-threading or GPU
if flag_is_source_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float)
else:
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build final concat result #########
# driving frame | source image | generation, or source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
######### build the final concatenation result #########
# driving frame | source frame | generation, or source frame | generation
if flag_is_source_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps
output_fps = source_fps if flag_is_source_video else output_fps
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_has_audio:
# final result with concat
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.source if flag_source_has_audio else args.driving
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
# save drived result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build final result #########
if flag_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.source if flag_source_has_audio else args.driving
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp} with {wfp_with_audio}")
@ -250,36 +396,3 @@ class LivePortraitPipeline(object):
log(f'Animated video with concat: {wfp_concat}')
return wfp, wfp_concat
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
n_frames = I_d_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
'c_d_eyes_lst': [],
'c_d_lip_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s_d, R_d, δ_d and t_d for inference
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
item_dct = {
'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
'R_d': R_d_i.cpu().numpy().astype(np.float32),
'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
template_dct['c_d_eyes_lst'].append(c_d_eyes)
c_d_lip = c_d_lip_lst[i].astype(np.float32)
template_dct['c_d_lip_lst'].append(c_d_lip)
return template_dct

View File

@ -95,7 +95,7 @@ class LivePortraitWrapper(object):
x = x.to(self.device)
return x
def prepare_driving_videos(self, imgs) -> torch.Tensor:
def prepare_videos(self, imgs) -> torch.Tensor:
""" construct the input as standard
imgs: NxBxHxWx3, uint8
"""
@ -216,7 +216,7 @@ class LivePortraitWrapper(object):
with torch.no_grad():
delta = self.stitching_retargeting_module['eye'](feat_eye)
return delta
return delta.reshape(-1, kp_source.shape[1], 3)
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
"""
@ -229,7 +229,7 @@ class LivePortraitWrapper(object):
with torch.no_grad():
delta = self.stitching_retargeting_module['lip'](feat_lip)
return delta
return delta.reshape(-1, kp_source.shape[1], 3)
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
@ -301,10 +301,10 @@ class LivePortraitWrapper(object):
return out
def calc_driving_ratio(self, driving_lmk_lst):
def calc_ratio(self, lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in driving_lmk_lst:
for lmk in lmk_lst:
# for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting

View File

@ -31,6 +31,7 @@ class Trajectory:
end: int = -1 # end frame
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
@ -104,6 +105,7 @@ class Cropper(object):
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(img_rgb, lmk)
@ -115,6 +117,58 @@ class Cropper(object):
return ret_dct
def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
for idx, frame_rgb in enumerate(source_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
# crop the face
ret_dct = crop_image(
frame_rgb, # ndarray
lmk, # 106x2 or Nx2
dsize=crop_cfg.dsize,
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(frame_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"])
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"])
trajectory.M_c2o_lst.append(ret_dct['M_c2o'])
return {
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
"lmk_crop_lst": trajectory.lmk_crop_lst,
"M_c2o_lst": trajectory.M_c2o_lst,
}
def crop_driving_video(self, driving_rgb_lst, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
@ -142,9 +196,9 @@ class Cropper(object):
trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(
lmk,
scale=self.crop_cfg.scale_crop_video,
vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_video,
scale=self.crop_cfg.scale_crop_driving_video,
vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video,
)["bbox"]
bbox = [
ret_bbox[0, 0],
@ -174,6 +228,7 @@ class Cropper(object):
"lmk_crop_lst": trajectory.lmk_crop_lst,
}
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
"""Tracking based landmarks/alignment"""
trajectory = Trajectory()

19
src/utils/filter.py Normal file
View File

@ -0,0 +1,19 @@
# coding: utf-8
import torch
import numpy as np
from pykalman import KalmanFilter
def smooth(x_d_lst, shape, device, observation_variance=3e-6, process_variance=1e-5):
x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
x_d_stacked = np.vstack(x_d_lst_reshape)
kf = KalmanFilter(
initial_state_mean=x_d_stacked[0],
n_dim_obs=x_d_stacked.shape[1],
transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]),
observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1])
)
smoothed_state_means, _ = kf.smooth(x_d_stacked)
x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means]
return x_d_lst_smooth

View File

@ -8,6 +8,8 @@ import os
import os.path as osp
import torch
from collections import OrderedDict
import numpy as np
import cv2
from ..modules.spade_generator import SPADEDecoder
from ..modules.warping_network import WarpingNetwork
@ -42,6 +44,11 @@ def remove_suffix(filepath):
return osp.join(osp.dirname(filepath), basename(filepath))
def is_image(file_path):
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
return file_path.lower().endswith(image_extensions)
def is_video(file_path):
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
return True
@ -143,3 +150,16 @@ def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
return content
def is_square_video(video_path):
video = cv2.VideoCapture(video_path)
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
video.release()
# if width != height:
# gr.Info(f"Uploaded video is not square, force do crop (driving) to be True")
return width == height

View File

@ -1,7 +1,5 @@
# coding: utf-8
import os
from glob import glob
import os.path as osp
import imageio
import numpy as np
@ -18,23 +16,17 @@ def load_image_rgb(image_path: str):
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def load_driving_info(driving_info):
driving_video_ori = []
def load_video(video_info, n_frames=-1):
reader = imageio.get_reader(video_info, "ffmpeg")
def load_images_from_directory(directory):
image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg')))
return [load_image_rgb(im_path) for im_path in image_paths]
ret = []
for idx, frame_rgb in enumerate(reader):
if n_frames > 0 and idx >= n_frames:
break
ret.append(frame_rgb)
def load_images_from_video(file_path):
reader = imageio.get_reader(file_path, "ffmpeg")
return [image for _, image in enumerate(reader)]
if osp.isdir(driving_info):
driving_video_ori = load_images_from_directory(driving_info)
elif osp.isfile(driving_info):
driving_video_ori = load_images_from_video(driving_info)
return driving_video_ori
reader.close()
return ret
def contiguous(obj):

View File

@ -80,14 +80,15 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
return img
def concat_frames(driving_image_lst, source_image, I_p_lst):
def concat_frames(driving_image_lst, source_image_lst, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
h, w, _ = I_p_lst[0].shape
source_image_resized_lst = [cv2.resize(img, (w, h)) for img in source_image_lst]
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
I_p = I_p_lst[idx]
source_image_resized = cv2.resize(source_image, (w, h))
source_image_resized = source_image_resized_lst[idx] if len(source_image_lst) > 1 else source_image_resized_lst[0]
if driving_image_lst is None:
out = np.hstack((source_image_resized, I_p))