Merge branch 'main' into develop

This commit is contained in:
Maki 2024-07-12 16:09:58 +09:00
commit b0ad5dc17c
49 changed files with 825 additions and 553 deletions

4
.gitignore vendored
View File

@ -9,9 +9,13 @@ __pycache__/
**/*.pth
**/*.onnx
pretrained_weights/*.md
pretrained_weights/docs
# Ipython notebook
*.ipynb
# Temporary files or benchmark resources
animations/*
tmp/*
.vscode/launch.json

102
app.py
View File

@ -25,28 +25,40 @@ args = tyro.cli(ArgumentConfig)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
def gpu_wrapped_execute_image(*args, **kwargs):
return gradio_pipeline.execute_image(*args, **kwargs)
# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
#################### interface logic ####################
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
retargeting_input_image = gr.Image(type="numpy")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
@ -58,15 +70,39 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Row():
with gr.Accordion(open=True, label="Source Portrait"):
image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
],
inputs=[image_input],
cache_examples=False,
)
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video()
gr.Examples(
examples=[
[osp.join(example_video_dir, "d0.mp4")],
[osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
],
inputs=[video_input],
cache_examples=False,
)
with gr.Row():
with gr.Accordion(open=False, label="Animation Instructions and Options"):
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
@ -81,24 +117,28 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
output_video_concat.render()
with gr.Row():
# Examples
gr.Markdown("## You could choose the examples below ⬇️")
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
gr.Examples(
examples=data_examples,
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input
flag_remap_input,
flag_crop_driving_video_input
],
examples_per_page=5
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples),
cache_examples=False,
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
with gr.Row():
gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row():
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
@ -110,10 +150,22 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
],
value="🧹 Clear"
)
with gr.Row():
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
retargeting_input_image.render()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
)
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
@ -122,33 +174,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
fn=gradio_pipeline.execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider],
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
fn=gradio_pipeline.execute_video,
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input
flag_remap_input,
flag_crop_driving_video_input
],
outputs=[output_video, output_video_concat],
show_progress=True
)
image_input.change(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
)
##########################################################
demo.launch(
server_name=args.server_name,
server_port=args.server_port,
share=args.share,
server_name=args.server_name
)

2
assets/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
examples/driving/*.pkl
examples/driving/*_crop.mp4

View File

@ -0,0 +1,22 @@
## 2024/07/10
**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
### Updates
- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`.
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
### About driving video
- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
### Others
- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

View File

@ -1,7 +1,16 @@
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. If you want to upload your own driving video, <strong>the best practice</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
</div>

View File

@ -1 +1,4 @@
<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
<br>
## Retargeting
<span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>

View File

@ -1,2 +1,2 @@
## 🤗 This is the official gradio demo for **LivePortrait**.
<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>

View File

@ -5,6 +5,7 @@
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
</div>
</div>
</div>

View File

@ -1,6 +1,8 @@
# coding: utf-8
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
@ -11,11 +13,34 @@ def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
    """Return True if the ``ffmpeg`` executable can be launched, else False.

    Runs ``ffmpeg -version`` with its output captured and treats any launch
    failure or non-zero exit status as "not available".

    Returns:
        bool: True when the command ran and exited with status 0.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (OSError, subprocess.SubprocessError):
        # OSError: the binary is missing or cannot be executed.
        # SubprocessError: the command ran but exited non-zero (check=True).
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return False
    return True
def fast_check_args(args: ArgumentConfig):
    """Fail fast when an input path given in ``args`` does not exist.

    The source image is checked first, then the driving info.

    Raises:
        FileNotFoundError: if ``args.source_image`` or ``args.driving_info``
            does not exist on disk.
    """
    required_inputs = (
        ("source image", "source_image"),
        ("driving info", "driving_info"),
    )
    for label, field_name in required_inputs:
        candidate = getattr(args, field_name)
        if osp.exists(candidate):
            continue
        raise FileNotFoundError(f"{label} not found: {candidate}")
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
)
# fast check the args
fast_check_args(args)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
@ -29,5 +54,5 @@ def main():
live_portrait_pipeline.execute(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -4,7 +4,7 @@
<a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>&emsp;
<a href='https://scholar.google.com/citations?user=t88nyvsAAAAJ&hl' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>&emsp;
<a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup>&emsp;
</div>
@ -35,8 +35,12 @@
## 🔥 Updates
- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned!
- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
- **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned!
- **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
## Introduction
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
@ -55,8 +59,19 @@ conda activate LivePortrait
pip install -r requirements.txt
```
**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!
### 2. Download pretrained weights
Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
The easiest way to download the pretrained weights is from HuggingFace:
```bash
# you may need to run `git lfs install` first
git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
```
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
Ensure the directory structure matches, or at least contains, the following:
```text
pretrained_weights
├── insightface
@ -77,6 +92,7 @@ pretrained_weights
### 3. Inference 🚀
#### Fast hands-on
```bash
python inference.py
```
@ -92,16 +108,37 @@ Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
# or disable pasting back
# disable pasting back to run faster
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
# more options to see
python inference.py -h
```
**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊
#### Driving video auto-cropping
### 4. Gradio interface
📕 To use your own driving video, we **recommend**:
- Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
Below is an auto-cropping case using `--flag_crop_driving_video`:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
```
If the auto-cropping results are unsatisfactory, you can adjust the scale and offset with the `--scale_crop_video` and `--vy_ratio_crop_video` options, or crop the video manually.
#### Motion template making
You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl
```
**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊
### 4. Gradio interface 🤗
We also provide a Gradio interface for a better experience, just run by:
@ -109,6 +146,10 @@ We also provide a Gradio interface for a better experience, just run by:
python app.py
```
You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
### 5. Inference speed evaluation 🚀🚀🚀
We have also provided a script to evaluate the inference speed of each module:
@ -126,8 +167,20 @@ Below are the results of inferring one frame on an RTX 4090 GPU using the native
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
## Community Resources 🤗
Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
And many more amazing contributions from our community!
## Acknowledgements
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
@ -135,10 +188,10 @@ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrS
## Citation 💖
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
```bibtex
@article{guo2024live,
@article{guo2024liveportrait,
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
year = {2024},
journal = {arXiv preprint:2407.03168},
author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di},
journal = {arXiv preprint arXiv:2407.03168},
year = {2024}
}
```

View File

@ -11,7 +11,7 @@ imageio==2.34.2
lmdb==1.4.1
tqdm==4.66.4
rich==13.7.1
ffmpeg==1.4
ffmpeg-python==0.2.0
onnxruntime-gpu==1.18.0
onnx==1.16.1
scikit-image==0.24.0

View File

@ -47,11 +47,11 @@ def load_and_compile_models(cfg, model_config):
"""
Load and compile models for inference
"""
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')
models_with_params = [
('Appearance Feature Extractor', appearance_feature_extractor),

View File

@ -1,13 +1,13 @@
# coding: utf-8
"""
config for user
All configs for user
"""
import os.path as osp
from dataclasses import dataclass
import tyro
from typing_extensions import Annotated
from typing import Optional
from .base_config import PrintableConfig, make_abs_path
@ -17,28 +17,31 @@ class ArgumentConfig(PrintableConfig):
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
#####################################
########## inference arguments ##########
device_id: int = 0
flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP!
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
flag_relative: bool = True # whether to use relative motion
flag_eye_retargeting: bool = False # not recommend to be True, WIP
flag_lip_retargeting: bool = False # not recommend to be True, WIP
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
#########################################
########## crop arguments ##########
dsize: int = 512
scale: float = 2.3
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
####################################
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
scale_crop_video: float = 2.2 # scale factor for cropping video
vx_ratio_crop_video: float = 0. # adjust x offset
vy_ratio_crop_video: float = -0.1 # adjust y offset
########## gradio arguments ##########
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
share: bool = True
server_name: str = "0.0.0.0"
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
share: bool = False # whether to share the server to public
server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all

View File

@ -4,15 +4,26 @@
parameters used for crop faces
"""
import os.path as osp
from dataclasses import dataclass
from typing import Union, List
from .base_config import PrintableConfig
# Parameters used for cropping faces: model paths, device selection,
# source-image cropping options and driving-video auto-cropping options.
@dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig):
insightface_root: str = "../../pretrained_weights/insightface"
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
########## source image cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.3 # scale factor
scale: float = 2.5 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 mean no limit
########## driving video auto cropping option ##########
scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video
vx_ratio_crop_video: float = 0.0 # adjust x offset
vy_ratio_crop_video: float = -0.1 # adjust y offset
direction: str = "large-small" # direction of cropping

View File

@ -5,6 +5,8 @@ config dataclass used for inference
"""
import os.path as osp
import cv2
from numpy import ndarray
from dataclasses import dataclass
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
@ -12,38 +14,38 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig):
# MODEL CONFIG, NOT EXPORTED PARAMS
models_config: str = make_abs_path('./models.yaml') # portrait animation config
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
flag_use_half_precision: bool = True # whether to use half precision
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
lip_zero_threshold: float = 0.03
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint of M
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip
# EXPORTED PARAMS
flag_use_half_precision: bool = True
flag_crop_driving_video: bool = False
device_id: int = 0
flag_lip_zero: bool = True
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
flag_stitching: bool = True
flag_relative_motion: bool = True
flag_pasteback: bool = True
flag_do_crop: bool = True
flag_do_rot: bool = True
flag_force_cpu: bool = False
flag_relative: bool = True # whether to use relative motion
anchor_frame: int = 0 # set this value if find_best_frame is True
# NOT EXPORTED PARAMS
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
anchor_frame: int = 0 # TO IMPLEMENT
input_shape: Tuple[int, int] = (256, 256) # input shape
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
output_fps: int = 30 # fps for output video
crf: int = 15 # crf for output video
output_fps: int = 25 # default output fps
flag_write_result: bool = True # whether to write output video
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
mask_crop = None
flag_write_gif: bool = False
size_gif: int = 256
ref_max_shape: int = 1280
ref_shape_n: int = 2
device_id: int = 0
flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
size_gif: int = 256 # default gif size, TO IMPLEMENT
source_max_dim: int = 1280 # the max dim of height and width of source image
source_division: int = 2 # make sure the height and width of source image can be divided by this number

View File

@ -4,13 +4,14 @@
Pipeline for gradio
"""
import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
def update_args(args, user_args):
"""update the args according to user inputs
@ -20,22 +21,13 @@ def update_args(args, user_args):
setattr(args, k, v)
return args
class GradioPipeline(LivePortraitPipeline):
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
    """Set up the gradio-facing pipeline.

    Args:
        inference_cfg: forwarded unchanged to the parent pipeline constructor.
        crop_cfg: forwarded unchanged to the parent pipeline constructor.
        args: argument object stored on the instance as ``self.args``.
    """
    super().__init__(inference_cfg, crop_cfg)
    self.args = args

    # for single image retargeting: nothing is prepared yet, so the flag is
    # off and every cached slot starts out empty
    self.start_prepare = False
    self.f_s_user = self.x_c_s_info_user = self.x_s_user = None
    self.source_lmk_user = self.mask_ori = None
    self.img_rgb = self.crop_M_c2o = None
def execute_video(
self,
@ -44,6 +36,7 @@ class GradioPipeline(LivePortraitPipeline):
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
):
""" for video driven potrait animation
"""
@ -54,6 +47,7 @@ class GradioPipeline(LivePortraitPipeline):
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_crop_driving_video': flag_crop_driving_video_input
}
# update config from user input
self.args = update_args(self.args, args_user)
@ -66,51 +60,45 @@ class GradioPipeline(LivePortraitPipeline):
else:
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
""" for single image retargeting
"""
if input_eye_ratio is None or input_eye_ratio is None:
# disposable feature
f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, flag_do_crop)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
elif self.f_s_user is None:
if self.start_prepare:
raise gr.Error(
"The source portrait is under processing 💥! Please wait for a second.",
duration=5
)
else:
raise gr.Error(
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
duration=5
)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
num_kp = self.x_s_user.shape[1]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
num_kp = x_s_user.shape[1]
# default: use x_s
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
def prepare_retargeting(self, input_image, flag_do_crop=True):
""" for single image retargeting
"""
if input_image_path is not None:
gr.Info("Upload successfully!", duration=2)
self.start_prepare = True
inference_cfg = self.live_portrait_wrapper.cfg
if input_image is not None:
# gr.Info("Upload successfully!", duration=2)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
@ -118,23 +106,12 @@ class GradioPipeline(LivePortraitPipeline):
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################
# record global info for next time use
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
self.img_rgb = img_rgb
self.crop_M_c2o = crop_info['M_c2o']
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
# for vis
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
return eye_close_ratio, lip_close_ratio, self.I_s_vis
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
return 0.8, 0.8, self.I_s_vis
raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)

View File

@ -4,13 +4,12 @@
Pipeline of LivePortrait
"""
# TODO:
# 1. Currently assumes every template is already cropped; this needs to be changed
# 2. Pick example images: source + driving
import torch
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
import cv2
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import pickle
import os
import os.path as osp
from rich.progress import track
@ -19,12 +18,12 @@ from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper
@ -35,84 +34,124 @@ def make_abs_path(fn):
class LivePortraitPipeline(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg)
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
def execute(self, args: ArgumentConfig):
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
# for convenience
inf_cfg = self.live_portrait_wrapper.inference_cfg
device = self.live_portrait_wrapper.device
crop_cfg = self.cropper.crop_cfg
######## process source portrait ########
img_rgb = load_image_rgb(args.source_image)
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source_image}")
crop_info = self.cropper.crop_single_image(img_rgb)
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
if inference_cfg.flag_do_crop:
if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
if inference_cfg.flag_lip_zero:
flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
if flag_lip_zero:
# let lip-open scalar to be 0 at first
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
inference_cfg.flag_lip_zero = False
if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
flag_lip_zero = False
else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ########
if is_video(args.driving_info):
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
# TODO: track the driving video here -> build the template
driving_rgb_lst = load_driving_info(args.driving_info)
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
n_frames = I_d_lst.shape[0]
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
elif is_template(args.driving_info):
log(f"Load from video templates {args.driving_info}")
with open(args.driving_info, 'rb') as f:
template_lst, driving_lmk_lst = pickle.load(f)
n_frames = template_lst[0]['n_frames']
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
flag_load_from_template = is_template(args.driving_info)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: load from template, it is fast, but the cropping video is None
log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
template_dct = load(args.driving_info)
n_frames = template_dct['n_frames']
# set output_fps
output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving_info) and is_video(args.driving_info):
# load from video file, AND make motion template
log(f"Load video: {args.driving_info}")
if osp.isdir(args.driving_info):
output_fps = inf_cfg.output_fps
else:
raise Exception("Unsupported driving types!")
output_fps = int(get_fps(args.driving_info))
log(f'The FPS of {args.driving_info} is: {output_fps}')
log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
driving_rgb_lst = load_driving_info(args.driving_info)
######## make motion template ########
log("Start making motion template...")
if inf_cfg.flag_crop_driving_video:
ret = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving_info) + '.pkl'
dump(wfp_template, template_dct)
log(f"Dump motion template to {wfp_template}")
n_frames = I_d_lst.shape[0]
else:
raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
#########################################
######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = []
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(n_frames), description='Animating...', total=n_frames):
if is_video(args.driving_info):
# extract kp info by M
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
else:
# from template
x_d_i_info = template_lst[i]
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
x_d_i_info = template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inference_cfg.flag_relative:
if inf_cfg.flag_relative_motion:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
@ -127,32 +166,32 @@ class LivePortraitPipeline(object):
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
# Algorithm 1:
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
if inference_cfg.flag_lip_zero:
if flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if inference_cfg.flag_lip_zero:
if flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if inference_cfg.flag_eye_retargeting:
c_d_eyes_i = input_eye_ratio_lst[i]
if inf_cfg.flag_eye_retargeting:
c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inference_cfg.flag_lip_retargeting:
c_d_lip_i = input_lip_ratio_lst[i]
if inf_cfg.flag_lip_retargeting:
c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
if inference_cfg.flag_relative: # use x_s
if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
@ -161,30 +200,86 @@ class LivePortraitPipeline(object):
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
if inference_cfg.flag_stitching:
if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback:
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
I_p_paste_lst.append(I_p_i_to_ori_blend)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
# TODO: pasteback is slow, considering optimize it using multi-threading or GPU
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
if is_video(args.driving_info):
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
# save (driving frames, source image, drived frames) result
flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
######### build final concat result #########
# driving frame | source image | generation, or source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat)
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_has_audio:
# final result with concact
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
# save drived result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
if inference_cfg.flag_pasteback:
images2video(I_p_paste_lst, wfp=wfp)
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp)
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build final result #########
if flag_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp} with {wfp_with_audio}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concact: {wfp_concat}')
return wfp, wfp_concat
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
    """Build a motion-template dict from prepared driving frames.

    For each frame, the keypoint info returned by
    ``self.live_portrait_wrapper.get_kp_info`` is reduced to scale, rotation
    matrix, expression and translation; each is moved to CPU and stored as a
    float32 numpy array. The matching eye and lip ratios are stored as float32.

    Args:
        I_d_lst: driving frames, indexed along the first axis; ``shape[0]``
            is taken as the number of frames.
        c_d_eyes_lst: per-frame eye ratios (each item must support ``astype``).
        c_d_lip_lst: per-frame lip ratios (each item must support ``astype``).
        **kwargs: ``output_fps`` is recorded in the template (25 when absent).

    Returns:
        dict with keys 'n_frames', 'output_fps', 'motion', 'c_d_eyes_lst'
        and 'c_d_lip_lst'.
    """
    def _as_f32(tensor):
        # detach from the device and normalise the dtype for serialisation
        return tensor.cpu().numpy().astype(np.float32)

    frame_count = I_d_lst.shape[0]
    motion, eyes_ratios, lip_ratios = [], [], []

    progress = track(range(frame_count), description='Making motion templates...', total=frame_count)
    for idx in progress:
        # collect s_d, R_d, δ_d and t_d for inference
        kp_info = self.live_portrait_wrapper.get_kp_info(I_d_lst[idx])
        rotation = get_rotation_matrix(kp_info['pitch'], kp_info['yaw'], kp_info['roll'])
        motion.append({
            'scale': _as_f32(kp_info['scale']),
            'R_d': _as_f32(rotation),
            'exp': _as_f32(kp_info['exp']),
            't': _as_f32(kp_info['t']),
        })

        eyes_ratios.append(c_d_eyes_lst[idx].astype(np.float32))
        lip_ratios.append(c_d_lip_lst[idx].astype(np.float32))

    return {
        'n_frames': frame_count,
        'output_fps': kwargs.get('output_fps', 25),
        'motion': motion,
        'c_d_eyes_lst': eyes_ratios,
        'c_d_lip_lst': lip_ratios,
    }

View File

@ -20,45 +20,51 @@ from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
def __init__(self, cfg: InferenceConfig):
def __init__(self, inference_cfg: InferenceConfig):
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
self.inference_cfg = inference_cfg
self.device_id = inference_cfg.device_id
if inference_cfg.flag_force_cpu:
self.device = 'cpu'
else:
self.device = 'cuda:' + str(self.device_id)
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.')
# init M
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor done.')
# init W
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
log(f'Load warping_module done.')
# init G
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
log(f'Load spade_generator done.')
# init S and R
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.')
else:
self.stitching_retargeting_module = None
self.cfg = cfg
self.device_id = cfg.device_id
self.timer = Timer()
def update_config(self, user_args):
for k, v in user_args.items():
if hasattr(self.cfg, k):
setattr(self.cfg, k, v)
if hasattr(self.inference_cfg, k):
setattr(self.inference_cfg, k, v)
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
""" construct the input as standard
img: HxWx3, uint8, 256x256
"""
h, w = img.shape[:2]
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
else:
x = img.copy()
@ -70,7 +76,7 @@ class LivePortraitWrapper(object):
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
x = x.cuda(self.device_id)
x = x.to(self.device)
return x
def prepare_driving_videos(self, imgs) -> torch.Tensor:
@ -87,7 +93,7 @@ class LivePortraitWrapper(object):
y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
y = y.cuda(self.device_id)
y = y.to(self.device)
return y
@ -96,7 +102,7 @@ class LivePortraitWrapper(object):
x: Bx3xHxW, normalized to 0~1
"""
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float()
@ -108,10 +114,10 @@ class LivePortraitWrapper(object):
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
"""
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
kp_info = self.motion_extractor(x)
if self.cfg.flag_use_half_precision:
if self.inference_cfg.flag_use_half_precision:
# float the dict
for k, v in kp_info.items():
if isinstance(v, torch.Tensor):
@ -254,14 +260,14 @@ class LivePortraitWrapper(object):
"""
# The line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i)
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
# get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict
if self.cfg.flag_use_half_precision:
if self.inference_cfg.flag_use_half_precision:
for k, v in ret_dct.items():
if isinstance(v, torch.Tensor):
ret_dct[k] = v.float()
@ -278,7 +284,7 @@ class LivePortraitWrapper(object):
return out
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
def calc_driving_ratio(self, driving_lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in driving_lmk_lst:
@ -288,20 +294,18 @@ class LivePortraitWrapper(object):
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
# [c_s,eyes, c_d,eyes,i]
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
return combined_eye_ratio_tensor
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
c_s_lip = calc_lip_close_ratio(source_lmk[None])
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
# [c_s,lip, c_d,lip,i]
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
if input_lip_ratio_tensor.shape != [1, 1]:
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
return combined_lip_ratio_tensor

View File

@ -1,65 +0,0 @@
# coding: utf-8
"""
Make video template
"""
import os
import cv2
import numpy as np
import pickle
from rich.progress import track
from .utils.cropper import Cropper
from .utils.io import load_driving_info
from .utils.camera import get_rotation_matrix
from .utils.helper import mkdir, basename
from .utils.rprint import rlog as log
from .config.crop_config import CropConfig
from .config.inference_config import InferenceConfig
from .live_portrait_wrapper import LivePortraitWrapper
class TemplateMaker:
    """Turn a driving video into a pickled list of per-frame motion templates."""

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
        """Create the model wrapper and the cropper from their configs."""
        self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)

    def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
        """Make a video template (.pkl format).

        Args:
            video_fp: driving video file path.
            output_path: directory where the pickle file is written; the file
                is named after the basename of ``video_fp``.
            **kwargs: accepted but not used.
        """
        def _as_f32(tensor):
            return tensor.cpu().numpy().astype(np.float32)

        # every driving frame is resized to 256x256 before any further processing
        resized_frames = [cv2.resize(frame, (256, 256)) for frame in load_driving_info(video_fp)]
        landmarks = self.cropper.get_retargeting_lmk_info(resized_frames)
        driving_batch = self.live_portrait_wrapper.prepare_driving_videos(resized_frames)
        total = driving_batch.shape[0]

        templates = []
        for idx in track(range(total), description='Making templates...', total=total):
            kp_info = self.live_portrait_wrapper.get_kp_info(driving_batch[idx])
            rotation = get_rotation_matrix(kp_info['pitch'], kp_info['yaw'], kp_info['roll'])
            # collect s_d, R_d, δ_d and t_d for inference
            templates.append({
                'n_frames': total,
                'frames_index': idx,
                'scale': _as_f32(kp_info['scale']),
                'R_d': _as_f32(rotation),
                'exp': _as_f32(kp_info['exp']),
                't': _as_f32(kp_info['t']),
            })

        mkdir(output_path)
        # the templates and the landmark list are pickled together as a 2-item list
        pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
        with open(pickle_fp, 'wb') as f:
            pickle.dump([templates, landmarks], f)
        log(f"Template saved at {pickle_fp}")

View File

@ -31,8 +31,6 @@ def headpose_pred_to_degree(pred):
def get_rotation_matrix(pitch_, yaw_, roll_):
""" the input is in degree
"""
# calculate the rotation matrix: vps @ rot
# transform to radian
pitch = pitch_ / 180 * PI
yaw = yaw_ / 180 * PI

View File

@ -281,11 +281,10 @@ def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=Fals
dtype=DTYPE
)
if flag_rot and angle is None:
print('angle is None, but flag_rotate is True', style="bold yellow")
# if flag_rot and angle is None:
# print('angle is None, but flag_rotate is True', style="bold yellow")
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
@ -362,17 +361,6 @@ def crop_image(img, pts: np.ndarray, **kwargs):
flag_do_rot=kwargs.get('flag_do_rot', True),
)
if img is None:
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
M = np.linalg.inv(M_INV_H)
ret_dct = {
'M': M[:2, ...], # from the original image to the cropped image
'M_o2c': M[:2, ...], # from the cropped image to the original image
'img_crop': None,
'pt_crop': None,
}
return ret_dct
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
pt_crop = _transform_pts(pts, M_INV)
@ -397,16 +385,14 @@ def average_bbox_lst(bbox_lst):
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
"""prepare mask for later image paste back
"""
if mask_crop is None:
mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
mask_ori = mask_ori.astype(np.float32) / 255.
return mask_ori
def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
def paste_back(img_crop, M_c2o, img_ori, mask_ori):
"""paste back the image
"""
dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
dsize = (img_ori.shape[1], img_ori.shape[0])
result = _transform_img(img_crop, M_c2o, dsize=dsize)
result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
return result

View File

@ -1,20 +1,23 @@
# coding: utf-8
import gradio as gr
import numpy as np
import os.path as osp
from typing import List, Union, Tuple
from dataclasses import dataclass, field
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from typing import List, Tuple, Union
from .landmark_runner import LandmarkRunner
from .face_analysis_diy import FaceAnalysisDIY
from .helper import prefix
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
from .timer import Timer
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
from ..config.crop_config import CropConfig
from .crop import (
average_bbox_lst,
crop_image,
crop_image_by_bbox,
parse_bbox_from_landmark,
)
from .io import contiguous
from .rprint import rlog as log
from .io import load_image_rgb
from .video import VideoWriter, get_fps, change_video_fps
from .face_analysis_diy import FaceAnalysisDIY
from .landmark_runner import LandmarkRunner
def make_abs_path(fn):
@ -23,123 +26,171 @@ def make_abs_path(fn):
@dataclass
class Trajectory:
start: int = -1 # 起始帧 闭区间
end: int = -1 # 结束帧 闭区间
start: int = -1 # start frame
end: int = -1 # end frame
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
class Cropper(object):
def __init__(self, **kwargs) -> None:
device_id = kwargs.get('device_id', 0)
self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
device_id = kwargs.get("device_id", 0)
flag_force_cpu = kwargs.get("flag_force_cpu", False)
if flag_force_cpu:
device = "cpu"
face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
else:
device = "cuda"
face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
self.landmark_runner = LandmarkRunner(
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
onnx_provider='cuda',
device_id=device_id
ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
onnx_provider=device,
device_id=device_id,
)
self.landmark_runner.warmup()
self.face_analysis_wrapper = FaceAnalysisDIY(
name='buffalo_l',
root=make_abs_path('../../pretrained_weights/insightface'),
providers=["CUDAExecutionProvider"]
name="buffalo_l",
root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provicer,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.warmup()
self.crop_cfg = kwargs.get('crop_cfg', None)
def update_config(self, user_args):
    """Overwrite existing crop-config attributes with user-supplied values.

    :param user_args: mapping of option name -> value. Names that are not
        already attributes of ``self.crop_cfg`` are ignored.
    """
    for option_name in user_args:
        # Only known options are updated; unknown keys never create new attributes.
        if not hasattr(self.crop_cfg, option_name):
            continue
        setattr(self.crop_cfg, option_name, user_args[option_name])
def crop_single_image(self, obj, **kwargs):
direction = kwargs.get('direction', 'large-small')
# crop and align a single image
if isinstance(obj, str):
img_rgb = load_image_rgb(obj)
elif isinstance(obj, np.ndarray):
img_rgb = obj
def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
# crop a source image and get necessary information
img_rgb = img_rgb_.copy() # copy it
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
src_face = self.face_analysis_wrapper.get(
img_rgb,
img_bgr,
flag_do_landmark_2d_106=True,
direction=direction
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
log('No face detected in the source image.')
raise gr.Error("No face detected in the source image 💥!", duration=5)
raise Exception("No face detected in the source image!")
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {direction}.')
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
# NOTE: temporarily only pick the first face, to support multiple face in the future
src_face = src_face[0]
pts = src_face.landmark_2d_106
lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
# crop the face
ret_dct = crop_image(
img_rgb, # ndarray
pts, # 106x2 or Nx2
dsize=kwargs.get('dsize', 512),
scale=kwargs.get('scale', 2.3),
vy_ratio=kwargs.get('vy_ratio', -0.15),
lmk, # 106x2 or Nx2
dsize=crop_cfg.dsize,
scale=crop_cfg.scale,
vx_ratio=crop_cfg.vx_ratio,
vy_ratio=crop_cfg.vy_ratio,
)
# update a 256x256 version for network input or else
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
recon_ret = self.landmark_runner.run(img_rgb, pts)
lmk = recon_ret['pts']
ret_dct['lmk_crop'] = lmk
lmk = self.landmark_runner.run(img_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
return ret_dct
def get_retargeting_lmk_info(self, driving_rgb_lst):
# TODO: implement a tracking-based version
driving_lmk_lst = []
for driving_image in driving_rgb_lst:
ret_dct = self.crop_single_image(driving_image)
driving_lmk_lst.append(ret_dct['lmk_crop'])
return driving_lmk_lst
def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
def crop_driving_video(self, driving_rgb_lst, **kwargs):
"""Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory()
direction = kwargs.get('direction', 'large-small')
for idx, driving_image in enumerate(driving_rgb_lst):
direction = kwargs.get("direction", "large-small")
for idx, frame_rgb in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
driving_image,
contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True,
direction=direction
direction=direction,
)
if len(src_face) == 0:
# No face detected in the driving_image
log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
pts = src_face.landmark_2d_106
lmk_203 = self.landmark_runner(driving_image, pts)['pts']
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts']
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk_203)
ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=elf.crop_cfg.vy_ratio)['bbox']
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4,
trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(
lmk,
scale=self.crop_cfg.scale_crop_video,
vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_video,
)["bbox"]
bbox = [
ret_bbox[0, 0],
ret_bbox[0, 1],
ret_bbox[2, 0],
ret_bbox[2, 1],
] # 4,
trajectory.bbox_lst.append(bbox) # bbox
trajectory.frame_rgb_lst.append(driving_image)
trajectory.frame_rgb_lst.append(frame_rgb)
global_bbox = average_bbox_lst(trajectory.bbox_lst)
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
ret_dct = crop_image_by_bbox(
frame_rgb, global_bbox, lmk=lmk,
dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue
frame_rgb,
global_bbox,
lmk=lmk,
dsize=kwargs.get("dsize", 512),
flag_rot=False,
borderValue=(0, 0, 0),
)
frame_rgb_crop = ret_dct['img_crop']
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
return {
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
"lmk_crop_lst": trajectory.lmk_crop_lst,
}
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
    """Compute per-frame landmarks of an already-cropped driving video by tracking.

    A face detector is run on the first frame (or while no track has been
    started); every later frame is handed to the landmark runner together
    with the landmarks of the previous frame.

    :param driving_rgb_crop_lst: iterable of cropped RGB frames
        (presumably HxWx3 ndarrays -- the channel axis is reversed to get BGR).
    :param kwargs: ``direction`` -- face sorting rule forwarded to the face
        detector, defaults to "large-small".
    :return: list with one landmark result per frame, in frame order.
    :raises Exception: if no face is detected on a frame that needs detection.
    """
    track_state = Trajectory()
    sort_rule = kwargs.get("direction", "large-small")

    for frame_no, crop_rgb in enumerate(driving_rgb_crop_lst):
        needs_detection = frame_no == 0 or track_state.start == -1

        if not needs_detection:
            # Tracking step: seed the landmark runner with the previous frame's landmarks.
            cur_lmk = self.landmark_runner.run(crop_rgb, track_state.lmk_lst[-1])
            track_state.end = frame_no
        else:
            # Detection step: the face detector expects BGR input.
            crop_bgr = contiguous(crop_rgb[..., ::-1])
            faces = self.face_analysis_wrapper.get(
                crop_bgr,
                flag_do_landmark_2d_106=True,
                direction=sort_rule,
            )
            n_faces = len(faces)
            if n_faces == 0:
                log(f"No face detected in the frame #{frame_no}")
                raise Exception(f"No face detected in the frame #{frame_no}")
            if n_faces > 1:
                log(f"More than one face detected in the driving frame_{frame_no}, only pick one face by rule {sort_rule}.")

            # Refine the detector's 106-point landmarks of the first face with the landmark runner.
            cur_lmk = self.landmark_runner.run(crop_rgb, faces[0].landmark_2d_106)
            track_state.start, track_state.end = frame_no, frame_no

        track_state.lmk_lst.append(cur_lmk)

    return track_state.lmk_lst

View File

@ -39,7 +39,7 @@ class FaceAnalysisDIY(FaceAnalysis):
self.timer = Timer()
def get(self, img_bgr, **kwargs):
max_num = kwargs.get('max_num', 0) # the number of the detected faces, 0 means no limit
max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit
flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection
direction = kwargs.get('direction', 'large-small') # sorting direction
face_center = None

View File

@ -37,6 +37,11 @@ def basename(filename):
return prefix(osp.basename(filename))
def remove_suffix(filepath):
    """Drop the extension but keep the directory part, e.g. a/b/c.jpg -> a/b/c"""
    parent_dir = osp.dirname(filepath)
    stem = basename(filepath)
    return osp.join(parent_dir, stem)
def is_video(file_path):
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
return True
@ -63,9 +68,9 @@ def squeeze_tensor_to_numpy(tensor):
return out
def dct2cuda(dct: dict, device_id: int):
def dct2device(dct: dict, device):
for key in dct:
dct[key] = torch.tensor(dct[key]).cuda(device_id)
dct[key] = torch.tensor(dct[key]).to(device)
return dct
@ -94,13 +99,13 @@ def load_model(ckpt_path, model_config, device, model_type):
model_params = model_config['model_params'][f'{model_type}_params']
if model_type == 'appearance_feature_extractor':
model = AppearanceFeatureExtractor(**model_params).cuda(device)
model = AppearanceFeatureExtractor(**model_params).to(device)
elif model_type == 'motion_extractor':
model = MotionExtractor(**model_params).cuda(device)
model = MotionExtractor(**model_params).to(device)
elif model_type == 'warping_module':
model = WarpingNetwork(**model_params).cuda(device)
model = WarpingNetwork(**model_params).to(device)
elif model_type == 'spade_generator':
model = SPADEDecoder(**model_params).cuda(device)
model = SPADEDecoder(**model_params).to(device)
elif model_type == 'stitching_retargeting_module':
# Special handling for stitching and retargeting module
config = model_config['model_params']['stitching_retargeting_module_params']
@ -108,17 +113,17 @@ def load_model(ckpt_path, model_config, device, model_type):
stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
stitcher = stitcher.cuda(device)
stitcher = stitcher.to(device)
stitcher.eval()
retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth']))
retargetor_lip = retargetor_lip.cuda(device)
retargetor_lip = retargetor_lip.to(device)
retargetor_lip.eval()
retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye']))
retargetor_eye = retargetor_eye.cuda(device)
retargetor_eye = retargetor_eye.to(device)
retargetor_eye.eval()
return {
@ -134,20 +139,6 @@ def load_model(ckpt_path, model_config, device, model_type):
return model
# get coefficients of Eqn. 7
def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
    """Combine source and driving motion into new rotation/expression/translation/scale.

    :param config: object with a boolean ``relative`` attribute.
    :param s_kp_info: source keypoint info with keys 'exp', 't', 'scale'.
    :param t_0_kp_info: keypoint info of the reference (first) driving frame.
    :param t_i_kp_info: keypoint info of the current driving frame.
    :param R_s: source rotation (batched 3-D tensor; ``R_t_0`` is transposed
        with ``permute(0, 2, 1)``).
    :param R_t_0: rotation of the reference driving frame.
    :param R_t_i: rotation of the current driving frame.
    :return: tuple ``(new_rotation, new_expression, new_translation, new_scale)``.
    """
    if not config.relative:
        # Absolute mode: take rotation and expression straight from the driving frame.
        rotation = R_t_i
        expression = t_i_kp_info['exp']
    else:
        # Relative mode: apply the driving change (frame i w.r.t. frame 0) to the source.
        driving_rot_change = R_t_i @ R_t_0.permute(0, 2, 1)
        rotation = driving_rot_change @ R_s
        exp_change = t_i_kp_info['exp'] - t_0_kp_info['exp']
        expression = s_kp_info['exp'] + exp_change

    # Translation and scale are always relative to the first driving frame.
    t_change = t_i_kp_info['t'] - t_0_kp_info['t']
    translation = s_kp_info['t'] + t_change
    translation[..., 2].fill_(0)  # zero the last (z) component of the fresh result tensor

    scale_ratio = t_i_kp_info['scale'] / t_0_kp_info['scale']
    scale = s_kp_info['scale'] * scale_ratio

    return rotation, expression, translation, scale
def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()

View File

@ -5,8 +5,11 @@ from glob import glob
import os.path as osp
import imageio
import numpy as np
import pickle
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from .helper import mkdir, suffix
def load_image_rgb(image_path: str):
if not osp.exists(image_path):
@ -23,8 +26,8 @@ def load_driving_info(driving_info):
return [load_image_rgb(im_path) for im_path in image_paths]
def load_images_from_video(file_path):
reader = imageio.get_reader(file_path)
return [image for idx, image in enumerate(reader)]
reader = imageio.get_reader(file_path, "ffmpeg")
return [image for _, image in enumerate(reader)]
if osp.isdir(driving_info):
driving_video_ori = load_images_from_directory(driving_info)
@ -40,7 +43,7 @@ def contiguous(obj):
return obj
def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
"""
adjust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of division.
:param img: the image to be processed.
@ -61,9 +64,9 @@ def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
img = cv2.resize(img, (new_w, new_h))
# ensure that the image dimensions are multiples of n
n = max(n, 1)
new_h = img.shape[0] - (img.shape[0] % n)
new_w = img.shape[1] - (img.shape[1] % n)
division = max(division, 1)
new_h = img.shape[0] - (img.shape[0] % division)
new_w = img.shape[1] - (img.shape[1] % division)
if new_h == 0 or new_w == 0:
# when the width or height is less than n, no need to process
@ -87,7 +90,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
img = obj
# Resize image to satisfy constraints
img = resize_to_limit(img, max_dim=max_dim, n=n)
img = resize_to_limit(img, max_dim=max_dim, division=n)
if mode.lower() == "bgr":
return contiguous(img)
@ -95,3 +98,28 @@ def load_img_online(obj, mode="bgr", **kwargs):
return contiguous(img[..., ::-1])
else:
raise Exception(f"Unknown mode {mode}")
def load(fp):
    """Load an object from a ``.npy`` or ``.pkl`` file.

    The format is selected from the value returned by ``suffix(fp)``.

    :param fp: path of the file to read.
    :return: the result of ``np.load`` for "npy", or the unpickled object for "pkl".
    :raises Exception: if the suffix is neither "npy" nor "pkl".
    """
    suffix_ = suffix(fp)

    if suffix_ == "npy":
        return np.load(fp)
    elif suffix_ == "pkl":
        # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
        # The context manager closes the handle even if unpickling fails.
        with open(fp, "rb") as f:
            return pickle.load(f)
    else:
        # Report the offending suffix string (previously the `suffix` function object was formatted).
        raise Exception(f"Unknown type: {suffix_}")
def dump(wfp, obj):
    """Write ``obj`` to ``wfp`` as ``.npy`` or ``.pkl``, chosen by the file suffix.

    The parent directory is created through ``mkdir`` when the path has one
    and it does not exist yet.

    :param wfp: destination file path.
    :param obj: object to store (passed to ``np.save`` or ``pickle.dump``).
    :raises Exception: if the suffix is neither "npy" nor "pkl".
    """
    wd = osp.split(wfp)[0]
    if wd != "" and not osp.exists(wd):
        mkdir(wd)

    _suffix = suffix(wfp)
    if _suffix == "npy":
        np.save(wfp, obj)
    elif _suffix == "pkl":
        # Close (and thereby flush) the file deterministically instead of
        # relying on garbage collection of an anonymous handle.
        with open(wfp, "wb") as f:
            pickle.dump(obj, f)
    else:
        raise Exception("Unknown type: {}".format(_suffix))

View File

@ -25,6 +25,7 @@ def to_ndarray(obj):
class LandmarkRunner(object):
"""landmark runner"""
def __init__(self, **kwargs):
ckpt_path = kwargs.get('ckpt_path')
onnx_provider = kwargs.get('onnx_provider', 'cuda') # 默认用cuda
@ -55,6 +56,7 @@ class LandmarkRunner(object):
crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
img_crop_rgb = crop_dct['img_crop']
else:
# NOTE: force resize to 224x224, NOT RECOMMEND!
img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
scale = max(img_rgb.shape[:2]) / self.dsize
crop_dct = {
@ -70,15 +72,13 @@ class LandmarkRunner(object):
out_lst = self._run(inp)
out_pts = out_lst[2]
pts = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224
pts = _transform_pts(pts, M=crop_dct['M_c2o'])
# 2d landmarks 203 points
lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224
lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])
return {
'pts': pts, # 2d landmarks 203 points
}
return lmk
def warmup(self):
# build a dummy image for warmup
self.timer.tic()
dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32)

View File

@ -7,32 +7,11 @@ import numpy as np
def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """
    Ratio of two landmark-to-landmark distances, computed per batch item.

    Parameters:
    lmk (np.ndarray): Landmarks array of shape (B, N, 2).
    idx1, idx2 (int): Landmark pair whose distance is the numerator.
    idx3, idx4 (int): Landmark pair whose distance is the denominator.
    eps (float): Added to the denominator to avoid division by zero.

    Returns:
    np.ndarray: Distance ratio of shape (B, 1).
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps
    return numerator / denominator
def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
"""
Calculate the eye-close ratio for left and right eyes.
Parameters:
lmk (np.ndarray): Landmarks array of shape (B, N, 2).
target_eye_ratio (np.ndarray, optional): Additional target eye ratio array to include.
Returns:
np.ndarray: Concatenated eye-close ratios.
"""
lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
if target_eye_ratio is not None:
@ -42,13 +21,4 @@ def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -
def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """
    Lip-close ratio: distance between landmarks 90 and 102 divided by the
    distance between landmarks 48 and 66.

    Parameters:
    lmk (np.ndarray): Landmarks array of shape (B, N, 2).

    Returns:
    np.ndarray: Calculated lip-close ratio.
    """
    return calculate_distance_ratio(lmk, idx1=90, idx2=102, idx3=48, idx4=66)

View File

@ -1,7 +1,9 @@
# coding: utf-8
"""
functions for processing video
Functions for processing video
ATTENTION: you need to install ffmpeg and ffprobe in your env!
"""
import os.path as osp
@ -9,14 +11,15 @@ import numpy as np
import subprocess
import imageio
import cv2
from rich.progress import track
from .helper import prefix
from .rprint import rlog as log
from .rprint import rprint as print
from .helper import prefix
def exec_cmd(cmd):
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
def images2video(images, wfp, **kwargs):
@ -35,7 +38,7 @@ def images2video(images, wfp, **kwargs):
)
n = len(images)
for i in track(range(n), description='writing', transient=True):
for i in track(range(n), description='Writing', transient=True):
if image_mode.lower() == 'bgr':
writer.append_data(images[i][..., ::-1])
else:
@ -43,9 +46,6 @@ def images2video(images, wfp, **kwargs):
writer.close()
# print(f':smiley: Dump to {wfp}\n', style="bold green")
print(f'Dump to {wfp}\n')
def video2gif(video_fp, fps=30, size=256):
if osp.exists(video_fp):
@ -54,10 +54,10 @@ def video2gif(video_fp, fps=30, size=256):
palette_wfp = osp.join(d, 'palette.png')
gif_wfp = osp.join(d, f'{fn}.gif')
# generate the palette
cmd = f'ffmpeg -i {video_fp} -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" {palette_wfp} -y'
cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y'
exec_cmd(cmd)
# use the palette to generate the gif
cmd = f'ffmpeg -i {video_fp} -i {palette_wfp} -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" {gif_wfp} -y'
cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
exec_cmd(cmd)
else:
print(f'video_fp: {video_fp} not exists!')
@ -65,7 +65,7 @@ def video2gif(video_fp, fps=30, size=256):
def merge_audio_video(video_fp, audio_fp, wfp):
if osp.exists(video_fp) and osp.exists(audio_fp):
cmd = f'ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y'
cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
exec_cmd(cmd)
print(f'merge {video_fp} and {audio_fp} to {wfp}')
else:
@ -80,21 +80,23 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
return img
def concat_frames(I_p_lst, driving_rgb_lst, img_rgb):
def concat_frames(driving_image_lst, source_image, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
h, w, _ = I_p_lst[0].shape
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
source_image_drived = I_p_lst[idx]
image_drive = driving_rgb_lst[idx]
I_p = I_p_lst[idx]
source_image_resized = cv2.resize(source_image, (w, h))
# resize images to match source_image_drived shape
h, w, _ = source_image_drived.shape
image_drive_resized = cv2.resize(image_drive, (w, h))
img_rgb_resized = cv2.resize(img_rgb, (w, h))
if driving_image_lst is None:
out = np.hstack((source_image_resized, I_p))
else:
driving_image = driving_image_lst[idx]
driving_image_resized = cv2.resize(driving_image, (w, h))
out = np.hstack((driving_image_resized, source_image_resized, I_p))
# concatenate images horizontally
frame = np.concatenate((image_drive_resized, img_rgb_resized, source_image_drived), axis=1)
out_lst.append(frame)
out_lst.append(out)
return out_lst
@ -126,14 +128,84 @@ class VideoWriter:
self.writer.close()
def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=5):
cmd = f"ffmpeg -i {input_file} -c:v {codec} -crf {crf} -r {fps} {output_file} -y"
def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12):
cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y'
exec_cmd(cmd)
def get_fps(filepath):
import ffmpeg
probe = ffmpeg.probe(filepath)
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
fps = eval(video_stream['avg_frame_rate'])
def get_fps(filepath, default_fps=25):
    """Return the frame rate OpenCV reports for a video, with a fallback.

    :param filepath: path handed to ``cv2.VideoCapture``.
    :param default_fps: value returned when OpenCV reports 0/None or raises.
    :return: the reported FPS, or ``default_fps``.
    """
    capture = None
    try:
        capture = cv2.VideoCapture(filepath)
        fps = capture.get(cv2.CAP_PROP_FPS)
        if fps in (0, None):
            fps = default_fps
    except Exception as e:
        # Best effort: any probing failure falls back to the default.
        log(e)
        fps = default_fps
    finally:
        # Release the capture so the underlying file/decoder handle is not leaked.
        if capture is not None:
            capture.release()

    return fps
def has_audio_stream(video_path: str) -> bool:
    """
    Check if the video file contains an audio stream.

    ffprobe is invoked with an argument list (no shell), so paths containing
    quotes, spaces or shell metacharacters are passed through verbatim.

    :param video_path: Path to the video file
    :return: True if the video contains an audio stream, False otherwise
        (also False for directories and whenever probing fails)
    """
    if osp.isdir(video_path):
        return False

    cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'a',
        '-show_entries', 'stream=codec_type',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        video_path,
    ]

    try:
        # check=True raises on a non-zero exit code; stderr is captured separately
        # so that diagnostic text is never mistaken for a stream entry.
        result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception:
        # Covers a missing ffprobe binary as well as a failing probe.
        log(f"Error occurred while probing video: {video_path}, you may need to install ffprobe! Now set audio to false!", style="bold red")
        return False

    # Any remaining stdout means at least one audio stream was listed.
    return bool(result.stdout.strip())
def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
    """Mux the video of one file with the audio of another using ffmpeg.

    The video stream is copied from ``silent_video_path``, the audio is taken
    from ``audio_video_path``, and the output is cut to the shorter input
    (``-shortest``). Failures are logged, not raised.

    :param silent_video_path: file providing the video stream (input 0).
    :param audio_video_path: file providing the audio stream (input 1).
    :param output_video_path: destination file; overwritten if present (``-y``).
    """
    # Argument list instead of a shell string: paths are passed verbatim, so
    # quotes or shell metacharacters in file names cannot break the command.
    cmd = [
        'ffmpeg',
        '-y',
        '-i', silent_video_path,
        '-i', audio_video_path,
        '-map', '0:v',
        '-map', '1:a',
        '-c:v', 'copy',
        '-shortest',
        output_video_path,
    ]

    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        log(f"Video with audio generated successfully: {output_video_path}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg binary, which without a shell surfaces
        # as FileNotFoundError instead of a non-zero exit status.
        log(f"Error occurred: {e}")
def bb_intersection_over_union(boxA, boxB):
    """Intersection-over-union of two axis-aligned boxes.

    Boxes are indexable as ``[x1, y1, x2, y2]`` with inclusive corner
    coordinates, which is why 1 is added to every width and height.

    :param boxA: first box.
    :param boxB: second box.
    :return: IoU as a float; 0.0 when the union area is not positive
        (degenerate boxes), instead of dividing by zero.
    """
    # Corners of the overlap rectangle.
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Clamp at 0 so disjoint boxes contribute no intersection.
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    unionArea = boxAArea + boxBArea - interArea
    if unionArea <= 0:
        return 0.0

    return interArea / float(unionArea)

19
src/utils/viz.py Normal file
View File

@ -0,0 +1,19 @@
# coding: utf-8
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
def viz_lmk(img_, vps, **kwargs):
    """Visualize points: draw each one as a green circle on a copy of the image.

    :param img_: image to draw on; it is copied, the input is left untouched.
    :param vps: iterable of points; the first two components are used as (x, y).
    :param kwargs: optional ``lineType`` (default ``cv2.LINE_8``), ``radius``
        (default 1) and ``thickness`` (default 1) forwarded to ``cv2.circle``.
    :return: the annotated copy.
    """
    line_type = kwargs.get("lineType", cv2.LINE_8)  # cv2.LINE_AA
    dot_radius = kwargs.get("radius", 1)
    dot_thickness = kwargs.get("thickness", 1)

    canvas = img_.copy()
    for point in vps:
        center = (int(point[0]), int(point[1]))
        cv2.circle(
            canvas,
            center,
            radius=dot_radius,
            color=(0, 255, 0),
            thickness=dot_thickness,
            lineType=line_type,
        )
    return canvas

View File

@ -1,37 +0,0 @@
# coding: utf-8
"""
[WIP] Pipeline for video template preparation
"""
import tyro
from src.config.crop_config import CropConfig
from src.config.inference_config import InferenceConfig
from src.config.argument_config import ArgumentConfig
from src.template_maker import TemplateMaker
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it knows about.

    Only keys for which ``hasattr(target_class, key)`` is true are forwarded
    as keyword arguments; every other entry is dropped.
    """
    accepted = {}
    for name, value in kwargs.items():
        if hasattr(target_class, name):
            accepted[name] = value
    return target_class(**accepted)
def main():
    """Parse CLI arguments and build a motion template from the driving video."""
    # set tyro theme
    tyro.extras.set_accent_color("bright_cyan")
    cli_args = tyro.cli(ArgumentConfig)
    arg_values = cli_args.__dict__

    # each config picks the CLI attributes it knows about
    inference_cfg = partial_fields(InferenceConfig, arg_values)
    crop_cfg = partial_fields(CropConfig, arg_values)

    maker = TemplateMaker(inference_cfg=inference_cfg, crop_cfg=crop_cfg)

    # run
    maker.make_motion_template(cli_args.driving_video_path, cli_args.template_output_dir)
if __name__ == '__main__':
main()