Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-23 13:04:23 +00:00)

Commit b0ad5dc17c: Merge branch 'main' into develop
.gitignore (vendored): 4 lines changed

@@ -9,9 +9,13 @@ __pycache__/
 **/*.pth
 **/*.onnx

+pretrained_weights/*.md
+pretrained_weights/docs

 # Ipython notebook
 *.ipynb

 # Temporary files or benchmark resources
 animations/*
 tmp/*
+
+.vscode/launch.json
app.py: 102 lines changed

@@ -25,28 +25,40 @@ args = tyro.cli(ArgumentConfig)
 # specify configs for inference
 inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
 crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig

 gradio_pipeline = GradioPipeline(
     inference_cfg=inference_cfg,
     crop_cfg=crop_cfg,
     args=args
 )


+def gpu_wrapped_execute_video(*args, **kwargs):
+    return gradio_pipeline.execute_video(*args, **kwargs)
+
+
+def gpu_wrapped_execute_image(*args, **kwargs):
+    return gradio_pipeline.execute_image(*args, **kwargs)
+
+
 # assets
 title_md = "assets/gradio_title.md"
 example_portrait_dir = "assets/examples/source"
 example_video_dir = "assets/examples/driving"
 data_examples = [
-    [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
-    [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
-    [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
-    [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
-    [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
+    [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+    [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+    [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
+    [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
+    [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
+    [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
 ]
 #################### interface logic ####################

 # Define components first
 eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
 lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
-retargeting_input_image = gr.Image(type="numpy")
+retargeting_input_image = gr.Image(type="filepath")
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
 output_video = gr.Video()

@@ -58,15 +70,39 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         with gr.Accordion(open=True, label="Source Portrait"):
             image_input = gr.Image(type="filepath")
+            gr.Examples(
+                examples=[
+                    [osp.join(example_portrait_dir, "s9.jpg")],
+                    [osp.join(example_portrait_dir, "s6.jpg")],
+                    [osp.join(example_portrait_dir, "s10.jpg")],
+                    [osp.join(example_portrait_dir, "s5.jpg")],
+                    [osp.join(example_portrait_dir, "s7.jpg")],
+                    [osp.join(example_portrait_dir, "s12.jpg")],
+                ],
+                inputs=[image_input],
+                cache_examples=False,
+            )
         with gr.Accordion(open=True, label="Driving Video"):
             video_input = gr.Video()
-    gr.Markdown(load_description("assets/gradio_description_animation.md"))
+            gr.Examples(
+                examples=[
+                    [osp.join(example_video_dir, "d0.mp4")],
+                    [osp.join(example_video_dir, "d18.mp4")],
+                    [osp.join(example_video_dir, "d19.mp4")],
+                    [osp.join(example_video_dir, "d14.mp4")],
+                    [osp.join(example_video_dir, "d6.mp4")],
+                ],
+                inputs=[video_input],
+                cache_examples=False,
+            )
     with gr.Row():
-        with gr.Accordion(open=True, label="Animation Options"):
+        with gr.Accordion(open=False, label="Animation Instructions and Options"):
+            gr.Markdown(load_description("assets/gradio_description_animation.md"))
             with gr.Row():
                 flag_relative_input = gr.Checkbox(value=True, label="relative motion")
-                flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
+                flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
+                flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
     with gr.Row():
         with gr.Column():
             process_button_animation = gr.Button("🚀 Animate", variant="primary")

@@ -81,24 +117,28 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             output_video_concat.render()
     with gr.Row():
         # Examples
-        gr.Markdown("## You could choose the examples below ⬇️")
+        gr.Markdown("## You could also choose the examples below by one click ⬇️")
     with gr.Row():
         gr.Examples(
             examples=data_examples,
+            fn=gpu_wrapped_execute_video,
             inputs=[
                 image_input,
                 video_input,
                 flag_relative_input,
                 flag_do_crop_input,
-                flag_remap_input
+                flag_remap_input,
+                flag_crop_driving_video_input
             ],
-            examples_per_page=5
+            outputs=[output_image, output_image_paste_back],
+            examples_per_page=len(data_examples),
+            cache_examples=False,
         )
-    gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
-    with gr.Row():
+    gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
+    with gr.Row(visible=True):
         eye_retargeting_slider.render()
         lip_retargeting_slider.render()
-    with gr.Row():
+    with gr.Row(visible=True):
         process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
         process_button_reset_retargeting = gr.ClearButton(
             [

@@ -110,10 +150,22 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             ],
             value="🧹 Clear"
         )
-    with gr.Row():
+    with gr.Row(visible=True):
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Input"):
                 retargeting_input_image.render()
+                gr.Examples(
+                    examples=[
+                        [osp.join(example_portrait_dir, "s9.jpg")],
+                        [osp.join(example_portrait_dir, "s6.jpg")],
+                        [osp.join(example_portrait_dir, "s10.jpg")],
+                        [osp.join(example_portrait_dir, "s5.jpg")],
+                        [osp.join(example_portrait_dir, "s7.jpg")],
+                        [osp.join(example_portrait_dir, "s12.jpg")],
+                    ],
+                    inputs=[retargeting_input_image],
+                    cache_examples=False,
+                )
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Result"):
                 output_image.render()

@@ -122,33 +174,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 output_image_paste_back.render()
     # binding functions for buttons
     process_button_retargeting.click(
-        fn=gradio_pipeline.execute_image,
-        inputs=[eye_retargeting_slider, lip_retargeting_slider],
+        # fn=gradio_pipeline.execute_image,
+        fn=gpu_wrapped_execute_image,
+        inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
         outputs=[output_image, output_image_paste_back],
         show_progress=True
     )
     process_button_animation.click(
-        fn=gradio_pipeline.execute_video,
+        fn=gpu_wrapped_execute_video,
         inputs=[
             image_input,
             video_input,
             flag_relative_input,
             flag_do_crop_input,
-            flag_remap_input
+            flag_remap_input,
+            flag_crop_driving_video_input
         ],
         outputs=[output_video, output_video_concat],
         show_progress=True
     )
-    image_input.change(
-        fn=gradio_pipeline.prepare_retargeting,
-        inputs=image_input,
-        outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
-    )

-##########################################################

 demo.launch(
-    server_name=args.server_name,
     server_port=args.server_port,
     share=args.share,
+    server_name=args.server_name
 )
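Editorial note on the expanded example table in app.py above: each row of data_examples now carries six values that line up positionally with the inputs list passed to the examples widget, so the final boolean drives the new "do crop (driving video)" checkbox. Restated as a sketch for readability (this mirrors the diff content; it is not additional code in the commit):

```python
# positional meaning of one data_examples row, matching the gr.Examples(inputs=[...]) order
example_row = [
    osp.join(example_portrait_dir, "s9.jpg"),   # image_input (source portrait)
    osp.join(example_video_dir, "d0.mp4"),      # video_input (driving video)
    True,                                       # flag_relative_input: relative motion
    True,                                       # flag_do_crop_input: do crop (source)
    True,                                       # flag_remap_input: paste-back
    False,                                      # flag_crop_driving_video_input: do crop (driving video)
]
```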
assets/.gitignore (vendored, new file): 2 lines

@@ -0,0 +1,2 @@
+examples/driving/*.pkl
+examples/driving/*_crop.mp4
assets/docs/changelog/2024-07-10.md (new file): 22 lines

@@ -0,0 +1,22 @@
+## 2024/07/10
+
+**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
+The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
+
+### Updates
+
+- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install the `ffprobe` and `ffmpeg` executables, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
+
+- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`.
+
+- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
+
+
+### About driving video
+
+- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
+
+
+### Others
+
+- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported in issues [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), and [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
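The auto-cropping entry above describes the approach only in words (track facial landmarks across the clip, then derive one global 1:1 crop box). As an editorial illustration of that idea, here is a minimal sketch; it is not the repository's implementation, the landmark detector is assumed to be provided elsewhere, and the scale factor only loosely mirrors `--scale_crop_video`:

```python
import numpy as np

def global_square_crop(landmarks_per_frame, frame_h, frame_w, scale=2.2):
    """Illustrative sketch: derive one 1:1 crop box covering the face across all frames.

    landmarks_per_frame: list of (N, 2) arrays of (x, y) facial landmarks, one per frame.
    scale: how much context to keep around the landmark extent.
    """
    pts = np.concatenate(landmarks_per_frame, axis=0)    # pool landmarks over the whole clip
    cx, cy = pts.mean(axis=0)                            # global face center
    extent = (pts.max(axis=0) - pts.min(axis=0)).max()   # widest landmark spread in x or y
    half = int(round(extent * scale / 2))                # enlarge to keep surrounding context
    # clamp the square so it stays 1:1 and inside the frame
    half = min(half, int(cx), int(cy), frame_w - int(cx), frame_h - int(cy))
    x0, y0 = int(cx) - half, int(cy) - half
    return x0, y0, 2 * half, 2 * half                    # x, y, width, height
```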
Binary files changed (contents not shown):
  assets/examples/driving/d1.pkl (new)
  assets/examples/driving/d10.mp4 (new)
  assets/examples/driving/d11.mp4 (new)
  assets/examples/driving/d12.mp4 (new)
  assets/examples/driving/d13.mp4 (new)
  assets/examples/driving/d14.mp4 (new)
  assets/examples/driving/d18.mp4 (new)
  assets/examples/driving/d19.mp4 (new)
  assets/examples/driving/d2.pkl (new)
  assets/examples/driving/d5.pkl (new)
  assets/examples/driving/d7.pkl (new)
  assets/examples/driving/d8.pkl (new)
  assets/examples/source/s11.jpg (new, 102 KiB)
  assets/examples/source/s12.jpg (new, 49 KiB)
assets/gradio_description_animation.md:

@@ -1,7 +1,16 @@
 <span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
 <div style="font-size: 1.2em; margin-left: 20px;">
-1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
+1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
 </div>
 <div style="font-size: 1.2em; margin-left: 20px;">
 2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
+</div>
+<div style="font-size: 1.2em; margin-left: 20px;">
+3. If you want to upload your own driving video, <strong>the best practice</strong>:
+
+- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping of the driving video by checking `do crop (driving video)`.
+- Focus on the head area, similar to the example videos.
+- Minimize shoulder movement.
+- Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
+
 </div>
assets/gradio_description_retargeting.md:

@@ -1 +1,4 @@
-<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
+<br>
+
+## Retargeting
+<span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
(gradio description file; name not shown)

@@ -1,2 +1,2 @@
 ## 🤗 This is the official gradio demo for **LivePortrait**.
-<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
+<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>
assets/gradio_title.md:

@@ -5,6 +5,7 @@
 <a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
 <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
 <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
+<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
 </div>
 </div>
 </div>
inference.py: 27 lines changed

@@ -1,6 +1,8 @@
 # coding: utf-8

+import os.path as osp
 import tyro
+import subprocess
 from src.config.argument_config import ArgumentConfig
 from src.config.inference_config import InferenceConfig
 from src.config.crop_config import CropConfig

@@ -11,11 +13,34 @@ def partial_fields(target_class, kwargs):
     return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
+
+
+def fast_check_args(args: ArgumentConfig):
+    if not osp.exists(args.source_image):
+        raise FileNotFoundError(f"source image not found: {args.source_image}")
+    if not osp.exists(args.driving_info):
+        raise FileNotFoundError(f"driving info not found: {args.driving_info}")
+
+
 def main():
     # set tyro theme
     tyro.extras.set_accent_color("bright_cyan")
     args = tyro.cli(ArgumentConfig)

+    if not fast_check_ffmpeg():
+        raise ImportError(
+            "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
+        )
+
+    # fast check the args
+    fast_check_args(args)
+
     # specify configs for inference
     inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
     crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig

@@ -29,5 +54,5 @@ def main():
     live_portrait_pipeline.execute(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
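An editorial aside on the fast_check_ffmpeg() helper added above: spawning `ffmpeg -version` in a subprocess confirms that the binary both exists and actually runs. If you only need to know whether the executable is on PATH, a lighter alternative sketch (not what the commit uses) is:

```python
import shutil

def ffmpeg_on_path() -> bool:
    # shutil.which returns the full path to the executable, or None if it is not found
    return shutil.which("ffmpeg") is not None
```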
readme.md: 79 lines changed

@@ -4,7 +4,7 @@
 <a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup>
 <a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup>
 <a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup>
-<a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>
+<a href='https://scholar.google.com/citations?user=t88nyvsAAAAJ&hl' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>
 <a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup>
 </div>

@@ -35,8 +35,12 @@

 ## 🔥 Updates
-- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned!
-- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
+- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md).
+- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)!
+- **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned!
+- **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).


 ## Introduction
 This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).

@@ -55,8 +59,19 @@ conda activate LivePortrait
 pip install -r requirements.txt
 ```

+**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!

 ### 2. Download pretrained weights
-Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights`, ensuring the directory structure is as follows:
+The easiest way to download the pretrained weights is from HuggingFace:
+```bash
+# you may need to run `git lfs install` first
+git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
+```
+
+Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
+
+Ensure the directory structure is as follows, or contains:
 ```text
 pretrained_weights
 ├── insightface

@@ -77,6 +92,7 @@ pretrained_weights

 ### 3. Inference 🚀

+#### Fast hands-on
 ```bash
 python inference.py
 ```

@@ -92,16 +108,37 @@ Or, you can change the input by specifying the `-s` and `-d` arguments:
 ```bash
 python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4

-# or disable pasting back
+# disable pasting back to run faster
 python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback

 # more options to see
 python inference.py -h
 ```

-**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊
-
-### 4. Gradio interface
+#### Driving video auto-cropping
+
+📕 To use your own driving video, we **recommend**:
+- Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping with `--flag_crop_driving_video`.
+- Focus on the head area, similar to the example videos.
+- Minimize shoulder movement.
+- Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
+
+Below is an auto-cropping example using `--flag_crop_driving_video`:
+```bash
+python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video
+```
+
+If the results of auto-cropping are not satisfactory, you can modify the `--scale_crop_video` and `--vy_ratio_crop_video` options to adjust the scale and offset, or do the cropping manually.
+
+#### Motion template making
+You can also use the auto-generated motion template files ending with `.pkl` to speed up inference and **protect privacy**, such as:
+```bash
+python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl
+```
+
+**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊
+
+### 4. Gradio interface 🤗

 We also provide a Gradio interface for a better experience, just run by:

@@ -109,6 +146,10 @@ We also provide a Gradio interface for a better experience, just run by:
 python app.py
 ```

+You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs!
+
+**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗**
+
 ### 5. Inference speed evaluation 🚀🚀🚀
 We have also provided a script to evaluate the inference speed of each module:

@@ -124,10 +165,22 @@ Below are the results of inferring one frame on an RTX 4090 GPU using the native
 | Motion Extractor | 28.12 | 108 | 0.84 |
 | Spade Generator | 55.37 | 212 | 7.59 |
 | Warping Module | 45.53 | 174 | 5.21 |
-| Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
+| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |

-*Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
+*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
+
+## Community Resources 🤗
+
+Discover the invaluable resources contributed by our community to enhance your LivePortrait experience:
+
+- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai)
+- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007)
+- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch)
+- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph)
+- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker)
+- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr)
+
+And many more amazing contributions from our community!
+
 ## Acknowledgements
 We would like to thank the contributors of the [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), and [InsightFace](https://github.com/deepinsight/insightface) repositories for their open research and contributions.

@@ -135,10 +188,10 @@ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrS
 ## Citation 💖
 If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
 ```bibtex
-@article{guo2024live,
+@article{guo2024liveportrait,
   title   = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
-  author  = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
-  year    = {2024},
-  journal = {arXiv preprint:2407.03168},
+  author  = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di},
+  journal = {arXiv preprint arXiv:2407.03168},
+  year    = {2024}
 }
 ```
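As a supplement to the weight-download step documented in the readme changes above (an editorial addition, not part of the commit): if you prefer not to use git-lfs, the same Hugging Face repository can be fetched programmatically. A minimal sketch using the huggingface_hub package, which is assumed to be installed separately:

```python
from huggingface_hub import snapshot_download

# Fetch the KwaiVGI/liveportrait weights into ./pretrained_weights, mirroring what
# `git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights` does.
snapshot_download(repo_id="KwaiVGI/liveportrait", local_dir="pretrained_weights")
```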
requirements.txt:

@@ -11,7 +11,7 @@ imageio==2.34.2
 lmdb==1.4.1
 tqdm==4.66.4
 rich==13.7.1
-ffmpeg==1.4
+ffmpeg-python==0.2.0
 onnxruntime-gpu==1.18.0
 onnx==1.16.1
 scikit-image==0.24.0
speed.py: 10 lines changed

@@ -47,11 +47,11 @@ def load_and_compile_models(cfg, model_config):
     """
     Load and compile models for inference
     """
-    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
-    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
-    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
-    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
-    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
+    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
+    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
+    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
+    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
+    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')

     models_with_params = [
         ('Appearance Feature Extractor', appearance_feature_extractor),
src/config/argument_config.py:

@@ -1,13 +1,13 @@
 # coding: utf-8

 """
-config for user
+All configs for user
 """

-import os.path as osp
 from dataclasses import dataclass
 import tyro
 from typing_extensions import Annotated
+from typing import Optional
 from .base_config import PrintableConfig, make_abs_path


@@ -17,28 +17,31 @@ class ArgumentConfig(PrintableConfig):
     source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg')  # path to the source portrait
     driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video
-    #####################################

     ########## inference arguments ##########
-    device_id: int = 0
+    flag_use_half_precision: bool = True  # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
+    flag_crop_driving_video: bool = False  # whether to crop the driving video, if the given driving info is a video
+    device_id: int = 0  # gpu device id
+    flag_force_cpu: bool = False  # force cpu inference, WIP!
     flag_lip_zero : bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
-    flag_eye_retargeting: bool = False
-    flag_lip_retargeting: bool = False
-    flag_stitching: bool = True  # we recommend setting it to True!
-    flag_relative: bool = True  # whether to use relative motion
+    flag_eye_retargeting: bool = False  # not recommend to be True, WIP
+    flag_lip_retargeting: bool = False  # not recommend to be True, WIP
+    flag_stitching: bool = True  # recommend to True if head movement is small, False if head movement is large
+    flag_relative_motion: bool = True  # whether to use relative motion
     flag_pasteback: bool = True  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
     flag_do_crop: bool = True  # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
-    #########################################

     ########## crop arguments ##########
-    dsize: int = 512
-    scale: float = 2.3
-    vx_ratio: float = 0  # vx ratio
-    vy_ratio: float = -0.125  # vy ratio +up, -down
-    ####################################
+    scale: float = 2.3  # the ratio of face area is smaller if scale is larger
+    vx_ratio: float = 0  # the ratio to move the face to left or right in cropping space
+    vy_ratio: float = -0.125  # the ratio to move the face to up or down in cropping space
+    scale_crop_video: float = 2.2  # scale factor for cropping video
+    vx_ratio_crop_video: float = 0.  # adjust y offset
+    vy_ratio_crop_video: float = -0.1  # adjust x offset

     ########## gradio arguments ##########
-    server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
-    share: bool = True
-    server_name: str = "0.0.0.0"
+    server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890  # port for gradio server
+    share: bool = False  # whether to share the server to public
+    server_name: Optional[str] = "127.0.0.1"  # set the local server name, "0.0.0.0" to broadcast all
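Editorial note on how these options reach the runtime configs: app.py and inference.py split the flat ArgumentConfig into InferenceConfig and CropConfig with the partial_fields helper already shown in the inference.py diff above. A short recap of that mechanism (no new behavior, just the pattern restated):

```python
def partial_fields(target_class, kwargs):
    # keep only the keys that the target dataclass actually declares as attributes
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

# e.g. gradio-only options such as `share`, `server_port`, or `server_name` are silently
# dropped when building InferenceConfig, because InferenceConfig declares no such fields.
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
```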
src/config/crop_config.py:

@@ -4,15 +4,26 @@
 parameters used for crop faces
 """

-import os.path as osp
 from dataclasses import dataclass
-from typing import Union, List
 from .base_config import PrintableConfig


 @dataclass(repr=False)  # use repr from PrintableConfig
 class CropConfig(PrintableConfig):
+    insightface_root: str = "../../pretrained_weights/insightface"
+    landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
+    device_id: int = 0  # gpu device id
+    flag_force_cpu: bool = False  # force cpu inference, WIP
+    ########## source image cropping option ##########
     dsize: int = 512  # crop size
-    scale: float = 2.3  # scale factor
+    scale: float = 2.5  # scale factor
     vx_ratio: float = 0  # vx ratio
     vy_ratio: float = -0.125  # vy ratio +up, -down
+    max_face_num: int = 0  # max face number, 0 mean no limit
+
+    ########## driving video auto cropping option ##########
+    scale_crop_video: float = 2.2  # 2.0  # scale factor for cropping video
+    vx_ratio_crop_video: float = 0.0  # adjust y offset
+    vy_ratio_crop_video: float = -0.1  # adjust x offset
+    direction: str = "large-small"  # direction of cropping
src/config/inference_config.py:

@@ -5,6 +5,8 @@ config dataclass used for inference
 """

 import os.path as osp
+import cv2
+from numpy import ndarray
 from dataclasses import dataclass
 from typing import Literal, Tuple
 from .base_config import PrintableConfig, make_abs_path

@@ -12,38 +14,38 @@ from .base_config import PrintableConfig, make_abs_path

 @dataclass(repr=False)  # use repr from PrintableConfig
 class InferenceConfig(PrintableConfig):
+    # MODEL CONFIG, NOT EXPORTED PARAMS
     models_config: str = make_abs_path('./models.yaml')  # portrait animation config
-    checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')  # path to checkpoint
-    checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')  # path to checkpoint
-    checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth')  # path to checkpoint
-    checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')  # path to checkpoint
-    checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')  # path to checkpoint
+    checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')  # path to checkpoint of F
+    checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')  # path to checkpoint of M
+    checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth')  # path to checkpoint of G
+    checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')  # path to checkpoint of W
+    checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')  # path to checkpoint of S and R_eyes, R_lip

-    flag_use_half_precision: bool = True  # whether to use half precision
-    flag_lip_zero: bool = True  # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
-    lip_zero_threshold: float = 0.03
+    # EXPORTED PARAMS
+    flag_use_half_precision: bool = True
+    flag_crop_driving_video: bool = False
+    device_id: int = 0
+    flag_lip_zero: bool = True
     flag_eye_retargeting: bool = False
     flag_lip_retargeting: bool = False
-    flag_stitching: bool = True  # we recommend setting it to True!
-    flag_relative: bool = True  # whether to use relative motion
-    anchor_frame: int = 0  # set this value if find_best_frame is True
+    flag_stitching: bool = True
+    flag_relative_motion: bool = True
+    flag_pasteback: bool = True
+    flag_do_crop: bool = True
+    flag_do_rot: bool = True
+    flag_force_cpu: bool = False
+
+    # NOT EXPORTED PARAMS
+    lip_zero_threshold: float = 0.03  # threshold for flag_lip_zero
+    anchor_frame: int = 0  # TO IMPLEMENT

     input_shape: Tuple[int, int] = (256, 256)  # input shape
     output_format: Literal['mp4', 'gif'] = 'mp4'  # output video format
-    output_fps: int = 30  # fps for output video
     crf: int = 15  # crf for output video
+    output_fps: int = 25  # default output fps

-    flag_write_result: bool = True  # whether to write output video
-    flag_pasteback: bool = True  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-    mask_crop = None
-    flag_write_gif: bool = False
-    size_gif: int = 256
-    ref_max_shape: int = 1280
-    ref_shape_n: int = 2
-
-    device_id: int = 0
-    flag_do_crop: bool = False  # whether to crop the source portrait to the face-cropping space
-    flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
+    mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
+    size_gif: int = 256  # default gif size, TO IMPLEMENT
+    source_max_dim: int = 1280  # the max dim of height and width of source image
+    source_division: int = 2  # make sure the height and width of source image can be divided by this number
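Editorial note on the two new source-size parameters above: source_max_dim caps the longer side of the source image and source_division forces both sides to be divisible by a small integer. A hypothetical helper showing how such constraints are typically applied (an illustration under that assumption, not the repository's code):

```python
def limit_source_dims(h: int, w: int, max_dim: int = 1280, division: int = 2):
    """Scale (h, w) so the longer side is at most max_dim, then round both sides
    down to a multiple of `division` (mirrors source_max_dim / source_division)."""
    scale = min(1.0, max_dim / max(h, w))
    h, w = int(h * scale), int(w * scale)
    return h - h % division, w - w % division
```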
src/gradio_pipeline.py:

@@ -4,13 +4,14 @@
 Pipeline for gradio
 """
 import gradio as gr

 from .config.argument_config import ArgumentConfig
 from .live_portrait_pipeline import LivePortraitPipeline
 from .utils.io import load_img_online
 from .utils.rprint import rlog as log
 from .utils.crop import prepare_paste_back, paste_back
 from .utils.camera import get_rotation_matrix
-from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio

 def update_args(args, user_args):
     """update the args according to user inputs

@@ -20,22 +21,13 @@ def update_args(args, user_args):
         setattr(args, k, v)
     return args


 class GradioPipeline(LivePortraitPipeline):

     def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
         super().__init__(inference_cfg, crop_cfg)
         # self.live_portrait_wrapper = self.live_portrait_wrapper
         self.args = args
-        # for single image retargeting
-        self.start_prepare = False
-        self.f_s_user = None
-        self.x_c_s_info_user = None
-        self.x_s_user = None
-        self.source_lmk_user = None
-        self.mask_ori = None
-        self.img_rgb = None
-        self.crop_M_c2o = None

     def execute_video(
         self,

@@ -44,7 +36,8 @@ class GradioPipeline(LivePortraitPipeline):
         flag_relative_input,
         flag_do_crop_input,
         flag_remap_input,
+        flag_crop_driving_video_input
     ):
         """ for video driven portrait animation
         """
         if input_image_path is not None and input_video_path is not None:

@@ -54,6 +47,7 @@ class GradioPipeline(LivePortraitPipeline):
             'flag_relative': flag_relative_input,
             'flag_do_crop': flag_do_crop_input,
             'flag_pasteback': flag_remap_input,
+            'flag_crop_driving_video': flag_crop_driving_video_input
         }
         # update config from user input
         self.args = update_args(self.args, args_user)

@@ -66,51 +60,45 @@ class GradioPipeline(LivePortraitPipeline):
         else:
             raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)

-    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
+    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
         """ for single image retargeting
         """
-        if input_eye_ratio is None or input_eye_ratio is None:
+        # disposable feature
+        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
+            self.prepare_retargeting(input_image, flag_do_crop)
+
+        if input_eye_ratio is None or input_lip_ratio is None:
             raise gr.Error("Invalid ratio input 💥!", duration=5)
-        elif self.f_s_user is None:
-            if self.start_prepare:
-                raise gr.Error(
-                    "The source portrait is under processing 💥! Please wait for a second.",
-                    duration=5
-                )
-            else:
-                raise gr.Error(
-                    "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
-                    duration=5
-                )
         else:
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
+            f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
             # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
-            eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
+            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
+            eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
             # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
-            lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
-            num_kp = self.x_s_user.shape[1]
+            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
+            lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+            num_kp = x_s_user.shape[1]
             # default: use x_s
-            x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
+            x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
             # D(W(f_s; x_s, x′_d))
-            out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
+            out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
             out = self.live_portrait_wrapper.parse_output(out['out'])[0]
-            out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
+            out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
             gr.Info("Run successfully!", duration=2)
             return out, out_to_ori_blend

-    def prepare_retargeting(self, input_image_path, flag_do_crop = True):
+    def prepare_retargeting(self, input_image, flag_do_crop=True):
         """ for single image retargeting
         """
-        if input_image_path is not None:
-            gr.Info("Upload successfully!", duration=2)
-            self.start_prepare = True
-            inference_cfg = self.live_portrait_wrapper.cfg
+        if input_image is not None:
+            # gr.Info("Upload successfully!", duration=2)
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
             ######## process source portrait ########
-            img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
-            log(f"Load source image from {input_image_path}.")
-            crop_info = self.cropper.crop_single_image(img_rgb)
+            img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
+            log(f"Load source image from {input_image}.")
+            crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
             if flag_do_crop:
                 I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
             else:

@@ -118,23 +106,12 @@ class GradioPipeline(LivePortraitPipeline):
             x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
             R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
             ############################################
-            # record global info for next time use
-            self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
-            self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
-            self.x_s_info_user = x_s_info
-            self.source_lmk_user = crop_info['lmk_crop']
-            self.img_rgb = img_rgb
-            self.crop_M_c2o = crop_info['M_c2o']
-            self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-            # update slider
-            eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
-            eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
-            lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
-            lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
-            # for vis
-            self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
-            return eye_close_ratio, lip_close_ratio, self.I_s_vis
+            f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+            x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+            source_lmk_user = crop_info['lmk_crop']
+            crop_M_c2o = crop_info['M_c2o']
+            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
+            return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
         else:
             # when press the clear button, go here
-            return 0.8, 0.8, self.I_s_vis
+            raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
@ -4,13 +4,12 @@
|
|||||||
Pipeline of LivePortrait
|
Pipeline of LivePortrait
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO:
|
import torch
|
||||||
# 1. Currently all templates are assumed to be pre-cropped; this needs to be revised
|
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
|
||||||
# 2. Pick example source + driving images
|
|
||||||
|
|
||||||
import cv2
|
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
|
|
||||||
@ -19,12 +18,12 @@ from .config.inference_config import InferenceConfig
|
|||||||
from .config.crop_config import CropConfig
|
from .config.crop_config import CropConfig
|
||||||
from .utils.cropper import Cropper
|
from .utils.cropper import Cropper
|
||||||
from .utils.camera import get_rotation_matrix
|
from .utils.camera import get_rotation_matrix
|
||||||
from .utils.video import images2video, concat_frames
|
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
|
||||||
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
||||||
from .utils.retargeting_utils import calc_lip_close_ratio
|
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
|
||||||
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
|
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
|
||||||
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
|
|
||||||
from .utils.rprint import rlog as log
|
from .utils.rprint import rlog as log
|
||||||
|
# from .utils.viz import viz_lmk
|
||||||
from .live_portrait_wrapper import LivePortraitWrapper
|
from .live_portrait_wrapper import LivePortraitWrapper
|
||||||
|
|
||||||
|
|
||||||
@ -35,84 +34,124 @@ def make_abs_path(fn):
|
|||||||
class LivePortraitPipeline(object):
|
class LivePortraitPipeline(object):
|
||||||
|
|
||||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
||||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
|
||||||
self.cropper = Cropper(crop_cfg=crop_cfg)
|
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
|
||||||
|
|
||||||
def execute(self, args: ArgumentConfig):
|
def execute(self, args: ArgumentConfig):
|
||||||
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
# for convenience
|
||||||
|
inf_cfg = self.live_portrait_wrapper.inference_cfg
|
||||||
|
device = self.live_portrait_wrapper.device
|
||||||
|
crop_cfg = self.cropper.crop_cfg
|
||||||
|
|
||||||
######## process source portrait ########
|
######## process source portrait ########
|
||||||
img_rgb = load_image_rgb(args.source_image)
|
img_rgb = load_image_rgb(args.source_image)
|
||||||
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
|
||||||
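For orientation, `resize_to_limit` caps the source at `source_max_dim` on its longest side and keeps both sides divisible by `source_division`. A minimal sketch of that behavior follows; the function body is an assumption for illustration, not the project's implementation:

import cv2
import numpy as np

def resize_to_limit_sketch(img: np.ndarray, max_dim: int = 1280, division: int = 2) -> np.ndarray:
    # shrink (never enlarge) so the longest side is <= max_dim
    h, w = img.shape[:2]
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        h, w = int(h * scale), int(w * scale)
        img = cv2.resize(img, (w, h))
    # trim a few pixels so both sides are divisible by `division`
    h, w = img.shape[:2]
    return img[: h - h % division, : w - w % division]
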
log(f"Load source image from {args.source_image}")
|
log(f"Load source image from {args.source_image}")
|
||||||
crop_info = self.cropper.crop_single_image(img_rgb)
|
|
||||||
|
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
|
||||||
|
if crop_info is None:
|
||||||
|
raise Exception("No face detected in the source image!")
|
||||||
source_lmk = crop_info['lmk_crop']
|
source_lmk = crop_info['lmk_crop']
|
||||||
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
||||||
if inference_cfg.flag_do_crop:
|
|
||||||
|
if inf_cfg.flag_do_crop:
|
||||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||||
else:
|
else:
|
||||||
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
|
||||||
|
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||||
x_c_s = x_s_info['kp']
|
x_c_s = x_s_info['kp']
|
||||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||||
|
|
||||||
if inference_cfg.flag_lip_zero:
|
flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
|
||||||
|
if flag_lip_zero:
|
||||||
# let lip-open scalar to be 0 at first
|
# let lip-open scalar to be 0 at first
|
||||||
c_d_lip_before_animation = [0.]
|
c_d_lip_before_animation = [0.]
|
||||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
||||||
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
|
||||||
inference_cfg.flag_lip_zero = False
|
flag_lip_zero = False
|
||||||
else:
|
else:
|
||||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||||
############################################
|
############################################
|
||||||
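Roughly speaking, the lip-zero branch only injects a closing delta when the source lips are noticeably open; below the threshold it switches itself off. A toy sketch of that decision (the threshold value here is illustrative, not the project default):

def should_apply_lip_zero(combined_lip_ratio: float, lip_zero_threshold: float = 0.03) -> bool:
    # if the source lips are already (almost) closed, skip the zeroing delta
    return combined_lip_ratio >= lip_zero_threshold

# example: an open mouth keeps the flag on, a closed one turns it off
print(should_apply_lip_zero(0.12))  # True  -> compute lip_delta_before_animation
print(should_apply_lip_zero(0.01))  # False -> flag_lip_zero is switched off
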
|
|
||||||
######## process driving info ########
|
######## process driving info ########
|
||||||
if is_video(args.driving_info):
|
flag_load_from_template = is_template(args.driving_info)
|
||||||
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
|
driving_rgb_crop_256x256_lst = None
|
||||||
# TODO: track the driving video here -> build the motion template
|
wfp_template = None
|
||||||
|
|
||||||
|
if flag_load_from_template:
|
||||||
|
# NOTE: loading from a template is fast, but the cropped driving video is None
|
||||||
|
log(f"Load from template: {args.driving_info}. This is NOT a video, so the cropped driving video and the audio are both unavailable.", style='bold green')
|
||||||
|
template_dct = load(args.driving_info)
|
||||||
|
n_frames = template_dct['n_frames']
|
||||||
|
|
||||||
|
# set output_fps
|
||||||
|
output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
|
||||||
|
log(f'The FPS of template: {output_fps}')
|
||||||
|
|
||||||
|
if args.flag_crop_driving_video:
|
||||||
|
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||||
|
|
||||||
|
elif osp.exists(args.driving_info) and is_video(args.driving_info):
|
||||||
|
# load from video file, AND make motion template
|
||||||
|
log(f"Load video: {args.driving_info}")
|
||||||
|
if osp.isdir(args.driving_info):
|
||||||
|
output_fps = inf_cfg.output_fps
|
||||||
|
else:
|
||||||
|
output_fps = int(get_fps(args.driving_info))
|
||||||
|
log(f'The FPS of {args.driving_info} is: {output_fps}')
|
||||||
|
|
||||||
|
log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
|
||||||
driving_rgb_lst = load_driving_info(args.driving_info)
|
driving_rgb_lst = load_driving_info(args.driving_info)
|
||||||
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
|
||||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
|
######## make motion template ########
|
||||||
|
log("Start making motion template...")
|
||||||
|
if inf_cfg.flag_crop_driving_video:
|
||||||
|
ret = self.cropper.crop_driving_video(driving_rgb_lst)
|
||||||
|
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
|
||||||
|
driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
|
||||||
|
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
|
||||||
|
else:
|
||||||
|
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
|
||||||
|
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
|
||||||
|
|
||||||
|
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
|
||||||
|
# save the motion template
|
||||||
|
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
|
||||||
|
template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
|
||||||
|
|
||||||
|
wfp_template = remove_suffix(args.driving_info) + '.pkl'
|
||||||
|
dump(wfp_template, template_dct)
|
||||||
|
log(f"Dump motion template to {wfp_template}")
|
||||||
|
|
||||||
n_frames = I_d_lst.shape[0]
|
n_frames = I_d_lst.shape[0]
|
||||||
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
|
|
||||||
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
|
||||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
|
||||||
elif is_template(args.driving_info):
|
|
||||||
log(f"Load from video templates {args.driving_info}")
|
|
||||||
with open(args.driving_info, 'rb') as f:
|
|
||||||
template_lst, driving_lmk_lst = pickle.load(f)
|
|
||||||
n_frames = template_lst[0]['n_frames']
|
|
||||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported driving types!")
|
raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
|
||||||
#########################################
|
#########################################
|
||||||
|
|
||||||
######## prepare for pasteback ########
|
######## prepare for pasteback ########
|
||||||
if inference_cfg.flag_pasteback:
|
I_p_pstbk_lst = None
|
||||||
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||||
I_p_paste_lst = []
|
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||||
|
I_p_pstbk_lst = []
|
||||||
|
log("Prepared pasteback mask done.")
|
||||||
#########################################
|
#########################################
|
||||||
|
|
||||||
I_p_lst = []
|
I_p_lst = []
|
||||||
R_d_0, x_d_0_info = None, None
|
R_d_0, x_d_0_info = None, None
|
||||||
for i in track(range(n_frames), description='Animating...', total=n_frames):
|
|
||||||
if is_video(args.driving_info):
|
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
|
||||||
# extract kp info by M
|
x_d_i_info = template_dct['motion'][i]
|
||||||
I_d_i = I_d_lst[i]
|
x_d_i_info = dct2device(x_d_i_info, device)
|
||||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
R_d_i = x_d_i_info['R_d']
|
||||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
|
||||||
else:
|
|
||||||
# from template
|
|
||||||
x_d_i_info = template_lst[i]
|
|
||||||
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
|
|
||||||
R_d_i = x_d_i_info['R_d']
|
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
R_d_0 = R_d_i
|
R_d_0 = R_d_i
|
||||||
x_d_0_info = x_d_i_info
|
x_d_0_info = x_d_i_info
|
||||||
|
|
||||||
if inference_cfg.flag_relative:
|
if inf_cfg.flag_relative_motion:
|
||||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||||
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||||
@ -123,36 +162,36 @@ class LivePortraitPipeline(object):
|
|||||||
scale_new = x_s_info['scale']
|
scale_new = x_s_info['scale']
|
||||||
t_new = x_d_i_info['t']
|
t_new = x_d_i_info['t']
|
||||||
|
|
||||||
t_new[..., 2].fill_(0) # zero tz
|
t_new[..., 2].fill_(0) # zero tz
|
||||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
||||||
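The line above is the per-frame keypoint update x'_d = scale * (x_c,s @ R_new + delta_new) + t_new, broadcast over the keypoint axis. A small numpy sketch with illustrative shapes (21 keypoints is an assumption for the example, not the model's actual count):

import numpy as np

num_kp = 21                                                  # illustrative keypoint count
x_c_s = np.random.randn(1, num_kp, 3).astype(np.float32)     # canonical source keypoints
R_new = np.eye(3, dtype=np.float32)[None]                    # (1, 3, 3) rotation
delta_new = np.zeros((1, num_kp, 3), np.float32)             # expression offset
scale_new = np.ones((1, 1), np.float32)
t_new = np.zeros((1, 3), np.float32)                         # tz is zeroed upstream

# rotate, add expression, scale, then translate
x_d_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
print(x_d_new.shape)  # (1, 21, 3)
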
|
|
||||||
# Algorithm 1:
|
# Algorithm 1:
|
||||||
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||||
# without stitching or retargeting
|
# without stitching or retargeting
|
||||||
if inference_cfg.flag_lip_zero:
|
if flag_lip_zero:
|
||||||
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||||
# with stitching and without retargeting
|
# with stitching and without retargeting
|
||||||
if inference_cfg.flag_lip_zero:
|
if flag_lip_zero:
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||||
else:
|
else:
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||||
else:
|
else:
|
||||||
eyes_delta, lip_delta = None, None
|
eyes_delta, lip_delta = None, None
|
||||||
if inference_cfg.flag_eye_retargeting:
|
if inf_cfg.flag_eye_retargeting:
|
||||||
c_d_eyes_i = input_eye_ratio_lst[i]
|
c_d_eyes_i = c_d_eyes_lst[i]
|
||||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
||||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
||||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
||||||
if inference_cfg.flag_lip_retargeting:
|
if inf_cfg.flag_lip_retargeting:
|
||||||
c_d_lip_i = input_lip_ratio_lst[i]
|
c_d_lip_i = c_d_lip_lst[i]
|
||||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
||||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
||||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
||||||
|
|
||||||
if inference_cfg.flag_relative: # use x_s
|
if inf_cfg.flag_relative_motion: # use x_s
|
||||||
x_d_i_new = x_s + \
|
x_d_i_new = x_s + \
|
||||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||||
@ -161,30 +200,86 @@ class LivePortraitPipeline(object):
|
|||||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||||
|
|
||||||
if inference_cfg.flag_stitching:
|
if inf_cfg.flag_stitching:
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||||
|
|
||||||
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
||||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||||
I_p_lst.append(I_p_i)
|
I_p_lst.append(I_p_i)
|
||||||
|
|
||||||
if inference_cfg.flag_pasteback:
|
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||||
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
# TODO: pasteback is slow; consider optimizing it with multi-threading or on the GPU
|
||||||
I_p_paste_lst.append(I_p_i_to_ori_blend)
|
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
|
||||||
|
I_p_pstbk_lst.append(I_p_pstbk)
|
||||||
|
|
||||||
mkdir(args.output_dir)
|
mkdir(args.output_dir)
|
||||||
wfp_concat = None
|
wfp_concat = None
|
||||||
if is_video(args.driving_info):
|
flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
|
||||||
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
|
|
||||||
# save (driving frames, source image, driven frames) result
|
######### build the final concat result #########
|
||||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
# driving frame | source image | generation, or source image | generation
|
||||||
images2video(frames_concatenated, wfp=wfp_concat)
|
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
|
||||||
|
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
||||||
|
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||||
|
|
||||||
|
if flag_has_audio:
|
||||||
|
# final concat result with audio
|
||||||
|
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
|
||||||
|
add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
|
||||||
|
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||||
|
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
|
||||||
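The audio muxing above maps naturally onto an ffmpeg call that copies the video stream and borrows the audio track from the driving file. A rough stand-in for `add_audio_to_video` (the real helper in `utils/video.py` may differ; this sketch assumes `ffmpeg` is available on PATH):

import subprocess

def add_audio_to_video_sketch(silent_video: str, audio_source: str, output_path: str) -> None:
    # copy the video stream as-is, take the audio from the driving file,
    # and stop at the shorter of the two streams
    cmd = [
        "ffmpeg", "-y",
        "-i", silent_video,
        "-i", audio_source,
        "-map", "0:v", "-map", "1:a",
        "-c:v", "copy", "-shortest",
        output_path,
    ]
    subprocess.run(cmd, check=True)
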
|
|
||||||
# save the driven result
|
# save the driven result
|
||||||
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
||||||
if inference_cfg.flag_pasteback:
|
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||||
images2video(I_p_paste_lst, wfp=wfp)
|
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||||
else:
|
else:
|
||||||
images2video(I_p_lst, wfp=wfp)
|
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||||
|
|
||||||
|
######### build final result #########
|
||||||
|
if flag_has_audio:
|
||||||
|
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
|
||||||
|
add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
|
||||||
|
os.replace(wfp_with_audio, wfp)
|
||||||
|
log(f"Replace {wfp} with {wfp_with_audio}")
|
||||||
|
|
||||||
|
# final log
|
||||||
|
if wfp_template not in (None, ''):
|
||||||
|
log(f'Animated template: {wfp_template}. You can pass this template path to the `-d` argument next time to skip video cropping and motion-template building, and to protect privacy.', style='bold green')
|
||||||
|
log(f'Animated video: {wfp}')
|
||||||
|
log(f'Animated video with concat: {wfp_concat}')
|
||||||
|
|
||||||
return wfp, wfp_concat
|
return wfp, wfp_concat
|
||||||
|
|
||||||
|
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
|
||||||
|
n_frames = I_d_lst.shape[0]
|
||||||
|
template_dct = {
|
||||||
|
'n_frames': n_frames,
|
||||||
|
'output_fps': kwargs.get('output_fps', 25),
|
||||||
|
'motion': [],
|
||||||
|
'c_d_eyes_lst': [],
|
||||||
|
'c_d_lip_lst': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
|
||||||
|
# collect s_d, R_d, δ_d and t_d for inference
|
||||||
|
I_d_i = I_d_lst[i]
|
||||||
|
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
||||||
|
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
||||||
|
|
||||||
|
item_dct = {
|
||||||
|
'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
|
||||||
|
'R_d': R_d_i.cpu().numpy().astype(np.float32),
|
||||||
|
'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
|
||||||
|
't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
template_dct['motion'].append(item_dct)
|
||||||
|
|
||||||
|
c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
|
||||||
|
template_dct['c_d_eyes_lst'].append(c_d_eyes)
|
||||||
|
|
||||||
|
c_d_lip = c_d_lip_lst[i].astype(np.float32)
|
||||||
|
template_dct['c_d_lip_lst'].append(c_d_lip)
|
||||||
|
|
||||||
|
return template_dct
|
||||||
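Once dumped, a template like the one returned above can be reloaded without touching the original driving video. A minimal sketch with plain `pickle` (the project routes this through its own `dump`/`load` helpers, whose exact format is assumed here; the example path is hypothetical):

import pickle

def load_motion_template_sketch(path: str) -> dict:
    # the template is a plain dict: n_frames, output_fps, motion list, eye/lip ratio lists
    with open(path, 'rb') as f:
        return pickle.load(f)

# usage: read the frame count and fps before running the animation loop
# tpl = load_motion_template_sketch('assets/examples/driving/d0.pkl')  # hypothetical path
# print(tpl['n_frames'], tpl.get('output_fps', 25))
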
|
@ -20,45 +20,51 @@ from .utils.rprint import rlog as log
|
|||||||
|
|
||||||
class LivePortraitWrapper(object):
|
class LivePortraitWrapper(object):
|
||||||
|
|
||||||
def __init__(self, cfg: InferenceConfig):
|
def __init__(self, inference_cfg: InferenceConfig):
|
||||||
|
|
||||||
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
self.inference_cfg = inference_cfg
|
||||||
|
self.device_id = inference_cfg.device_id
|
||||||
|
if inference_cfg.flag_force_cpu:
|
||||||
|
self.device = 'cpu'
|
||||||
|
else:
|
||||||
|
self.device = 'cuda:' + str(self.device_id)
|
||||||
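The device selection above trusts the flag alone. A common defensive variant, shown only as a sketch and not as the project's behavior, also checks whether CUDA is actually available:

import torch

def pick_device_sketch(device_id: int = 0, flag_force_cpu: bool = False) -> str:
    # fall back to CPU when CUDA is unavailable, even if the flag allows the GPU
    if flag_force_cpu or not torch.cuda.is_available():
        return 'cpu'
    return f'cuda:{device_id}'

print(pick_device_sketch())  # e.g. 'cuda:0' on a GPU machine, 'cpu' otherwise
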
|
|
||||||
|
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
||||||
# init F
|
# init F
|
||||||
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
|
||||||
log(f'Load appearance_feature_extractor done.')
|
log(f'Load appearance_feature_extractor done.')
|
||||||
# init M
|
# init M
|
||||||
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
|
||||||
log(f'Load motion_extractor done.')
|
log(f'Load motion_extractor done.')
|
||||||
# init W
|
# init W
|
||||||
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
|
||||||
log(f'Load warping_module done.')
|
log(f'Load warping_module done.')
|
||||||
# init G
|
# init G
|
||||||
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
|
||||||
log(f'Load spade_generator done.')
|
log(f'Load spade_generator done.')
|
||||||
# init S and R
|
# init S and R
|
||||||
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
|
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
|
||||||
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
|
||||||
log(f'Load stitching_retargeting_module done.')
|
log(f'Load stitching_retargeting_module done.')
|
||||||
else:
|
else:
|
||||||
self.stitching_retargeting_module = None
|
self.stitching_retargeting_module = None
|
||||||
|
|
||||||
self.cfg = cfg
|
|
||||||
self.device_id = cfg.device_id
|
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
|
|
||||||
def update_config(self, user_args):
|
def update_config(self, user_args):
|
||||||
for k, v in user_args.items():
|
for k, v in user_args.items():
|
||||||
if hasattr(self.cfg, k):
|
if hasattr(self.inference_cfg, k):
|
||||||
setattr(self.cfg, k, v)
|
setattr(self.inference_cfg, k, v)
|
||||||
|
|
||||||
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
||||||
""" construct the input as standard
|
""" construct the input as standard
|
||||||
img: HxWx3, uint8, 256x256
|
img: HxWx3, uint8, 256x256
|
||||||
"""
|
"""
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
|
if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
|
||||||
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
|
x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
|
||||||
else:
|
else:
|
||||||
x = img.copy()
|
x = img.copy()
|
||||||
|
|
||||||
@ -70,7 +76,7 @@ class LivePortraitWrapper(object):
|
|||||||
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
||||||
x = np.clip(x, 0, 1) # clip to 0~1
|
x = np.clip(x, 0, 1) # clip to 0~1
|
||||||
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
||||||
x = x.cuda(self.device_id)
|
x = x.to(self.device)
|
||||||
return x
|
return x
|
||||||
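Taken together, `prepare_source` is just uint8 HxWx3 -> float32 1x3xHxW in [0, 1] on the target device. A compact, self-contained sketch of the same conversion (a 256x256 dummy input is assumed):

import numpy as np
import torch

def prepare_source_sketch(img: np.ndarray, device: str = 'cpu') -> torch.Tensor:
    # img: HxWx3, uint8 -> 1x3xHxW, float32 in [0, 1]
    x = img.astype(np.float32) / 255.0
    x = np.clip(x, 0, 1)[None]                   # 1xHxWx3
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1x3xHxW
    return x.to(device)

dummy = np.zeros((256, 256, 3), dtype=np.uint8)
print(prepare_source_sketch(dummy).shape)  # torch.Size([1, 3, 256, 256])
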
|
|
||||||
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
||||||
@ -87,7 +93,7 @@ class LivePortraitWrapper(object):
|
|||||||
y = _imgs.astype(np.float32) / 255.
|
y = _imgs.astype(np.float32) / 255.
|
||||||
y = np.clip(y, 0, 1) # clip to 0~1
|
y = np.clip(y, 0, 1) # clip to 0~1
|
||||||
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
||||||
y = y.cuda(self.device_id)
|
y = y.to(self.device)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -96,7 +102,7 @@ class LivePortraitWrapper(object):
|
|||||||
x: Bx3xHxW, normalized to 0~1
|
x: Bx3xHxW, normalized to 0~1
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||||
feature_3d = self.appearance_feature_extractor(x)
|
feature_3d = self.appearance_feature_extractor(x)
|
||||||
|
|
||||||
return feature_3d.float()
|
return feature_3d.float()
|
||||||
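The half-precision pattern used throughout the wrapper — `no_grad` plus `autocast`, then cast the result back to float32 — can be isolated like this (a generic sketch with a stand-in module, not the actual feature extractor; the CPU bfloat16 choice is an assumption):

import torch
import torch.nn as nn

def run_half_precision_sketch(module: nn.Module, x: torch.Tensor, use_half: bool = True) -> torch.Tensor:
    device_type = 'cuda' if x.is_cuda else 'cpu'
    # float16 matches the GPU path; bfloat16 is the safer autocast dtype on CPU
    dtype = torch.float16 if device_type == 'cuda' else torch.bfloat16
    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=dtype, enabled=use_half):
            y = module(x)
    return y.float()  # hand float32 back to the rest of the pipeline

print(run_half_precision_sketch(nn.Linear(8, 4), torch.randn(2, 8)).dtype)  # torch.float32
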
@ -108,10 +114,10 @@ class LivePortraitWrapper(object):
|
|||||||
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||||
kp_info = self.motion_extractor(x)
|
kp_info = self.motion_extractor(x)
|
||||||
|
|
||||||
if self.cfg.flag_use_half_precision:
|
if self.inference_cfg.flag_use_half_precision:
|
||||||
# float the dict
|
# float the dict
|
||||||
for k, v in kp_info.items():
|
for k, v in kp_info.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
@ -254,14 +260,14 @@ class LivePortraitWrapper(object):
|
|||||||
"""
|
"""
|
||||||
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||||
# get decoder input
|
# get decoder input
|
||||||
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
||||||
# decode
|
# decode
|
||||||
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
||||||
|
|
||||||
# float the dict
|
# float the dict
|
||||||
if self.cfg.flag_use_half_precision:
|
if self.inference_cfg.flag_use_half_precision:
|
||||||
for k, v in ret_dct.items():
|
for k, v in ret_dct.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
ret_dct[k] = v.float()
|
ret_dct[k] = v.float()
|
||||||
@ -278,7 +284,7 @@ class LivePortraitWrapper(object):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
|
def calc_driving_ratio(self, driving_lmk_lst):
|
||||||
input_eye_ratio_lst = []
|
input_eye_ratio_lst = []
|
||||||
input_lip_ratio_lst = []
|
input_lip_ratio_lst = []
|
||||||
for lmk in driving_lmk_lst:
|
for lmk in driving_lmk_lst:
|
||||||
@ -288,20 +294,18 @@ class LivePortraitWrapper(object):
|
|||||||
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
||||||
return input_eye_ratio_lst, input_lip_ratio_lst
|
return input_eye_ratio_lst, input_lip_ratio_lst
|
||||||
|
|
||||||
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
|
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
|
||||||
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
|
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
|
||||||
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
|
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
|
||||||
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
|
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
|
||||||
# [c_s,eyes, c_d,eyes,i]
|
# [c_s,eyes, c_d,eyes,i]
|
||||||
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
|
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
|
||||||
return combined_eye_ratio_tensor
|
return combined_eye_ratio_tensor
|
||||||
|
|
||||||
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
|
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
|
||||||
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
|
c_s_lip = calc_lip_close_ratio(source_lmk[None])
|
||||||
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
|
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
|
||||||
|
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
|
||||||
# [c_s,lip, c_d,lip,i]
|
# [c_s,lip, c_d,lip,i]
|
||||||
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
|
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
|
||||||
if input_lip_ratio_tensor.shape != [1, 1]:
|
|
||||||
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
|
|
||||||
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
|
|
||||||
return combined_lip_ratio_tensor
|
return combined_lip_ratio_tensor
|
||||||
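Both retargeting conditions are built the same way: the ratio measured on the source landmarks is concatenated with the per-frame driving ratio along dim 1. A shape-only sketch of the lip case, where both halves are 1x1 (the eye case works the same way, but its source half may be wider):

import torch

def combine_ratio_sketch(c_s: float, c_d: float, device: str = 'cpu') -> torch.Tensor:
    # [c_s, c_d] -> a 1x2 conditioning tensor for the retargeting module
    c_s_tensor = torch.tensor([[c_s]], dtype=torch.float32, device=device)
    c_d_tensor = torch.tensor([[c_d]], dtype=torch.float32, device=device)
    return torch.cat([c_s_tensor, c_d_tensor], dim=1)

print(combine_ratio_sketch(0.05, 0.32))  # tensor([[0.0500, 0.3200]])
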
|
@ -1,65 +0,0 @@
|
|||||||
# coding: utf-8
|
|
||||||
|
|
||||||
"""
|
|
||||||
Make video template
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import pickle
|
|
||||||
from rich.progress import track
|
|
||||||
from .utils.cropper import Cropper
|
|
||||||
|
|
||||||
from .utils.io import load_driving_info
|
|
||||||
from .utils.camera import get_rotation_matrix
|
|
||||||
from .utils.helper import mkdir, basename
|
|
||||||
from .utils.rprint import rlog as log
|
|
||||||
from .config.crop_config import CropConfig
|
|
||||||
from .config.inference_config import InferenceConfig
|
|
||||||
from .live_portrait_wrapper import LivePortraitWrapper
|
|
||||||
|
|
||||||
class TemplateMaker:
|
|
||||||
|
|
||||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
|
||||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
|
||||||
self.cropper = Cropper(crop_cfg=crop_cfg)
|
|
||||||
|
|
||||||
def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
|
|
||||||
""" make video template (.pkl format)
|
|
||||||
video_fp: driving video file path
|
|
||||||
output_path: where to save the pickle file
|
|
||||||
"""
|
|
||||||
|
|
||||||
driving_rgb_lst = load_driving_info(video_fp)
|
|
||||||
driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
|
||||||
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
|
||||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst)
|
|
||||||
|
|
||||||
n_frames = I_d_lst.shape[0]
|
|
||||||
|
|
||||||
templates = []
|
|
||||||
|
|
||||||
|
|
||||||
for i in track(range(n_frames), description='Making templates...', total=n_frames):
|
|
||||||
I_d_i = I_d_lst[i]
|
|
||||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
|
||||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
|
||||||
# collect s_d, R_d, δ_d and t_d for inference
|
|
||||||
template_dct = {
|
|
||||||
'n_frames': n_frames,
|
|
||||||
'frames_index': i,
|
|
||||||
}
|
|
||||||
template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32)
|
|
||||||
template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32)
|
|
||||||
template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32)
|
|
||||||
template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32)
|
|
||||||
|
|
||||||
templates.append(template_dct)
|
|
||||||
|
|
||||||
mkdir(output_path)
|
|
||||||
# Save the dictionary as a pickle file
|
|
||||||
pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
|
|
||||||
with open(pickle_fp, 'wb') as f:
|
|
||||||
pickle.dump([templates, driving_lmk_lst], f)
|
|
||||||
log(f"Template saved at {pickle_fp}")
|
|
@ -31,8 +31,6 @@ def headpose_pred_to_degree(pred):
|
|||||||
def get_rotation_matrix(pitch_, yaw_, roll_):
|
def get_rotation_matrix(pitch_, yaw_, roll_):
|
||||||
""" the input is in degree
|
""" the input is in degree
|
||||||
"""
|
"""
|
||||||
# calculate the rotation matrix: vps @ rot
|
|
||||||
|
|
||||||
# transform to radian
|
# transform to radian
|
||||||
pitch = pitch_ / 180 * PI
|
pitch = pitch_ / 180 * PI
|
||||||
yaw = yaw_ / 180 * PI
|
yaw = yaw_ / 180 * PI
|
||||||
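The degrees-to-radians conversion above feeds a standard Euler-angle composition. A hedged numpy sketch (the composition order R = Rz @ Ry @ Rx below is one common convention and only an assumption about what `get_rotation_matrix` actually uses):

import numpy as np

def rotation_matrix_sketch(pitch_deg: float, yaw_deg: float, roll_deg: float) -> np.ndarray:
    p, y, r = np.radians([pitch_deg, yaw_deg, roll_deg])
    Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]])  # pitch
    Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]])  # yaw
    Rz = np.array([[np.cos(r), -np.sin(r), 0], [np.sin(r), np.cos(r), 0], [0, 0, 1]])  # roll
    return Rz @ Ry @ Rx

R = rotation_matrix_sketch(10.0, -5.0, 2.0)
print(np.allclose(R @ R.T, np.eye(3)))  # True: a proper rotation is orthonormal
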
|
@ -281,11 +281,10 @@ def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=Fals
|
|||||||
dtype=DTYPE
|
dtype=DTYPE
|
||||||
)
|
)
|
||||||
|
|
||||||
if flag_rot and angle is None:
|
# if flag_rot and angle is None:
|
||||||
print('angle is None, but flag_rotate is True', style="bold yellow")
|
# print('angle is None, but flag_rotate is True', style="bold yellow")
|
||||||
|
|
||||||
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
|
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
|
||||||
|
|
||||||
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
|
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
|
||||||
|
|
||||||
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
|
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
|
||||||
@ -362,17 +361,6 @@ def crop_image(img, pts: np.ndarray, **kwargs):
|
|||||||
flag_do_rot=kwargs.get('flag_do_rot', True),
|
flag_do_rot=kwargs.get('flag_do_rot', True),
|
||||||
)
|
)
|
||||||
|
|
||||||
if img is None:
|
|
||||||
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
|
|
||||||
M = np.linalg.inv(M_INV_H)
|
|
||||||
ret_dct = {
|
|
||||||
'M': M[:2, ...], # from the original image to the cropped image
|
|
||||||
'M_o2c': M[:2, ...], # from the cropped image to the original image
|
|
||||||
'img_crop': None,
|
|
||||||
'pt_crop': None,
|
|
||||||
}
|
|
||||||
return ret_dct
|
|
||||||
|
|
||||||
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
|
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
|
||||||
pt_crop = _transform_pts(pts, M_INV)
|
pt_crop = _transform_pts(pts, M_INV)
|
||||||
|
|
||||||
@ -397,16 +385,14 @@ def average_bbox_lst(bbox_lst):
|
|||||||
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
|
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
|
||||||
"""prepare mask for later image paste back
|
"""prepare mask for later image paste back
|
||||||
"""
|
"""
|
||||||
if mask_crop is None:
|
|
||||||
mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR)
|
|
||||||
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
|
mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
|
||||||
mask_ori = mask_ori.astype(np.float32) / 255.
|
mask_ori = mask_ori.astype(np.float32) / 255.
|
||||||
return mask_ori
|
return mask_ori
|
||||||
|
|
||||||
def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
|
def paste_back(img_crop, M_c2o, img_ori, mask_ori):
|
||||||
"""paste back the image
|
"""paste back the image
|
||||||
"""
|
"""
|
||||||
dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
|
dsize = (img_ori.shape[1], img_ori.shape[0])
|
||||||
result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
|
result = _transform_img(img_crop, M_c2o, dsize=dsize)
|
||||||
result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
|
result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
|
||||||
return result
|
return result
|
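Conceptually, `paste_back` is a warp followed by a per-pixel alpha blend with the soft mask. A compact sketch with `cv2.warpAffine` standing in for `_transform_img` (an assumption about that helper's behavior):

import cv2
import numpy as np

def paste_back_sketch(img_crop: np.ndarray, M_c2o: np.ndarray,
                      img_ori: np.ndarray, mask_ori: np.ndarray) -> np.ndarray:
    # warp the animated crop into the original frame, then blend with the float mask
    h, w = img_ori.shape[:2]
    warped = cv2.warpAffine(img_crop, M_c2o[:2], (w, h))
    blended = mask_ori * warped + (1.0 - mask_ori) * img_ori
    return np.clip(blended, 0, 255).astype(np.uint8)
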
@ -1,20 +1,23 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import numpy as np
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from typing import List, Union, Tuple
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from .landmark_runner import LandmarkRunner
|
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||||
from .face_analysis_diy import FaceAnalysisDIY
|
import numpy as np
|
||||||
from .helper import prefix
|
|
||||||
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
|
from ..config.crop_config import CropConfig
|
||||||
from .timer import Timer
|
from .crop import (
|
||||||
|
average_bbox_lst,
|
||||||
|
crop_image,
|
||||||
|
crop_image_by_bbox,
|
||||||
|
parse_bbox_from_landmark,
|
||||||
|
)
|
||||||
|
from .io import contiguous
|
||||||
from .rprint import rlog as log
|
from .rprint import rlog as log
|
||||||
from .io import load_image_rgb
|
from .face_analysis_diy import FaceAnalysisDIY
|
||||||
from .video import VideoWriter, get_fps, change_video_fps
|
from .landmark_runner import LandmarkRunner
|
||||||
|
|
||||||
|
|
||||||
def make_abs_path(fn):
|
def make_abs_path(fn):
|
||||||
@ -23,123 +26,171 @@ def make_abs_path(fn):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Trajectory:
|
class Trajectory:
|
||||||
start: int = -1 # start frame (inclusive)
|
start: int = -1 # start frame
|
||||||
end: int = -1 # end frame (inclusive)
|
end: int = -1 # end frame
|
||||||
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
|
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
|
||||||
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
|
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
|
||||||
|
|
||||||
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
|
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
|
||||||
|
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
|
||||||
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
|
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
|
||||||
|
|
||||||
|
|
||||||
class Cropper(object):
|
class Cropper(object):
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
device_id = kwargs.get('device_id', 0)
|
self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
|
||||||
|
device_id = kwargs.get("device_id", 0)
|
||||||
|
flag_force_cpu = kwargs.get("flag_force_cpu", False)
|
||||||
|
if flag_force_cpu:
|
||||||
|
device = "cpu"
|
||||||
|
face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
|
||||||
|
else:
|
||||||
|
device = "cuda"
|
||||||
|
face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
|
||||||
self.landmark_runner = LandmarkRunner(
|
self.landmark_runner = LandmarkRunner(
|
||||||
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
|
ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
|
||||||
onnx_provider='cuda',
|
onnx_provider=device,
|
||||||
device_id=device_id
|
device_id=device_id,
|
||||||
)
|
)
|
||||||
self.landmark_runner.warmup()
|
self.landmark_runner.warmup()
|
||||||
|
|
||||||
self.face_analysis_wrapper = FaceAnalysisDIY(
|
self.face_analysis_wrapper = FaceAnalysisDIY(
|
||||||
name='buffalo_l',
|
name="buffalo_l",
|
||||||
root=make_abs_path('../../pretrained_weights/insightface'),
|
root=make_abs_path(self.crop_cfg.insightface_root),
|
||||||
providers=["CUDAExecutionProvider"]
|
providers=face_analysis_wrapper_provider,
|
||||||
)
|
)
|
||||||
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
|
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
|
||||||
self.face_analysis_wrapper.warmup()
|
self.face_analysis_wrapper.warmup()
|
||||||
|
|
||||||
self.crop_cfg = kwargs.get('crop_cfg', None)
|
|
||||||
|
|
||||||
def update_config(self, user_args):
|
def update_config(self, user_args):
|
||||||
for k, v in user_args.items():
|
for k, v in user_args.items():
|
||||||
if hasattr(self.crop_cfg, k):
|
if hasattr(self.crop_cfg, k):
|
||||||
setattr(self.crop_cfg, k, v)
|
setattr(self.crop_cfg, k, v)
|
||||||
|
|
||||||
def crop_single_image(self, obj, **kwargs):
|
def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
|
||||||
direction = kwargs.get('direction', 'large-small')
|
# crop the source image and get the necessary information
|
||||||
|
img_rgb = img_rgb_.copy() # copy it
|
||||||
# crop and align a single image
|
|
||||||
if isinstance(obj, str):
|
|
||||||
img_rgb = load_image_rgb(obj)
|
|
||||||
elif isinstance(obj, np.ndarray):
|
|
||||||
img_rgb = obj
|
|
||||||
|
|
||||||
|
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
||||||
src_face = self.face_analysis_wrapper.get(
|
src_face = self.face_analysis_wrapper.get(
|
||||||
img_rgb,
|
img_bgr,
|
||||||
flag_do_landmark_2d_106=True,
|
flag_do_landmark_2d_106=True,
|
||||||
direction=direction
|
direction=crop_cfg.direction,
|
||||||
|
max_face_num=crop_cfg.max_face_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(src_face) == 0:
|
if len(src_face) == 0:
|
||||||
log('No face detected in the source image.')
|
log("No face detected in the source image.")
|
||||||
raise gr.Error("No face detected in the source image 💥!", duration=5)
|
return None
|
||||||
raise Exception("No face detected in the source image!")
|
|
||||||
elif len(src_face) > 1:
|
elif len(src_face) > 1:
|
||||||
log(f'More than one face detected in the image, only pick one face by rule {direction}.')
|
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
|
||||||
|
|
||||||
|
# NOTE: temporarily only pick the first face; multiple faces will be supported in the future
|
||||||
src_face = src_face[0]
|
src_face = src_face[0]
|
||||||
pts = src_face.landmark_2d_106
|
lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
|
||||||
|
|
||||||
# crop the face
|
# crop the face
|
||||||
ret_dct = crop_image(
|
ret_dct = crop_image(
|
||||||
img_rgb, # ndarray
|
img_rgb, # ndarray
|
||||||
pts, # 106x2 or Nx2
|
lmk, # 106x2 or Nx2
|
||||||
dsize=kwargs.get('dsize', 512),
|
dsize=crop_cfg.dsize,
|
||||||
scale=kwargs.get('scale', 2.3),
|
scale=crop_cfg.scale,
|
||||||
vy_ratio=kwargs.get('vy_ratio', -0.15),
|
vx_ratio=crop_cfg.vx_ratio,
|
||||||
|
vy_ratio=crop_cfg.vy_ratio,
|
||||||
)
|
)
|
||||||
# update a 256x256 version for network input or else
|
|
||||||
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
|
|
||||||
ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
|
|
||||||
|
|
||||||
recon_ret = self.landmark_runner.run(img_rgb, pts)
|
lmk = self.landmark_runner.run(img_rgb, lmk)
|
||||||
lmk = recon_ret['pts']
|
ret_dct["lmk_crop"] = lmk
|
||||||
ret_dct['lmk_crop'] = lmk
|
|
||||||
|
# update a 256x256 version for network input
|
||||||
|
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
|
||||||
|
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
|
||||||
|
|
||||||
return ret_dct
|
return ret_dct
|
||||||
|
|
||||||
def get_retargeting_lmk_info(self, driving_rgb_lst):
|
def crop_driving_video(self, driving_rgb_lst, **kwargs):
|
||||||
# TODO: implement a tracking-based version
|
"""Tracking based landmarks/alignment and cropping"""
|
||||||
driving_lmk_lst = []
|
|
||||||
for driving_image in driving_rgb_lst:
|
|
||||||
ret_dct = self.crop_single_image(driving_image)
|
|
||||||
driving_lmk_lst.append(ret_dct['lmk_crop'])
|
|
||||||
return driving_lmk_lst
|
|
||||||
|
|
||||||
def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
|
|
||||||
trajectory = Trajectory()
|
trajectory = Trajectory()
|
||||||
direction = kwargs.get('direction', 'large-small')
|
direction = kwargs.get("direction", "large-small")
|
||||||
for idx, driving_image in enumerate(driving_rgb_lst):
|
for idx, frame_rgb in enumerate(driving_rgb_lst):
|
||||||
if idx == 0 or trajectory.start == -1:
|
if idx == 0 or trajectory.start == -1:
|
||||||
src_face = self.face_analysis_wrapper.get(
|
src_face = self.face_analysis_wrapper.get(
|
||||||
driving_image,
|
contiguous(frame_rgb[..., ::-1]),
|
||||||
flag_do_landmark_2d_106=True,
|
flag_do_landmark_2d_106=True,
|
||||||
direction=direction
|
direction=direction,
|
||||||
)
|
)
|
||||||
if len(src_face) == 0:
|
if len(src_face) == 0:
|
||||||
# No face detected in the driving_image
|
log(f"No face detected in the frame #{idx}")
|
||||||
continue
|
continue
|
||||||
elif len(src_face) > 1:
|
elif len(src_face) > 1:
|
||||||
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
|
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
|
||||||
src_face = src_face[0]
|
src_face = src_face[0]
|
||||||
pts = src_face.landmark_2d_106
|
lmk = src_face.landmark_2d_106
|
||||||
lmk_203 = self.landmark_runner(driving_image, pts)['pts']
|
lmk = self.landmark_runner.run(frame_rgb, lmk)
|
||||||
trajectory.start, trajectory.end = idx, idx
|
trajectory.start, trajectory.end = idx, idx
|
||||||
else:
|
else:
|
||||||
lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts']
|
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
|
||||||
trajectory.end = idx
|
trajectory.end = idx
|
||||||
|
|
||||||
trajectory.lmk_lst.append(lmk_203)
|
trajectory.lmk_lst.append(lmk)
|
||||||
ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=elf.crop_cfg.vy_ratio)['bbox']
|
ret_bbox = parse_bbox_from_landmark(
|
||||||
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4,
|
lmk,
|
||||||
|
scale=self.crop_cfg.scale_crop_video,
|
||||||
|
vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video,
|
||||||
|
vy_ratio=self.crop_cfg.vy_ratio_crop_video,
|
||||||
|
)["bbox"]
|
||||||
|
bbox = [
|
||||||
|
ret_bbox[0, 0],
|
||||||
|
ret_bbox[0, 1],
|
||||||
|
ret_bbox[2, 0],
|
||||||
|
ret_bbox[2, 1],
|
||||||
|
] # 4,
|
||||||
trajectory.bbox_lst.append(bbox) # bbox
|
trajectory.bbox_lst.append(bbox) # bbox
|
||||||
trajectory.frame_rgb_lst.append(driving_image)
|
trajectory.frame_rgb_lst.append(frame_rgb)
|
||||||
|
|
||||||
global_bbox = average_bbox_lst(trajectory.bbox_lst)
|
global_bbox = average_bbox_lst(trajectory.bbox_lst)
|
||||||
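Averaging the per-frame boxes into one global crop keeps the driving video stable instead of jittering from frame to frame. A minimal stand-in for `average_bbox_lst`:

import numpy as np

def average_bbox_sketch(bbox_lst):
    # bbox_lst: list of [x1, y1, x2, y2]; the mean box is used to crop every frame
    return np.mean(np.asarray(bbox_lst, dtype=np.float32), axis=0).tolist()

print(average_bbox_sketch([[10, 12, 200, 220], [14, 10, 204, 216]]))  # [12.0, 11.0, 202.0, 218.0]
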
|
|
||||||
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
|
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
|
||||||
ret_dct = crop_image_by_bbox(
|
ret_dct = crop_image_by_bbox(
|
||||||
frame_rgb, global_bbox, lmk=lmk,
|
frame_rgb,
|
||||||
dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue
|
global_bbox,
|
||||||
|
lmk=lmk,
|
||||||
|
dsize=kwargs.get("dsize", 512),
|
||||||
|
flag_rot=False,
|
||||||
|
borderValue=(0, 0, 0),
|
||||||
)
|
)
|
||||||
frame_rgb_crop = ret_dct['img_crop']
|
trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
|
||||||
|
trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"frame_crop_lst": trajectory.frame_rgb_crop_lst,
|
||||||
|
"lmk_crop_lst": trajectory.lmk_crop_lst,
|
||||||
|
}
|
||||||
|
|
||||||
|
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
|
||||||
|
"""Tracking based landmarks/alignment"""
|
||||||
|
trajectory = Trajectory()
|
||||||
|
direction = kwargs.get("direction", "large-small")
|
||||||
|
|
||||||
|
for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
|
||||||
|
if idx == 0 or trajectory.start == -1:
|
||||||
|
src_face = self.face_analysis_wrapper.get(
|
||||||
|
contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR
|
||||||
|
flag_do_landmark_2d_106=True,
|
||||||
|
direction=direction,
|
||||||
|
)
|
||||||
|
if len(src_face) == 0:
|
||||||
|
log(f"No face detected in the frame #{idx}")
|
||||||
|
raise Exception(f"No face detected in the frame #{idx}")
|
||||||
|
elif len(src_face) > 1:
|
||||||
|
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
|
||||||
|
src_face = src_face[0]
|
||||||
|
lmk = src_face.landmark_2d_106
|
||||||
|
lmk = self.landmark_runner.run(frame_rgb_crop, lmk)
|
||||||
|
trajectory.start, trajectory.end = idx, idx
|
||||||
|
else:
|
||||||
|
lmk = self.landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
|
||||||
|
trajectory.end = idx
|
||||||
|
|
||||||
|
trajectory.lmk_lst.append(lmk)
|
||||||
|
return trajectory.lmk_lst
|
||||||
|
@ -39,7 +39,7 @@ class FaceAnalysisDIY(FaceAnalysis):
|
|||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
|
|
||||||
def get(self, img_bgr, **kwargs):
|
def get(self, img_bgr, **kwargs):
|
||||||
max_num = kwargs.get('max_num', 0) # the number of the detected faces, 0 means no limit
|
max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit
|
||||||
flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection
|
flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection
|
||||||
direction = kwargs.get('direction', 'large-small') # sorting direction
|
direction = kwargs.get('direction', 'large-small') # sorting direction
|
||||||
face_center = None
|
face_center = None
|
||||||
|
@ -37,6 +37,11 @@ def basename(filename):
|
|||||||
return prefix(osp.basename(filename))
|
return prefix(osp.basename(filename))
|
||||||
|
|
||||||
|
|
||||||
|
def remove_suffix(filepath):
|
||||||
|
"""a/b/c.jpg -> a/b/c"""
|
||||||
|
return osp.join(osp.dirname(filepath), basename(filepath))
|
||||||
|
|
||||||
|
|
||||||
def is_video(file_path):
|
def is_video(file_path):
|
||||||
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
|
if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
|
||||||
return True
|
return True
|
||||||
@ -63,9 +68,9 @@ def squeeze_tensor_to_numpy(tensor):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def dct2cuda(dct: dict, device_id: int):
|
def dct2device(dct: dict, device):
|
||||||
for key in dct:
|
for key in dct:
|
||||||
dct[key] = torch.tensor(dct[key]).cuda(device_id)
|
dct[key] = torch.tensor(dct[key]).to(device)
|
||||||
return dct
|
return dct
|
||||||
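`dct2device` simply walks one frame's motion dict and moves every value onto the inference device. A short usage sketch with dummy numpy entries (the keypoint count is illustrative):

import numpy as np
import torch

motion_frame = {
    'scale': np.ones((1, 1), np.float32),
    'exp': np.zeros((1, 21, 3), np.float32),  # 21 keypoints is illustrative
    't': np.zeros((1, 3), np.float32),
}
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
motion_frame = {k: torch.tensor(v).to(device) for k, v in motion_frame.items()}
print({k: v.device for k, v in motion_frame.items()})
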
|
|
||||||
|
|
||||||
@@ -94,13 +99,13 @@ def load_model(ckpt_path, model_config, device, model_type):
     model_params = model_config['model_params'][f'{model_type}_params']

     if model_type == 'appearance_feature_extractor':
-        model = AppearanceFeatureExtractor(**model_params).cuda(device)
+        model = AppearanceFeatureExtractor(**model_params).to(device)
     elif model_type == 'motion_extractor':
-        model = MotionExtractor(**model_params).cuda(device)
+        model = MotionExtractor(**model_params).to(device)
     elif model_type == 'warping_module':
-        model = WarpingNetwork(**model_params).cuda(device)
+        model = WarpingNetwork(**model_params).to(device)
     elif model_type == 'spade_generator':
-        model = SPADEDecoder(**model_params).cuda(device)
+        model = SPADEDecoder(**model_params).to(device)
     elif model_type == 'stitching_retargeting_module':
         # Special handling for stitching and retargeting module
         config = model_config['model_params']['stitching_retargeting_module_params']
@@ -108,17 +113,17 @@ def load_model(ckpt_path, model_config, device, model_type):

         stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
         stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
-        stitcher = stitcher.cuda(device)
+        stitcher = stitcher.to(device)
         stitcher.eval()

         retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
         retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth']))
-        retargetor_lip = retargetor_lip.cuda(device)
+        retargetor_lip = retargetor_lip.to(device)
         retargetor_lip.eval()

         retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
         retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye']))
-        retargetor_eye = retargetor_eye.cuda(device)
+        retargetor_eye = retargetor_eye.to(device)
         retargetor_eye.eval()

         return {
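Replacing `.cuda(device)` with `.to(device)` is the portability change behind both `load_model` hunks: `.to()` accepts any torch device (CPU, MPS, or a CUDA index), while `.cuda()` fails outright on machines without a GPU. A small sketch under that assumption, using a throwaway module rather than the repository's networks:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"  # "mps" also works on Apple silicon

module = nn.Linear(4, 4).to(device)  # .to() is backend-agnostic; .cuda() would require a GPU
module.eval()

with torch.no_grad():
    out = module(torch.zeros(1, 4, device=device))
print(out.shape)  # -> torch.Size([1, 4])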
@@ -134,20 +139,6 @@ def load_model(ckpt_path, model_config, device, model_type):
     return model


-# get coefficients of Eqn. 7
-def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
-    if config.relative:
-        new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s
-        new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp'])
-    else:
-        new_rotation = R_t_i
-        new_expression = t_i_kp_info['exp']
-    new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t'])
-    new_translation[..., 2].fill_(0)  # Keep the z-axis unchanged
-    new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
-    return new_rotation, new_expression, new_translation, new_scale
-
-
 def load_description(fp):
     with open(fp, 'r', encoding='utf-8') as f:
         content = f.read()
@@ -5,8 +5,11 @@ from glob import glob
 import os.path as osp
 import imageio
 import numpy as np
+import pickle
 import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)

+from .helper import mkdir, suffix
+

 def load_image_rgb(image_path: str):
     if not osp.exists(image_path):
@@ -23,8 +26,8 @@ def load_driving_info(driving_info):
         return [load_image_rgb(im_path) for im_path in image_paths]

     def load_images_from_video(file_path):
-        reader = imageio.get_reader(file_path)
-        return [image for idx, image in enumerate(reader)]
+        reader = imageio.get_reader(file_path, "ffmpeg")
+        return [image for _, image in enumerate(reader)]

     if osp.isdir(driving_info):
         driving_video_ori = load_images_from_directory(driving_info)
@@ -40,7 +43,7 @@ def contiguous(obj):
     return obj


-def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
+def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
     """
     ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
     :param img: the image to be processed.
@@ -61,9 +64,9 @@ def resize_to_limit(img: np.ndarray, max_dim=1920, n=2):
         img = cv2.resize(img, (new_w, new_h))

     # ensure that the image dimensions are multiples of n
-    n = max(n, 1)
-    new_h = img.shape[0] - (img.shape[0] % n)
-    new_w = img.shape[1] - (img.shape[1] % n)
+    division = max(division, 1)
+    new_h = img.shape[0] - (img.shape[0] % division)
+    new_w = img.shape[1] - (img.shape[1] % division)

     if new_h == 0 or new_w == 0:
         # when the width or height is less than n, no need to process
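The `n` → `division` rename does not change behaviour: the helper still caps the longest side at `max_dim` and then rounds each side down to a multiple of `division`. A small worked example of the rounding step, independent of the repository code:

def round_down_to_multiple(h, w, division=2):
    # drop at most (division - 1) pixels per side so both dimensions divide evenly
    division = max(division, 1)
    return h - (h % division), w - (w % division)

print(round_down_to_multiple(1080, 1917, division=2))   # -> (1080, 1916)
print(round_down_to_multiple(719, 405, division=16))    # -> (704, 400)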
@@ -87,7 +90,7 @@ def load_img_online(obj, mode="bgr", **kwargs):
         img = obj

     # Resize image to satisfy constraints
-    img = resize_to_limit(img, max_dim=max_dim, n=n)
+    img = resize_to_limit(img, max_dim=max_dim, division=n)

     if mode.lower() == "bgr":
         return contiguous(img)
@@ -95,3 +98,28 @@ def load_img_online(obj, mode="bgr", **kwargs):
         return contiguous(img[..., ::-1])
     else:
         raise Exception(f"Unknown mode {mode}")
+
+
+def load(fp):
+    suffix_ = suffix(fp)
+
+    if suffix_ == "npy":
+        return np.load(fp)
+    elif suffix_ == "pkl":
+        return pickle.load(open(fp, "rb"))
+    else:
+        raise Exception(f"Unknown type: {suffix}")
+
+
+def dump(wfp, obj):
+    wd = osp.split(wfp)[0]
+    if wd != "" and not osp.exists(wd):
+        mkdir(wd)
+
+    _suffix = suffix(wfp)
+    if _suffix == "npy":
+        np.save(wfp, obj)
+    elif _suffix == "pkl":
+        pickle.dump(obj, open(wfp, "wb"))
+    else:
+        raise Exception("Unknown type: {}".format(_suffix))
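The new `load`/`dump` helpers dispatch on the file suffix so `.npy` arrays and `.pkl` templates share one call site. A self-contained sketch of the same round trip using only the standard library and NumPy; the file names here are examples only:

import os.path as osp
import pickle
import tempfile

import numpy as np

def dump_obj(wfp, obj):
    ext = osp.splitext(wfp)[1].lower()
    if ext == ".npy":
        np.save(wfp, obj)
    elif ext == ".pkl":
        with open(wfp, "wb") as f:  # context manager closes the handle, unlike an inline open(...)
            pickle.dump(obj, f)
    else:
        raise ValueError(f"Unknown type: {ext}")

def load_obj(fp):
    ext = osp.splitext(fp)[1].lower()
    if ext == ".npy":
        return np.load(fp)
    elif ext == ".pkl":
        with open(fp, "rb") as f:
            return pickle.load(f)
    raise ValueError(f"Unknown type: {ext}")

with tempfile.TemporaryDirectory() as d:
    wfp = osp.join(d, "motion_template.pkl")
    dump_obj(wfp, {"n_frames": 3})
    print(load_obj(wfp))  # -> {'n_frames': 3}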
@@ -25,6 +25,7 @@ def to_ndarray(obj):

 class LandmarkRunner(object):
     """landmark runner"""
+
     def __init__(self, **kwargs):
         ckpt_path = kwargs.get('ckpt_path')
         onnx_provider = kwargs.get('onnx_provider', 'cuda')  # use CUDA by default
@@ -55,6 +56,7 @@ class LandmarkRunner(object):
             crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
             img_crop_rgb = crop_dct['img_crop']
         else:
+            # NOTE: force resize to 224x224, NOT RECOMMEND!
             img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
             scale = max(img_rgb.shape[:2]) / self.dsize
             crop_dct = {
@@ -70,15 +72,13 @@ class LandmarkRunner(object):
         out_lst = self._run(inp)
         out_pts = out_lst[2]

-        pts = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize  # scale to 0-224
-        pts = _transform_pts(pts, M=crop_dct['M_c2o'])
+        # 2d landmarks 203 points
+        lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize  # scale to 0-224
+        lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])

-        return {
-            'pts': pts,  # 2d landmarks 203 points
-        }
+        return lmk

     def warmup(self):
-        # construct a dummy image for warmup
         self.timer.tic()

         dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32)
@@ -7,32 +7,11 @@ import numpy as np


 def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
-    """
-    Calculate the ratio of the distance between two pairs of landmarks.
-
-    Parameters:
-    lmk (np.ndarray): Landmarks array of shape (B, N, 2).
-    idx1, idx2, idx3, idx4 (int): Indices of the landmarks.
-    eps (float): Small value to avoid division by zero.
-
-    Returns:
-    np.ndarray: Calculated distance ratio.
-    """
     return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
             (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))


 def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
-    """
-    Calculate the eye-close ratio for left and right eyes.
-
-    Parameters:
-    lmk (np.ndarray): Landmarks array of shape (B, N, 2).
-    target_eye_ratio (np.ndarray, optional): Additional target eye ratio array to include.
-
-    Returns:
-    np.ndarray: Concatenated eye-close ratios.
-    """
     lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
     righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
     if target_eye_ratio is not None:
@@ -42,13 +21,4 @@ def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -


 def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
-    """
-    Calculate the lip-close ratio.
-
-    Parameters:
-    lmk (np.ndarray): Landmarks array of shape (B, N, 2).
-
-    Returns:
-    np.ndarray: Calculated lip-close ratio.
-    """
     return calculate_distance_ratio(lmk, 90, 102, 48, 66)
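All three ratios above reduce to the same measurement: the distance between one landmark pair divided by the distance between a second, normalizing pair. A self-contained numerical sketch on a synthetic `(B, N, 2)` landmark array; the index pairs match the ones used in the hunks, the coordinate values are made up:

import numpy as np

def distance_ratio(lmk, idx1, idx2, idx3, idx4, eps=1e-6):
    # ||p_idx1 - p_idx2|| / (||p_idx3 - p_idx4|| + eps), per batch item
    num = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    den = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True)
    return num / (den + eps)

rng = np.random.default_rng(0)
lmk = rng.uniform(0, 224, size=(1, 203, 2)).astype(np.float32)  # one face, 203 2D points

left_eye = distance_ratio(lmk, 6, 18, 0, 12)
right_eye = distance_ratio(lmk, 30, 42, 24, 36)
lip = distance_ratio(lmk, 90, 102, 48, 66)
print(left_eye.shape, right_eye.shape, lip.shape)  # -> (1, 1) (1, 1) (1, 1)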
@@ -1,7 +1,9 @@
 # coding: utf-8

 """
-functions for processing video
+Functions for processing video
+
+ATTENTION: you need to install ffmpeg and ffprobe in your env!
 """

 import os.path as osp
@@ -9,14 +11,15 @@ import numpy as np
 import subprocess
 import imageio
 import cv2

 from rich.progress import track
-from .helper import prefix
+from .rprint import rlog as log
 from .rprint import rprint as print
+from .helper import prefix


 def exec_cmd(cmd):
-    subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)


 def images2video(images, wfp, **kwargs):
@@ -35,7 +38,7 @@ def images2video(images, wfp, **kwargs):
     )

     n = len(images)
-    for i in track(range(n), description='writing', transient=True):
+    for i in track(range(n), description='Writing', transient=True):
         if image_mode.lower() == 'bgr':
             writer.append_data(images[i][..., ::-1])
         else:
@@ -43,9 +46,6 @@ def images2video(images, wfp, **kwargs):

     writer.close()

-    # print(f':smiley: Dump to {wfp}\n', style="bold green")
-    print(f'Dump to {wfp}\n')
-

 def video2gif(video_fp, fps=30, size=256):
     if osp.exists(video_fp):
@@ -54,10 +54,10 @@ def video2gif(video_fp, fps=30, size=256):
         palette_wfp = osp.join(d, 'palette.png')
         gif_wfp = osp.join(d, f'{fn}.gif')
         # generate the palette
-        cmd = f'ffmpeg -i {video_fp} -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" {palette_wfp} -y'
+        cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y'
         exec_cmd(cmd)
         # use the palette to generate the gif
-        cmd = f'ffmpeg -i {video_fp} -i {palette_wfp} -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" {gif_wfp} -y'
+        cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
         exec_cmd(cmd)
     else:
         print(f'video_fp: {video_fp} not exists!')
@@ -65,7 +65,7 @@ def video2gif(video_fp, fps=30, size=256):

 def merge_audio_video(video_fp, audio_fp, wfp):
     if osp.exists(video_fp) and osp.exists(audio_fp):
-        cmd = f'ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y'
+        cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
         exec_cmd(cmd)
         print(f'merge {video_fp} and {audio_fp} to {wfp}')
     else:
@@ -80,21 +80,23 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
     return img


-def concat_frames(I_p_lst, driving_rgb_lst, img_rgb):
+def concat_frames(driving_image_lst, source_image, I_p_lst):
     # TODO: add more concat style, e.g., left-down corner driving
     out_lst = []
+    h, w, _ = I_p_lst[0].shape
+
     for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
-        source_image_drived = I_p_lst[idx]
-        image_drive = driving_rgb_lst[idx]
+        I_p = I_p_lst[idx]
+        source_image_resized = cv2.resize(source_image, (w, h))

-        # resize images to match source_image_drived shape
-        h, w, _ = source_image_drived.shape
-        image_drive_resized = cv2.resize(image_drive, (w, h))
-        img_rgb_resized = cv2.resize(img_rgb, (w, h))
+        if driving_image_lst is None:
+            out = np.hstack((source_image_resized, I_p))
+        else:
+            driving_image = driving_image_lst[idx]
+            driving_image_resized = cv2.resize(driving_image, (w, h))
+            out = np.hstack((driving_image_resized, source_image_resized, I_p))

-        # concatenate images horizontally
-        frame = np.concatenate((image_drive_resized, img_rgb_resized, source_image_drived), axis=1)
-        out_lst.append(frame)
+        out_lst.append(out)
     return out_lst

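The rewritten `concat_frames` resizes the driving frame and the source portrait to the result's resolution and stacks them in a row, and now also tolerates `driving_image_lst is None` by dropping the first column. A stand-alone sketch of that layout logic with dummy images (OpenCV and NumPy only):

import cv2
import numpy as np

def concat_row(result, source, driving=None):
    h, w = result.shape[:2]
    tiles = []
    if driving is not None:
        tiles.append(cv2.resize(driving, (w, h)))  # driving | source | result
    tiles.append(cv2.resize(source, (w, h)))
    tiles.append(result)
    return np.hstack(tiles)

result = np.zeros((256, 256, 3), dtype=np.uint8)
source = np.zeros((512, 512, 3), dtype=np.uint8)
driving = np.zeros((720, 1280, 3), dtype=np.uint8)
print(concat_row(result, source, driving).shape)  # -> (256, 768, 3)
print(concat_row(result, source).shape)           # -> (256, 512, 3)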
@@ -126,14 +128,84 @@ class VideoWriter:
         self.writer.close()


-def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=5):
-    cmd = f"ffmpeg -i {input_file} -c:v {codec} -crf {crf} -r {fps} {output_file} -y"
+def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12):
+    cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y'
     exec_cmd(cmd)


-def get_fps(filepath):
-    import ffmpeg
-    probe = ffmpeg.probe(filepath)
-    video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
-    fps = eval(video_stream['avg_frame_rate'])
+def get_fps(filepath, default_fps=25):
+    try:
+        fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS)
+
+        if fps in (0, None):
+            fps = default_fps
+    except Exception as e:
+        log(e)
+        fps = default_fps
+
     return fps
+
+
+def has_audio_stream(video_path: str) -> bool:
+    """
+    Check if the video file contains an audio stream.
+
+    :param video_path: Path to the video file
+    :return: True if the video contains an audio stream, False otherwise
+    """
+    if osp.isdir(video_path):
+        return False
+
+    cmd = [
+        'ffprobe',
+        '-v', 'error',
+        '-select_streams', 'a',
+        '-show_entries', 'stream=codec_type',
+        '-of', 'default=noprint_wrappers=1:nokey=1',
+        f'"{video_path}"'
+    ]
+
+    try:
+        # result = subprocess.run(cmd, capture_output=True, text=True)
+        result = exec_cmd(' '.join(cmd))
+        if result.returncode != 0:
+            log(f"Error occurred while probing video: {result.stderr}")
+            return False
+
+        # Check if there is any output from ffprobe command
+        return bool(result.stdout.strip())
+    except Exception as e:
+        log(f"Error occurred while probing video: {video_path}, you may need to install ffprobe! Now set audio to false!", style="bold red")
+        return False
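`has_audio_stream` shells out to ffprobe through `exec_cmd` with `shell=True`, which is why every path in the new code is wrapped in double quotes: an unquoted path containing spaces would be split into separate arguments. An alternative sketch that sidesteps quoting by passing the argument list without a shell (standard library only; ffprobe still has to be on PATH):

import subprocess

def probe_has_audio(video_path: str) -> bool:
    # argument list, no shell: paths with spaces need no manual quoting
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "a",
        "-show_entries", "stream=codec_type",
        "-of", "default=noprint_wrappers=1:nokey=1",
        video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    except FileNotFoundError:
        return False  # ffprobe not installed
    return result.returncode == 0 and bool(result.stdout.strip())

The commented-out `subprocess.run(cmd, capture_output=True, text=True)` line in the diff is the same idea; joining the list into one string and running it through a shell also works as long as the paths stay quoted.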
+
+
+def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
+    cmd = [
+        'ffmpeg',
+        '-y',
+        '-i', f'"{silent_video_path}"',
+        '-i', f'"{audio_video_path}"',
+        '-map', '0:v',
+        '-map', '1:a',
+        '-c:v', 'copy',
+        '-shortest',
+        f'"{output_video_path}"'
+    ]
+
+    try:
+        exec_cmd(' '.join(cmd))
+        log(f"Video with audio generated successfully: {output_video_path}")
+    except subprocess.CalledProcessError as e:
+        log(f"Error occurred: {e}")
+
+
+def bb_intersection_over_union(boxA, boxB):
+    xA = max(boxA[0], boxB[0])
+    yA = max(boxA[1], boxB[1])
+    xB = min(boxA[2], boxB[2])
+    yB = min(boxA[3], boxB[3])
+    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
+    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
+    iou = interArea / float(boxAArea + boxBArea - interArea)
+    return iou
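`bb_intersection_over_union` uses the inclusive-pixel convention (hence the `+ 1` terms). A tiny worked example with made-up boxes:

def bb_iou(boxA, boxB):
    xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    areaA = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    areaB = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    return inter / float(areaA + areaB - inter)

# two 101x101-pixel boxes overlapping in a 51x101 region:
# 5151 / (10201 + 10201 - 5151) ~= 0.3378
print(round(bb_iou([0, 0, 100, 100], [50, 0, 150, 100]), 4))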
19	src/utils/viz.py	Normal file
@@ -0,0 +1,19 @@
+# coding: utf-8
+
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+
+
+def viz_lmk(img_, vps, **kwargs):
+    """Visualize landmark points"""
+    lineType = kwargs.get("lineType", cv2.LINE_8)  # cv2.LINE_AA
+    img_for_viz = img_.copy()
+    for pt in vps:
+        cv2.circle(
+            img_for_viz,
+            (int(pt[0]), int(pt[1])),
+            radius=kwargs.get("radius", 1),
+            color=(0, 255, 0),
+            thickness=kwargs.get("thickness", 1),
+            lineType=lineType,
+        )
+    return img_for_viz
@@ -1,37 +0,0 @@
-# coding: utf-8
-
-"""
-[WIP] Pipeline for video template preparation
-"""
-
-import tyro
-from src.config.crop_config import CropConfig
-from src.config.inference_config import InferenceConfig
-from src.config.argument_config import ArgumentConfig
-from src.template_maker import TemplateMaker
-
-
-def partial_fields(target_class, kwargs):
-    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
-
-
-def main():
-    # set tyro theme
-    tyro.extras.set_accent_color("bright_cyan")
-    args = tyro.cli(ArgumentConfig)
-
-    # specify configs for inference
-    inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
-    crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
-
-    video_template_maker = TemplateMaker(
-        inference_cfg=inference_cfg,
-        crop_cfg=crop_cfg
-    )
-
-    # run
-    video_template_maker.make_motion_template(args.driving_video_path, args.template_output_dir)
-
-
-if __name__ == '__main__':
-    main()