diff --git a/.gitignore b/.gitignore
index 07050fd..1f85f19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,9 +9,13 @@ __pycache__/
**/*.pth
**/*.onnx
+pretrained_weights/*.md
+pretrained_weights/docs
+
# Ipython notebook
*.ipynb
# Temporary files or benchmark resources
animations/*
tmp/*
+.vscode/launch.json
diff --git a/app.py b/app.py
index a82443b..5494d9b 100644
--- a/app.py
+++ b/app.py
@@ -25,28 +25,40 @@ args = tyro.cli(ArgumentConfig)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
+
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
+
+
+def gpu_wrapped_execute_video(*args, **kwargs):
+ return gradio_pipeline.execute_video(*args, **kwargs)
+
+
+def gpu_wrapped_execute_image(*args, **kwargs):
+ return gradio_pipeline.execute_image(*args, **kwargs)
+
+
# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
- [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
- [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
- [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
- [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
- [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
+ [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+ [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+ [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+ [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
+ [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
+ [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
#################### interface logic ####################
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
-retargeting_input_image = gr.Image(type="numpy")
+retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
@@ -58,15 +70,39 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Row():
with gr.Accordion(open=True, label="Source Portrait"):
image_input = gr.Image(type="filepath")
+ gr.Examples(
+ examples=[
+ [osp.join(example_portrait_dir, "s9.jpg")],
+ [osp.join(example_portrait_dir, "s6.jpg")],
+ [osp.join(example_portrait_dir, "s10.jpg")],
+ [osp.join(example_portrait_dir, "s5.jpg")],
+ [osp.join(example_portrait_dir, "s7.jpg")],
+ [osp.join(example_portrait_dir, "s12.jpg")],
+ ],
+ inputs=[image_input],
+ cache_examples=False,
+ )
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video()
- gr.Markdown(load_description("assets/gradio_description_animation.md"))
+ gr.Examples(
+ examples=[
+ [osp.join(example_video_dir, "d0.mp4")],
+ [osp.join(example_video_dir, "d18.mp4")],
+ [osp.join(example_video_dir, "d19.mp4")],
+ [osp.join(example_video_dir, "d14.mp4")],
+ [osp.join(example_video_dir, "d6.mp4")],
+ ],
+ inputs=[video_input],
+ cache_examples=False,
+ )
with gr.Row():
- with gr.Accordion(open=True, label="Animation Options"):
+ with gr.Accordion(open=False, label="Animation Instructions and Options"):
+ gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
- flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
+ flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
+ flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("๐ Animate", variant="primary")
@@ -81,24 +117,28 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
output_video_concat.render()
with gr.Row():
# Examples
- gr.Markdown("## You could choose the examples below โฌ๏ธ")
+ gr.Markdown("## You could also choose the examples below by one click โฌ๏ธ")
with gr.Row():
gr.Examples(
examples=data_examples,
+ fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
- flag_remap_input
+ flag_remap_input,
+ flag_crop_driving_video_input
],
- examples_per_page=5
+ outputs=[output_image, output_image_paste_back],
+ examples_per_page=len(data_examples),
+ cache_examples=False,
)
- gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
- with gr.Row():
+ gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
+ with gr.Row(visible=True):
eye_retargeting_slider.render()
lip_retargeting_slider.render()
- with gr.Row():
+ with gr.Row(visible=True):
process_button_retargeting = gr.Button("๐ Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
@@ -110,10 +150,22 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
],
value="๐งน Clear"
)
- with gr.Row():
+ with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
retargeting_input_image.render()
+ gr.Examples(
+ examples=[
+ [osp.join(example_portrait_dir, "s9.jpg")],
+ [osp.join(example_portrait_dir, "s6.jpg")],
+ [osp.join(example_portrait_dir, "s10.jpg")],
+ [osp.join(example_portrait_dir, "s5.jpg")],
+ [osp.join(example_portrait_dir, "s7.jpg")],
+ [osp.join(example_portrait_dir, "s12.jpg")],
+ ],
+ inputs=[retargeting_input_image],
+ cache_examples=False,
+ )
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
@@ -122,33 +174,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
- fn=gradio_pipeline.execute_image,
- inputs=[eye_retargeting_slider, lip_retargeting_slider],
+ # fn=gradio_pipeline.execute_image,
+ fn=gpu_wrapped_execute_image,
+ inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
- fn=gradio_pipeline.execute_video,
+ fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
- flag_remap_input
+ flag_remap_input,
+ flag_crop_driving_video_input
],
outputs=[output_video, output_video_concat],
show_progress=True
)
- image_input.change(
- fn=gradio_pipeline.prepare_retargeting,
- inputs=image_input,
- outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
- )
-##########################################################
demo.launch(
- server_name=args.server_name,
server_port=args.server_port,
share=args.share,
+ server_name=args.server_name
)
diff --git a/assets/.gitignore b/assets/.gitignore
new file mode 100644
index 0000000..892dfa4
--- /dev/null
+++ b/assets/.gitignore
@@ -0,0 +1,2 @@
+examples/driving/*.pkl
+examples/driving/*_crop.mp4
diff --git a/assets/docs/changelog/2024-07-10.md b/assets/docs/changelog/2024-07-10.md
new file mode 100644
index 0000000..fe0fa72
--- /dev/null
+++ b/assets/docs/changelog/2024-07-10.md
@@ -0,0 +1,22 @@
+## 2024/07/10
+
+**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** โค๏ธ
+The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
+
+### Updates
+
+- Audio and video concatenating: If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
+
+- Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`.
+
+- Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
+
+
+### About driving video
+
+- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
+
+
+### Others
+
+- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
diff --git a/assets/examples/driving/d1.mp4 b/assets/examples/driving/d1.mp4
deleted file mode 100644
index e2825c1..0000000
Binary files a/assets/examples/driving/d1.mp4 and /dev/null differ
diff --git a/assets/examples/driving/d1.pkl b/assets/examples/driving/d1.pkl
new file mode 100644
index 0000000..94a4b3f
Binary files /dev/null and b/assets/examples/driving/d1.pkl differ
diff --git a/assets/examples/driving/d10.mp4 b/assets/examples/driving/d10.mp4
new file mode 100644
index 0000000..5e98204
Binary files /dev/null and b/assets/examples/driving/d10.mp4 differ
diff --git a/assets/examples/driving/d11.mp4 b/assets/examples/driving/d11.mp4
new file mode 100644
index 0000000..378d000
Binary files /dev/null and b/assets/examples/driving/d11.mp4 differ
diff --git a/assets/examples/driving/d12.mp4 b/assets/examples/driving/d12.mp4
new file mode 100644
index 0000000..984922e
Binary files /dev/null and b/assets/examples/driving/d12.mp4 differ
diff --git a/assets/examples/driving/d13.mp4 b/assets/examples/driving/d13.mp4
new file mode 100644
index 0000000..6ae3e97
Binary files /dev/null and b/assets/examples/driving/d13.mp4 differ
diff --git a/assets/examples/driving/d14.mp4 b/assets/examples/driving/d14.mp4
new file mode 100644
index 0000000..e4a25d6
Binary files /dev/null and b/assets/examples/driving/d14.mp4 differ
diff --git a/assets/examples/driving/d18.mp4 b/assets/examples/driving/d18.mp4
new file mode 100644
index 0000000..c23ade1
Binary files /dev/null and b/assets/examples/driving/d18.mp4 differ
diff --git a/assets/examples/driving/d19.mp4 b/assets/examples/driving/d19.mp4
new file mode 100644
index 0000000..07562e9
Binary files /dev/null and b/assets/examples/driving/d19.mp4 differ
diff --git a/assets/examples/driving/d2.mp4 b/assets/examples/driving/d2.mp4
deleted file mode 100644
index a14da2d..0000000
Binary files a/assets/examples/driving/d2.mp4 and /dev/null differ
diff --git a/assets/examples/driving/d2.pkl b/assets/examples/driving/d2.pkl
new file mode 100644
index 0000000..893555a
Binary files /dev/null and b/assets/examples/driving/d2.pkl differ
diff --git a/assets/examples/driving/d5.mp4 b/assets/examples/driving/d5.mp4
deleted file mode 100644
index 332bc88..0000000
Binary files a/assets/examples/driving/d5.mp4 and /dev/null differ
diff --git a/assets/examples/driving/d5.pkl b/assets/examples/driving/d5.pkl
new file mode 100644
index 0000000..0a198c6
Binary files /dev/null and b/assets/examples/driving/d5.pkl differ
diff --git a/assets/examples/driving/d7.mp4 b/assets/examples/driving/d7.mp4
deleted file mode 100644
index 81b5ae1..0000000
Binary files a/assets/examples/driving/d7.mp4 and /dev/null differ
diff --git a/assets/examples/driving/d7.pkl b/assets/examples/driving/d7.pkl
new file mode 100644
index 0000000..28ff425
Binary files /dev/null and b/assets/examples/driving/d7.pkl differ
diff --git a/assets/examples/driving/d8.mp4 b/assets/examples/driving/d8.mp4
deleted file mode 100644
index 7fabdde..0000000
Binary files a/assets/examples/driving/d8.mp4 and /dev/null differ
diff --git a/assets/examples/driving/d8.pkl b/assets/examples/driving/d8.pkl
new file mode 100644
index 0000000..b6a97d6
Binary files /dev/null and b/assets/examples/driving/d8.pkl differ
diff --git a/assets/examples/source/s11.jpg b/assets/examples/source/s11.jpg
new file mode 100644
index 0000000..bd2fa2d
Binary files /dev/null and b/assets/examples/source/s11.jpg differ
diff --git a/assets/examples/source/s12.jpg b/assets/examples/source/s12.jpg
new file mode 100644
index 0000000..d3d65c1
Binary files /dev/null and b/assets/examples/source/s12.jpg differ
diff --git a/assets/gradio_description_animation.md b/assets/gradio_description_animation.md
index 34b3897..cad1ad6 100644
--- a/assets/gradio_description_animation.md
+++ b/assets/gradio_description_animation.md
@@ -1,7 +1,16 @@
๐ฅ To animate the source portrait with the driving video, please follow these steps:
- 1. Specify the options in the Animation Options section. We recommend checking the do crop option when facial areas occupy a relatively small portion of your image.
+1. In the Animation Options section, we recommend enabling the do crop (source) option if faces occupy a small portion of your image.
- 2. Press the ๐ Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
+2. Press the ๐ Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
+
+
+3. If you want to upload your own driving video, the best practice:
+
+ - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
+ - Focus on the head area, similar to the example videos.
+ - Minimize shoulder movement.
+ - Make sure the first frame of driving video is a frontal face with **neutral expression**.
+
diff --git a/assets/gradio_description_retargeting.md b/assets/gradio_description_retargeting.md
index a99796d..4ff1a80 100644
--- a/assets/gradio_description_retargeting.md
+++ b/assets/gradio_description_retargeting.md
@@ -1 +1,4 @@
-๐ฅ To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the ๐ Retargeting button. The result would be shown in the middle block. You can try running it multiple times. ๐ Set both ratios to 0.8 to see what's going on!
+
+
+## Retargeting
+๐ฅ To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the ๐ Retargeting button. You can try running it multiple times. ๐ Set both ratios to 0.8 to see what's going on!
diff --git a/assets/gradio_description_upload.md b/assets/gradio_description_upload.md
index 46a5fa5..035a6c2 100644
--- a/assets/gradio_description_upload.md
+++ b/assets/gradio_description_upload.md
@@ -1,2 +1,2 @@
## ๐ค This is the official gradio demo for **LivePortrait**.
-Please upload or use the webcam to get a source portrait to the Source Portrait field and a driving video to the Driving Video field.
+Please upload or use a webcam to get a Source Portrait (any aspect ratio) and upload a Driving Video (1:1 aspect ratio, or any aspect ratio with do crop (driving video)
checked).
diff --git a/assets/gradio_title.md b/assets/gradio_title.md
index e2b765e..c9bbfc2 100644
--- a/assets/gradio_title.md
+++ b/assets/gradio_title.md
@@ -5,6 +5,7 @@
+
diff --git a/inference.py b/inference.py
index 8387e7f..dd7a768 100644
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,8 @@
# coding: utf-8
+import os.path as osp
import tyro
+import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
@@ -11,11 +13,34 @@ def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
+def fast_check_ffmpeg():
+ try:
+ subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+ return True
+ except:
+ return False
+
+
+def fast_check_args(args: ArgumentConfig):
+ if not osp.exists(args.source_image):
+ raise FileNotFoundError(f"source image not found: {args.source_image}")
+ if not osp.exists(args.driving_info):
+ raise FileNotFoundError(f"driving info not found: {args.driving_info}")
+
+
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
+ if not fast_check_ffmpeg():
+ raise ImportError(
+ "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+ )
+
+ # fast check the args
+ fast_check_args(args)
+
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
@@ -29,5 +54,5 @@ def main():
live_portrait_pipeline.execute(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/readme.md b/readme.md
index 060c0b3..bcfacb1 100644
--- a/readme.md
+++ b/readme.md
@@ -4,7 +4,7 @@
Jianzhu Guo 1โ
Dingyun Zhang 1,2
Xiaoqiang Liu 1
- Zhizhou Zhong 1,3
+ Zhizhou Zhong 1,3
Yuan Zhang 1
@@ -35,8 +35,12 @@
## ๐ฅ Updates
-- **`2024/07/04`**: ๐ฅ We released the initial version of the inference code and models. Continuous updates, stay tuned!
-- **`2024/07/04`**: ๐ We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
+- **`2024/07/10`**: ๐ช We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
+- **`2024/07/09`**: ๐ค We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
+- **`2024/07/04`**: ๐ We released the initial version of the inference code and models. Continuous updates, stay tuned!
+- **`2024/07/04`**: ๐ฅ We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
+
+
## Introduction
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
@@ -55,8 +59,19 @@ conda activate LivePortrait
pip install -r requirements.txt
```
+**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!
+
### 2. Download pretrained weights
-Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory ๐. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
+
+The easiest way to download the pretrained weights is from HuggingFace:
+```bash
+# you may need to run `git lfs install` first
+git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
+```
+
+Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
+
+Ensuring the directory structure is as follows, or contains:
```text
pretrained_weights
โโโ insightface
@@ -77,6 +92,7 @@ pretrained_weights
### 3. Inference ๐
+#### Fast hands-on
```bash
python inference.py
```
@@ -92,16 +108,37 @@ Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
-# or disable pasting back
+# disable pasting back to run faster
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
# more options to see
python inference.py -h
```
-**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** ๐
+#### Driving video auto-cropping
-### 4. Gradio interface
+๐ To use your own driving video, we **recommend**:
+ - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
+ - Focus on the head area, similar to the example videos.
+ - Minimize shoulder movement.
+ - Make sure the first frame of driving video is a frontal face with **neutral expression**.
+
+Below is a auto-cropping case by `--flag_crop_driving_video`:
+```bash
+python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
+```
+
+If you find the results of auto-cropping is not well, you can modify the `--scale_crop_video`, `--vy_ratio_crop_video` options to adjust the scale and offset, or do it manually.
+
+#### Motion template making
+You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
+```bash
+python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl
+```
+
+**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** ๐
+
+### 4. Gradio interface ๐ค
We also provide a Gradio interface for a better experience, just run by:
@@ -109,6 +146,10 @@ We also provide a Gradio interface for a better experience, just run by:
python app.py
```
+You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
+
+**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) ๐ค**
+
### 5. Inference speed evaluation ๐๐๐
We have also provided a script to evaluate the inference speed of each module:
@@ -124,10 +165,22 @@ Below are the results of inferring one frame on an RTX 4090 GPU using the native
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
-| Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
+| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
-*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
+*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
+## Community Resources ๐ค
+
+Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
+
+- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
+- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
+- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
+- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
+- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
+- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
+
+And many more amazing contributions from our community!
## Acknowledgements
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
@@ -135,10 +188,10 @@ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrS
## Citation ๐
If you find LivePortrait useful for your research, welcome to ๐ this repo and cite our work using the following BibTeX:
```bibtex
-@article{guo2024live,
+@article{guo2024liveportrait,
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
- author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
- year = {2024},
- journal = {arXiv preprint:2407.03168},
+ author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di},
+ journal = {arXiv preprint arXiv:2407.03168},
+ year = {2024}
}
```
diff --git a/requirements.txt b/requirements.txt
index 73dbda9..b2e1c85 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,7 @@ imageio==2.34.2
lmdb==1.4.1
tqdm==4.66.4
rich==13.7.1
-ffmpeg==1.4
+ffmpeg-python==0.2.0
onnxruntime-gpu==1.18.0
onnx==1.16.1
scikit-image==0.24.0
diff --git a/speed.py b/speed.py
index 02459d2..3cad248 100644
--- a/speed.py
+++ b/speed.py
@@ -47,11 +47,11 @@ def load_and_compile_models(cfg, model_config):
"""
Load and compile models for inference
"""
- appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
- motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
- warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
- spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
- stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
+ appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
+ motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
+ warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
+ spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
+ stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')
models_with_params = [
('Appearance Feature Extractor', appearance_feature_extractor),
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index 0431627..0bbaa20 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -1,13 +1,13 @@
# coding: utf-8
"""
-config for user
+All configs for user
"""
-import os.path as osp
from dataclasses import dataclass
import tyro
from typing_extensions import Annotated
+from typing import Optional
from .base_config import PrintableConfig, make_abs_path
@@ -17,28 +17,31 @@ class ArgumentConfig(PrintableConfig):
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
- #####################################
########## inference arguments ##########
- device_id: int = 0
+ flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
+ flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
+ device_id: int = 0 # gpu device id
+ flag_force_cpu: bool = False # force cpu inference, WIP!
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
- flag_eye_retargeting: bool = False
- flag_lip_retargeting: bool = False
- flag_stitching: bool = True # we recommend setting it to True!
- flag_relative: bool = True # whether to use relative motion
+ flag_eye_retargeting: bool = False # not recommend to be True, WIP
+ flag_lip_retargeting: bool = False # not recommend to be True, WIP
+ flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
+ flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
- #########################################
########## crop arguments ##########
- dsize: int = 512
- scale: float = 2.3
- vx_ratio: float = 0 # vx ratio
- vy_ratio: float = -0.125 # vy ratio +up, -down
- ####################################
+ scale: float = 2.3 # the ratio of face area is smaller if scale is larger
+ vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
+ vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
+
+ scale_crop_video: float = 2.2 # scale factor for cropping video
+ vx_ratio_crop_video: float = 0. # adjust y offset
+ vy_ratio_crop_video: float = -0.1 # adjust x offset
########## gradio arguments ##########
- server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
- share: bool = True
- server_name: str = "0.0.0.0"
+ server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
+ share: bool = False # whether to share the server to public
+ server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
diff --git a/src/config/crop_config.py b/src/config/crop_config.py
index d3c79be..f3b12ef 100644
--- a/src/config/crop_config.py
+++ b/src/config/crop_config.py
@@ -4,15 +4,26 @@
parameters used for crop faces
"""
-import os.path as osp
from dataclasses import dataclass
-from typing import Union, List
+
from .base_config import PrintableConfig
@dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig):
+ insightface_root: str = "../../pretrained_weights/insightface"
+ landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
+ device_id: int = 0 # gpu device id
+ flag_force_cpu: bool = False # force cpu inference, WIP
+ ########## source image cropping option ##########
dsize: int = 512 # crop size
- scale: float = 2.3 # scale factor
+ scale: float = 2.5 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
+ max_face_num: int = 0 # max face number, 0 mean no limit
+
+ ########## driving video auto cropping option ##########
+ scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video
+ vx_ratio_crop_video: float = 0.0 # adjust y offset
+ vy_ratio_crop_video: float = -0.1 # adjust x offset
+ direction: str = "large-small" # direction of cropping
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
index e94aeb8..70eedd8 100644
--- a/src/config/inference_config.py
+++ b/src/config/inference_config.py
@@ -5,6 +5,8 @@ config dataclass used for inference
"""
import os.path as osp
+import cv2
+from numpy import ndarray
from dataclasses import dataclass
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
@@ -12,38 +14,38 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig):
+ # MODEL CONFIG, NOT EXPOERTED PARAMS
models_config: str = make_abs_path('./models.yaml') # portrait animation config
- checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
- checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
- checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
- checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
-
- checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
- flag_use_half_precision: bool = True # whether to use half precision
-
- flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
- lip_zero_threshold: float = 0.03
+ checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
+ checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M
+ checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G
+ checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
+ checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip
+ # EXPOERTED PARAMS
+ flag_use_half_precision: bool = True
+ flag_crop_driving_video: bool = False
+ device_id: int = 0
+ flag_lip_zero: bool = True
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
- flag_stitching: bool = True # we recommend setting it to True!
+ flag_stitching: bool = True
+ flag_relative_motion: bool = True
+ flag_pasteback: bool = True
+ flag_do_crop: bool = True
+ flag_do_rot: bool = True
+ flag_force_cpu: bool = False
- flag_relative: bool = True # whether to use relative motion
- anchor_frame: int = 0 # set this value if find_best_frame is True
+ # NOT EXPOERTED PARAMS
+ lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
+ anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
- output_fps: int = 30 # fps for output video
crf: int = 15 # crf for output video
+ output_fps: int = 25 # default output fps
- flag_write_result: bool = True # whether to write output video
- flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
- mask_crop = None
- flag_write_gif: bool = False
- size_gif: int = 256
- ref_max_shape: int = 1280
- ref_shape_n: int = 2
-
- device_id: int = 0
- flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
- flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
+ mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
+ size_gif: int = 256 # default gif size, TO IMPLEMENT
+ source_max_dim: int = 1280 # the max dim of height and width of source image
+ source_division: int = 2 # make sure the height and width of source image can be divided by this number
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index c717897..f7343f7 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -4,13 +4,14 @@
Pipeline for gradio
"""
import gradio as gr
+
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
-from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
+
def update_args(args, user_args):
"""update the args according to user inputs
@@ -20,22 +21,13 @@ def update_args(args, user_args):
setattr(args, k, v)
return args
+
class GradioPipeline(LivePortraitPipeline):
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
- # for single image retargeting
- self.start_prepare = False
- self.f_s_user = None
- self.x_c_s_info_user = None
- self.x_s_user = None
- self.source_lmk_user = None
- self.mask_ori = None
- self.img_rgb = None
- self.crop_M_c2o = None
-
def execute_video(
self,
@@ -44,7 +36,8 @@ class GradioPipeline(LivePortraitPipeline):
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
- ):
+ flag_crop_driving_video_input
+ ):
""" for video driven potrait animation
"""
if input_image_path is not None and input_video_path is not None:
@@ -54,6 +47,7 @@ class GradioPipeline(LivePortraitPipeline):
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
+ 'flag_crop_driving_video': flag_crop_driving_video_input
}
# update config from user input
self.args = update_args(self.args, args_user)
@@ -66,51 +60,45 @@ class GradioPipeline(LivePortraitPipeline):
else:
raise gr.Error("The input source portrait or driving video hasn't been prepared yet ๐ฅ!", duration=5)
- def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
+ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
""" for single image retargeting
"""
- if input_eye_ratio is None or input_eye_ratio is None:
+ # disposable feature
+ f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
+ self.prepare_retargeting(input_image, flag_do_crop)
+
+ if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input ๐ฅ!", duration=5)
- elif self.f_s_user is None:
- if self.start_prepare:
- raise gr.Error(
- "The source portrait is under processing ๐ฅ! Please wait for a second.",
- duration=5
- )
- else:
- raise gr.Error(
- "The source portrait hasn't been prepared yet ๐ฅ! Please scroll to the top of the page to upload.",
- duration=5
- )
else:
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
+ x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
+ f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
# โ_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
- combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
- eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# โ_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
- combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
- lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
- num_kp = self.x_s_user.shape[1]
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
+ lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+ num_kp = x_s_user.shape[1]
# default: use x_s
- x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+ x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, xโฒ_d))
- out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
+ out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
- out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
+ out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
-
- def prepare_retargeting(self, input_image_path, flag_do_crop = True):
+ def prepare_retargeting(self, input_image, flag_do_crop=True):
""" for single image retargeting
"""
- if input_image_path is not None:
- gr.Info("Upload successfully!", duration=2)
- self.start_prepare = True
- inference_cfg = self.live_portrait_wrapper.cfg
+ if input_image is not None:
+ # gr.Info("Upload successfully!", duration=2)
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
- img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
- log(f"Load source image from {input_image_path}.")
- crop_info = self.cropper.crop_single_image(img_rgb)
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
+ log(f"Load source image from {input_image}.")
+ crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
@@ -118,23 +106,12 @@ class GradioPipeline(LivePortraitPipeline):
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################
-
- # record global info for next time use
- self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
- self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
- self.x_s_info_user = x_s_info
- self.source_lmk_user = crop_info['lmk_crop']
- self.img_rgb = img_rgb
- self.crop_M_c2o = crop_info['M_c2o']
- self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
- # update slider
- eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
- eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
- lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
- lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
- # for vis
- self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
- return eye_close_ratio, lip_close_ratio, self.I_s_vis
+ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+ x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+ source_lmk_user = crop_info['lmk_crop']
+ crop_M_c2o = crop_info['M_c2o']
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+ return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
- return 0.8, 0.8, self.I_s_vis
+ raise gr.Error("The retargeting input hasn't been prepared yet ๐ฅ!", duration=5)
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 7fda1f5..a27b52f 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -4,13 +4,12 @@
Pipeline of LivePortrait
"""
-# TODO:
-# 1. ๅฝๅๅๅฎๆๆ็ๆจกๆฟ้ฝๆฏๅทฒ็ป่ฃๅฅฝ็๏ผ้่ฆไฟฎๆนไธ
-# 2. pickๆ ทไพๅพ source + driving
+import torch
+torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
-import cv2
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
-import pickle
+import os
import os.path as osp
from rich.progress import track
@@ -19,12 +18,12 @@ from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
-from .utils.video import images2video, concat_frames
+from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
-from .utils.retargeting_utils import calc_lip_close_ratio
-from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
-from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
+from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
+from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
from .utils.rprint import rlog as log
+# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper
@@ -35,84 +34,124 @@ def make_abs_path(fn):
class LivePortraitPipeline(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
- self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
- self.cropper = Cropper(crop_cfg=crop_cfg)
+ self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
+ self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
def execute(self, args: ArgumentConfig):
- inference_cfg = self.live_portrait_wrapper.cfg # for convenience
+ # for convenience
+ inf_cfg = self.live_portrait_wrapper.inference_cfg
+ device = self.live_portrait_wrapper.device
+ crop_cfg = self.cropper.crop_cfg
+
######## process source portrait ########
img_rgb = load_image_rgb(args.source_image)
- img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
+ img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source_image}")
- crop_info = self.cropper.crop_single_image(img_rgb)
+
+ crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
+ if crop_info is None:
+ raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
- if inference_cfg.flag_do_crop:
+
+ if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
- I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+ img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
+ I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
- if inference_cfg.flag_lip_zero:
+ flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
+ if flag_lip_zero:
# let lip-open scalar to be 0 at first
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
- if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
- inference_cfg.flag_lip_zero = False
+ if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
+ flag_lip_zero = False
else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ########
- if is_video(args.driving_info):
- log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
- # TODO: ่ฟ้trackไธไธ้ฉฑๅจ่ง้ข -> ๆๅปบๆจกๆฟ
+ flag_load_from_template = is_template(args.driving_info)
+ driving_rgb_crop_256x256_lst = None
+ wfp_template = None
+
+ if flag_load_from_template:
+ # NOTE: load from template, it is fast, but the cropping video is None
+ log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
+ template_dct = load(args.driving_info)
+ n_frames = template_dct['n_frames']
+
+ # set output_fps
+ output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
+ log(f'The FPS of template: {output_fps}')
+
+ if args.flag_crop_driving_video:
+ log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
+
+ elif osp.exists(args.driving_info) and is_video(args.driving_info):
+ # load from video file, AND make motion template
+ log(f"Load video: {args.driving_info}")
+ if osp.isdir(args.driving_info):
+ output_fps = inf_cfg.output_fps
+ else:
+ output_fps = int(get_fps(args.driving_info))
+ log(f'The FPS of {args.driving_info} is: {output_fps}')
+
+ log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
driving_rgb_lst = load_driving_info(args.driving_info)
- driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
- I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
+
+ ######## make motion template ########
+ log("Start making motion template...")
+ if inf_cfg.flag_crop_driving_video:
+ ret = self.cropper.crop_driving_video(driving_rgb_lst)
+ log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
+ driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
+ driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
+ else:
+ driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
+ driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
+
+ c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
+ # save the motion template
+ I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
+ template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
+
+ wfp_template = remove_suffix(args.driving_info) + '.pkl'
+ dump(wfp_template, template_dct)
+ log(f"Dump motion template to {wfp_template}")
+
n_frames = I_d_lst.shape[0]
- if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
- driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
- input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
- elif is_template(args.driving_info):
- log(f"Load from video templates {args.driving_info}")
- with open(args.driving_info, 'rb') as f:
- template_lst, driving_lmk_lst = pickle.load(f)
- n_frames = template_lst[0]['n_frames']
- input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
else:
- raise Exception("Unsupported driving types!")
+ raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
#########################################
######## prepare for pasteback ########
- if inference_cfg.flag_pasteback:
- mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
- I_p_paste_lst = []
+ I_p_pstbk_lst = None
+ if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
+ mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+ I_p_pstbk_lst = []
+ log("Prepared pasteback mask done.")
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
- for i in track(range(n_frames), description='Animating...', total=n_frames):
- if is_video(args.driving_info):
- # extract kp info by M
- I_d_i = I_d_lst[i]
- x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
- R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
- else:
- # from template
- x_d_i_info = template_lst[i]
- x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
- R_d_i = x_d_i_info['R_d']
+
+ for i in track(range(n_frames), description='๐Animating...', total=n_frames):
+ x_d_i_info = template_dct['motion'][i]
+ x_d_i_info = dct2device(x_d_i_info, device)
+ R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
- if inference_cfg.flag_relative:
+ if inf_cfg.flag_relative_motion:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
@@ -123,36 +162,36 @@ class LivePortraitPipeline(object):
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
- t_new[..., 2].fill_(0) # zero tz
+ t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
# Algorithm 1:
- if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
+ if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
- if inference_cfg.flag_lip_zero:
+ if flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
- elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
+ elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting
- if inference_cfg.flag_lip_zero:
+ if flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
- if inference_cfg.flag_eye_retargeting:
- c_d_eyes_i = input_eye_ratio_lst[i]
+ if inf_cfg.flag_eye_retargeting:
+ c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# โ_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
- if inference_cfg.flag_lip_retargeting:
- c_d_lip_i = input_lip_ratio_lst[i]
+ if inf_cfg.flag_lip_retargeting:
+ c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# โ_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
- if inference_cfg.flag_relative: # use x_s
+ if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
@@ -161,30 +200,86 @@ class LivePortraitPipeline(object):
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
- if inference_cfg.flag_stitching:
+ if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
- if inference_cfg.flag_pasteback:
- I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
- I_p_paste_lst.append(I_p_i_to_ori_blend)
+ if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
+ # TODO: pasteback is slow, considering optimize it using multi-threading or GPU
+ I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
+ I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
- if is_video(args.driving_info):
- frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
- # save (driving frames, source image, drived frames) result
- wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
- images2video(frames_concatenated, wfp=wfp_concat)
+ flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
+
+ ######### build final concact result #########
+ # driving frame | source image | generation, or source image | generation
+ frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
+ wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
+ images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
+
+ if flag_has_audio:
+ # final result with concact
+ wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
+ add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
+ os.replace(wfp_concat_with_audio, wfp_concat)
+ log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
# save drived result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
- if inference_cfg.flag_pasteback:
- images2video(I_p_paste_lst, wfp=wfp)
+ if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
+ images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
- images2video(I_p_lst, wfp=wfp)
+ images2video(I_p_lst, wfp=wfp, fps=output_fps)
+
+ ######### build final result #########
+ if flag_has_audio:
+ wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
+ add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
+ os.replace(wfp_with_audio, wfp)
+ log(f"Replace {wfp} with {wfp_with_audio}")
+
+ # final log
+ if wfp_template not in (None, ''):
+ log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
+ log(f'Animated video: {wfp}')
+ log(f'Animated video with concact: {wfp_concat}')
return wfp, wfp_concat
+
+ def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
+ n_frames = I_d_lst.shape[0]
+ template_dct = {
+ 'n_frames': n_frames,
+ 'output_fps': kwargs.get('output_fps', 25),
+ 'motion': [],
+ 'c_d_eyes_lst': [],
+ 'c_d_lip_lst': [],
+ }
+
+ for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
+ # collect s_d, R_d, ฮด_d and t_d for inference
+ I_d_i = I_d_lst[i]
+ x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
+ R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
+
+ item_dct = {
+ 'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
+ 'R_d': R_d_i.cpu().numpy().astype(np.float32),
+ 'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
+ 't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
+ }
+
+ template_dct['motion'].append(item_dct)
+
+ c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
+ template_dct['c_d_eyes_lst'].append(c_d_eyes)
+
+ c_d_lip = c_d_lip_lst[i].astype(np.float32)
+ template_dct['c_d_lip_lst'].append(c_d_lip)
+
+ return template_dct
diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py
index 0ad9d06..8869b95 100644
--- a/src/live_portrait_wrapper.py
+++ b/src/live_portrait_wrapper.py
@@ -20,45 +20,51 @@ from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
- def __init__(self, cfg: InferenceConfig):
+ def __init__(self, inference_cfg: InferenceConfig):
- model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
+ self.inference_cfg = inference_cfg
+ self.device_id = inference_cfg.device_id
+ if inference_cfg.flag_force_cpu:
+ self.device = 'cpu'
+ else:
+ self.device = 'cuda:' + str(self.device_id)
+ model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
- self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
+ self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.')
# init M
- self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
+ self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor done.')
# init W
- self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
+ self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
log(f'Load warping_module done.')
# init G
- self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
+ self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
log(f'Load spade_generator done.')
# init S and R
- if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
- self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
+ if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
+ self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.')
else:
self.stitching_retargeting_module = None
- self.cfg = cfg
- self.device_id = cfg.device_id
+
+
self.timer = Timer()
def update_config(self, user_args):
for k, v in user_args.items():
- if hasattr(self.cfg, k):
- setattr(self.cfg, k, v)
+ if hasattr(self.inference_cfg, k):
+ setattr(self.inference_cfg, k, v)
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
""" construct the input as standard
img: HxWx3, uint8, 256x256
"""
h, w = img.shape[:2]
- if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
- x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
+ if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
+ x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
else:
x = img.copy()
@@ -70,7 +76,7 @@ class LivePortraitWrapper(object):
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
- x = x.cuda(self.device_id)
+ x = x.to(self.device)
return x
def prepare_driving_videos(self, imgs) -> torch.Tensor:
@@ -87,7 +93,7 @@ class LivePortraitWrapper(object):
y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
- y = y.cuda(self.device_id)
+ y = y.to(self.device)
return y
@@ -96,7 +102,7 @@ class LivePortraitWrapper(object):
x: Bx3xHxW, normalized to 0~1
"""
with torch.no_grad():
- with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
+ with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float()
@@ -108,10 +114,10 @@ class LivePortraitWrapper(object):
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
"""
with torch.no_grad():
- with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
+ with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
kp_info = self.motion_extractor(x)
- if self.cfg.flag_use_half_precision:
+ if self.inference_cfg.flag_use_half_precision:
# float the dict
for k, v in kp_info.items():
if isinstance(v, torch.Tensor):
@@ -254,14 +260,14 @@ class LivePortraitWrapper(object):
"""
# The line 18 in Algorithm 1: D(W(f_s; x_s, xโฒ_d,i)๏ผ
with torch.no_grad():
- with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
+ with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
# get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict
- if self.cfg.flag_use_half_precision:
+ if self.inference_cfg.flag_use_half_precision:
for k, v in ret_dct.items():
if isinstance(v, torch.Tensor):
ret_dct[k] = v.float()
@@ -278,7 +284,7 @@ class LivePortraitWrapper(object):
return out
- def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
+ def calc_driving_ratio(self, driving_lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in driving_lmk_lst:
@@ -288,20 +294,18 @@ class LivePortraitWrapper(object):
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst
- def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
- eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
- eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
- input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
+ def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
+ c_s_eyes = calc_eye_close_ratio(source_lmk[None])
+ c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
+ c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
# [c_s,eyes, c_d,eyes,i]
- combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
+ combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
return combined_eye_ratio_tensor
- def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
- lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
- lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
+ def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
+ c_s_lip = calc_lip_close_ratio(source_lmk[None])
+ c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
+ c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
# [c_s,lip, c_d,lip,i]
- input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
- if input_lip_ratio_tensor.shape != [1, 1]:
- input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
- combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
+ combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
return combined_lip_ratio_tensor
diff --git a/src/template_maker.py b/src/template_maker.py
deleted file mode 100644
index 7f3ce06..0000000
--- a/src/template_maker.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# coding: utf-8
-
-"""
-Make video template
-"""
-
-import os
-import cv2
-import numpy as np
-import pickle
-from rich.progress import track
-from .utils.cropper import Cropper
-
-from .utils.io import load_driving_info
-from .utils.camera import get_rotation_matrix
-from .utils.helper import mkdir, basename
-from .utils.rprint import rlog as log
-from .config.crop_config import CropConfig
-from .config.inference_config import InferenceConfig
-from .live_portrait_wrapper import LivePortraitWrapper
-
-class TemplateMaker:
-
- def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
- self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
- self.cropper = Cropper(crop_cfg=crop_cfg)
-
- def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
- """ make video template (.pkl format)
- video_fp: driving video file path
- output_path: where to save the pickle file
- """
-
- driving_rgb_lst = load_driving_info(video_fp)
- driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
- driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
- I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst)
-
- n_frames = I_d_lst.shape[0]
-
- templates = []
-
-
- for i in track(range(n_frames), description='Making templates...', total=n_frames):
- I_d_i = I_d_lst[i]
- x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
- R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
- # collect s_d, R_d, ฮด_d and t_d for inference
- template_dct = {
- 'n_frames': n_frames,
- 'frames_index': i,
- }
- template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32)
- template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32)
- template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32)
- template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32)
-
- templates.append(template_dct)
-
- mkdir(output_path)
- # Save the dictionary as a pickle file
- pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
- with open(pickle_fp, 'wb') as f:
- pickle.dump([templates, driving_lmk_lst], f)
- log(f"Template saved at {pickle_fp}")
diff --git a/src/utils/camera.py b/src/utils/camera.py
index 8bbfc90..a3dd942 100644
--- a/src/utils/camera.py
+++ b/src/utils/camera.py
@@ -31,8 +31,6 @@ def headpose_pred_to_degree(pred):
def get_rotation_matrix(pitch_, yaw_, roll_):
""" the input is in degree
"""
- # calculate the rotation matrix: vps @ rot
-
# transform to radian
pitch = pitch_ / 180 * PI
yaw = yaw_ / 180 * PI
diff --git a/src/utils/crop.py b/src/utils/crop.py
index 8f23363..065b9f0 100644
--- a/src/utils/crop.py
+++ b/src/utils/crop.py
@@ -281,11 +281,10 @@ def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=Fals
dtype=DTYPE
)
- if flag_rot and angle is None:
- print('angle is None, but flag_rotate is True', style="bold yellow")
+ # if flag_rot and angle is None:
+ # print('angle is None, but flag_rotate is True', style="bold yellow")
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
-
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
@@ -362,17 +361,6 @@ def crop_image(img, pts: np.ndarray, **kwargs):
flag_do_rot=kwargs.get('flag_do_rot', True),
)
- if img is None:
- M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
- M = np.linalg.inv(M_INV_H)
- ret_dct = {
- 'M': M[:2, ...], # from the original image to the cropped image
- 'M_o2c': M[:2, ...], # from the cropped image to the original image
- 'img_crop': None,
- 'pt_crop': None,
- }
- return ret_dct
-
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
pt_crop = _transform_pts(pts, M_INV)
@@ -397,16 +385,14 @@ def average_bbox_lst(bbox_lst):
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
"""prepare mask for later image paste back
"""
- if mask_crop is None:
- mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
mask_ori = mask_ori.astype(np.float32) / 255.
return mask_ori
-def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
+def paste_back(img_crop, M_c2o, img_ori, mask_ori):
"""paste back the image
"""
- dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
- result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
- result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
- return result
\ No newline at end of file
+ dsize = (img_ori.shape[1], img_ori.shape[0])
+ result = _transform_img(img_crop, M_c2o, dsize=dsize)
+ result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
+ return result
diff --git a/src/utils/cropper.py b/src/utils/cropper.py
index d5d511c..916d33b 100644
--- a/src/utils/cropper.py
+++ b/src/utils/cropper.py
@@ -1,20 +1,23 @@
# coding: utf-8
-import gradio as gr
-import numpy as np
import os.path as osp
-from typing import List, Union, Tuple
from dataclasses import dataclass, field
-import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+from typing import List, Tuple, Union
-from .landmark_runner import LandmarkRunner
-from .face_analysis_diy import FaceAnalysisDIY
-from .helper import prefix
-from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
-from .timer import Timer
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+import numpy as np
+
+from ..config.crop_config import CropConfig
+from .crop import (
+ average_bbox_lst,
+ crop_image,
+ crop_image_by_bbox,
+ parse_bbox_from_landmark,
+)
+from .io import contiguous
from .rprint import rlog as log
-from .io import load_image_rgb
-from .video import VideoWriter, get_fps, change_video_fps
+from .face_analysis_diy import FaceAnalysisDIY
+from .landmark_runner import LandmarkRunner
def make_abs_path(fn):
@@ -23,123 +26,171 @@ def make_abs_path(fn):
@dataclass
class Trajectory:
- start: int = -1 # ่ตทๅงๅธง ้ญๅบ้ด
- end: int = -1 # ็ปๆๅธง ้ญๅบ้ด
+ start: int = -1 # start frame
+ end: int = -1 # end frame
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
+
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
+ lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
class Cropper(object):
def __init__(self, **kwargs) -> None:
- device_id = kwargs.get('device_id', 0)
+ self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
+ device_id = kwargs.get("device_id", 0)
+ flag_force_cpu = kwargs.get("flag_force_cpu", False)
+ if flag_force_cpu:
+ device = "cpu"
+ face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
+ else:
+ device = "cuda"
+ face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
self.landmark_runner = LandmarkRunner(
- ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
- onnx_provider='cuda',
- device_id=device_id
+ ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
+ onnx_provider=device,
+ device_id=device_id,
)
self.landmark_runner.warmup()
self.face_analysis_wrapper = FaceAnalysisDIY(
- name='buffalo_l',
- root=make_abs_path('../../pretrained_weights/insightface'),
- providers=["CUDAExecutionProvider"]
+ name="buffalo_l",
+ root=make_abs_path(self.crop_cfg.insightface_root),
+ providers=face_analysis_wrapper_provicer,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.warmup()
- self.crop_cfg = kwargs.get('crop_cfg', None)
-
def update_config(self, user_args):
for k, v in user_args.items():
if hasattr(self.crop_cfg, k):
setattr(self.crop_cfg, k, v)
- def crop_single_image(self, obj, **kwargs):
- direction = kwargs.get('direction', 'large-small')
-
- # crop and align a single image
- if isinstance(obj, str):
- img_rgb = load_image_rgb(obj)
- elif isinstance(obj, np.ndarray):
- img_rgb = obj
+ def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
+ # crop a source image and get neccessary information
+ img_rgb = img_rgb_.copy() # copy it
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
src_face = self.face_analysis_wrapper.get(
- img_rgb,
+ img_bgr,
flag_do_landmark_2d_106=True,
- direction=direction
+ direction=crop_cfg.direction,
+ max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
- log('No face detected in the source image.')
- raise gr.Error("No face detected in the source image ๐ฅ!", duration=5)
- raise Exception("No face detected in the source image!")
+ log("No face detected in the source image.")
+ return None
elif len(src_face) > 1:
- log(f'More than one face detected in the image, only pick one face by rule {direction}.')
+ log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
+ # NOTE: temporarily only pick the first face, to support multiple face in the future
src_face = src_face[0]
- pts = src_face.landmark_2d_106
+ lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
# crop the face
ret_dct = crop_image(
img_rgb, # ndarray
- pts, # 106x2 or Nx2
- dsize=kwargs.get('dsize', 512),
- scale=kwargs.get('scale', 2.3),
- vy_ratio=kwargs.get('vy_ratio', -0.15),
+ lmk, # 106x2 or Nx2
+ dsize=crop_cfg.dsize,
+ scale=crop_cfg.scale,
+ vx_ratio=crop_cfg.vx_ratio,
+ vy_ratio=crop_cfg.vy_ratio,
)
- # update a 256x256 version for network input or else
- ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
- ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
- recon_ret = self.landmark_runner.run(img_rgb, pts)
- lmk = recon_ret['pts']
- ret_dct['lmk_crop'] = lmk
+ lmk = self.landmark_runner.run(img_rgb, lmk)
+ ret_dct["lmk_crop"] = lmk
+
+ # update a 256x256 version for network input
+ ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
+ ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
return ret_dct
- def get_retargeting_lmk_info(self, driving_rgb_lst):
- # TODO: implement a tracking-based version
- driving_lmk_lst = []
- for driving_image in driving_rgb_lst:
- ret_dct = self.crop_single_image(driving_image)
- driving_lmk_lst.append(ret_dct['lmk_crop'])
- return driving_lmk_lst
-
- def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
+ def crop_driving_video(self, driving_rgb_lst, **kwargs):
+ """Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
- direction = kwargs.get('direction', 'large-small')
- for idx, driving_image in enumerate(driving_rgb_lst):
+ direction = kwargs.get("direction", "large-small")
+ for idx, frame_rgb in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
- driving_image,
+ contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
- direction=direction
+ direction=direction,
)
if len(src_face) == 0:
- # No face detected in the driving_image
+ log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
- log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
+ log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
- pts = src_face.landmark_2d_106
- lmk_203 = self.landmark_runner(driving_image, pts)['pts']
+ lmk = src_face.landmark_2d_106
+ lmk = self.landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
- lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts']
+ lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
- trajectory.lmk_lst.append(lmk_203)
- ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=elf.crop_cfg.vy_ratio)['bbox']
- bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4,
+ trajectory.lmk_lst.append(lmk)
+ ret_bbox = parse_bbox_from_landmark(
+ lmk,
+ scale=self.crop_cfg.scale_crop_video,
+ vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video,
+ vy_ratio=self.crop_cfg.vy_ratio_crop_video,
+ )["bbox"]
+ bbox = [
+ ret_bbox[0, 0],
+ ret_bbox[0, 1],
+ ret_bbox[2, 0],
+ ret_bbox[2, 1],
+ ] # 4,
trajectory.bbox_lst.append(bbox) # bbox
- trajectory.frame_rgb_lst.append(driving_image)
+ trajectory.frame_rgb_lst.append(frame_rgb)
global_bbox = average_bbox_lst(trajectory.bbox_lst)
+
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
ret_dct = crop_image_by_bbox(
- frame_rgb, global_bbox, lmk=lmk,
- dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue
+ frame_rgb,
+ global_bbox,
+ lmk=lmk,
+ dsize=kwargs.get("dsize", 512),
+ flag_rot=False,
+ borderValue=(0, 0, 0),
)
- frame_rgb_crop = ret_dct['img_crop']
+ trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
+ trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
+
+ return {
+ "frame_crop_lst": trajectory.frame_rgb_crop_lst,
+ "lmk_crop_lst": trajectory.lmk_crop_lst,
+ }
+
+ def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
+ """Tracking based landmarks/alignment"""
+ trajectory = Trajectory()
+ direction = kwargs.get("direction", "large-small")
+
+ for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
+ if idx == 0 or trajectory.start == -1:
+ src_face = self.face_analysis_wrapper.get(
+ contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR
+ flag_do_landmark_2d_106=True,
+ direction=direction,
+ )
+ if len(src_face) == 0:
+ log(f"No face detected in the frame #{idx}")
+ raise Exception(f"No face detected in the frame #{idx}")
+ elif len(src_face) > 1:
+ log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
+ src_face = src_face[0]
+ lmk = src_face.landmark_2d_106
+ lmk = self.landmark_runner.run(frame_rgb_crop, lmk)
+ trajectory.start, trajectory.end = idx, idx
+ else:
+ lmk = self.landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
+ trajectory.end = idx
+
+ trajectory.lmk_lst.append(lmk)
+ return trajectory.lmk_lst
diff --git a/src/utils/face_analysis_diy.py b/src/utils/face_analysis_diy.py
index 456be5e..f13a659 100644
--- a/src/utils/face_analysis_diy.py
+++ b/src/utils/face_analysis_diy.py
@@ -39,7 +39,7 @@ class FaceAnalysisDIY(FaceAnalysis):
self.timer = Timer()
def get(self, img_bgr, **kwargs):
- max_num = kwargs.get('max_num', 0) # the number of the detected faces, 0 means no limit
+ max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit
flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection
direction = kwargs.get('direction', 'large-small') # sorting direction
face_center = None
diff --git a/src/utils/helper.py b/src/utils/helper.py
index 4974fc5..0e2af94 100644
--- a/src/utils/helper.py
+++ b/src/utils/helper.py
@@ -37,6 +37,11 @@ def basename(filename):
return prefix(osp.basename(filename))
+def remove_suffix(filepath):
+ """a/b/c.jpg -> a/b/c"""
+ return osp.join(osp.dirname(filepath), basename(filepath))
+
+
def is_video(file_path):
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
return True
@@ -63,9 +68,9 @@ def squeeze_tensor_to_numpy(tensor):
return out
-def dct2cuda(dct: dict, device_id: int):
+def dct2device(dct: dict, device):
for key in dct:
- dct[key] = torch.tensor(dct[key]).cuda(device_id)
+ dct[key] = torch.tensor(dct[key]).to(device)
return dct
@@ -94,13 +99,13 @@ def load_model(ckpt_path, model_config, device, model_type):
model_params = model_config['model_params'][f'{model_type}_params']
if model_type == 'appearance_feature_extractor':
- model = AppearanceFeatureExtractor(**model_params).cuda(device)
+ model = AppearanceFeatureExtractor(**model_params).to(device)
elif model_type == 'motion_extractor':
- model = MotionExtractor(**model_params).cuda(device)
+ model = MotionExtractor(**model_params).to(device)
elif model_type == 'warping_module':
- model = WarpingNetwork(**model_params).cuda(device)
+ model = WarpingNetwork(**model_params).to(device)
elif model_type == 'spade_generator':
- model = SPADEDecoder(**model_params).cuda(device)
+ model = SPADEDecoder(**model_params).to(device)
elif model_type == 'stitching_retargeting_module':
# Special handling for stitching and retargeting module
config = model_config['model_params']['stitching_retargeting_module_params']
@@ -108,17 +113,17 @@ def load_model(ckpt_path, model_config, device, model_type):
stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
- stitcher = stitcher.cuda(device)
+ stitcher = stitcher.to(device)
stitcher.eval()
retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth']))
- retargetor_lip = retargetor_lip.cuda(device)
+ retargetor_lip = retargetor_lip.to(device)
retargetor_lip.eval()
retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye']))
- retargetor_eye = retargetor_eye.cuda(device)
+ retargetor_eye = retargetor_eye.to(device)
retargetor_eye.eval()
return {
@@ -134,20 +139,6 @@ def load_model(ckpt_path, model_config, device, model_type):
return model
-# get coefficients of Eqn. 7
-def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
- if config.relative:
- new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s
- new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp'])
- else:
- new_rotation = R_t_i
- new_expression = t_i_kp_info['exp']
- new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t'])
- new_translation[..., 2].fill_(0) # Keep the z-axis unchanged
- new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
- return new_rotation, new_expression, new_translation, new_scale
-
-
def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
diff --git a/src/utils/io.py b/src/utils/io.py
index 29a7e00..28c2d99 100644
--- a/src/utils/io.py
+++ b/src/utils/io.py
@@ -5,8 +5,11 @@ from glob import glob
import os.path as osp
import imageio
import numpy as np
+import pickle
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+from .helper import mkdir, suffix
+
def load_image_rgb(image_path: str):
if not osp.exists(image_path):
@@ -23,8 +26,8 @@ def load_driving_info(driving_info):
return [load_image_rgb(im_path) for im_path in image_paths]
def load_images_from_video(file_path):
- reader = imageio.get_reader(file_path)
- return [image for idx, image in enumerate(reader)]
+ reader = imageio.get_reader(file_path, "ffmpeg")
+ return [image for _, image in enumerate(reader)]
if osp.isdir(driving_info):
driving_video_ori = load_images_from_directory(driving_info)
@@ -40,7 +43,7 @@ def contiguous(obj):
return obj
-def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
+def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
"""
ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
:param img: the image to be processed.
@@ -61,9 +64,9 @@ def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
img = cv2.resize(img, (new_w, new_h))
# ensure that the image dimensions are multiples of n
- n = max(n, 1)
- new_h = img.shape[0] - (img.shape[0] % n)
- new_w = img.shape[1] - (img.shape[1] % n)
+ division = max(division, 1)
+ new_h = img.shape[0] - (img.shape[0] % division)
+ new_w = img.shape[1] - (img.shape[1] % division)
if new_h == 0 or new_w == 0:
# when the width or height is less than n, no need to process
@@ -87,7 +90,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
img = obj
# Resize image to satisfy constraints
- img = resize_to_limit(img, max_dim=max_dim, n=n)
+ img = resize_to_limit(img, max_dim=max_dim, division=n)
if mode.lower() == "bgr":
return contiguous(img)
@@ -95,3 +98,28 @@ def load_img_online(obj, mode="bgr", **kwargs):
return contiguous(img[..., ::-1])
else:
raise Exception(f"Unknown mode {mode}")
+
+
+def load(fp):
+ suffix_ = suffix(fp)
+
+ if suffix_ == "npy":
+ return np.load(fp)
+ elif suffix_ == "pkl":
+ return pickle.load(open(fp, "rb"))
+ else:
+ raise Exception(f"Unknown type: {suffix}")
+
+
+def dump(wfp, obj):
+ wd = osp.split(wfp)[0]
+ if wd != "" and not osp.exists(wd):
+ mkdir(wd)
+
+ _suffix = suffix(wfp)
+ if _suffix == "npy":
+ np.save(wfp, obj)
+ elif _suffix == "pkl":
+ pickle.dump(obj, open(wfp, "wb"))
+ else:
+ raise Exception("Unknown type: {}".format(_suffix))
diff --git a/src/utils/landmark_runner.py b/src/utils/landmark_runner.py
index 7b0dcbe..7680a2c 100644
--- a/src/utils/landmark_runner.py
+++ b/src/utils/landmark_runner.py
@@ -25,6 +25,7 @@ def to_ndarray(obj):
class LandmarkRunner(object):
"""landmark runner"""
+
def __init__(self, **kwargs):
ckpt_path = kwargs.get('ckpt_path')
onnx_provider = kwargs.get('onnx_provider', 'cuda') # ้ป่ฎค็จcuda
@@ -55,6 +56,7 @@ class LandmarkRunner(object):
crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
img_crop_rgb = crop_dct['img_crop']
else:
+ # NOTE: force resize to 224x224, NOT RECOMMEND!
img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
scale = max(img_rgb.shape[:2]) / self.dsize
crop_dct = {
@@ -70,15 +72,13 @@ class LandmarkRunner(object):
out_lst = self._run(inp)
out_pts = out_lst[2]
- pts = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224
- pts = _transform_pts(pts, M=crop_dct['M_c2o'])
+ # 2d landmarks 203 points
+ lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224
+ lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])
- return {
- 'pts': pts, # 2d landmarks 203 points
- }
+ return lmk
def warmup(self):
- # ๆ้ dummy image่ฟ่กwarmup
self.timer.tic()
dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32)
diff --git a/src/utils/retargeting_utils.py b/src/utils/retargeting_utils.py
index 20a1bdd..ae2e5f5 100644
--- a/src/utils/retargeting_utils.py
+++ b/src/utils/retargeting_utils.py
@@ -7,32 +7,11 @@ import numpy as np
def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
- """
- Calculate the ratio of the distance between two pairs of landmarks.
-
- Parameters:
- lmk (np.ndarray): Landmarks array of shape (B, N, 2).
- idx1, idx2, idx3, idx4 (int): Indices of the landmarks.
- eps (float): Small value to avoid division by zero.
-
- Returns:
- np.ndarray: Calculated distance ratio.
- """
return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
(np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))
def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
- """
- Calculate the eye-close ratio for left and right eyes.
-
- Parameters:
- lmk (np.ndarray): Landmarks array of shape (B, N, 2).
- target_eye_ratio (np.ndarray, optional): Additional target eye ratio array to include.
-
- Returns:
- np.ndarray: Concatenated eye-close ratios.
- """
lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
if target_eye_ratio is not None:
@@ -42,13 +21,4 @@ def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -
def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
- """
- Calculate the lip-close ratio.
-
- Parameters:
- lmk (np.ndarray): Landmarks array of shape (B, N, 2).
-
- Returns:
- np.ndarray: Calculated lip-close ratio.
- """
return calculate_distance_ratio(lmk, 90, 102, 48, 66)
diff --git a/src/utils/video.py b/src/utils/video.py
index 720e082..5144e03 100644
--- a/src/utils/video.py
+++ b/src/utils/video.py
@@ -1,7 +1,9 @@
# coding: utf-8
"""
-functions for processing video
+Functions for processing video
+
+ATTENTION: you need to install ffmpeg and ffprobe in your env!
"""
import os.path as osp
@@ -9,14 +11,15 @@ import numpy as np
import subprocess
import imageio
import cv2
-
from rich.progress import track
-from .helper import prefix
+
+from .rprint import rlog as log
from .rprint import rprint as print
+from .helper import prefix
def exec_cmd(cmd):
- subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
def images2video(images, wfp, **kwargs):
@@ -35,7 +38,7 @@ def images2video(images, wfp, **kwargs):
)
n = len(images)
- for i in track(range(n), description='writing', transient=True):
+ for i in track(range(n), description='Writing', transient=True):
if image_mode.lower() == 'bgr':
writer.append_data(images[i][..., ::-1])
else:
@@ -43,9 +46,6 @@ def images2video(images, wfp, **kwargs):
writer.close()
- # print(f':smiley: Dump to {wfp}\n', style="bold green")
- print(f'Dump to {wfp}\n')
-
def video2gif(video_fp, fps=30, size=256):
if osp.exists(video_fp):
@@ -54,10 +54,10 @@ def video2gif(video_fp, fps=30, size=256):
palette_wfp = osp.join(d, 'palette.png')
gif_wfp = osp.join(d, f'{fn}.gif')
# generate the palette
- cmd = f'ffmpeg -i {video_fp} -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" {palette_wfp} -y'
+ cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y'
exec_cmd(cmd)
# use the palette to generate the gif
- cmd = f'ffmpeg -i {video_fp} -i {palette_wfp} -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" {gif_wfp} -y'
+ cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
exec_cmd(cmd)
else:
print(f'video_fp: {video_fp} not exists!')
@@ -65,7 +65,7 @@ def video2gif(video_fp, fps=30, size=256):
def merge_audio_video(video_fp, audio_fp, wfp):
if osp.exists(video_fp) and osp.exists(audio_fp):
- cmd = f'ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y'
+ cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
exec_cmd(cmd)
print(f'merge {video_fp} and {audio_fp} to {wfp}')
else:
@@ -80,21 +80,23 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
return img
-def concat_frames(I_p_lst, driving_rgb_lst, img_rgb):
+def concat_frames(driving_image_lst, source_image, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
+ h, w, _ = I_p_lst[0].shape
+
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
- source_image_drived = I_p_lst[idx]
- image_drive = driving_rgb_lst[idx]
+ I_p = I_p_lst[idx]
+ source_image_resized = cv2.resize(source_image, (w, h))
- # resize images to match source_image_drived shape
- h, w, _ = source_image_drived.shape
- image_drive_resized = cv2.resize(image_drive, (w, h))
- img_rgb_resized = cv2.resize(img_rgb, (w, h))
+ if driving_image_lst is None:
+ out = np.hstack((source_image_resized, I_p))
+ else:
+ driving_image = driving_image_lst[idx]
+ driving_image_resized = cv2.resize(driving_image, (w, h))
+ out = np.hstack((driving_image_resized, source_image_resized, I_p))
- # concatenate images horizontally
- frame = np.concatenate((image_drive_resized, img_rgb_resized, source_image_drived), axis=1)
- out_lst.append(frame)
+ out_lst.append(out)
return out_lst
@@ -126,14 +128,84 @@ class VideoWriter:
self.writer.close()
-def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=5):
- cmd = f"ffmpeg -i {input_file} -c:v {codec} -crf {crf} -r {fps} {output_file} -y"
+def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12):
+ cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y'
exec_cmd(cmd)
-def get_fps(filepath):
- import ffmpeg
- probe = ffmpeg.probe(filepath)
- video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
- fps = eval(video_stream['avg_frame_rate'])
+def get_fps(filepath, default_fps=25):
+ try:
+ fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS)
+
+ if fps in (0, None):
+ fps = default_fps
+ except Exception as e:
+ log(e)
+ fps = default_fps
+
return fps
+
+
+def has_audio_stream(video_path: str) -> bool:
+ """
+ Check if the video file contains an audio stream.
+
+ :param video_path: Path to the video file
+ :return: True if the video contains an audio stream, False otherwise
+ """
+ if osp.isdir(video_path):
+ return False
+
+ cmd = [
+ 'ffprobe',
+ '-v', 'error',
+ '-select_streams', 'a',
+ '-show_entries', 'stream=codec_type',
+ '-of', 'default=noprint_wrappers=1:nokey=1',
+ f'"{video_path}"'
+ ]
+
+ try:
+ # result = subprocess.run(cmd, capture_output=True, text=True)
+ result = exec_cmd(' '.join(cmd))
+ if result.returncode != 0:
+ log(f"Error occurred while probing video: {result.stderr}")
+ return False
+
+ # Check if there is any output from ffprobe command
+ return bool(result.stdout.strip())
+ except Exception as e:
+ log(f"Error occurred while probing video: {video_path}, you may need to install ffprobe! Now set audio to false!", style="bold red")
+ return False
+
+
+def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
+ cmd = [
+ 'ffmpeg',
+ '-y',
+ '-i', f'"{silent_video_path}"',
+ '-i', f'"{audio_video_path}"',
+ '-map', '0:v',
+ '-map', '1:a',
+ '-c:v', 'copy',
+ '-shortest',
+ f'"{output_video_path}"'
+ ]
+
+ try:
+ exec_cmd(' '.join(cmd))
+ log(f"Video with audio generated successfully: {output_video_path}")
+ except subprocess.CalledProcessError as e:
+ log(f"Error occurred: {e}")
+
+
+def bb_intersection_over_union(boxA, boxB):
+ xA = max(boxA[0], boxB[0])
+ yA = max(boxA[1], boxB[1])
+ xB = min(boxA[2], boxB[2])
+ yB = min(boxA[3], boxB[3])
+ interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+ boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
+ boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
+ iou = interArea / float(boxAArea + boxBArea - interArea)
+ return iou
diff --git a/src/utils/viz.py b/src/utils/viz.py
new file mode 100644
index 0000000..59443cb
--- /dev/null
+++ b/src/utils/viz.py
@@ -0,0 +1,19 @@
+# coding: utf-8
+
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+
+
+def viz_lmk(img_, vps, **kwargs):
+ """ๅฏ่งๅ็น"""
+ lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA
+ img_for_viz = img_.copy()
+ for pt in vps:
+ cv2.circle(
+ img_for_viz,
+ (int(pt[0]), int(pt[1])),
+ radius=kwargs.get("radius", 1),
+ color=(0, 255, 0),
+ thickness=kwargs.get("thickness", 1),
+ lineType=lineType,
+ )
+ return img_for_viz
diff --git a/video2template.py b/video2template.py
deleted file mode 100644
index c187396..0000000
--- a/video2template.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# coding: utf-8
-
-"""
-[WIP] Pipeline for video template preparation
-"""
-
-import tyro
-from src.config.crop_config import CropConfig
-from src.config.inference_config import InferenceConfig
-from src.config.argument_config import ArgumentConfig
-from src.template_maker import TemplateMaker
-
-
-def partial_fields(target_class, kwargs):
- return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
-
-
-def main():
- # set tyro theme
- tyro.extras.set_accent_color("bright_cyan")
- args = tyro.cli(ArgumentConfig)
-
- # specify configs for inference
- inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
- crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
-
- video_template_maker = TemplateMaker(
- inference_cfg=inference_cfg,
- crop_cfg=crop_cfg
- )
-
- # run
- video_template_maker.make_motion_template(args.driving_video_path, args.template_output_dir)
-
-
-if __name__ == '__main__':
- main()