feat: launch LivePortrait

guojianzhu 2024-07-04 04:32:47 +08:00
parent 90223399b2
commit a272d74e70
130 changed files with 30323 additions and 0 deletions

17
.gitignore vendored Normal file
View File

@ -0,0 +1,17 @@
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
**/*.py[cod]
*$py.class

# Model weights
**/*.pth
**/*.onnx

# Ipython notebook
*.ipynb

# Temporary files or benchmark resources
animations/*
tmp/*

19
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,19 @@
{
"[python]": {
"editor.tabSize": 4
},
"files.eol": "\n",
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,
"files.trimTrailingWhitespace": true,
"files.exclude": {
"**/.git": true,
"**/.svn": true,
"**/.hg": true,
"**/CVS": true,
"**/.DS_Store": true,
"**/Thumbs.db": true,
"**/*.crswap": true,
"**/__pycache__": true
}
}

136
app.py Normal file
View File

@ -0,0 +1,136 @@
# coding: utf-8
"""
The entry point of the Gradio demo
"""
import os
import os.path as osp
import gradio as gr
import tyro
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
[osp.join(example_portrait_dir, "s1.jpg"), osp.join(example_video_dir, "d1.mp4"), True, True, True],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d2.mp4"), True, True, True],
[osp.join(example_portrait_dir, "s3.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True],
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True],
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
]
#################### interface logic ####################
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Reference Portrait"):
image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video(label="Please upload the driving video here.")
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative pose")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
with gr.Row():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video = gr.Video(label="The animated video after pasted back.")
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat = gr.Video(label="The animated video and driving video.")
with gr.Row():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
# Examples
gr.Markdown("## You could choose the examples below ⬇️")
with gr.Row():
gr.Examples(
examples=data_examples,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input
],
examples_per_page=5
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
with gr.Row():
with gr.Column():
process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
# with gr.Column():
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Column():
with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
output_image.render()
# binding functions for buttons
process_button_close_ratio.click(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider],
show_progress=True
)
process_button_retargeting.click(
fn=gradio_pipeline.execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider],
outputs=output_image,
show_progress=True
)
process_button_animation.click(
fn=gradio_pipeline.execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input
],
outputs=[output_video, output_video_concat],
show_progress=True
)
# process_button_reset and process_button_reset_retargeting are gr.ClearButton instances and clear their target components automatically, so no extra event bindings are needed
##########################################################
demo.launch(
server_name=args.server_name,
server_port=args.server_port,
share=args.share,
)

23
assets/.gitattributes vendored Normal file
View File

@ -0,0 +1,23 @@
docs/inference.gif filter=lfs diff=lfs merge=lfs -text
docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text
docs/showcase.gif filter=lfs diff=lfs merge=lfs -text
examples/driving/d5.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d7.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d9.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d6.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d8.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d0.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d1.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d2.mp4 filter=lfs diff=lfs merge=lfs -text
examples/driving/d3.mp4 filter=lfs diff=lfs merge=lfs -text
examples/source/s5.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s7.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s9.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s4.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s10.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s1.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s2.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s3.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s6.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s8.jpg filter=lfs diff=lfs merge=lfs -text
examples/source/s0.jpg filter=lfs diff=lfs merge=lfs -text

3
assets/docs/showcase.gif Normal file
View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bca5f38bfd555bf7c013312d87883afdf39d97fba719ac171c60f897af49e21
size 6623248

(Git LFS pointer files for the remaining assets listed in assets/.gitattributes above: the other docs GIFs and the examples/driving and examples/source files. Each contains only version, oid sha256, and size lines.)
7
assets/gradio_description_animation.md Normal file
View File

@ -0,0 +1,7 @@
<span style="font-size: 1.2em;">🔥 To animate the reference portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button; the animated video will appear in the result block after a few moments.
</div>

7
assets/gradio_description_retargeting.md Normal file
View File

@ -0,0 +1,7 @@
<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:</span>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">1. Please <strong>first</strong> press the <strong>🤖 Calculate the eye-close and lip-close ratio</strong> button, and wait for the result shown in the sliders.</span>
</div>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">2. Please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. Then the result would be shown in the middle block. You can try running it multiple times!</span>
</div>

4
assets/gradio_description_upload.md Normal file
View File

@ -0,0 +1,4 @@
## 🤗 This is the official Gradio demo for **Live Portrait**.
### Guidance for the Gradio page:
<div style="font-size: 1.2em;">Please upload a reference portrait (or capture one with the webcam) in the <strong>Reference Portrait</strong> field, and a driving video in the <strong>Driving Video</strong> field.</div>

10
assets/gradio_title.md Normal file
View File

@ -0,0 +1,10 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href=""><img src="https://img.shields.io/badge/arXiv-XXXX.XXXX-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>
</div>
</div>

33
inference.py Normal file
View File

@ -0,0 +1,33 @@
# coding: utf-8
import tyro
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig
live_portrait_pipeline = LivePortraitPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg
)
# run
live_portrait_pipeline.execute(args)
if __name__ == '__main__':
main()

145
readme.md Normal file
View File

@ -0,0 +1,145 @@
<h1 align="center">LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div align='center'>
<a href='https://github.com/cleardusk' target='_blank'>Jianzhu Guo</a><sup>1*</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'>Dingyun Zhang</a><sup>1,2</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'>Xiaoqiang Liu</a><sup>1</sup>&emsp;
<a href='https://github.com/KwaiVGI' target='_blank'>Zhizhou Zhong</a><sup>1,3</sup>&emsp;
<a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'>Yuan Zhang</a><sup>1</sup>&emsp;
<a href='https://scholar.google.com/citations?user=P6MraaYAAAAJ' target='_blank'>Pengfei Wan</a><sup>1</sup>&emsp;
<a href='https://openreview.net/profile?id=~Di_ZHANG3' target='_blank'>Di Zhang</a><sup>1</sup>&emsp;
</div>
<div align='center'>
<sup>1</sup>Kuaishou Technology&emsp; <sup>2</sup>University of Science and Technology of China&emsp; <sup>3</sup>Fudan University&emsp;
</div>
<br>
<div align="center">
<!-- <a href='LICENSE'><img src='https://img.shields.io/badge/license-MIT-yellow'></a> -->
<a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-Homepage-green'></a>
<a href='https://github.com/KwaiVGI/LivePortrait'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
</div>
<br>
<p align="center">
<img src="./assets/docs/showcase2.gif" alt="showcase">
</p>
## 🔥 Updates
- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models.
- **`2024/07/04`**: 😊 We released the technical report on [arXiv]().
## Introduction
This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control]().
We are actively updating and improving this repository. If you find any bugs or have suggestions, feel free to open an issue or submit a pull request (PR) 💖.
## 🔥 Getting Started
### 1. Clone the code and prepare the environment
```bash
git clone https://github.com/KwaiVGI/LivePortrait
cd LivePortrait
# using lfs to pull the data
git lfs install
git lfs pull
# create env using conda
conda create -n LivePortrait python==3.9.18
conda activate LivePortrait
# install dependencies with pip
pip install -r requirements.txt
```
### 2. Download pretrained weights
Download our pretrained LivePortrait weights and the InsightFace face detection models from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights into one directory 😊. Unzip and place them in `./pretrained_weights`, ensuring the directory structure is as follows:
```text
pretrained_weights
├── insightface
│ └── models
│ └── buffalo_l
│ ├── 2d106det.onnx
│ └── det_10g.onnx
└── liveportrait
├── base_models
│ ├── appearance_feature_extractor.pth
│ ├── motion_extractor.pth
│ ├── spade_generator.pth
│ └── warping_module.pth
├── landmark.onnx
└── retargeting_models
└── stitching_retargeting_module.pth
```
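As an optional sanity check, you can confirm that every expected file is in place before running inference. This is a minimal sketch (not part of the repo); the paths simply mirror the tree above:

```python
# Verify the ./pretrained_weights layout shown above (sketch only).
from pathlib import Path

expected = [
    "insightface/models/buffalo_l/2d106det.onnx",
    "insightface/models/buffalo_l/det_10g.onnx",
    "liveportrait/base_models/appearance_feature_extractor.pth",
    "liveportrait/base_models/motion_extractor.pth",
    "liveportrait/base_models/spade_generator.pth",
    "liveportrait/base_models/warping_module.pth",
    "liveportrait/landmark.onnx",
    "liveportrait/retargeting_models/stitching_retargeting_module.pth",
]
root = Path("pretrained_weights")
missing = [p for p in expected if not (root / p).is_file()]
print("All weights found." if not missing else f"Missing: {missing}")
```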
### 3. Inference 🚀
```bash
python inference.py
```
If the script runs successfully, you will see the following results: driving video, input image, and generated result.
<p align="center">
<img src="./assets/docs/inference.gif" alt="image">
</p>
Or, you can change the input by specifying the `-s` and `-d` arguments:
```bash
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
# or disable pasting back
python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
# see more options
python inference.py -h
```
**More interesting results can be found in our [Homepage](https://liveportrait.github.io/)** 😊
### 4. Gradio interface (WIP)
We also provide a Gradio interface for a better experience. Please install `gradio` and then run `app.py`:
```bash
pip install gradio==4.36.1
python app.py
```
***NOTE:*** *We are working on the Gradio interface and will upgrade it soon.*
### 5. Inference speed evaluation 🚀🚀🚀
We have also provided a script to evaluate the inference speed of each module:
```bash
python speed.py
```
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|-----------------------------------|:-------------:|:--------------:|:-------------:|
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
*Note: the values listed for the Stitching and Retargeting Modules represent the combined parameter count and the total sequential inference time of the three MLP networks.*
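Summing the per-module timings above gives a rough upper bound on the throughput of the neural modules alone (cropping, landmark detection, and I/O are not included); a quick back-of-the-envelope check:

```python
# Rough throughput estimate from the per-frame timings in the table above (sketch only).
times_ms = {
    "Appearance Feature Extractor": 0.82,
    "Motion Extractor": 0.84,
    "Spade Generator": 7.59,
    "Warping Module": 5.21,
    "Stitching and Retargeting Modules": 0.31,
}
total_ms = sum(times_ms.values())
print(f"~{total_ms:.2f} ms per frame -> ~{1000 / total_ms:.0f} FPS (neural modules only)")
```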
## Acknowledgements
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
## Citation 💖
If you find LivePortrait useful for your research, please 🌟 this repo and cite our work using the following BibTeX:
```bibtex
@article{guo2024live,
title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
year = {2024},
journal = {arXiv preprint:24xx.xxxx},
}
```

192
speed.py Normal file
View File

@ -0,0 +1,192 @@
# coding: utf-8
"""
Benchmark the inference speed of each module in LivePortrait.
TODO: heavy GPT style, need to refactor
"""
import yaml
import torch
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1):
"""
Generate random input tensors and move them to GPU
"""
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
kp_source = torch.randn(batch_size, 21, 3).cuda().half()
kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
feat_stitching = concat_feat(kp_source, kp_driving).half()
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
inputs = {
'feature_3d': feature_3d,
'kp_source': kp_source,
'kp_driving': kp_driving,
'source_image': source_image,
'generator_input': generator_input,
'feat_stitching': feat_stitching,
'feat_eye': feat_eye,
'feat_lip': feat_lip
}
return inputs
def load_and_compile_models(cfg, model_config):
"""
Load and compile models for inference
"""
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
models_with_params = [
('Appearance Feature Extractor', appearance_feature_extractor),
('Motion Extractor', motion_extractor),
('Warping Network', warping_module),
('SPADE Decoder', spade_generator)
]
compiled_models = {}
for name, model in models_with_params:
model = model.half()
model = torch.compile(model, mode='max-autotune') # Optimize for inference
model.eval() # Switch to evaluation mode
compiled_models[name] = model
retargeting_models = ['stitching', 'eye', 'lip']
for retarget in retargeting_models:
module = stitching_retargeting_module[retarget].half()
module = torch.compile(module, mode='max-autotune') # Optimize for inference
module.eval() # Switch to evaluation mode
stitching_retargeting_module[retarget] = module
return compiled_models, stitching_retargeting_module
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
"""
Warm up models to prepare them for benchmarking
"""
print("Warm up start!")
with torch.no_grad():
for _ in range(10):
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
compiled_models['Motion Extractor'](inputs['source_image'])
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip'])
print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
"""
Measure inference times for each model
"""
times = {name: [] for name in compiled_models.keys()}
times['Retargeting Models'] = []
overall_times = []
with torch.no_grad():
for _ in range(100):
torch.cuda.synchronize()
overall_start = time.time()
start = time.time()
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
torch.cuda.synchronize()
times['Appearance Feature Extractor'].append(time.time() - start)
start = time.time()
compiled_models['Motion Extractor'](inputs['source_image'])
torch.cuda.synchronize()
times['Motion Extractor'].append(time.time() - start)
start = time.time()
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
torch.cuda.synchronize()
times['Warping Network'].append(time.time() - start)
start = time.time()
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
torch.cuda.synchronize()
times['SPADE Decoder'].append(time.time() - start)
start = time.time()
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip'])
torch.cuda.synchronize()
times['Retargeting Models'].append(time.time() - start)
overall_times.append(time.time() - overall_start)
return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
"""
Print benchmark results with average and standard deviation of inference times
"""
average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
for name, model in compiled_models.items():
num_params = sum(p.numel() for p in model.parameters())
num_params_in_millions = num_params / 1e6
print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
for index, retarget in enumerate(retargeting_models):
num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
num_params_in_millions = num_params / 1e6
print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
for name, avg_time in average_times.items():
std_time = std_times[name]
print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
def main():
"""
Main function to benchmark speed and model parameters
"""
# Sample input tensors
inputs = initialize_inputs()
# Load configuration
cfg = InferenceConfig(device_id=0)
model_config_path = cfg.models_config
with open(model_config_path, 'r') as file:
model_config = yaml.safe_load(file)
# Load and compile models
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
# Warm up models
warm_up_models(compiled_models, stitching_retargeting_module, inputs)
# Measure inference times
times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
# Print benchmark results
print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
if __name__ == "__main__":
main()

0
src/config/__init__.py Normal file
View File

44
src/config/argument_config.py Normal file
View File

@ -0,0 +1,44 @@
# coding: utf-8
"""
config for user
"""
import os.path as osp
from dataclasses import dataclass
import tyro
from typing_extensions import Annotated
from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the reference portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
#####################################
########## inference arguments ##########
device_id: int = 0
flag_lip_zero: bool = True  # whether to set the lip to a closed state before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
flag_relative: bool = True # whether to use relative pose
flag_pasteback: bool = True  # whether to paste the animated face crop back from the cropped face space to the original image space
flag_do_crop: bool = True  # whether to crop the reference portrait to the face-crop space
flag_do_rot: bool = True  # whether to apply rotation when flag_do_crop is True
#########################################
########## crop arguments ##########
dsize: int = 512
scale: float = 2.3
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
####################################
########## gradio arguments ##########
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
share: bool = False
server_name: str = "0.0.0.0"

29
src/config/base_config.py Normal file
View File

@ -0,0 +1,29 @@
# coding: utf-8
"""
pretty printing class
"""
from __future__ import annotations
import os.path as osp
from typing import Tuple
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class PrintableConfig: # pylint: disable=too-few-public-methods
"""Printable Config defining str function"""
def __repr__(self):
lines = [self.__class__.__name__ + ":"]
for key, val in vars(self).items():
if isinstance(val, Tuple):
flattened_val = "["
for item in val:
flattened_val += str(item) + "\n"
flattened_val = flattened_val.rstrip("\n")
val = flattened_val + "]"
lines += f"{key}: {str(val)}".split("\n")
return "\n ".join(lines)

18
src/config/crop_config.py Normal file
View File

@ -0,0 +1,18 @@
# coding: utf-8
"""
parameters used for crop faces
"""
import os.path as osp
from dataclasses import dataclass
from typing import Union, List
from .base_config import PrintableConfig
@dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig):
dsize: int = 512 # crop size
scale: float = 2.3 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down

49
src/config/inference_config.py Normal file
View File

@ -0,0 +1,49 @@
# coding: utf-8
"""
config dataclass used for inference
"""
import os.path as osp
from dataclasses import dataclass
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig):
models_config: str = make_abs_path('./models.yaml') # portrait animation config
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
flag_use_half_precision: bool = True # whether to use half precision
flag_lip_zero: bool = True  # whether to set the lip to a closed state before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False
lip_zero_threshold: float = 0.03
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True # we recommend setting it to True!
flag_relative: bool = True # whether to use relative pose
anchor_frame: int = 0 # set this value if find_best_frame is True
input_shape: Tuple[int, int] = (256, 256) # input shape
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
output_fps: int = 30 # fps for output video
crf: int = 15 # crf for output video
flag_write_result: bool = True # whether to write output video
flag_pasteback: bool = True  # whether to paste the animated face crop back from the cropped face space to the original image space
mask_crop = None
flag_write_gif: bool = False
size_gif: int = 256
ref_max_shape: int = 1280
ref_shape_n: int = 2
device_id: int = 0
flag_do_crop: bool = False  # whether to crop the reference portrait to the face-crop space
flag_do_rot: bool = True  # whether to apply rotation when flag_do_crop is True

43
src/config/models.yaml Normal file
View File

@ -0,0 +1,43 @@
model_params:
appearance_feature_extractor_params: # the F in the paper
image_channel: 3
block_expansion: 64
num_down_blocks: 2
max_features: 512
reshape_channel: 32
reshape_depth: 16
num_resblocks: 6
motion_extractor_params: # the M in the paper
num_kp: 21
backbone: convnextv2_tiny
warping_module_params: # the W in the paper
num_kp: 21
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
spade_generator_params: # the G in the paper
upscale: 2 # represents upsample factor 256x256 -> 512x512
block_expansion: 64
max_features: 512
num_down_blocks: 2
stitching_retargeting_module_params: # the S in the paper
stitching:
input_size: 126 # (21*3)*2
hidden_sizes: [128, 128, 64]
output_size: 65 # (21*3)+2(tx,ty)
lip:
input_size: 65 # (21*3)+2
hidden_sizes: [128, 128, 64]
output_size: 63 # (21*3)
eye:
input_size: 66 # (21*3)+3
hidden_sizes: [256, 256, 128, 128, 64]
output_size: 63 # (21*3)
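The input/output sizes of the stitching and retargeting MLPs above all follow from the 21 implicit keypoints (num_kp) used by the motion extractor. A small arithmetic sketch (illustrative only, not code from this commit):

```python
# How the stitching_retargeting_module sizes in models.yaml are derived (sketch).
num_kp = 21
kp_dim = num_kp * 3          # 63: one flattened set of 3D keypoints

stitching_in = kp_dim * 2    # 126: concat(kp_source, kp_driving)
stitching_out = kp_dim + 2   # 65: keypoint deltas plus a (tx, ty) shift

lip_in = kp_dim + 2          # 65: concat(kp_source, lip_close_ratio), 2 lip ratios
lip_out = kp_dim             # 63: keypoint deltas

eye_in = kp_dim + 3          # 66: concat(kp_source, eye_close_ratio), 3 eye ratios
eye_out = kp_dim             # 63: keypoint deltas

assert (stitching_in, stitching_out, lip_in, lip_out, eye_in, eye_out) == (126, 65, 65, 63, 66, 63)
```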

104
src/gradio_pipeline.py Normal file
View File

@ -0,0 +1,104 @@
# coding: utf-8
"""
Pipeline for gradio
"""
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .utils.rprint import rlog as log
def update_args(args, user_args):
"""update the args according to user inputs
"""
for k, v in user_args.items():
if hasattr(args, k):
setattr(args, k, v)
return args
class GradioPipeline(LivePortraitPipeline):
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
# for single image retargeting
self.f_s_user = None
self.x_s_info_user = None
self.x_s_user = None
self.source_lmk_user = None
def execute_video(
self,
input_image_path,
input_video_path,
flag_relative_input,
flag_do_crop_input,
flag_remap_input
):
""" for video driven potrait animation
"""
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
return video_path, video_path_concat
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
""" for single image retargeting
"""
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
num_kp = self.x_s_user.shape[1]
# default: use x_s
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
return out
def prepare_retargeting(self, input_image_path, flag_do_crop = True):
""" for single image retargeting
"""
inference_cfg = self.live_portrait_wrapper.cfg
######## process reference portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################
# record global info for next time use
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
return eye_close_ratio, lip_close_ratio
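The two retargeting buttons in app.py simply chain the methods above: prepare_retargeting caches the source features and returns the current eye/lip close ratios, and execute_image then applies new target ratios. A minimal sketch of driving them outside the Gradio UI (not part of this commit; it assumes the pretrained weights and the example assets are in place, and the output path is arbitrary):

```python
# Sketch: use GradioPipeline's retargeting methods without the Gradio UI (not repo code).
import cv2
import tyro

from src.config.argument_config import ArgumentConfig
from src.config.crop_config import CropConfig
from src.config.inference_config import InferenceConfig
from src.gradio_pipeline import GradioPipeline


def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


args = tyro.cli(ArgumentConfig)
pipeline = GradioPipeline(
    inference_cfg=partial_fields(InferenceConfig, args.__dict__),
    crop_cfg=partial_fields(CropConfig, args.__dict__),
    args=args,
)

# Step 1: cache source features and read the current eye-close / lip-close ratios.
eye_ratio, lip_ratio = pipeline.prepare_retargeting("assets/examples/source/s1.jpg")

# Step 2: retarget to slightly larger ratios (the sliders in app.py use the range [0, 0.8]).
out_rgb = pipeline.execute_image(min(eye_ratio + 0.1, 0.8), min(lip_ratio + 0.1, 0.8))
cv2.imwrite("retargeted.jpg", cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR))
```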

195
src/live_portrait_pipeline.py Normal file
View File

@ -0,0 +1,195 @@
# coding: utf-8
"""
Pipeline of LivePortrait
"""
# TODO:
# 1. Currently we assume all templates are already cropped; this needs to be fixed.
# 2. Pick example source and driving images.
import cv2
import numpy as np
import pickle
import os.path as osp
from rich.progress import track
from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames
from .utils.crop import _transform_img
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
from .utils.rprint import rlog as log
from .live_portrait_wrapper import LivePortraitWrapper
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class LivePortraitPipeline(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg)
def execute(self, args: ArgumentConfig):
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
######## process reference portrait ########
img_rgb = load_image_rgb(args.source_image)
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
log(f"Load source image from {args.source_image}")
crop_info = self.cropper.crop_single_image(img_rgb)
source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
if inference_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
if inference_cfg.flag_lip_zero:
# let lip-open scalar to be 0 at first
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
inference_cfg.flag_lip_zero = False
else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ########
if is_video(args.driving_info):
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
# TODO: track the driving video here -> build a template
driving_rgb_lst = load_driving_info(args.driving_info)
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
n_frames = I_d_lst.shape[0]
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
elif is_template(args.driving_info):
log(f"Load from video templates {args.driving_info}")
with open(args.driving_info, 'rb') as f:
template_lst, driving_lmk_lst = pickle.load(f)
n_frames = template_lst[0]['n_frames']
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
else:
raise Exception("Unsupported driving types!")
#########################################
######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
if inference_cfg.mask_crop is None:
inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
mask_ori = mask_ori.astype(np.float32) / 255.
I_p_paste_lst = []
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(n_frames), description='Animating...', total=n_frames):
if is_video(args.driving_info):
# extract kp info by M
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
else:
# from template
x_d_i_info = template_lst[i]
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inference_cfg.flag_relative:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
# Algorithm 1:
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
# without stitching or retargeting
if inference_cfg.flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if inference_cfg.flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if inference_cfg.flag_eye_retargeting:
c_d_eyes_i = input_eye_ratio_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inference_cfg.flag_lip_retargeting:
c_d_lip_i = input_lip_ratio_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
if inference_cfg.flag_relative: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
if inference_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback:
I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
I_p_paste_lst.append(I_p_i_to_ori_blend)
mkdir(args.output_dir)
wfp_concat = None
if is_video(args.driving_info):
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
# save the (driving frames, source image, driven frames) result
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat)
# save the driven result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
if inference_cfg.flag_pasteback:
images2video(I_p_paste_lst, wfp=wfp)
else:
images2video(I_p_lst, wfp=wfp)
return wfp, wfp_concat
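The core of the per-frame loop above is the relative motion transfer (flag_relative): driving pose, expression, scale, and translation are measured against the first driving frame and applied on top of the source, then pushed through Eqn. 2. A self-contained toy version of just that arithmetic (illustrative sketch with random tensors, not code from this commit):

```python
# Toy version of the relative-motion update used in LivePortraitPipeline.execute()
# when inference_cfg.flag_relative is True (sketch with random tensors, not repo code).
import torch

bs, num_kp = 1, 21
x_c_s = torch.randn(bs, num_kp, 3)            # canonical source keypoints
exp_s = torch.randn(bs, num_kp, 3)            # source expression deformation
scale_s = torch.rand(bs, 1) + 0.5
t_s = torch.randn(bs, 3)
R_s = R_d_0 = R_d_i = torch.eye(3).expand(bs, 3, 3)   # identity rotations keep the toy simple

exp_d_0, exp_d_i = torch.randn(bs, num_kp, 3), torch.randn(bs, num_kp, 3)
scale_d_0, scale_d_i = torch.rand(bs, 1) + 0.5, torch.rand(bs, 1) + 0.5
t_d_0, t_d_i = torch.randn(bs, 3), torch.randn(bs, 3)

# relative pose / expression / scale / translation w.r.t. the first driving frame
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = exp_s + (exp_d_i - exp_d_0)
scale_new = scale_s * (scale_d_i / scale_d_0)
t_new = t_s + (t_d_i - t_d_0)
t_new[..., 2] = 0                             # zero tz, as in the pipeline

# Eqn. 2: x_d,i = s * (x_c,s @ R + delta) + t
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
print(x_d_i_new.shape)                        # torch.Size([1, 21, 3])
```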

335
src/live_portrait_wrapper.py Normal file
View File

@ -0,0 +1,335 @@
# coding: utf-8
"""
Wrapper for LivePortrait core functions
"""
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
from src.utils.timer import Timer
from src.utils.helper import load_model, concat_feat
from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta
from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.config.inference_config import InferenceConfig
from src.utils.rprint import rlog as log
class LivePortraitWrapper(object):
def __init__(self, cfg: InferenceConfig):
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.')
# init M
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
log(f'Load motion_extractor done.')
# init W
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
log(f'Load warping_module done.')
# init G
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
log(f'Load spade_generator done.')
# init S and R
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.')
else:
self.stitching_retargeting_module = None
self.cfg = cfg
self.device_id = cfg.device_id
self.timer = Timer()
def update_config(self, user_args):
for k, v in user_args.items():
if hasattr(self.cfg, k):
setattr(self.cfg, k, v)
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
""" construct the input as standard
img: HxWx3, uint8, 256x256
"""
h, w = img.shape[:2]
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
else:
x = img.copy()
if x.ndim == 3:
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
elif x.ndim == 4:
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
else:
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
x = x.cuda(self.device_id)
return x
def prepare_driving_videos(self, imgs) -> torch.Tensor:
""" construct the input as standard
imgs: NxBxHxWx3, uint8
"""
if isinstance(imgs, list):
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
elif isinstance(imgs, np.ndarray):
_imgs = imgs
else:
raise ValueError(f'imgs type error: {type(imgs)}')
y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
y = y.cuda(self.device_id)
return y
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
""" get the appearance feature of the image by F
x: Bx3xHxW, normalized to 0~1
"""
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float()
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
""" get the implicit keypoint information
x: Bx3xHxW, normalized to 0~1
flag_refine_info: whether to transform the pose into degrees and reshape the outputs
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
"""
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
kp_info = self.motion_extractor(x)
if self.cfg.flag_use_half_precision:
# float the dict
for k, v in kp_info.items():
if isinstance(v, torch.Tensor):
kp_info[k] = v.float()
flag_refine_info: bool = kwargs.get('flag_refine_info', True)
if flag_refine_info:
bs = kp_info['kp'].shape[0]
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
return kp_info
def get_pose_dct(self, kp_info: dict) -> dict:
pose_dct = dict(
pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
roll=headpose_pred_to_degree(kp_info['roll']).item(),
)
return pose_dct
def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
# get the canonical keypoints of source image by M
source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
# get the canonical keypoints of first driving frame by M
driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
driving_first_frame_rotation = get_rotation_matrix(
driving_first_frame_kp_info['pitch'],
driving_first_frame_kp_info['yaw'],
driving_first_frame_kp_info['roll']
)
# get feature volume by F
source_feature_3d = self.extract_feature_3d(source_prepared)
return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
def transform_keypoint(self, kp_info: dict):
"""
transform the implicit keypoints with the pose, shift, and expression deformation
kp: BxNx3
"""
kp = kp_info['kp'] # (bs, k, 3)
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
t, exp = kp_info['t'], kp_info['exp']
scale = kp_info['scale']
pitch = headpose_pred_to_degree(pitch)
yaw = headpose_pred_to_degree(yaw)
roll = headpose_pred_to_degree(roll)
bs = kp.shape[0]
if kp.ndim == 2:
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
else:
num_kp = kp.shape[1] # Bxnum_kpx3
rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
# Eqn.2: s * (R * x_c,s + exp) + t
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
return kp_transformed
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
eye_close_ratio: Bx3
Return: Bx(3*num_kp+2)
"""
feat_eye = concat_feat(kp_source, eye_close_ratio)
with torch.no_grad():
delta = self.stitching_retargeting_module['eye'](feat_eye)
return delta
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
lip_close_ratio: Bx2
"""
feat_lip = concat_feat(kp_source, lip_close_ratio)
with torch.no_grad():
delta = self.stitching_retargeting_module['lip'](feat_lip)
return delta
def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp):
# TODO: GPT style, refactor it...
if self.cfg.flag_eye_retargeting:
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source)
else:
# α_eyes = 0
eye_delta = None
if self.cfg.flag_lip_retargeting:
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source)
else:
# α_lip = 0
lip_delta = None
if self.cfg.flag_relative: # use x_s
new_driving_kp = kp_source + \
(eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
(lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
else: # use x_d,i
new_driving_kp = driving_transformed_kp + \
(eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \
(lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0)
return new_driving_kp
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
kp_driving: BxNx3
Return: Bx(3*num_kp+2)
"""
feat_stitching = concat_feat(kp_source, kp_driving)
with torch.no_grad():
delta = self.stitching_retargeting_module['stitching'](feat_stitching)
return delta
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" conduct the stitching
kp_source: Bxnum_kpx3
kp_driving: Bxnum_kpx3
"""
if self.stitching_retargeting_module is not None:
bs, num_kp = kp_source.shape[:2]
kp_driving_new = kp_driving.clone()
delta = self.stitch(kp_source, kp_driving_new)
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # Bxnum_kpx3
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
kp_driving_new += delta_exp
kp_driving_new[..., :2] += delta_tx_ty
return kp_driving_new
return kp_driving
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" get the image after the warping of the implicit keypoints
feature_3d: Bx32x16x64x64, feature volume
kp_source: BxNx3
kp_driving: BxNx3
"""
# Line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i))
with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
# get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict
if self.cfg.flag_use_half_precision:
for k, v in ret_dct.items():
if isinstance(v, torch.Tensor):
ret_dct[k] = v.float()
return ret_dct
def parse_output(self, out: torch.Tensor) -> np.ndarray:
""" construct the output as standard
return: 1xHxWx3, uint8
"""
out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
out = np.clip(out, 0, 1) # clip to 0~1
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
return out
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in driving_lmk_lst:
# for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
# [c_s,eyes, c_d,eyes,i]
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
return combined_eye_ratio_tensor
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
# [c_s,lip, c_d,lip,i]
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
if input_lip_ratio_tensor.shape != (1, 1):
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
return combined_lip_ratio_tensor

0
src/modules/__init__.py Normal file
View File

48
src/modules/appearance_feature_extractor.py Normal file
View File

@ -0,0 +1,48 @@
# coding: utf-8
"""
Appearance feature extractor (F) defined in the paper, which maps the source image S to a 3D appearance feature volume.
"""
import torch
from torch import nn
from .util import SameBlock2d, DownBlock2d, ResBlock3d
class AppearanceFeatureExtractor(nn.Module):
def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
super(AppearanceFeatureExtractor, self).__init__()
self.image_channel = image_channel
self.block_expansion = block_expansion
self.num_down_blocks = num_down_blocks
self.max_features = max_features
self.reshape_channel = reshape_channel
self.reshape_depth = reshape_depth
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
self.resblocks_3d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
def forward(self, source_image):
out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
out = self.second(out)
bs, c, h, w = out.shape # ->Bx512x64x64
f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
return f_s
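# Minimal usage sketch; the hyperparameters below are assumptions inferred from the
# shape comments above (Bx3x256x256 -> Bx32x16x64x64), not values read from the shipped config.
if __name__ == "__main__":
    extractor = AppearanceFeatureExtractor(image_channel=3, block_expansion=64, num_down_blocks=2,
                                           max_features=512, reshape_channel=32, reshape_depth=16,
                                           num_resblocks=6)
    f_s = extractor(torch.randn(1, 3, 256, 256))  # source image s -> appearance feature volume
    print(f_s.shape)                              # torch.Size([1, 32, 16, 64, 64])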

149
src/modules/convnextv2.py Normal file
View File

@ -0,0 +1,149 @@
# coding: utf-8
"""
This module adapts ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformation.
"""
import torch
import torch.nn as nn
# from timm.models.layers import trunc_normal_, DropPath
from .util import LayerNorm, DropPath, trunc_normal_, GRN
__all__ = ['convnextv2_tiny']
class Block(nn.Module):
""" ConvNeXtV2 Block.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
"""
def __init__(self, dim, drop_path=0.):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(4 * dim)
self.pwconv2 = nn.Linear(4 * dim, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class ConvNeXtV2(nn.Module):
""" ConvNeXt V2
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
self,
in_chans=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.,
**kwargs
):
super().__init__()
self.depths = depths
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(4):
stage = nn.Sequential(
*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
)
self.stages.append(stage)
cur += depths[i]
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
# NOTE: the output semantic items
num_bins = kwargs.get('num_bins', 66)
num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
# print('dims[-1]: ', dims[-1])
self.fc_scale = nn.Linear(dims[-1], 1) # scale
self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
self.fc_t = nn.Linear(dims[-1], 3) # translation
self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
def forward(self, x):
x = self.forward_features(x)
# implicit keypoints
kp = self.fc_kp(x)
# pose and expression deformation
pitch = self.fc_pitch(x)
yaw = self.fc_yaw(x)
roll = self.fc_roll(x)
t = self.fc_t(x)
exp = self.fc_exp(x)
scale = self.fc_scale(x)
ret_dct = {
'pitch': pitch,
'yaw': yaw,
'roll': roll,
't': t,
'exp': exp,
'scale': scale,
'kp': kp, # canonical keypoint
}
return ret_dct
def convnextv2_tiny(**kwargs):
model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
return model
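# Minimal usage sketch with the default kwargs of this file (num_kp=24, num_bins=66);
# the shipped motion-extractor config may override these values.
if __name__ == "__main__":
    net = convnextv2_tiny()
    out = net(torch.randn(1, 3, 256, 256))
    print(out['kp'].shape)      # torch.Size([1, 72]): 3 * num_kp canonical keypoints
    print(out['pitch'].shape)   # torch.Size([1, 66]): head-pose bins, decoded to degrees elsewhere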

104
src/modules/dense_motion.py Normal file
View File

@ -0,0 +1,104 @@
# coding: utf-8
"""
The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
"""
from torch import nn
import torch.nn.functional as F
import torch
from .util import Hourglass, make_coordinate_grid, kp2gaussian
class DenseMotionNetwork(nn.Module):
def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
super(DenseMotionNetwork, self).__init__()
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
self.norm = nn.BatchNorm3d(compress, affine=True)
self.num_kp = num_kp
self.flag_estimate_occlusion_map = estimate_occlusion_map
if self.flag_estimate_occlusion_map:
self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
else:
self.occlusion = None
def create_sparse_motions(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
k = coordinate_grid.shape[1]
# NOTE: the first-order flow term is omitted here
driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
# adding background feature
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
return sparse_motions
def create_deformed_feature(self, feature, sparse_motions):
bs, _, d, h, w = feature.shape
feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
return sparse_deformed
def create_heatmap_representations(self, feature, kp_driving, kp_source):
spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
# adding background feature
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
heatmap = torch.cat([zeros, heatmap], dim=1)
heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
return heatmap
def forward(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
feature = self.compress(feature) # (bs, 4, 16, 64, 64)
feature = self.norm(feature) # (bs, 4, 16, 64, 64)
feature = F.relu(feature) # (bs, 4, 16, 64, 64)
out_dict = dict()
# 1. deform 3d feature
sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
# 2. build the heatmap representation
heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
prediction = self.hourglass(input)
mask = self.mask(prediction)
mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
out_dict['mask'] = mask
mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
out_dict['deformation'] = deformation
if self.flag_estimate_occlusion_map:
bs, _, d, h, w = prediction.shape
prediction_reshape = prediction.view(bs, -1, h, w)
occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
out_dict['occlusion_map'] = occlusion_map
return out_dict
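# Minimal toy sketch; the hyperparameters below are illustrative stand-ins (chosen small so the
# sketch runs quickly) and are not the shipped config.
if __name__ == "__main__":
    dmn = DenseMotionNetwork(block_expansion=32, num_blocks=2, max_features=512, num_kp=21,
                             feature_channel=32, reshape_depth=16, compress=4,
                             estimate_occlusion_map=True)
    feature = torch.randn(1, 32, 16, 64, 64)   # f_s from the appearance extractor
    kp_s = torch.rand(1, 21, 3) * 2 - 1        # keypoints live in [-1, 1]
    kp_d = torch.rand(1, 21, 3) * 2 - 1
    out = dmn(feature, kp_driving=kp_d, kp_source=kp_s)
    print(out['deformation'].shape)            # torch.Size([1, 16, 64, 64, 3])
    print(out['occlusion_map'].shape)          # torch.Size([1, 1, 64, 64])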

View File

@ -0,0 +1,35 @@
# coding: utf-8
"""
Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
"""
from torch import nn
import torch
from .convnextv2 import convnextv2_tiny
from .util import filter_state_dict
model_dict = {
'convnextv2_tiny': convnextv2_tiny,
}
class MotionExtractor(nn.Module):
def __init__(self, **kwargs):
super(MotionExtractor, self).__init__()
# default is convnextv2_tiny
backbone = kwargs.get('backbone', 'convnextv2_tiny')
self.detector = model_dict.get(backbone)(**kwargs)
def load_pretrained(self, init_path: str):
if init_path not in (None, ''):
state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
state_dict = filter_state_dict(state_dict, remove_name='head')
ret = self.detector.load_state_dict(state_dict, strict=False)
print(f'Load pretrained model from {init_path}, ret: {ret}')
def forward(self, x):
out = self.detector(x)
return out

View File

@ -0,0 +1,59 @@
# coding: utf-8
"""
Spade decoder(G) defined in the paper, which takes the warped feature as input to generate the animated image.
"""
import torch
from torch import nn
import torch.nn.functional as F
from .util import SPADEResnetBlock
class SPADEDecoder(nn.Module):
def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
super().__init__()
self.upscale = upscale
for i in range(num_down_blocks):
input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
norm_G = 'spadespectralinstance'
label_num_channels = input_channels # 256
self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
self.up = nn.Upsample(scale_factor=2)
if self.upscale is None or self.upscale <= 1:
self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
else:
self.conv_img = nn.Sequential(
nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
nn.PixelShuffle(upscale_factor=2)
)
def forward(self, feature):
seg = feature # Bx256x64x64
x = self.fc(feature) # Bx512x64x64
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
x = self.G_middle_2(x, seg)
x = self.G_middle_3(x, seg)
x = self.G_middle_4(x, seg)
x = self.G_middle_5(x, seg)
x = self.up(x) # Bx512x64x64 -> Bx512x128x128
x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
x = self.up(x) # Bx256x128x128 -> Bx256x256x256
x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
x = torch.sigmoid(x) # Bx3xHxW
return x
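# Minimal usage sketch with the default constructor arguments: a Bx256x64x64 warped
# feature is decoded to a Bx3x256x256 image with values in (0, 1).
if __name__ == "__main__":
    decoder = SPADEDecoder()
    img = decoder(torch.randn(1, 256, 64, 64))
    print(img.shape, float(img.min()), float(img.max()))  # torch.Size([1, 3, 256, 256]), values in (0, 1)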

View File

@ -0,0 +1,38 @@
# coding: utf-8
"""
Stitching module(S) and two retargeting modules(R) defined in the paper.
- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
the stitching region.
- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
when a person with small eyes drives a person with larger eyes.
- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
the lips are in a closed state, which facilitates better animation driving.
"""
from torch import nn
class StitchingRetargetingNetwork(nn.Module):
def __init__(self, input_size, hidden_sizes, output_size):
super(StitchingRetargetingNetwork, self).__init__()
layers = []
for i in range(len(hidden_sizes)):
if i == 0:
layers.append(nn.Linear(input_size, hidden_sizes[i]))
else:
layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Linear(hidden_sizes[-1], output_size))
self.mlp = nn.Sequential(*layers)
def initialize_weights_to_zero(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.zeros_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
return self.mlp(x)
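# Minimal sketch of the stitching variant; the sizes below are assumptions for illustration:
# concatenated source/driving keypoints in (2 * 3 * num_kp), per-keypoint offsets plus a
# global (tx, ty) out (3 * num_kp + 2), with an assumed num_kp=21 and assumed hidden sizes.
if __name__ == "__main__":
    import torch
    num_kp = 21
    net = StitchingRetargetingNetwork(input_size=2 * 3 * num_kp,
                                      hidden_sizes=[128, 128, 64],
                                      output_size=3 * num_kp + 2)
    delta = net(torch.randn(2, 2 * 3 * num_kp))
    print(delta.shape)  # torch.Size([2, 65]): 3*num_kp offsets + 2 translation terms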

441
src/modules/util.py Normal file
View File

@ -0,0 +1,441 @@
# coding: utf-8
"""
This file defines various neural network modules and utility functions, including convolutional and residual blocks,
normalizations, and functions for spatial transformation and tensor manipulation.
"""
from torch import nn
import torch.nn.functional as F
import torch
import torch.nn.utils.spectral_norm as spectral_norm
import math
import warnings
def kp2gaussian(kp, spatial_size, kp_variance):
"""
Transform a keypoint into a Gaussian-like representation
"""
mean = kp
coordinate_grid = make_coordinate_grid(spatial_size, mean)
number_of_leading_dimensions = len(mean.shape) - 1
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
coordinate_grid = coordinate_grid.view(*shape)
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
coordinate_grid = coordinate_grid.repeat(*repeats)
# Preprocess kp shape
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
mean = mean.view(*shape)
mean_sub = (coordinate_grid - mean)
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
return out
def make_coordinate_grid(spatial_size, ref, **kwargs):
d, h, w = spatial_size
x = torch.arange(w).type(ref.dtype).to(ref.device)
y = torch.arange(h).type(ref.dtype).to(ref.device)
z = torch.arange(d).type(ref.dtype).to(ref.device)
# NOTE: must be right-down-in
x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
yy = y.view(1, -1, 1).repeat(d, 1, w)
xx = x.view(1, 1, -1).repeat(d, h, 1)
zz = z.view(-1, 1, 1).repeat(1, h, w)
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
return meshed
class ConvT2d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
super(ConvT2d, self).__init__()
self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
padding=padding, output_padding=output_padding)
self.norm = nn.InstanceNorm2d(out_features)
def forward(self, x):
out = self.convT(x)
out = self.norm(out)
out = F.leaky_relu(out)
return out
class ResBlock3d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock3d, self).__init__()
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
self.norm1 = nn.BatchNorm3d(in_features, affine=True)
self.norm2 = nn.BatchNorm3d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class UpBlock3d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock3d, self).__init__()
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = nn.BatchNorm3d(out_features, affine=True)
def forward(self, x):
out = F.interpolate(x, scale_factor=(1, 2, 2))
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class DownBlock2d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
self.norm = nn.BatchNorm2d(out_features, affine=True)
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class DownBlock3d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock3d, self).__init__()
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups, stride=(1, 2, 2))
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = nn.BatchNorm3d(out_features, affine=True)
self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class SameBlock2d(nn.Module):
"""
Simple block, preserve spatial resolution.
"""
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
self.norm = nn.BatchNorm2d(out_features, affine=True)
if lrelu:
self.ac = nn.LeakyReLU()
else:
self.ac = nn.ReLU()
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = self.ac(out)
return out
class Encoder(nn.Module):
"""
Hourglass Encoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Encoder, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
self.down_blocks = nn.ModuleList(down_blocks)
def forward(self, x):
outs = [x]
for down_block in self.down_blocks:
outs.append(down_block(outs[-1]))
return outs
class Decoder(nn.Module):
"""
Hourglass Decoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Decoder, self).__init__()
up_blocks = []
for i in range(num_blocks)[::-1]:
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
out_filters = min(max_features, block_expansion * (2 ** i))
up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
self.up_blocks = nn.ModuleList(up_blocks)
self.out_filters = block_expansion + in_features
self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
def forward(self, x):
out = x.pop()
for up_block in self.up_blocks:
out = up_block(out)
skip = x.pop()
out = torch.cat([out, skip], dim=1)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class Hourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Hourglass, self).__init__()
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
self.out_filters = self.decoder.out_filters
def forward(self, x):
return self.decoder(self.encoder(x))
class SPADE(nn.Module):
def __init__(self, norm_nc, label_nc):
super().__init__()
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
nhidden = 128
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
nn.ReLU())
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
def forward(self, x, segmap):
normalized = self.param_free_norm(x)
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
out = normalized * (1 + gamma) + beta
return out
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
self.use_se = use_se
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if 'spectral' in norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
self.norm_0 = SPADE(fin, label_nc)
self.norm_1 = SPADE(fmiddle, label_nc)
if self.learned_shortcut:
self.norm_s = SPADE(fin, label_nc)
def forward(self, x, seg1):
x_s = self.shortcut(x, seg1)
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
out = x_s + dx
return out
def shortcut(self, x, seg1):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg1))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
def filter_state_dict(state_dict, remove_name='fc'):
new_state_dict = {}
for key in state_dict:
if remove_name in key:
continue
new_state_dict[key] = state_dict[key]
return new_state_dict
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)

View File

@ -0,0 +1,77 @@
# coding: utf-8
"""
Warping field estimator(W) defined in the paper, which generates a warping field using the implicit
keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s.
"""
from torch import nn
import torch.nn.functional as F
from .util import SameBlock2d
from .dense_motion import DenseMotionNetwork
class WarpingNetwork(nn.Module):
def __init__(
self,
num_kp,
block_expansion,
max_features,
num_down_blocks,
reshape_channel,
estimate_occlusion_map=False,
dense_motion_params=None,
**kwargs
):
super(WarpingNetwork, self).__init__()
self.upscale = kwargs.get('upscale', 1)
self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(
num_kp=num_kp,
feature_channel=reshape_channel,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params
)
else:
self.dense_motion_network = None
self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True)
self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1)
self.estimate_occlusion_map = estimate_occlusion_map
def deform_input(self, inp, deformation):
return F.grid_sample(inp, deformation, align_corners=False)
def forward(self, feature_3d, kp_driving, kp_source):
if self.dense_motion_network is not None:
# Feature warper, Transforming feature representation according to deformation and occlusion
dense_motion = self.dense_motion_network(
feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
)
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64
else:
occlusion_map = None
deformation = dense_motion['deformation'] # Bx16x64x64x3
out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64
bs, c, d, h, w = out.shape # Bx32x16x64x64
out = out.view(bs, c * d, h, w) # -> Bx512x64x64
out = self.third(out) # -> Bx256x64x64
out = self.fourth(out) # -> Bx256x64x64
if self.flag_use_occlusion_map and (occlusion_map is not None):
out = out * occlusion_map
ret_dct = {
'occlusion_map': occlusion_map,
'deformation': deformation,
'out': out,
}
return ret_dct
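# Minimal toy sketch; all hyperparameters below (including dense_motion_params) are
# illustrative stand-ins, not the shipped config.
if __name__ == "__main__":
    import torch
    wn = WarpingNetwork(num_kp=21, block_expansion=64, max_features=512, num_down_blocks=2,
                        reshape_channel=32, estimate_occlusion_map=True,
                        dense_motion_params=dict(block_expansion=32, num_blocks=2,
                                                 max_features=512, reshape_depth=16, compress=4))
    f_s = torch.randn(1, 32, 16, 64, 64)
    kp_s = torch.rand(1, 21, 3) * 2 - 1
    kp_d = torch.rand(1, 21, 3) * 2 - 1
    ret = wn(f_s, kp_driving=kp_d, kp_source=kp_s)
    print(ret['out'].shape)  # torch.Size([1, 256, 64, 64]): input to the SPADE decoder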

65
src/template_maker.py Normal file
View File

@ -0,0 +1,65 @@
# coding: utf-8
"""
Make video template
"""
import os
import cv2
import numpy as np
import pickle
from rich.progress import track
from .utils.cropper import Cropper
from .utils.io import load_driving_info
from .utils.camera import get_rotation_matrix
from .utils.helper import mkdir, basename
from .utils.rprint import rlog as log
from .config.crop_config import CropConfig
from .config.inference_config import InferenceConfig
from .live_portrait_wrapper import LivePortraitWrapper
class TemplateMaker:
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg)
def make_motion_template(self, video_fp: str, output_path: str, **kwargs):
""" make video template (.pkl format)
video_fp: driving video file path
output_path: where to save the pickle file
"""
driving_rgb_lst = load_driving_info(video_fp)
driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst)
n_frames = I_d_lst.shape[0]
templates = []
for i in track(range(n_frames), description='Making templates...', total=n_frames):
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
# collect s_d, R_d, δ_d and t_d for inference
template_dct = {
'n_frames': n_frames,
'frames_index': i,
}
template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32)
template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32)
template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32)
template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32)
templates.append(template_dct)
mkdir(output_path)
# Save the dictionary as a pickle file
pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl')
with open(pickle_fp, 'wb') as f:
pickle.dump([templates, driving_lmk_lst], f)
log(f"Template saved at {pickle_fp}")

0
src/utils/__init__.py Normal file
View File

75
src/utils/camera.py Normal file
View File

@ -0,0 +1,75 @@
# coding: utf-8
"""
functions for processing and transforming 3D facial keypoints
"""
import numpy as np
import torch
import torch.nn.functional as F
PI = np.pi
def headpose_pred_to_degree(pred):
"""
pred: (bs, 66) or (bs, 1) or others
"""
if pred.ndim > 1 and pred.shape[1] == 66:
# NOTE: the offset is 97.5 so that the 66 bins (3 degrees each) are centered at 0
device = pred.device
idx_tensor = [idx for idx in range(0, 66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
pred = F.softmax(pred, dim=1)
degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5
return degree
return pred
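def _sketch_headpose_decoding():
    """Minimal sketch of the 66-bin decoding above (illustrative only): a uniform
    distribution over the bins decodes to ~0 degrees, and the bins span roughly
    [-97.5, +97.5] degrees in 3-degree steps."""
    uniform = torch.zeros(1, 66)                 # softmax -> uniform over the 66 bins
    print(headpose_pred_to_degree(uniform))      # ~tensor([0.]): expected index 32.5 -> 32.5*3 - 97.5
    first_bin = torch.full((1, 66), -1e4)
    first_bin[0, 0] = 0.
    print(headpose_pred_to_degree(first_bin))    # tensor([-97.5]): the lowest bin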
def get_rotation_matrix(pitch_, yaw_, roll_):
""" the input is in degree
"""
# calculate the rotation matrix: vps @ rot
# transform to radian
pitch = pitch_ / 180 * PI
yaw = yaw_ / 180 * PI
roll = roll_ / 180 * PI
device = pitch.device
if pitch.ndim == 1:
pitch = pitch.unsqueeze(1)
if yaw.ndim == 1:
yaw = yaw.unsqueeze(1)
if roll.ndim == 1:
roll = roll.unsqueeze(1)
# calculate the euler matrix
bs = pitch.shape[0]
ones = torch.ones([bs, 1]).to(device)
zeros = torch.zeros([bs, 1]).to(device)
x, y, z = pitch, yaw, roll
rot_x = torch.cat([
ones, zeros, zeros,
zeros, torch.cos(x), -torch.sin(x),
zeros, torch.sin(x), torch.cos(x)
], dim=1).reshape([bs, 3, 3])
rot_y = torch.cat([
torch.cos(y), zeros, torch.sin(y),
zeros, ones, zeros,
-torch.sin(y), zeros, torch.cos(y)
], dim=1).reshape([bs, 3, 3])
rot_z = torch.cat([
torch.cos(z), -torch.sin(z), zeros,
torch.sin(z), torch.cos(z), zeros,
zeros, zeros, ones
], dim=1).reshape([bs, 3, 3])
rot = rot_z @ rot_y @ rot_x
return rot.permute(0, 2, 1) # transpose
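# Minimal sanity sketch (illustrative only): the matrix returned above is a proper
# rotation, i.e. R @ R^T = I and det(R) = 1.
if __name__ == "__main__":
    pitch, yaw, roll = torch.tensor([10.]), torch.tensor([-20.]), torch.tensor([5.])
    R = get_rotation_matrix(pitch, yaw, roll)  # (1, 3, 3)
    assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(1, 3, 3), atol=1e-5)
    assert torch.allclose(torch.det(R), torch.ones(1), atol=1e-5)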

393
src/utils/crop.py Normal file
View File

@ -0,0 +1,393 @@
# coding: utf-8
"""
cropping function and the related preprocess functions for cropping
"""
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
import numpy as np
from .rprint import rprint as print
from math import sin, cos, acos, degrees
DTYPE = np.float32
CV2_INTERP = cv2.INTER_LINEAR
def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
""" conduct similarity or affine transformation to the image, do not do border operation!
img:
M: 2x3 matrix or 3x3 matrix
dsize: target shape (width, height)
"""
if isinstance(dsize, tuple) or isinstance(dsize, list):
_dsize = tuple(dsize)
else:
_dsize = (dsize, dsize)
if borderMode is not None:
return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0))
else:
return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags)
def _transform_pts(pts, M):
""" conduct similarity or affine transformation to the pts
pts: Nx2 ndarray
M: 2x3 matrix or 3x3 matrix
return: Nx2
"""
return pts @ M[:2, :2].T + M[:2, 2]
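def _sketch_transform_consistency():
    """Minimal sketch (illustrative only): warping an image and its landmarks with the
    same M keeps them aligned, i.e. a pixel at p maps to _transform_pts(p, M) in the warp."""
    M = np.array([[0.5, 0.0, 10.0],
                  [0.0, 0.5, 20.0]], dtype=DTYPE)      # scale by 0.5, then shift by (10, 20)
    img = np.zeros((256, 256, 3), dtype=np.uint8)
    pts = np.array([[100.0, 80.0], [30.0, 200.0]], dtype=DTYPE)
    warped = _transform_img(img, M, dsize=(128, 128))  # (128, 128, 3)
    print(warped.shape, _transform_pts(pts, M))        # [[60. 60.] [25. 120.]]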
def parse_pt2_from_pt101(pt101, use_lip=True):
"""
parse the 2 points (used to cancel the roll) from the 101 points
"""
# the former version used the eye center, but it was not robust; now use interpolation
pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0) # left eye center
pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0) # right eye center
if use_lip:
# use lip
pt_center_eye = (pt_left_eye + pt_right_eye) / 2
pt_center_lip = (pt101[75] + pt101[81]) / 2
pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
else:
pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
return pt2
def parse_pt2_from_pt106(pt106, use_lip=True):
"""
parse the 2 points (used to cancel the roll) from the 106 points
"""
pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0) # left eye center
pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0) # right eye center
if use_lip:
# use lip
pt_center_eye = (pt_left_eye + pt_right_eye) / 2
pt_center_lip = (pt106[52] + pt106[61]) / 2
pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
else:
pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
return pt2
def parse_pt2_from_pt203(pt203, use_lip=True):
"""
parse the 2 points (used to cancel the roll) from the 203 points
"""
pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], axis=0) # left eye center
pt_right_eye = np.mean(pt203[[24, 30, 36, 42]], axis=0) # right eye center
if use_lip:
# use lip
pt_center_eye = (pt_left_eye + pt_right_eye) / 2
pt_center_lip = (pt203[48] + pt203[66]) / 2
pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
else:
pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
return pt2
def parse_pt2_from_pt68(pt68, use_lip=True):
"""
parse the 2 points (used to cancel the roll) from the 68 points
"""
lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1
if use_lip:
pt5 = np.stack([
np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye
np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye
pt68[lm_idx[0], :], # nose
pt68[lm_idx[5], :], # lip
pt68[lm_idx[6], :] # lip
], axis=0)
pt2 = np.stack([
(pt5[0] + pt5[1]) / 2,
(pt5[3] + pt5[4]) / 2
], axis=0)
else:
pt2 = np.stack([
np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye
np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye
], axis=0)
return pt2
def parse_pt2_from_pt5(pt5, use_lip=True):
"""
parse the 2 points (used to cancel the roll) from the 5 points
"""
if use_lip:
pt2 = np.stack([
(pt5[0] + pt5[1]) / 2,
(pt5[3] + pt5[4]) / 2
], axis=0)
else:
pt2 = np.stack([
pt5[0],
pt5[1]
], axis=0)
return pt2
def parse_pt2_from_pt_x(pts, use_lip=True):
if pts.shape[0] == 101:
pt2 = parse_pt2_from_pt101(pts, use_lip=use_lip)
elif pts.shape[0] == 106:
pt2 = parse_pt2_from_pt106(pts, use_lip=use_lip)
elif pts.shape[0] == 68:
pt2 = parse_pt2_from_pt68(pts, use_lip=use_lip)
elif pts.shape[0] == 5:
pt2 = parse_pt2_from_pt5(pts, use_lip=use_lip)
elif pts.shape[0] == 203:
pt2 = parse_pt2_from_pt203(pts, use_lip=use_lip)
elif pts.shape[0] > 101:
# take the first 101 points
pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
else:
raise Exception(f'Unknown shape: {pts.shape}')
if not use_lip:
# NOTE: to comply with the later code, pt2 needs to be rotated 90 degrees clockwise manually
v = pt2[1] - pt2[0]
pt2[1, 0] = pt2[0, 0] - v[1]
pt2[1, 1] = pt2[0, 1] + v[0]
return pt2
def parse_rect_from_landmark(
pts,
scale=1.5,
need_square=True,
vx_ratio=0,
vy_ratio=0,
use_deg_flag=False,
**kwargs
):
"""parsing center, size, angle from 101/68/5/x landmarks
vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size
vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size, which is used to contain more forehead area
judge with pts.shape
"""
pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get('use_lip', True))
uy = pt2[1] - pt2[0]
l = np.linalg.norm(uy)
if l <= 1e-3:
uy = np.array([0, 1], dtype=DTYPE)
else:
uy /= l
ux = np.array((uy[1], -uy[0]), dtype=DTYPE)
# the rotation degree of the x-axis, the clockwise is positive, the counterclockwise is negative (image coordinate system)
# print(uy)
# print(ux)
angle = acos(ux[0])
if ux[1] < 0:
angle = -angle
# rotation matrix
M = np.array([ux, uy])
# calculate the size which contains the angle degree of the bbox, and the center
center0 = np.mean(pts, axis=0)
rpts = (pts - center0) @ M.T # (M @ P.T).T = P @ M.T
lt_pt = np.min(rpts, axis=0)
rb_pt = np.max(rpts, axis=0)
center1 = (lt_pt + rb_pt) / 2
size = rb_pt - lt_pt
if need_square:
m = max(size[0], size[1])
size[0] = m
size[1] = m
size *= scale # scale size
center = center0 + ux * center1[0] + uy * center1[1] # counterclockwise rotation, equivalent to M.T @ center1.T
center = center + ux * (vx_ratio * size) + uy * \
(vy_ratio * size) # considering the offset in vx and vy direction
if use_deg_flag:
angle = degrees(angle)
return center, size, angle
def parse_bbox_from_landmark(pts, **kwargs):
center, size, angle = parse_rect_from_landmark(pts, **kwargs)
cx, cy = center
w, h = size
# calculate the vertex positions before rotation
bbox = np.array([
[cx-w/2, cy-h/2], # left, top
[cx+w/2, cy-h/2],
[cx+w/2, cy+h/2], # right, bottom
[cx-w/2, cy+h/2]
], dtype=DTYPE)
# construct rotation matrix
bbox_rot = bbox.copy()
R = np.array([
[np.cos(angle), -np.sin(angle)],
[np.sin(angle), np.cos(angle)]
], dtype=DTYPE)
# calculate the relative position of each vertex from the rotation center, then rotate these positions, and finally add the coordinates of the rotation center
bbox_rot = (bbox_rot - center) @ R.T + center
return {
'center': center, # 2x1
'size': size, # scalar
'angle': angle, # rad, counterclockwise
'bbox': bbox, # 4x2
'bbox_rot': bbox_rot, # 4x2
}
def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs):
left, top, right, bot = bbox
if int(right - left) != int(bot - top):
print(f'right-left {right-left} != bot-top {bot-top}')
size = right - left
src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE)
tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)
s = dsize / size # scale
if flag_rot and angle is not None:
costheta, sintheta = cos(angle), sin(angle)
cx, cy = src_center[0], src_center[1] # ori center
tcx, tcy = tgt_center[0], tgt_center[1] # target center
# need to infer
M_o2c = np.array(
[[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
[-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]],
dtype=DTYPE
)
else:
M_o2c = np.array(
[[s, 0, tgt_center[0] - s * src_center[0]],
[0, s, tgt_center[1] - s * src_center[1]]],
dtype=DTYPE
)
if flag_rot and angle is None:
print('angle is None, but flag_rot is True', style="bold yellow")
img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
M_c2o = np.linalg.inv(M_o2c)
# cv2.imwrite('crop.jpg', img_crop)
return {
'img_crop': img_crop,
'lmk_crop': lmk_crop,
'M_o2c': M_o2c,
'M_c2o': M_c2o,
}
def _estimate_similar_transform_from_pts(
pts,
dsize,
scale=1.5,
vx_ratio=0,
vy_ratio=-0.1,
flag_do_rot=True,
**kwargs
):
""" calculate the affine matrix of the cropped image from sparse points, the original image to the cropped image, the inverse is the cropped image to the original image
pts: landmark, 101 or 68 points or other points, Nx2
scale: the larger scale factor, the smaller face ratio
vx_ratio: x shift
vy_ratio: y shift, the smaller the y shift, the lower the face region
rot_flag: if it is true, conduct correction
"""
center, size, angle = parse_rect_from_landmark(
pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio,
use_lip=kwargs.get('use_lip', True)
)
s = dsize / size[0] # scale
tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) # center of dsize
if flag_do_rot:
costheta, sintheta = cos(angle), sin(angle)
cx, cy = center[0], center[1] # ori center
tcx, tcy = tgt_center[0], tgt_center[1] # target center
# need to infer
M_INV = np.array(
[[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
[-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]],
dtype=DTYPE
)
else:
M_INV = np.array(
[[s, 0, tgt_center[0] - s * center[0]],
[0, s, tgt_center[1] - s * center[1]]],
dtype=DTYPE
)
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])])
M = np.linalg.inv(M_INV_H)
# M_INV is from the original image to the cropped image, M is from the cropped image to the original image
return M_INV, M[:2, ...]
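def _sketch_crop_round_trip():
    """Minimal sketch (illustrative only): M_INV maps original-image points into the crop
    and the returned M maps them back, so the two transforms are inverses up to numerical error."""
    pts = (np.random.rand(68, 2) * 300 + 100).astype(DTYPE)  # fake 68-point landmark
    M_INV, M = _estimate_similar_transform_from_pts(pts, dsize=224)
    pts_crop = _transform_pts(pts, M_INV)                    # original -> crop
    pts_back = _transform_pts(pts_crop, M)                   # crop -> original
    print(np.abs(pts_back - pts).max())                      # ~0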
def crop_image(img, pts: np.ndarray, **kwargs):
dsize = kwargs.get('dsize', 224)
scale = kwargs.get('scale', 1.5) # 1.5 | 1.6
vy_ratio = kwargs.get('vy_ratio', -0.1) # -0.0625 | -0.1
M_INV, _ = _estimate_similar_transform_from_pts(
pts,
dsize=dsize,
scale=scale,
vy_ratio=vy_ratio,
flag_do_rot=kwargs.get('flag_do_rot', True),
)
if img is None:
M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
M = np.linalg.inv(M_INV_H)
ret_dct = {
'M': M[:2, ...], # the inverse of M_INV, i.e. from the cropped image back to the original image
'M_o2c': M[:2, ...], # same matrix as 'M' above
'img_crop': None,
'pt_crop': None,
}
return ret_dct
img_crop = _transform_img(img, M_INV, dsize) # origin to crop
pt_crop = _transform_pts(pts, M_INV)
M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
M_c2o = np.linalg.inv(M_o2c)
ret_dct = {
'M_o2c': M_o2c, # from the original image to the cropped image 3x3
'M_c2o': M_c2o, # from the cropped image to the original image 3x3
'img_crop': img_crop, # the cropped image
'pt_crop': pt_crop, # the landmarks of the cropped image
}
return ret_dct
def average_bbox_lst(bbox_lst):
if len(bbox_lst) == 0:
return None
bbox_arr = np.array(bbox_lst)
return np.mean(bbox_arr, axis=0).tolist()

143
src/utils/cropper.py Normal file
View File

@ -0,0 +1,143 @@
# coding: utf-8
import numpy as np
import os.path as osp
from typing import List, Union, Tuple
from dataclasses import dataclass, field
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from .landmark_runner import LandmarkRunner
from .face_analysis_diy import FaceAnalysisDIY
from .helper import prefix
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
from .timer import Timer
from .rprint import rlog as log
from .io import load_image_rgb
from .video import VideoWriter, get_fps, change_video_fps
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
@dataclass
class Trajectory:
start: int = -1 # start frame index (inclusive)
end: int = -1 # end frame index (inclusive)
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
class Cropper(object):
def __init__(self, **kwargs) -> None:
device_id = kwargs.get('device_id', 0)
self.landmark_runner = LandmarkRunner(
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
onnx_provider='cuda',
device_id=device_id
)
self.landmark_runner.warmup()
self.face_analysis_wrapper = FaceAnalysisDIY(
name='buffalo_l',
root=make_abs_path('../../pretrained_weights/insightface'),
providers=["CUDAExecutionProvider"]
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.warmup()
self.crop_cfg = kwargs.get('crop_cfg', None)
def update_config(self, user_args):
for k, v in user_args.items():
if hasattr(self.crop_cfg, k):
setattr(self.crop_cfg, k, v)
def crop_single_image(self, obj, **kwargs):
direction = kwargs.get('direction', 'large-small')
# crop and align a single image
if isinstance(obj, str):
img_rgb = load_image_rgb(obj)
elif isinstance(obj, np.ndarray):
img_rgb = obj
src_face = self.face_analysis_wrapper.get(
img_rgb,
flag_do_landmark_2d_106=True,
direction=direction
)
if len(src_face) == 0:
log('No face detected in the source image.')
raise Exception("No face detected in the source image!")
elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {direction}.')
src_face = src_face[0]
pts = src_face.landmark_2d_106
# crop the face
ret_dct = crop_image(
img_rgb, # ndarray
pts, # 106x2 or Nx2
dsize=kwargs.get('dsize', 512),
scale=kwargs.get('scale', 2.3),
vy_ratio=kwargs.get('vy_ratio', -0.15),
)
# update a 256x256 version for network input or else
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
recon_ret = self.landmark_runner.run(img_rgb, pts)
lmk = recon_ret['pts']
ret_dct['lmk_crop'] = lmk
return ret_dct
def get_retargeting_lmk_info(self, driving_rgb_lst):
# TODO: implement a tracking-based version
driving_lmk_lst = []
for driving_image in driving_rgb_lst:
ret_dct = self.crop_single_image(driving_image)
driving_lmk_lst.append(ret_dct['lmk_crop'])
return driving_lmk_lst
def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
trajectory = Trajectory()
direction = kwargs.get('direction', 'large-small')
for idx, driving_image in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get(
driving_image,
flag_do_landmark_2d_106=True,
direction=direction
)
if len(src_face) == 0:
# No face detected in the driving_image
continue
elif len(src_face) > 1:
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
src_face = src_face[0]
pts = src_face.landmark_2d_106
lmk_203 = self.landmark_runner(driving_image, pts)['pts']
trajectory.start, trajectory.end = idx, idx
else:
lmk_203 = self.landmark_runner(driving_image, trajectory.lmk_lst[-1])['pts'] # track from the previous landmarks
trajectory.end = idx
trajectory.lmk_lst.append(lmk_203)
ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=self.crop_cfg.vy_ratio)['bbox']
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4,
trajectory.bbox_lst.append(bbox) # bbox
trajectory.frame_rgb_lst.append(driving_image)
global_bbox = average_bbox_lst(trajectory.bbox_lst)
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
ret_dct = crop_image_by_bbox(
frame_rgb, global_bbox, lmk=lmk,
dsize=self.crop_cfg.dsize, flag_rot=self.crop_cfg.flag_rot, borderValue=self.crop_cfg.borderValue
)
frame_rgb_crop = ret_dct['img_crop']

View File

@ -0,0 +1,21 @@
# coding: utf-8
# pylint: disable=wrong-import-position
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import
try:
#import mxnet as mx
import onnxruntime
except ImportError:
raise ImportError(
"Unable to import dependency onnxruntime. "
)
__version__ = '0.7.3'
from . import model_zoo
from . import utils
from . import app
from . import data
from . import thirdparty

View File

@ -0,0 +1,2 @@
from .face_analysis import *
from .mask_renderer import *

View File

@ -0,0 +1,49 @@
import numpy as np
from numpy.linalg import norm as l2norm
#from easydict import EasyDict
class Face(dict):
def __init__(self, d=None, **kwargs):
if d is None:
d = {}
if kwargs:
d.update(**kwargs)
for k, v in d.items():
setattr(self, k, v)
# Class attributes
#for k in self.__class__.__dict__.keys():
# if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
# setattr(self, k, getattr(self, k))
def __setattr__(self, name, value):
if isinstance(value, (list, tuple)):
value = [self.__class__(x)
if isinstance(x, dict) else x for x in value]
elif isinstance(value, dict) and not isinstance(value, self.__class__):
value = self.__class__(value)
super(Face, self).__setattr__(name, value)
super(Face, self).__setitem__(name, value)
__setitem__ = __setattr__
def __getattr__(self, name):
return None
@property
def embedding_norm(self):
if self.embedding is None:
return None
return l2norm(self.embedding)
@property
def normed_embedding(self):
if self.embedding is None:
return None
return self.embedding / self.embedding_norm
@property
def sex(self):
if self.gender is None:
return None
return 'M' if self.gender==1 else 'F'
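# Minimal usage sketch (illustrative only): Face is a dict with attribute access; attributes
# that were never set return None instead of raising, which the properties above rely on.
if __name__ == "__main__":
    face = Face(bbox=np.array([10, 20, 110, 140]), gender=1,
                embedding=np.random.rand(512).astype(np.float32))
    print(face.sex)                      # 'M'
    print(face.normed_embedding.shape)   # (512,)
    print(face.age)                      # None (never set)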

View File

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import glob
import os.path as osp
import numpy as np
import onnxruntime
from numpy.linalg import norm
from ..model_zoo import model_zoo
from ..utils import DEFAULT_MP_NAME, ensure_available
from .common import Face
__all__ = ['FaceAnalysis']
class FaceAnalysis:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
onnxruntime.set_default_logger_severity(3)
self.models = {}
self.model_dir = ensure_available('models', name, root=root)
onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
onnx_files = sorted(onnx_files)
for onnx_file in onnx_files:
model = model_zoo.get_model(onnx_file, **kwargs)
if model is None:
print('model not recognized:', onnx_file)
elif allowed_modules is not None and model.taskname not in allowed_modules:
print('model ignore:', onnx_file, model.taskname)
del model
elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
# print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std)
self.models[model.taskname] = model
else:
print('duplicated model task type, ignore:', onnx_file, model.taskname)
del model
assert 'detection' in self.models
self.det_model = self.models['detection']
def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
self.det_thresh = det_thresh
assert det_size is not None
# print('set det-size:', det_size)
self.det_size = det_size
for taskname, model in self.models.items():
if taskname=='detection':
model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
else:
model.prepare(ctx_id)
def get(self, img, max_num=0):
bboxes, kpss = self.det_model.detect(img,
max_num=max_num,
metric='default')
if bboxes.shape[0] == 0:
return []
ret = []
for i in range(bboxes.shape[0]):
bbox = bboxes[i, 0:4]
det_score = bboxes[i, 4]
kps = None
if kpss is not None:
kps = kpss[i]
face = Face(bbox=bbox, kps=kps, det_score=det_score)
for taskname, model in self.models.items():
if taskname=='detection':
continue
model.get(img, face)
ret.append(face)
return ret
def draw_on(self, img, faces):
import cv2
dimg = img.copy()
for i in range(len(faces)):
face = faces[i]
box = face.bbox.astype(int)  # np.int was removed in recent NumPy releases
color = (0, 0, 255)
cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
if face.kps is not None:
kps = face.kps.astype(int)
#print(landmark.shape)
for l in range(kps.shape[0]):
color = (0, 0, 255)
if l == 0 or l == 3:
color = (0, 255, 0)
cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color,
2)
if face.gender is not None and face.age is not None:
cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
#for key, value in face.items():
# if key.startswith('landmark_3d'):
# print(key, value.shape)
# print(value[0:10,:])
# lmk = np.round(value).astype(np.int)
# for l in range(lmk.shape[0]):
# color = (255, 0, 0)
# cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
# 2)
return dimg
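# Minimal usage sketch (illustrative only; 'portrait.jpg' is a hypothetical input path and
# 'buffalo_l' the default insightface model pack resolved via ensure_available):
if __name__ == "__main__":
    import cv2
    app = FaceAnalysis(name='buffalo_l')
    app.prepare(ctx_id=0, det_size=(640, 640))
    img = cv2.imread('portrait.jpg')
    faces = app.get(img)
    print(len(faces), 'face(s) detected')
    cv2.imwrite('portrait_vis.jpg', app.draw_on(img, faces))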

View File

@ -0,0 +1,232 @@
import os, sys, datetime
import numpy as np
import os.path as osp
import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
from .face_analysis import FaceAnalysis
from ..utils import get_model_dir
from ..thirdparty import face3d
from ..data import get_image as ins_get_image
from ..utils import DEFAULT_MP_NAME
import cv2
class MaskRenderer:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', insfa=None):
#if insfa is None, enter render_only mode
self.mp_name = name
self.root = root
self.insfa = insfa
model_dir = get_model_dir(name, root)
bfm_file = osp.join(model_dir, 'BFM.mat')
assert osp.exists(bfm_file), 'should contain BFM.mat in your model directory'
self.bfm = face3d.morphable_model.MorphabelModel(bfm_file)
self.index_ind = self.bfm.kpt_ind
bfm_uv_file = osp.join(model_dir, 'BFM_UV.mat')
assert osp.exists(bfm_uv_file), 'should contain BFM_UV.mat in your model directory'
uv_coords = face3d.morphable_model.load.load_uv_coords(bfm_uv_file)
self.uv_size = (224,224)
self.mask_stxr = 0.1
self.mask_styr = 0.33
self.mask_etxr = 0.9
self.mask_etyr = 0.7
self.tex_h , self.tex_w, self.tex_c = self.uv_size[1] , self.uv_size[0],3
texcoord = np.zeros_like(uv_coords)
texcoord[:, 0] = uv_coords[:, 0] * (self.tex_h - 1)
texcoord[:, 1] = uv_coords[:, 1] * (self.tex_w - 1)
texcoord[:, 1] = self.tex_w - texcoord[:, 1] - 1
self.texcoord = np.hstack((texcoord, np.zeros((texcoord.shape[0], 1))))
self.X_ind = self.bfm.kpt_ind
self.mask_image_names = ['mask_white', 'mask_blue', 'mask_black', 'mask_green']
self.mask_aug_probs = [0.4, 0.4, 0.1, 0.1]
#self.mask_images = []
#self.mask_images_rgb = []
#for image_name in mask_image_names:
# mask_image = ins_get_image(image_name)
# self.mask_images.append(mask_image)
# mask_image_rgb = mask_image[:,:,::-1]
# self.mask_images_rgb.append(mask_image_rgb)
def prepare(self, ctx_id=0, det_thresh=0.5, det_size=(128, 128)):
self.pre_ctx_id = ctx_id
self.pre_det_thresh = det_thresh
self.pre_det_size = det_size
def transform(self, shape3D, R):
s = 1.0
shape3D[:2, :] = shape3D[:2, :]
shape3D = s * np.dot(R, shape3D)
return shape3D
def preprocess(self, vertices, w, h):
R1 = face3d.mesh.transform.angle2matrix([0, 180, 180])
t = np.array([-w // 2, -h // 2, 0])
vertices = vertices.T
vertices += t
vertices = self.transform(vertices.T, R1).T
return vertices
def project_to_2d(self,vertices,s,angles,t):
transformed_vertices = self.bfm.transform(vertices, s, angles, t)
projected_vertices = transformed_vertices.copy()  # using standard camera & orthographic projection
return projected_vertices[self.bfm.kpt_ind, :2]
def params_to_vertices(self,params , H , W):
fitted_sp, fitted_ep, fitted_s, fitted_angles, fitted_t = params
fitted_vertices = self.bfm.generate_vertices(fitted_sp, fitted_ep)
transformed_vertices = self.bfm.transform(fitted_vertices, fitted_s, fitted_angles,
fitted_t)
transformed_vertices = self.preprocess(transformed_vertices.T, W, H)
image_vertices = face3d.mesh.transform.to_image(transformed_vertices, H, W)
return image_vertices
def draw_lmk(self, face_image):
faces = self.insfa.get(face_image, max_num=1)
if len(faces)==0:
return face_image
return self.insfa.draw_on(face_image, faces)
def build_params(self, face_image):
#landmark = self.if3d68_handler.get(face_image)
#if landmark is None:
# return None #face not found
if self.insfa is None:
self.insfa = FaceAnalysis(name=self.mp_name, root=self.root, allowed_modules=['detection', 'landmark_3d_68'])
self.insfa.prepare(ctx_id=self.pre_ctx_id, det_thresh=self.pre_det_thresh, det_size=self.pre_det_size)
faces = self.insfa.get(face_image, max_num=1)
if len(faces)==0:
return None
landmark = faces[0].landmark_3d_68[:,:2]
fitted_sp, fitted_ep, fitted_s, fitted_angles, fitted_t = self.bfm.fit(landmark, self.X_ind, max_iter = 3)
return [fitted_sp, fitted_ep, fitted_s, fitted_angles, fitted_t]
def generate_mask_uv(self,mask, positions):
uv_size = (self.uv_size[1], self.uv_size[0], 3)
h, w, c = uv_size
uv = np.zeros(shape=(self.uv_size[1],self.uv_size[0], 3), dtype=np.uint8)
stxr, styr = positions[0], positions[1]
etxr, etyr = positions[2], positions[3]
stx, sty = int(w * stxr), int(h * styr)
etx, ety = int(w * etxr), int(h * etyr)
height = ety - sty
width = etx - stx
mask = cv2.resize(mask, (width, height))
uv[sty:ety, stx:etx] = mask
return uv
def render_mask(self,face_image, mask_image, params, input_is_rgb=False, auto_blend = True, positions=[0.1, 0.33, 0.9, 0.7]):
if isinstance(mask_image, str):
to_rgb = True if input_is_rgb else False
mask_image = ins_get_image(mask_image, to_rgb=to_rgb)
uv_mask_image = self.generate_mask_uv(mask_image, positions)
h,w,c = face_image.shape
image_vertices = self.params_to_vertices(params ,h,w)
output = (1-face3d.mesh.render.render_texture(image_vertices, self.bfm.full_triangles , uv_mask_image, self.texcoord, self.bfm.full_triangles, h , w ))*255
output = output.astype(np.uint8)
if auto_blend:
mask_bd = (output==255).astype(np.uint8)
final = face_image*mask_bd + (1-mask_bd)*output
return final
return output
#def mask_augmentation(self, face_image, label, input_is_rgb=False, p=0.1):
# if np.random.random()<p:
# assert isinstance(label, (list, np.ndarray)), 'make sure the rec dataset includes mask params'
# assert len(label)==237 or len(label)==235, 'make sure the rec dataset includes mask params'
# if len(label)==237:
# if label[1]<0.0: #invalid label for mask aug
# return face_image
# label = label[2:]
# params = self.decode_params(label)
# mask_image_name = np.random.choice(self.mask_image_names, p=self.mask_aug_probs)
# pos = np.random.uniform(0.33, 0.5)
# face_image = self.render_mask(face_image, mask_image_name, params, input_is_rgb=input_is_rgb, positions=[0.1, pos, 0.9, 0.7])
# return face_image
@staticmethod
def encode_params(params):
p0 = list(params[0])
p1 = list(params[1])
p2 = [float(params[2])]
p3 = list(params[3])
p4 = list(params[4])
return p0+p1+p2+p3+p4
@staticmethod
def decode_params(params):
p0 = params[0:199]
p0 = np.array(p0, dtype=np.float32).reshape( (-1, 1))
p1 = params[199:228]
p1 = np.array(p1, dtype=np.float32).reshape( (-1, 1))
p2 = params[228]
p3 = tuple(params[229:232])
p4 = params[232:235]
p4 = np.array(p4, dtype=np.float32).reshape( (-1, 1))
return p0, p1, p2, p3, p4
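# Layout of the flat 235-float parameter vector handled by encode_params /
# decode_params (sizes follow the slices above): 199 shape coefficients,
# 29 expression coefficients, 1 scale, 3 rotation angles, 3 translation values.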
class MaskAugmentation(ImageOnlyTransform):
def __init__(
self,
mask_names=['mask_white', 'mask_blue', 'mask_black', 'mask_green'],
mask_probs=[0.4,0.4,0.1,0.1],
h_low = 0.33,
h_high = 0.35,
always_apply=False,
p=1.0,
):
super(MaskAugmentation, self).__init__(always_apply, p)
self.renderer = MaskRenderer()
assert len(mask_names)>0
assert len(mask_names)==len(mask_probs)
self.mask_names = mask_names
self.mask_probs = mask_probs
self.h_low = h_low
self.h_high = h_high
#self.hlabel = None
def apply(self, image, hlabel, mask_name, h_pos, **params):
#print(params.keys())
#hlabel = params.get('hlabel')
assert len(hlabel)==237 or len(hlabel)==235, 'make sure the rec dataset includes mask params'
if len(hlabel)==237:
if hlabel[1]<0.0:
return image
hlabel = hlabel[2:]
#print(len(hlabel))
mask_params = self.renderer.decode_params(hlabel)
image = self.renderer.render_mask(image, mask_name, mask_params, input_is_rgb=True, positions=[0.1, h_pos, 0.9, 0.7])
return image
@property
def targets_as_params(self):
return ["image", "hlabel"]
def get_params_dependent_on_targets(self, params):
hlabel = params['hlabel']
mask_name = np.random.choice(self.mask_names, p=self.mask_probs)
h_pos = np.random.uniform(self.h_low, self.h_high)
return {'hlabel': hlabel, 'mask_name': mask_name, 'h_pos': h_pos}
def get_transform_init_args_names(self):
#return ("hlabel", 'mask_names', 'mask_probs', 'h_low', 'h_high')
return ('mask_names', 'mask_probs', 'h_low', 'h_high')
if __name__ == "__main__":
tool = MaskRenderer('antelope')
tool.prepare(det_size=(128,128))
image = cv2.imread("Tom_Hanks_54745.png")
params = tool.build_params(image)
#out = tool.draw_lmk(image)
#cv2.imwrite('output_lmk.jpg', out)
#mask_image = cv2.imread("masks/mask1.jpg")
#mask_image = cv2.imread("masks/black-mask.png")
#mask_image = cv2.imread("masks/mask2.jpg")
mask_out = tool.render_mask(image, 'mask_blue', params)# use single thread to test the time cost
cv2.imwrite('output_mask.jpg', mask_out)

View File

@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser
class BaseInsightFaceCLICommand(ABC):
@staticmethod
@abstractmethod
def register_subcommand(parser: ArgumentParser):
raise NotImplementedError()
@abstractmethod
def run(self):
raise NotImplementedError()

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
from argparse import ArgumentParser
from .model_download import ModelDownloadCommand
from .rec_add_mask_param import RecAddMaskParamCommand
def main():
parser = ArgumentParser("InsightFace CLI tool", usage="insightface-cli <command> [<args>]")
commands_parser = parser.add_subparsers(help="insightface-cli command-line helpers")
# Register commands
ModelDownloadCommand.register_subcommand(commands_parser)
RecAddMaskParamCommand.register_subcommand(commands_parser)
args = parser.parse_args()
if not hasattr(args, "func"):
parser.print_help()
exit(1)
# Run
service = args.func(args)
service.run()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,36 @@
from argparse import ArgumentParser
from . import BaseInsightFaceCLICommand
import os
import os.path as osp
import zipfile
import glob
from ..utils import download
def model_download_command_factory(args):
return ModelDownloadCommand(args.model, args.root, args.force)
class ModelDownloadCommand(BaseInsightFaceCLICommand):
#_url_format = '{repo_url}models/{file_name}.zip'
@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser("model.download")
download_parser.add_argument(
"--root", type=str, default='~/.insightface', help="Path to location to store the models"
)
download_parser.add_argument(
"--force", action="store_true", help="Force the model to be download even if already in root-dir"
)
download_parser.add_argument("model", type=str, help="Name of the model to download")
download_parser.set_defaults(func=model_download_command_factory)
def __init__(self, model: str, root: str, force: bool):
self._model = model
self._root = root
self._force = force
def run(self):
download('models', self._model, force=self._force, root=self._root)

View File

@ -0,0 +1,94 @@
import numbers
import os
from argparse import ArgumentParser, Namespace
import mxnet as mx
import numpy as np
from ..app import MaskRenderer
from ..data.rec_builder import RecBuilder
from . import BaseInsightFaceCLICommand
def rec_add_mask_param_command_factory(args: Namespace):
return RecAddMaskParamCommand(
args.input, args.output
)
class RecAddMaskParamCommand(BaseInsightFaceCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
_parser = parser.add_parser("rec.addmaskparam")
_parser.add_argument("input", type=str, help="input rec")
_parser.add_argument("output", type=str, help="output rec, with mask param")
_parser.set_defaults(func=rec_add_mask_param_command_factory)
def __init__(
self,
input: str,
output: str,
):
self._input = input
self._output = output
def run(self):
tool = MaskRenderer()
tool.prepare(ctx_id=0, det_size=(128,128))
root_dir = self._input
path_imgrec = os.path.join(root_dir, 'train.rec')
path_imgidx = os.path.join(root_dir, 'train.idx')
imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
save_path = self._output
wrec=RecBuilder(path=save_path)
s = imgrec.read_idx(0)
header, _ = mx.recordio.unpack(s)
if header.flag > 0:
if len(header.label)==2:
imgidx = np.array(range(1, int(header.label[0])))
else:
imgidx = np.array(list(imgrec.keys))
else:
imgidx = np.array(list(imgrec.keys))
stat = [0, 0]
print('total:', len(imgidx))
for iid, idx in enumerate(imgidx):
#if iid==500000:
# break
if iid%1000==0:
print('processing:', iid)
s = imgrec.read_idx(idx)
header, img = mx.recordio.unpack(s)
label = header.label
if not isinstance(label, numbers.Number):
label = label[0]
sample = mx.image.imdecode(img).asnumpy()
bgr = sample[:,:,::-1]
params = tool.build_params(bgr)
#if iid<10:
# mask_out = tool.render_mask(bgr, 'mask_blue', params)
# cv2.imwrite('maskout_%d.jpg'%iid, mask_out)
stat[1] += 1
if params is None:
wlabel = [label] + [-1.0]*236
stat[0] += 1
else:
#print(0, params[0].shape, params[0].dtype)
#print(1, params[1].shape, params[1].dtype)
#print(2, params[2])
#print(3, len(params[3]), params[3][0].__class__)
#print(4, params[4].shape, params[4].dtype)
mask_label = tool.encode_params(params)
wlabel = [label, 0.0]+mask_label # 237 including idlabel, total mask params size is 235
if iid==0:
print('param size:', len(mask_label), len(wlabel), label)
assert len(wlabel)==237
wrec.add_image(img, wlabel)
#print(len(params))
wrec.close()
print('finished on', self._output, ', failed:', stat[0])

View File

@ -0,0 +1,2 @@
from .image import get_image
from .pickle_object import get_object

View File

@ -0,0 +1,27 @@
import cv2
import os
import os.path as osp
from pathlib import Path
class ImageCache:
data = {}
def get_image(name, to_rgb=False):
key = (name, to_rgb)
if key in ImageCache.data:
return ImageCache.data[key]
images_dir = osp.join(Path(__file__).parent.absolute(), 'images')
ext_names = ['.jpg', '.png', '.jpeg']
image_file = None
for ext_name in ext_names:
_image_file = osp.join(images_dir, "%s%s"%(name, ext_name))
if osp.exists(_image_file):
image_file = _image_file
break
assert image_file is not None, '%s not found'%name
img = cv2.imread(image_file)
if to_rgb:
img = img[:,:,::-1]
ImageCache.data[key] = img
return img

Binary image files not shown (6 images added; sizes: 12 KiB, 21 KiB, 44 KiB, 6.0 KiB, 77 KiB, 126 KiB).

View File

@ -0,0 +1,17 @@
import cv2
import os
import os.path as osp
from pathlib import Path
import pickle
def get_object(name):
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
if not name.endswith('.pkl'):
name = name+".pkl"
filepath = osp.join(objects_dir, name)
if not osp.exists(filepath):
return None
with open(filepath, 'rb') as f:
obj = pickle.load(f)
return obj

View File

@ -0,0 +1,71 @@
import pickle
import numpy as np
import os
import os.path as osp
import sys
import mxnet as mx
class RecBuilder():
def __init__(self, path, image_size=(112, 112)):
self.path = path
self.image_size = image_size
self.widx = 0
self.wlabel = 0
self.max_label = -1
assert not osp.exists(path), '%s exists' % path
os.makedirs(path)
self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'),
os.path.join(path, 'train.rec'),
'w')
self.meta = []
def add(self, imgs):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
assert len(imgs) > 0
label = self.wlabel
for img in imgs:
idx = self.widx
image_meta = {'image_index': idx, 'image_classes': [label]}
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = label
self.wlabel += 1
def add_image(self, img, label):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
idx = self.widx
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(label, list):
idlabel = label[0]
else:
idlabel = label
image_meta = {'image_index': idx, 'image_classes': [idlabel]}
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = max(self.max_label, idlabel)
def close(self):
with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
print('stat:', self.widx, self.wlabel)
with open(os.path.join(self.path, 'property'), 'w') as f:
f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1]))
f.write("%d\n" % (self.widx))

View File

@ -0,0 +1,6 @@
from .model_zoo import get_model
from .arcface_onnx import ArcFaceONNX
from .retinaface import RetinaFace
from .scrfd import SCRFD
from .landmark import Landmark
from .attribute import Attribute

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'ArcFaceONNX',
]
class ArcFaceONNX:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
self.taskname = 'recognition'
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 127.5
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
self.output_shape = outputs[0].shape
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
face.embedding = self.get_feat(aimg).flatten()
return face.embedding
def compute_sim(self, feat1, feat2):
from numpy.linalg import norm
feat1 = feat1.ravel()
feat2 = feat2.ravel()
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
return sim
def get_feat(self, imgs):
if not isinstance(imgs, list):
imgs = [imgs]
input_size = self.input_size
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
def forward(self, batch_data):
blob = (batch_data - self.input_mean) / self.input_std
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
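A minimal sketch of comparing two pre-aligned 112x112 crops with the class above; the ONNX filename and image paths are assumptions:
import cv2
rec = ArcFaceONNX(model_file='w600k_r50.onnx')  # assumed recognition model file
rec.prepare(ctx_id=-1)                          # ctx_id < 0 selects the CPU provider
feat1 = rec.get_feat(cv2.imread('a_112x112.jpg')).flatten()
feat2 = rec.get_feat(cv2.imread('b_112x112.jpg')).flatten()
print('cosine similarity:', rec.compute_sim(feat1, feat2))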

View File

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-06-19
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'Attribute',
]
class Attribute:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
#print('init output_shape:', output_shape)
if output_shape[1]==3:
self.taskname = 'genderage'
else:
self.taskname = 'attribute_%d'%output_shape[1]
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if self.taskname=='genderage':
assert len(pred)==3
gender = np.argmax(pred[:2])
age = int(np.round(pred[2]*100))
face['gender'] = gender
face['age'] = age
return gender, age
else:
return pred

View File

@ -0,0 +1,114 @@
import time
import numpy as np
import onnxruntime
import cv2
import onnx
from onnx import numpy_helper
from ..utils import face_align
class INSwapper():
def __init__(self, model_file=None, session=None):
self.model_file = model_file
self.session = session
model = onnx.load(self.model_file)
graph = model.graph
self.emap = numpy_helper.to_array(graph.initializer[-1])
self.input_mean = 0.0
self.input_std = 255.0
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
inputs = self.session.get_inputs()
self.input_names = []
for inp in inputs:
self.input_names.append(inp.name)
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
input_cfg = inputs[0]
input_shape = input_cfg.shape
self.input_shape = input_shape
# print('inswapper-shape:', self.input_shape)
self.input_size = tuple(input_shape[2:4][::-1])
def forward(self, img, latent):
img = (img - self.input_mean) / self.input_std
pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
return pred
def get(self, img, target_face, source_face, paste_back=True):
face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1)
aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
latent = source_face.normed_embedding.reshape((1,-1))
latent = np.dot(latent, self.emap)
latent /= np.linalg.norm(latent)
pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
#print(latent.shape, latent.dtype, pred.shape)
img_fake = pred.transpose((0,2,3,1))[0]
bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
if not paste_back:
return bgr_fake, M
else:
target_img = img
fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
fake_diff = np.abs(fake_diff).mean(axis=2)
fake_diff[:2,:] = 0
fake_diff[-2:,:] = 0
fake_diff[:,:2] = 0
fake_diff[:,-2:] = 0
IM = cv2.invertAffineTransform(M)
img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white[img_white>20] = 255
fthresh = 10
fake_diff[fake_diff<fthresh] = 0
fake_diff[fake_diff>=fthresh] = 255
img_mask = img_white
mask_h_inds, mask_w_inds = np.where(img_mask==255)
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
mask_size = int(np.sqrt(mask_h*mask_w))
k = max(mask_size//10, 10)
#k = max(mask_size//20, 6)
#k = 6
kernel = np.ones((k,k),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel = np.ones((2,2),np.uint8)
fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
fake_diff[face_mask==1] = 255
k = max(mask_size//20, 5)
#k = 3
#k = 3
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
k = 5
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
fake_diff = cv2.blur(fake_diff, (11,11), 0)
##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
# print('blur_size: ', blur_size)
# fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
img_mask /= 255
fake_diff /= 255
# img_mask = fake_diff
img_mask = img_mask*fake_diff
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
fake_merged = fake_merged.astype(np.uint8)
return fake_merged

View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
from ..utils import transform
from ..data import get_object
__all__ = [
'Landmark',
]
class Landmark:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
self.require_pose = False
#print('init output_shape:', output_shape)
if output_shape[1]==3309:
self.lmk_dim = 3
self.lmk_num = 68
self.mean_lmk = get_object('meanshape_68.pkl')
self.require_pose = True
else:
self.lmk_dim = 2
self.lmk_num = output_shape[1]//self.lmk_dim
self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num)
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if pred.shape[0] >= 3000:
pred = pred.reshape((-1, 3))
else:
pred = pred.reshape((-1, 2))
if self.lmk_num < pred.shape[0]:
pred = pred[self.lmk_num*-1:,:]
pred[:, 0:2] += 1
pred[:, 0:2] *= (self.input_size[0] // 2)
if pred.shape[1] == 3:
pred[:, 2] *= (self.input_size[0] // 2)
IM = cv2.invertAffineTransform(M)
pred = face_align.trans_points(pred, IM)
face[self.taskname] = pred
if self.require_pose:
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
s, R, t = transform.P2sRt(P)
rx, ry, rz = transform.matrix2angle(R)
pose = np.array( [rx, ry, rz], dtype=np.float32 )
face['pose'] = pose #pitch, yaw, roll
return pred

View File

@ -0,0 +1,103 @@
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
"""
from __future__ import print_function
__all__ = ['get_model_file']
import os
import zipfile
import glob
from ..utils import download, check_sha1
_model_sha1 = {
name: checksum
for checksum, name in [
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
('', 'arcface_mfn_v1'),
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'),
('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'),
]
}
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
def short_hash(name):
if name not in _model_sha1:
raise ValueError(
'Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]
def find_params_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.params" % dir_path)
if len(paths) == 0:
return None
paths = sorted(paths)
return paths[-1]
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
The root directory will be created if it doesn't exist.
Parameters
----------
name : str
Name of the model.
root : str, default '~/.insightface/models'
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
file_name = name
root = os.path.expanduser(root)
dir_path = os.path.join(root, name)
file_path = find_params_file(dir_path)
#file_path = os.path.join(root, file_name + '.params')
sha1_hash = _model_sha1[name]
if file_path is not None:
if check_sha1(file_path, sha1_hash):
return file_path
else:
print(
'Mismatch in the content of model file detected. Downloading again.'
)
else:
print('Model file is not found. Downloading.')
if not os.path.exists(root):
os.makedirs(root)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
zip_file_path = os.path.join(root, file_name + '.zip')
repo_url = base_repo_url
if repo_url[-1] != '/':
repo_url = repo_url + '/'
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(dir_path)
os.remove(zip_file_path)
file_path = find_params_file(dir_path)
if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError(
'Downloaded file has different hash. Please try again.')

View File

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .retinaface import *
#from .scrfd import *
from .landmark import *
from .attribute import Attribute
from .inswapper import INSwapper
from ..utils import download_onnx
__all__ = ['get_model']
class PickableInferenceSession(onnxruntime.InferenceSession):
# This is a wrapper to make the current InferenceSession class picklable.
def __init__(self, model_path, **kwargs):
super().__init__(model_path, **kwargs)
self.model_path = model_path
def __getstate__(self):
return {'model_path': self.model_path}
def __setstate__(self, values):
model_path = values['model_path']
self.__init__(model_path)
class ModelRouter:
def __init__(self, onnx_file):
self.onnx_file = onnx_file
def get_model(self, **kwargs):
session = PickableInferenceSession(self.onnx_file, **kwargs)
# print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
inputs = session.get_inputs()
input_cfg = inputs[0]
input_shape = input_cfg.shape
outputs = session.get_outputs()
if len(outputs)>=5:
return RetinaFace(model_file=self.onnx_file, session=session)
elif input_shape[2]==192 and input_shape[3]==192:
return Landmark(model_file=self.onnx_file, session=session)
elif input_shape[2]==96 and input_shape[3]==96:
return Attribute(model_file=self.onnx_file, session=session)
elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128:
return INSwapper(model_file=self.onnx_file, session=session)
elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0:
return ArcFaceONNX(model_file=self.onnx_file, session=session)
else:
#raise RuntimeError('error on model routing')
return None
def find_onnx_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.onnx" % dir_path)
if len(paths) == 0:
return None
paths = sorted(paths)
return paths[-1]
def get_default_providers():
return ['CUDAExecutionProvider', 'CPUExecutionProvider']
def get_default_provider_options():
return None
def get_model(name, **kwargs):
root = kwargs.get('root', '~/.insightface')
root = os.path.expanduser(root)
model_root = osp.join(root, 'models')
allow_download = kwargs.get('download', False)
download_zip = kwargs.get('download_zip', False)
if not name.endswith('.onnx'):
model_dir = os.path.join(model_root, name)
model_file = find_onnx_file(model_dir)
if model_file is None:
return None
else:
model_file = name
if not osp.exists(model_file) and allow_download:
model_file = download_onnx('models', model_file, root=root, download_zip=download_zip)
assert osp.exists(model_file), 'model_file %s should exist'%model_file
assert osp.isfile(model_file), 'model_file %s should be a file'%model_file
router = ModelRouter(model_file)
providers = kwargs.get('providers', get_default_providers())
provider_options = kwargs.get('provider_options', get_default_provider_options())
model = router.get_model(providers=providers, provider_options=provider_options)
return model
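A minimal sketch of the routing above; the local ONNX path is an assumption:
import os.path as osp
onnx_path = osp.expanduser('~/.insightface/models/buffalo_l/det_10g.onnx')  # assumed local file
det = get_model(onnx_path, providers=['CPUExecutionProvider'])
# a detection model exposes five or more outputs, so ModelRouter returns a RetinaFace
det.prepare(ctx_id=-1, input_size=(640, 640), det_thresh=0.5)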

View File

@ -0,0 +1,301 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-09-18
# @Function :
from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys
def softmax(z):
assert len(z.shape) == 2
s = np.max(z, axis=1)
s = s[:, np.newaxis] # necessary step to do broadcasting
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis]  # ditto
return e_x / div
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
x1 = points[:, 0] - distance[:, 0]
y1 = points[:, 1] - distance[:, 1]
x2 = points[:, 0] + distance[:, 2]
y2 = points[:, 1] + distance[:, 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return np.stack([x1, y1, x2, y2], axis=-1)
def distance2kps(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
preds = []
for i in range(0, distance.shape[1], 2):
px = points[:, i%2] + distance[:, i]
py = points[:, i%2+1] + distance[:, i+1]
if max_shape is not None:
px = px.clamp(min=0, max=max_shape[1])
py = py.clamp(min=0, max=max_shape[0])
preds.append(px)
preds.append(py)
return np.stack(preds, axis=-1)
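# Worked example of the decoders above: an anchor center at (100, 100) with bbox
# distances (10, 20, 30, 40) decodes to the box (90, 80, 130, 140); keypoint
# offsets are interleaved per (x, y) pair and are simply added to the center.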
class RetinaFace:
def __init__(self, model_file=None, session=None):
import onnxruntime
self.model_file = model_file
self.session = session
self.taskname = 'detection'
if self.session is None:
assert self.model_file is not None
assert osp.exists(self.model_file)
self.session = onnxruntime.InferenceSession(self.model_file, None)
self.center_cache = {}
self.nms_thresh = 0.4
self.det_thresh = 0.5
self._init_vars()
def _init_vars(self):
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
#print(input_shape)
if isinstance(input_shape[2], str):
self.input_size = None
else:
self.input_size = tuple(input_shape[2:4][::-1])
#print('image_size:', self.image_size)
input_name = input_cfg.name
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for o in outputs:
output_names.append(o.name)
self.input_name = input_name
self.output_names = output_names
self.input_mean = 127.5
self.input_std = 128.0
#print(self.output_names)
#assert len(outputs)==10 or len(outputs)==15
self.use_kps = False
self._anchor_ratio = 1.0
self._num_anchors = 1
if len(outputs)==6:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
elif len(outputs)==9:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
self.use_kps = True
elif len(outputs)==10:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
elif len(outputs)==15:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
self.use_kps = True
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
nms_thresh = kwargs.get('nms_thresh', None)
if nms_thresh is not None:
self.nms_thresh = nms_thresh
det_thresh = kwargs.get('det_thresh', None)
if det_thresh is not None:
self.det_thresh = det_thresh
input_size = kwargs.get('input_size', None)
if input_size is not None:
if self.input_size is not None:
print('warning: det_size is already set in detection model, ignore')
else:
self.input_size = input_size
def forward(self, img, threshold):
scores_list = []
bboxes_list = []
kpss_list = []
input_size = tuple(img.shape[0:2][::-1])
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_outs = self.session.run(self.output_names, {self.input_name : blob})
input_height = blob.shape[2]
input_width = blob.shape[3]
fmc = self.fmc
for idx, stride in enumerate(self._feat_stride_fpn):
scores = net_outs[idx]
bbox_preds = net_outs[idx+fmc]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx+fmc*2] * stride
height = input_height // stride
width = input_width // stride
K = height * width
key = (height, width, stride)
if key in self.center_cache:
anchor_centers = self.center_cache[key]
else:
#solution-1, c style:
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
#for i in range(height):
# anchor_centers[i, :, 1] = i
#for i in range(width):
# anchor_centers[:, i, 0] = i
#solution-2:
#ax = np.arange(width, dtype=np.float32)
#ay = np.arange(height, dtype=np.float32)
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
#solution-3:
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
#print(anchor_centers.shape)
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
if self._num_anchors>1:
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
if len(self.center_cache)<100:
self.center_cache[key] = anchor_centers
pos_inds = np.where(scores>=threshold)[0]
bboxes = distance2bbox(anchor_centers, bbox_preds)
pos_scores = scores[pos_inds]
pos_bboxes = bboxes[pos_inds]
scores_list.append(pos_scores)
bboxes_list.append(pos_bboxes)
if self.use_kps:
kpss = distance2kps(anchor_centers, kps_preds)
#kpss = kps_preds
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
pos_kpss = kpss[pos_inds]
kpss_list.append(pos_kpss)
return scores_list, bboxes_list, kpss_list
def detect(self, img, input_size = None, max_num=0, metric='default'):
assert input_size is not None or self.input_size is not None
input_size = self.input_size if input_size is None else input_size
im_ratio = float(img.shape[0]) / img.shape[1]
model_ratio = float(input_size[1]) / input_size[0]
if im_ratio>model_ratio:
new_height = input_size[1]
new_width = int(new_height / im_ratio)
else:
new_width = input_size[0]
new_height = int(new_width * im_ratio)
det_scale = float(new_height) / img.shape[0]
resized_img = cv2.resize(img, (new_width, new_height))
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
det_img[:new_height, :new_width, :] = resized_img
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
scores = np.vstack(scores_list)
scores_ravel = scores.ravel()
order = scores_ravel.argsort()[::-1]
bboxes = np.vstack(bboxes_list) / det_scale
if self.use_kps:
kpss = np.vstack(kpss_list) / det_scale
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
pre_det = pre_det[order, :]
keep = self.nms(pre_det)
det = pre_det[keep, :]
if self.use_kps:
kpss = kpss[order,:,:]
kpss = kpss[keep,:,:]
else:
kpss = None
if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max':
values = area
else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
bindex = np.argsort(
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num]
det = det[bindex, :]
if kpss is not None:
kpss = kpss[bindex, :]
return det, kpss
def nms(self, dets):
thresh = self.nms_thresh
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs):
if not download:
assert os.path.exists(name)
return RetinaFace(name)
else:
from .model_store import get_model_file
_file = get_model_file("retinaface_%s" % name, root=root)
return RetinaFace(_file)
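A minimal detection sketch using the class above directly; the model file and image path are assumptions:
import cv2
detector = RetinaFace(model_file='retinaface_det.onnx')  # assumed ONNX detection model
detector.prepare(ctx_id=-1, input_size=(640, 640), det_thresh=0.5, nms_thresh=0.4)
img = cv2.imread('group.jpg')
bboxes, kpss = detector.detect(img)
# bboxes: (N, 5) rows of x1, y1, x2, y2, score; kpss: (N, 5, 2) when keypoints are predicted
print(bboxes.shape)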

View File

@ -0,0 +1,348 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys
def softmax(z):
assert len(z.shape) == 2
s = np.max(z, axis=1)
s = s[:, np.newaxis] # necessary step to do broadcasting
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis]  # ditto
return e_x / div
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
x1 = points[:, 0] - distance[:, 0]
y1 = points[:, 1] - distance[:, 1]
x2 = points[:, 0] + distance[:, 2]
y2 = points[:, 1] + distance[:, 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return np.stack([x1, y1, x2, y2], axis=-1)
def distance2kps(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
preds = []
for i in range(0, distance.shape[1], 2):
px = points[:, i%2] + distance[:, i]
py = points[:, i%2+1] + distance[:, i+1]
if max_shape is not None:
px = px.clamp(min=0, max=max_shape[1])
py = py.clamp(min=0, max=max_shape[0])
preds.append(px)
preds.append(py)
return np.stack(preds, axis=-1)
class SCRFD:
def __init__(self, model_file=None, session=None):
import onnxruntime
self.model_file = model_file
self.session = session
self.taskname = 'detection'
self.batched = False
if self.session is None:
assert self.model_file is not None
assert osp.exists(self.model_file)
self.session = onnxruntime.InferenceSession(self.model_file, None)
self.center_cache = {}
self.nms_thresh = 0.4
self.det_thresh = 0.5
self._init_vars()
def _init_vars(self):
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
#print(input_shape)
if isinstance(input_shape[2], str):
self.input_size = None
else:
self.input_size = tuple(input_shape[2:4][::-1])
#print('image_size:', self.image_size)
input_name = input_cfg.name
self.input_shape = input_shape
outputs = self.session.get_outputs()
if len(outputs[0].shape) == 3:
self.batched = True
output_names = []
for o in outputs:
output_names.append(o.name)
self.input_name = input_name
self.output_names = output_names
self.input_mean = 127.5
self.input_std = 128.0
#print(self.output_names)
#assert len(outputs)==10 or len(outputs)==15
self.use_kps = False
self._anchor_ratio = 1.0
self._num_anchors = 1
if len(outputs)==6:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
elif len(outputs)==9:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
self.use_kps = True
elif len(outputs)==10:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
elif len(outputs)==15:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
self.use_kps = True
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
nms_thresh = kwargs.get('nms_thresh', None)
if nms_thresh is not None:
self.nms_thresh = nms_thresh
det_thresh = kwargs.get('det_thresh', None)
if det_thresh is not None:
self.det_thresh = det_thresh
input_size = kwargs.get('input_size', None)
if input_size is not None:
if self.input_size is not None:
print('warning: det_size is already set in scrfd model, ignore')
else:
self.input_size = input_size
def forward(self, img, threshold):
scores_list = []
bboxes_list = []
kpss_list = []
input_size = tuple(img.shape[0:2][::-1])
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_outs = self.session.run(self.output_names, {self.input_name : blob})
input_height = blob.shape[2]
input_width = blob.shape[3]
fmc = self.fmc
for idx, stride in enumerate(self._feat_stride_fpn):
# If model support batch dim, take first output
if self.batched:
scores = net_outs[idx][0]
bbox_preds = net_outs[idx + fmc][0]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx + fmc * 2][0] * stride
# If model doesn't support batching take output as is
else:
scores = net_outs[idx]
bbox_preds = net_outs[idx + fmc]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx + fmc * 2] * stride
height = input_height // stride
width = input_width // stride
K = height * width
key = (height, width, stride)
if key in self.center_cache:
anchor_centers = self.center_cache[key]
else:
#solution-1, c style:
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
#for i in range(height):
# anchor_centers[i, :, 1] = i
#for i in range(width):
# anchor_centers[:, i, 0] = i
#solution-2:
#ax = np.arange(width, dtype=np.float32)
#ay = np.arange(height, dtype=np.float32)
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
#solution-3:
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
#print(anchor_centers.shape)
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
if self._num_anchors>1:
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
if len(self.center_cache)<100:
self.center_cache[key] = anchor_centers
pos_inds = np.where(scores>=threshold)[0]
bboxes = distance2bbox(anchor_centers, bbox_preds)
pos_scores = scores[pos_inds]
pos_bboxes = bboxes[pos_inds]
scores_list.append(pos_scores)
bboxes_list.append(pos_bboxes)
if self.use_kps:
kpss = distance2kps(anchor_centers, kps_preds)
#kpss = kps_preds
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
pos_kpss = kpss[pos_inds]
kpss_list.append(pos_kpss)
return scores_list, bboxes_list, kpss_list
def detect(self, img, input_size = None, max_num=0, metric='default'):
assert input_size is not None or self.input_size is not None
input_size = self.input_size if input_size is None else input_size
im_ratio = float(img.shape[0]) / img.shape[1]
model_ratio = float(input_size[1]) / input_size[0]
if im_ratio>model_ratio:
new_height = input_size[1]
new_width = int(new_height / im_ratio)
else:
new_width = input_size[0]
new_height = int(new_width * im_ratio)
det_scale = float(new_height) / img.shape[0]
resized_img = cv2.resize(img, (new_width, new_height))
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
det_img[:new_height, :new_width, :] = resized_img
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
scores = np.vstack(scores_list)
scores_ravel = scores.ravel()
order = scores_ravel.argsort()[::-1]
bboxes = np.vstack(bboxes_list) / det_scale
if self.use_kps:
kpss = np.vstack(kpss_list) / det_scale
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
pre_det = pre_det[order, :]
keep = self.nms(pre_det)
det = pre_det[keep, :]
if self.use_kps:
kpss = kpss[order,:,:]
kpss = kpss[keep,:,:]
else:
kpss = None
if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max':
values = area
else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
bindex = np.argsort(
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num]
det = det[bindex, :]
if kpss is not None:
kpss = kpss[bindex, :]
return det, kpss
def nms(self, dets):
thresh = self.nms_thresh
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs):
if not download:
assert os.path.exists(name)
return SCRFD(name)
else:
from .model_store import get_model_file
_file = get_model_file("scrfd_%s" % name, root=root)
return SCRFD(_file)
def scrfd_2p5gkps(**kwargs):
return get_scrfd("2p5gkps", download=True, **kwargs)
if __name__ == '__main__':
import glob
detector = SCRFD(model_file='./det.onnx')
detector.prepare(-1)
img_paths = ['tests/data/t1.jpg']
for img_path in img_paths:
img = cv2.imread(img_path)
for _ in range(1):
ta = datetime.datetime.now()
#bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640))
bboxes, kpss = detector.detect(img, input_size=(640, 640))
tb = datetime.datetime.now()
print('all cost:', (tb-ta).total_seconds()*1000)
print(img_path, bboxes.shape)
if kpss is not None:
print(kpss.shape)
for i in range(bboxes.shape[0]):
bbox = bboxes[i]
x1, y1, x2, y2, score = bbox.astype(int)
cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2)
if kpss is not None:
kps = kpss[i]
for kp in kps:
kp = kp.astype(int)
cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2)
filename = img_path.split('/')[-1]
print('output:', filename)
cv2.imwrite('./outputs/%s'%filename, img)

View File

@ -0,0 +1,4 @@
#import mesh
#import morphable_model
from . import mesh
from . import morphable_model

View File

@ -0,0 +1,15 @@
#from __future__ import absolute_import
#from cython import mesh_core_cython
#import io
#import vis
#import transform
#import light
#import render
from .cython import mesh_core_cython
from . import io
from . import vis
from . import transform
from . import light
from . import render

View File

@ -0,0 +1,375 @@
/*
Functions that cannot be optimized by vectorization in Python.
1. rasterization (needs to process each triangle)
2. normal of each vertex (uses the one-ring, needs to process each vertex)
3. write obj (it seems this could be vectorized? anyway, writing it in C++ is simple, so the function is also added here. --> however, why is writing it in C++ still slow?)
Author: Yao Feng
Mail: yaofeng1995@gmail.com
*/
#include "mesh_core.h"
/* Judge whether the point is in the triangle
Method:
http://blackpawn.com/texts/pointinpoly/
Args:
point: [x, y]
tri_points: three vertices(2d points) of a triangle. 2 coords x 3 vertices
Returns:
bool: true for in triangle
*/
bool isPointInTri(point p, point p0, point p1, point p2)
{
// vectors
point v0, v1, v2;
v0 = p2 - p0;
v1 = p1 - p0;
v2 = p - p0;
// dot products
float dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0)
float dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1)
float dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2)
float dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1)
float dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2)
// barycentric coordinates
float inverDeno;
if(dot00*dot11 - dot01*dot01 == 0)
inverDeno = 0;
else
inverDeno = 1/(dot00*dot11 - dot01*dot01);
float u = (dot11*dot02 - dot01*dot12)*inverDeno;
float v = (dot00*dot12 - dot01*dot02)*inverDeno;
// check if point in triangle
return (u >= 0) && (v >= 0) && (u + v < 1);
}
void get_point_weight(float* weight, point p, point p0, point p1, point p2)
{
// vectors
point v0, v1, v2;
v0 = p2 - p0;
v1 = p1 - p0;
v2 = p - p0;
// dot products
float dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0)
float dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1)
float dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2)
float dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1)
float dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2)
// barycentric coordinates
float inverDeno;
if(dot00*dot11 - dot01*dot01 == 0)
inverDeno = 0;
else
inverDeno = 1/(dot00*dot11 - dot01*dot01);
float u = (dot11*dot02 - dot01*dot12)*inverDeno;
float v = (dot00*dot12 - dot01*dot02)*inverDeno;
// weight
weight[0] = 1 - u - v;
weight[1] = v;
weight[2] = u;
}
void _get_normal_core(
float* normal, float* tri_normal, int* triangles,
int ntri)
{
int i, j;
int tri_p0_ind, tri_p1_ind, tri_p2_ind;
for(i = 0; i < ntri; i++)
{
tri_p0_ind = triangles[3*i];
tri_p1_ind = triangles[3*i + 1];
tri_p2_ind = triangles[3*i + 2];
for(j = 0; j < 3; j++)
{
normal[3*tri_p0_ind + j] = normal[3*tri_p0_ind + j] + tri_normal[3*i + j];
normal[3*tri_p1_ind + j] = normal[3*tri_p1_ind + j] + tri_normal[3*i + j];
normal[3*tri_p2_ind + j] = normal[3*tri_p2_ind + j] + tri_normal[3*i + j];
}
}
}
void _rasterize_triangles_core(
float* vertices, int* triangles,
float* depth_buffer, int* triangle_buffer, float* barycentric_weight,
int nver, int ntri,
int h, int w)
{
int i;
int x, y, k;
int tri_p0_ind, tri_p1_ind, tri_p2_ind;
point p0, p1, p2, p;
int x_min, x_max, y_min, y_max;
float p_depth, p0_depth, p1_depth, p2_depth;
float weight[3];
for(i = 0; i < ntri; i++)
{
tri_p0_ind = triangles[3*i];
tri_p1_ind = triangles[3*i + 1];
tri_p2_ind = triangles[3*i + 2];
p0.x = vertices[3*tri_p0_ind]; p0.y = vertices[3*tri_p0_ind + 1]; p0_depth = vertices[3*tri_p0_ind + 2];
p1.x = vertices[3*tri_p1_ind]; p1.y = vertices[3*tri_p1_ind + 1]; p1_depth = vertices[3*tri_p1_ind + 2];
p2.x = vertices[3*tri_p2_ind]; p2.y = vertices[3*tri_p2_ind + 1]; p2_depth = vertices[3*tri_p2_ind + 2];
x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
if(x_max < x_min || y_max < y_min)
{
continue;
}
for(y = y_min; y <= y_max; y++) //h
{
for(x = x_min; x <= x_max; x++) //w
{
p.x = x; p.y = y;
if(p.x < 2 || p.x > w - 3 || p.y < 2 || p.y > h - 3 || isPointInTri(p, p0, p1, p2))
{
get_point_weight(weight, p, p0, p1, p2);
p_depth = weight[0]*p0_depth + weight[1]*p1_depth + weight[2]*p2_depth;
if((p_depth > depth_buffer[y*w + x]))
{
depth_buffer[y*w + x] = p_depth;
triangle_buffer[y*w + x] = i;
for(k = 0; k < 3; k++)
{
barycentric_weight[y*w*3 + x*3 + k] = weight[k];
}
}
}
}
}
}
}
void _render_colors_core(
float* image, float* vertices, int* triangles,
float* colors,
float* depth_buffer,
int nver, int ntri,
int h, int w, int c)
{
int i;
int x, y, k;
int tri_p0_ind, tri_p1_ind, tri_p2_ind;
point p0, p1, p2, p;
int x_min, x_max, y_min, y_max;
float p_depth, p0_depth, p1_depth, p2_depth;
float p_color, p0_color, p1_color, p2_color;
float weight[3];
for(i = 0; i < ntri; i++)
{
tri_p0_ind = triangles[3*i];
tri_p1_ind = triangles[3*i + 1];
tri_p2_ind = triangles[3*i + 2];
p0.x = vertices[3*tri_p0_ind]; p0.y = vertices[3*tri_p0_ind + 1]; p0_depth = vertices[3*tri_p0_ind + 2];
p1.x = vertices[3*tri_p1_ind]; p1.y = vertices[3*tri_p1_ind + 1]; p1_depth = vertices[3*tri_p1_ind + 2];
p2.x = vertices[3*tri_p2_ind]; p2.y = vertices[3*tri_p2_ind + 1]; p2_depth = vertices[3*tri_p2_ind + 2];
x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
if(x_max < x_min || y_max < y_min)
{
continue;
}
for(y = y_min; y <= y_max; y++) //h
{
for(x = x_min; x <= x_max; x++) //w
{
p.x = x; p.y = y;
if(p.x < 2 || p.x > w - 3 || p.y < 2 || p.y > h - 3 || isPointInTri(p, p0, p1, p2))
{
get_point_weight(weight, p, p0, p1, p2);
p_depth = weight[0]*p0_depth + weight[1]*p1_depth + weight[2]*p2_depth;
if((p_depth > depth_buffer[y*w + x]))
{
for(k = 0; k < c; k++) // c
{
p0_color = colors[c*tri_p0_ind + k];
p1_color = colors[c*tri_p1_ind + k];
p2_color = colors[c*tri_p2_ind + k];
p_color = weight[0]*p0_color + weight[1]*p1_color + weight[2]*p2_color;
image[y*w*c + x*c + k] = p_color;
}
depth_buffer[y*w + x] = p_depth;
}
}
}
}
}
}
void _render_texture_core(
float* image, float* vertices, int* triangles,
float* texture, float* tex_coords, int* tex_triangles,
float* depth_buffer,
int nver, int tex_nver, int ntri,
int h, int w, int c,
int tex_h, int tex_w, int tex_c,
int mapping_type)
{
int i;
int x, y, k;
int tri_p0_ind, tri_p1_ind, tri_p2_ind;
int tex_tri_p0_ind, tex_tri_p1_ind, tex_tri_p2_ind;
point p0, p1, p2, p;
point tex_p0, tex_p1, tex_p2, tex_p;
int x_min, x_max, y_min, y_max;
float weight[3];
float p_depth, p0_depth, p1_depth, p2_depth;
float xd, yd;
float ul, ur, dl, dr;
for(i = 0; i < ntri; i++)
{
// mesh
tri_p0_ind = triangles[3*i];
tri_p1_ind = triangles[3*i + 1];
tri_p2_ind = triangles[3*i + 2];
p0.x = vertices[3*tri_p0_ind]; p0.y = vertices[3*tri_p0_ind + 1]; p0_depth = vertices[3*tri_p0_ind + 2];
p1.x = vertices[3*tri_p1_ind]; p1.y = vertices[3*tri_p1_ind + 1]; p1_depth = vertices[3*tri_p1_ind + 2];
p2.x = vertices[3*tri_p2_ind]; p2.y = vertices[3*tri_p2_ind + 1]; p2_depth = vertices[3*tri_p2_ind + 2];
// texture
tex_tri_p0_ind = tex_triangles[3*i];
tex_tri_p1_ind = tex_triangles[3*i + 1];
tex_tri_p2_ind = tex_triangles[3*i + 2];
tex_p0.x = tex_coords[3*tex_tri_p0_ind]; tex_p0.y = tex_coords[3*tex_tri_p0_ind + 1];
tex_p1.x = tex_coords[3*tex_tri_p1_ind]; tex_p1.y = tex_coords[3*tex_tri_p1_ind + 1];
tex_p2.x = tex_coords[3*tex_tri_p2_ind]; tex_p2.y = tex_coords[3*tex_tri_p2_ind + 1];
x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
if(x_max < x_min || y_max < y_min)
{
continue;
}
for(y = y_min; y <= y_max; y++) //h
{
for(x = x_min; x <= x_max; x++) //w
{
p.x = x; p.y = y;
if(p.x < 2 || p.x > w - 3 || p.y < 2 || p.y > h - 3 || isPointInTri(p, p0, p1, p2))
{
get_point_weight(weight, p, p0, p1, p2);
p_depth = weight[0]*p0_depth + weight[1]*p1_depth + weight[2]*p2_depth;
if((p_depth > depth_buffer[y*w + x]))
{
// -- color from texture: reuse the barycentric weights computed above
// map p into texture space
tex_p = tex_p0*weight[0] + tex_p1*weight[1] + tex_p2*weight[2];
tex_p.x = max(min(tex_p.x, float(tex_w - 1)), float(0));
tex_p.y = max(min(tex_p.y, float(tex_h - 1)), float(0));
yd = tex_p.y - floor(tex_p.y);
xd = tex_p.x - floor(tex_p.x);
for(k = 0; k < c; k++)
{
if(mapping_type==0)// nearest
{
image[y*w*c + x*c + k] = texture[int(round(tex_p.y))*tex_w*tex_c + int(round(tex_p.x))*tex_c + k];
}
else//bilinear interp
{
ul = texture[(int)floor(tex_p.y)*tex_w*tex_c + (int)floor(tex_p.x)*tex_c + k];
ur = texture[(int)floor(tex_p.y)*tex_w*tex_c + (int)ceil(tex_p.x)*tex_c + k];
dl = texture[(int)ceil(tex_p.y)*tex_w*tex_c + (int)floor(tex_p.x)*tex_c + k];
dr = texture[(int)ceil(tex_p.y)*tex_w*tex_c + (int)ceil(tex_p.x)*tex_c + k];
image[y*w*c + x*c + k] = ul*(1-xd)*(1-yd) + ur*xd*(1-yd) + dl*(1-xd)*yd + dr*xd*yd;
}
}
depth_buffer[y*w + x] = p_depth;
}
}
}
}
}
}
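// Bilinear branch check (worked numbers): with fractional offsets xd=0.25 and yd=0.5 the four
// texel weights above are ul=(1-xd)*(1-yd)=0.375, ur=xd*(1-yd)=0.125, dl=(1-xd)*yd=0.375 and
// dr=xd*yd=0.125, which sum to 1 as expected for bilinear interpolation.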
// ------------------------------------------------- write
// obj write
// Ref: https://github.com/patrikhuber/eos/blob/master/include/eos/core/Mesh.hpp
void _write_obj_with_colors_texture(string filename, string mtl_name,
float* vertices, int* triangles, float* colors, float* uv_coords,
int nver, int ntri, int ntexver)
{
int i;
ofstream obj_file(filename.c_str());
// first line of the obj file: the mtl name
obj_file << "mtllib " << mtl_name << endl;
// write vertices
for (i = 0; i < nver; ++i)
{
obj_file << "v " << vertices[3*i] << " " << vertices[3*i + 1] << " " << vertices[3*i + 2] << colors[3*i] << " " << colors[3*i + 1] << " " << colors[3*i + 2] << endl;
}
// write uv coordinates
for (i = 0; i < ntexver; ++i)
{
//obj_file << "vt " << uv_coords[2*i] << " " << (1 - uv_coords[2*i + 1]) << endl;
obj_file << "vt " << uv_coords[2*i] << " " << uv_coords[2*i + 1] << endl;
}
obj_file << "usemtl FaceTexture" << endl;
// write triangles
for (i = 0; i < ntri; ++i)
{
// obj_file << "f " << triangles[3*i] << "/" << triangles[3*i] << " " << triangles[3*i + 1] << "/" << triangles[3*i + 1] << " " << triangles[3*i + 2] << "/" << triangles[3*i + 2] << endl;
obj_file << "f " << triangles[3*i + 2] << "/" << triangles[3*i + 2] << " " << triangles[3*i + 1] << "/" << triangles[3*i + 1] << " " << triangles[3*i] << "/" << triangles[3*i] << endl;
}
}

View File

@ -0,0 +1,83 @@
#ifndef MESH_CORE_HPP_
#define MESH_CORE_HPP_
#include <stdio.h>
#include <cmath>
#include <algorithm>
#include <string>
#include <iostream>
#include <fstream>
using namespace std;
class point
{
public:
float x;
float y;
float dot(point p)
{
return this->x * p.x + this->y * p.y;
}
point operator-(const point& p)
{
point np;
np.x = this->x - p.x;
np.y = this->y - p.y;
return np;
}
point operator+(const point& p)
{
point np;
np.x = this->x + p.x;
np.y = this->y + p.y;
return np;
}
point operator*(float s)
{
point np;
np.x = s * this->x;
np.y = s * this->y;
return np;
}
};
bool isPointInTri(point p, point p0, point p1, point p2);
void get_point_weight(float* weight, point p, point p0, point p1, point p2);
void _get_normal_core(
float* normal, float* tri_normal, int* triangles,
int ntri);
void _rasterize_triangles_core(
float* vertices, int* triangles,
float* depth_buffer, int* triangle_buffer, float* barycentric_weight,
int nver, int ntri,
int h, int w);
void _render_colors_core(
float* image, float* vertices, int* triangles,
float* colors,
float* depth_buffer,
int nver, int ntri,
int h, int w, int c);
void _render_texture_core(
float* image, float* vertices, int* triangles,
float* texture, float* tex_coords, int* tex_triangles,
float* depth_buffer,
int nver, int tex_nver, int ntri,
int h, int w, int c,
int tex_h, int tex_w, int tex_c,
int mapping_type);
void _write_obj_with_colors_texture(string filename, string mtl_name,
float* vertices, int* triangles, float* colors, float* uv_coords,
int nver, int ntri, int ntexver);
#endif

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,109 @@
import numpy as np
cimport numpy as np
from libcpp.string cimport string
# use the Numpy-C-API from Cython
np.import_array()
# cdefine the signature of our c function
cdef extern from "mesh_core.h":
void _rasterize_triangles_core(
float* vertices, int* triangles,
float* depth_buffer, int* triangle_buffer, float* barycentric_weight,
int nver, int ntri,
int h, int w)
void _render_colors_core(
float* image, float* vertices, int* triangles,
float* colors,
float* depth_buffer,
int nver, int ntri,
int h, int w, int c)
void _render_texture_core(
float* image, float* vertices, int* triangles,
float* texture, float* tex_coords, int* tex_triangles,
float* depth_buffer,
int nver, int tex_nver, int ntri,
int h, int w, int c,
int tex_h, int tex_w, int tex_c,
int mapping_type)
void _get_normal_core(
float* normal, float* tri_normal, int* triangles,
int ntri)
void _write_obj_with_colors_texture(string filename, string mtl_name,
float* vertices, int* triangles, float* colors, float* uv_coords,
int nver, int ntri, int ntexver)
def get_normal_core(np.ndarray[float, ndim=2, mode = "c"] normal not None,
np.ndarray[float, ndim=2, mode = "c"] tri_normal not None,
np.ndarray[int, ndim=2, mode="c"] triangles not None,
int ntri
):
_get_normal_core(
<float*> np.PyArray_DATA(normal), <float*> np.PyArray_DATA(tri_normal), <int*> np.PyArray_DATA(triangles),
ntri)
def rasterize_triangles_core(
np.ndarray[float, ndim=2, mode = "c"] vertices not None,
np.ndarray[int, ndim=2, mode="c"] triangles not None,
np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
np.ndarray[int, ndim=2, mode = "c"] triangle_buffer not None,
np.ndarray[float, ndim=2, mode = "c"] barycentric_weight not None,
int nver, int ntri,
int h, int w
):
_rasterize_triangles_core(
<float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles),
<float*> np.PyArray_DATA(depth_buffer), <int*> np.PyArray_DATA(triangle_buffer), <float*> np.PyArray_DATA(barycentric_weight),
nver, ntri,
h, w)
def render_colors_core(np.ndarray[float, ndim=3, mode = "c"] image not None,
np.ndarray[float, ndim=2, mode = "c"] vertices not None,
np.ndarray[int, ndim=2, mode="c"] triangles not None,
np.ndarray[float, ndim=2, mode = "c"] colors not None,
np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
int nver, int ntri,
int h, int w, int c
):
_render_colors_core(
<float*> np.PyArray_DATA(image), <float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles),
<float*> np.PyArray_DATA(colors),
<float*> np.PyArray_DATA(depth_buffer),
nver, ntri,
h, w, c)
def render_texture_core(np.ndarray[float, ndim=3, mode = "c"] image not None,
np.ndarray[float, ndim=2, mode = "c"] vertices not None,
np.ndarray[int, ndim=2, mode="c"] triangles not None,
np.ndarray[float, ndim=3, mode = "c"] texture not None,
np.ndarray[float, ndim=2, mode = "c"] tex_coords not None,
np.ndarray[int, ndim=2, mode="c"] tex_triangles not None,
np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
int nver, int tex_nver, int ntri,
int h, int w, int c,
int tex_h, int tex_w, int tex_c,
int mapping_type
):
_render_texture_core(
<float*> np.PyArray_DATA(image), <float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles),
<float*> np.PyArray_DATA(texture), <float*> np.PyArray_DATA(tex_coords), <int*> np.PyArray_DATA(tex_triangles),
<float*> np.PyArray_DATA(depth_buffer),
nver, tex_nver, ntri,
h, w, c,
tex_h, tex_w, tex_c,
mapping_type)
def write_obj_with_colors_texture_core(string filename, string mtl_name,
np.ndarray[float, ndim=2, mode = "c"] vertices not None,
np.ndarray[int, ndim=2, mode="c"] triangles not None,
np.ndarray[float, ndim=2, mode = "c"] colors not None,
np.ndarray[float, ndim=2, mode = "c"] uv_coords not None,
int nver, int ntri, int ntexver
):
_write_obj_with_colors_texture(filename, mtl_name,
<float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles), <float*> np.PyArray_DATA(colors), <float*> np.PyArray_DATA(uv_coords),
nver, ntri, ntexver)
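# Usage sketch (assumes the extension has been built, e.g. via the setup.py below): every
# array argument must be C-contiguous with the dtypes declared above (float32 for vertex,
# color and buffer arrays, int32 for triangle indices). A minimal call might look like:
#
#   import numpy as np
#   import mesh_core_cython
#   vertices = np.ascontiguousarray(my_vertices, dtype=np.float32)    # [nver, 3], hypothetical data
#   triangles = np.ascontiguousarray(my_triangles, dtype=np.int32)    # [ntri, 3], hypothetical data
#   depth_buffer = np.zeros((h, w), dtype=np.float32) - 999999.
#   triangle_buffer = np.zeros((h, w), dtype=np.int32) - 1
#   barycentric_weight = np.zeros((h, w, 3), dtype=np.float32)
#   mesh_core_cython.rasterize_triangles_core(vertices, triangles, depth_buffer,
#       triangle_buffer, barycentric_weight, vertices.shape[0], triangles.shape[0], h, w)
#
# `my_vertices`, `my_triangles`, `h` and `w` are placeholders.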

View File

@ -0,0 +1,20 @@
'''
python setup.py build_ext -i
to compile
'''
# setup.py
from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy
setup(
name = 'mesh_core_cython',
cmdclass={'build_ext': build_ext},
ext_modules=[Extension("mesh_core_cython",
sources=["mesh_core_cython.pyx", "mesh_core.cpp"],
language='c++',
include_dirs=[numpy.get_include()])],
)
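# Build sketch: `python setup.py build_ext -i` (as noted in the docstring above) compiles
# mesh_core.cpp together with the Cython wrapper into an in-place mesh_core_cython extension,
# which the pure-Python modules then load via `from .cython import mesh_core_cython`.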

View File

@ -0,0 +1,142 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
from skimage import io
from time import time
from .cython import mesh_core_cython
## TODO
## TODO: c++ version
def read_obj(obj_name):
''' read mesh
'''
return 0
# ------------------------- write
def write_asc(path, vertices):
'''
Args:
vertices: shape = (nver, 3)
'''
if path.split('.')[-1] == 'asc':
np.savetxt(path, vertices)
else:
np.savetxt(path + '.asc', vertices)
def write_obj_with_colors(obj_name, vertices, triangles, colors):
''' Save 3D face model with texture represented by colors.
Args:
obj_name: str
vertices: shape = (nver, 3)
triangles: shape = (ntri, 3)
colors: shape = (nver, 3)
'''
triangles = triangles.copy()
triangles += 1 # meshlab start with 1
if obj_name.split('.')[-1] != 'obj':
obj_name = obj_name + '.obj'
# write obj
with open(obj_name, 'w') as f:
# write vertices & colors
for i in range(vertices.shape[0]):
# s = 'v {} {} {} \n'.format(vertices[0,i], vertices[1,i], vertices[2,i])
s = 'v {} {} {} {} {} {}\n'.format(vertices[i, 0], vertices[i, 1], vertices[i, 2], colors[i, 0], colors[i, 1], colors[i, 2])
f.write(s)
# write faces: vertex index / uv index
for i in range(triangles.shape[0]):
# s = 'f {} {} {}\n'.format(triangles[i, 0], triangles[i, 1], triangles[i, 2])
s = 'f {} {} {}\n'.format(triangles[i, 2], triangles[i, 1], triangles[i, 0])
f.write(s)
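# Usage sketch (hypothetical data): write a single colored triangle.
#   import numpy as np
#   vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])   # (nver, 3)
#   triangles = np.array([[0, 1, 2]], dtype=np.int32)                 # (ntri, 3), 0-based
#   colors = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])     # (nver, 3), values in [0, 1]
#   write_obj_with_colors('triangle.obj', vertices, triangles, colors)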
## TODO: c++ version
def write_obj_with_texture(obj_name, vertices, triangles, texture, uv_coords):
''' Save 3D face model with texture represented by texture map.
Ref: https://github.com/patrikhuber/eos/blob/bd00155ebae4b1a13b08bf5a991694d682abbada/include/eos/core/Mesh.hpp
Args:
obj_name: str
vertices: shape = (nver, 3)
triangles: shape = (ntri, 3)
texture: shape = (256,256,3)
uv_coords: shape = (nver, 3) max value<=1
'''
if obj_name.split('.')[-1] != 'obj':
obj_name = obj_name + '.obj'
mtl_name = obj_name.replace('.obj', '.mtl')
texture_name = obj_name.replace('.obj', '_texture.png')
triangles = triangles.copy()
triangles += 1 # mesh lab start with 1
# write obj
with open(obj_name, 'w') as f:
# first line: write mtlib(material library)
s = "mtllib {}\n".format(os.path.abspath(mtl_name))
f.write(s)
# write vertices
for i in range(vertices.shape[0]):
s = 'v {} {} {}\n'.format(vertices[i, 0], vertices[i, 1], vertices[i, 2])
f.write(s)
# write uv coords
for i in range(uv_coords.shape[0]):
s = 'vt {} {}\n'.format(uv_coords[i,0], 1 - uv_coords[i,1])
f.write(s)
f.write("usemtl FaceTexture\n")
# write f: ver ind/ uv ind
for i in range(triangles.shape[0]):
s = 'f {}/{} {}/{} {}/{}\n'.format(triangles[i,2], triangles[i,2], triangles[i,1], triangles[i,1], triangles[i,0], triangles[i,0])
f.write(s)
# write mtl
with open(mtl_name, 'w') as f:
f.write("newmtl FaceTexture\n")
s = 'map_Kd {}\n'.format(os.path.abspath(texture_name)) # map to image
f.write(s)
# write texture as png
io.imsave(texture_name, texture)
# c++ version
def write_obj_with_colors_texture(obj_name, vertices, triangles, colors, texture, uv_coords):
''' Save 3D face model with texture.
Ref: https://github.com/patrikhuber/eos/blob/bd00155ebae4b1a13b08bf5a991694d682abbada/include/eos/core/Mesh.hpp
Args:
obj_name: str
vertices: shape = (nver, 3)
triangles: shape = (ntri, 3)
colors: shape = (nver, 3)
texture: shape = (256,256,3)
uv_coords: shape = (nver, 2), values in [0, 1]
'''
if obj_name.split('.')[-1] != 'obj':
obj_name = obj_name + '.obj'
mtl_name = obj_name.replace('.obj', '.mtl')
texture_name = obj_name.replace('.obj', '_texture.png')
triangles = triangles.copy()
triangles += 1 # mesh lab start with 1
# write obj
vertices, colors, uv_coords = vertices.astype(np.float32).copy(), colors.astype(np.float32).copy(), uv_coords.astype(np.float32).copy()
mesh_core_cython.write_obj_with_colors_texture_core(str.encode(obj_name), str.encode(os.path.abspath(mtl_name)), vertices, triangles, colors, uv_coords, vertices.shape[0], triangles.shape[0], uv_coords.shape[0])
# write mtl
with open(mtl_name, 'w') as f:
f.write("newmtl FaceTexture\n")
s = 'map_Kd {}\n'.format(os.path.abspath(texture_name)) # map to image
f.write(s)
# write texture as png
io.imsave(texture_name, texture)

View File

@ -0,0 +1,213 @@
'''
Functions about lighting mesh(changing colors/texture of mesh).
1. add light to colors/texture (shade each vertex)
2. fit light according to colors/texture & image.
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from .cython import mesh_core_cython
def get_normal(vertices, triangles):
''' calculate normal direction in each vertex
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
Returns:
normal: [nver, 3]
'''
pt0 = vertices[triangles[:, 0], :] # [ntri, 3]
pt1 = vertices[triangles[:, 1], :] # [ntri, 3]
pt2 = vertices[triangles[:, 2], :] # [ntri, 3]
tri_normal = np.cross(pt0 - pt1, pt0 - pt2) # [ntri, 3]. normal of each triangle
normal = np.zeros_like(vertices, dtype = np.float32).copy() # [nver, 3]
# for i in range(triangles.shape[0]):
# normal[triangles[i, 0], :] = normal[triangles[i, 0], :] + tri_normal[i, :]
# normal[triangles[i, 1], :] = normal[triangles[i, 1], :] + tri_normal[i, :]
# normal[triangles[i, 2], :] = normal[triangles[i, 2], :] + tri_normal[i, :]
mesh_core_cython.get_normal_core(normal, tri_normal.astype(np.float32).copy(), triangles.copy(), triangles.shape[0])
# normalize to unit length
mag = np.sum(normal**2, 1) # [nver]
zero_ind = (mag == 0)
mag[zero_ind] = 1;
normal[zero_ind, 0] = np.ones((np.sum(zero_ind)))
normal = normal/np.sqrt(mag[:,np.newaxis])
return normal
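# Sanity-check sketch (hypothetical data): for a single triangle lying in the z=0 plane,
# vertices = [[0,0,0],[1,0,0],[0,1,0]] and triangles = [[0,1,2]],
# tri_normal = cross(pt0-pt1, pt0-pt2) = (0, 0, 1), so after accumulation and normalization
# every vertex normal returned by get_normal is (0, 0, 1).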
# TODO: test
def add_light_sh(vertices, triangles, colors, sh_coeff):
'''
In 3d face, usually assume:
1. The surface of face is Lambertian(reflect only the low frequencies of lighting)
2. Lighting can be an arbitrary combination of point sources
--> can be expressed in terms of spherical harmonics(omit the lighting coefficients)
I = albedo * (sh(n) x sh_coeff)
albedo: n x 1
sh_coeff: 9 x 1
Y(n) = (1, n_x, n_y, n_z, n_xn_y, n_xn_z, n_yn_z, n_x^2 - n_y^2, 3n_z^2 - 1)': n x 9
# Y(n) = (1, n_x, n_y, n_z)': n x 4
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
colors: [nver, 3] albedo
sh_coeff: [9, 1] spherical harmonics coefficients
Returns:
lit_colors: [nver, 3]
'''
assert vertices.shape[0] == colors.shape[0]
nver = vertices.shape[0]
n = get_normal(vertices, triangles) # [nver, 3]
sh = np.array((np.ones(nver), n[:,0], n[:,1], n[:,2], n[:,0]*n[:,1], n[:,0]*n[:,2], n[:,1]*n[:,2], n[:,0]**2 - n[:,1]**2, 3*(n[:,2]**2) - 1)).T # [nver, 9]
ref = sh.dot(sh_coeff) # [nver, 1]
lit_colors = colors*ref
return lit_colors
def add_light(vertices, triangles, colors, light_positions = 0, light_intensities = 0):
''' Gouraud shading. add point lights.
In 3d face, usually assume:
1. The surface of face is Lambertian(reflect only the low frequencies of lighting)
2. Lighting can be an arbitrary combination of point sources
3. No specular highlights (unless the skin is oily).
Ref: https://cs184.eecs.berkeley.edu/lecture/pipeline
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
light_positions: [nlight, 3]
light_intensities: [nlight, 3]
Returns:
lit_colors: [nver, 3]
'''
nver = vertices.shape[0]
normals = get_normal(vertices, triangles) # [nver, 3]
# ambient
# La = ka*Ia
# diffuse
# Ld = kd*(I/r^2)max(0, nxl)
direction_to_lights = vertices[np.newaxis, :, :] - light_positions[:, np.newaxis, :] # [nlight, nver, 3]
direction_to_lights_n = np.sqrt(np.sum(direction_to_lights**2, axis = 2)) # [nlight, nver]
direction_to_lights = direction_to_lights/direction_to_lights_n[:, :, np.newaxis]
normals_dot_lights = normals[np.newaxis, :, :]*direction_to_lights # [nlight, nver, 3]
normals_dot_lights = np.sum(normals_dot_lights, axis = 2) # [nlight, nver]
diffuse_output = colors[np.newaxis, :, :]*normals_dot_lights[:, :, np.newaxis]*light_intensities[:, np.newaxis, :]
diffuse_output = np.sum(diffuse_output, axis = 0) # [nver, 3]
# specular
# h = (v + l)/(|v + l|) bisector
# Ls = ks*(I/r^2)max(0, nxh)^p
# increasing p narrows the reflection lobe
lit_colors = diffuse_output # only diffuse part here.
lit_colors = np.minimum(np.maximum(lit_colors, 0), 1)
return lit_colors
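# Usage sketch (hypothetical arrays): a single white point light above the mesh.
#   light_positions = np.array([[0., 0., 300.]])     # [nlight, 3]
#   light_intensities = np.array([[1., 1., 1.]])     # [nlight, 3]
#   lit_colors = add_light(vertices, triangles, colors, light_positions, light_intensities)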
## TODO. estimate light(sh coeff)
## -------------------------------- estimate. can not use now.
def fit_light(image, vertices, colors, triangles, vis_ind, lamb = 10, max_iter = 3):
[h, w, c] = image.shape
texture = colors # per-vertex albedo; note this unfinished routine assumes channel-first [c, nver] layouts throughout
# surface normal
norm = get_normal(vertices, triangles)
nver = vertices.shape[1]
# vertices --> corresponding image pixel
pt2d = vertices[:2, :]
pt2d[0,:] = np.minimum(np.maximum(pt2d[0,:], 0), w - 1)
pt2d[1,:] = np.minimum(np.maximum(pt2d[1,:], 0), h - 1)
pt2d = np.round(pt2d).astype(np.int32) # 2 x nver
image_pixel = image[pt2d[1,:], pt2d[0,:], :] # nver x 3
image_pixel = image_pixel.T # 3 x nver
# vertices --> corresponding mean texture pixel with illumination
# Spherical Harmonic Basis
harmonic_dim = 9
nx = norm[0,:]
ny = norm[1,:]
nz = norm[2,:]
harmonic = np.zeros((nver, harmonic_dim))
pi = np.pi
harmonic[:,0] = np.sqrt(1/(4*pi)) * np.ones((nver,))
harmonic[:,1] = np.sqrt(3/(4*pi)) * nx
harmonic[:,2] = np.sqrt(3/(4*pi)) * ny
harmonic[:,3] = np.sqrt(3/(4*pi)) * nz
harmonic[:,4] = 1/2. * np.sqrt(3/(4*pi)) * (2*nz**2 - nx**2 - ny**2)
harmonic[:,5] = 3 * np.sqrt(5/(12*pi)) * (ny*nz)
harmonic[:,6] = 3 * np.sqrt(5/(12*pi)) * (nx*nz)
harmonic[:,7] = 3 * np.sqrt(5/(12*pi)) * (nx*ny)
harmonic[:,8] = 3/2. * np.sqrt(5/(12*pi)) * (nx*nx - ny*ny)
'''
I' = sum(albedo * lj * hj) j = 0:9 (albedo = tex)
set A = albedo*h (n x 9)
alpha = lj (9 x 1)
Y = I (n x 1)
Y' = A.dot(alpha)
objective:
||Y - A*alpha||^2 + lambda*(alpha'*alpha)
setting the gradient w.r.t. alpha to zero:
-A'*(Y - A*alpha) + lambda*alpha = 0
==>
(A'*A + lambda*I)*alpha = A'*Y
left: 9 x 9
right: 9 x 1
'''
n_vis_ind = len(vis_ind)
n = n_vis_ind*c
Y = np.zeros((n, 1))
A = np.zeros((n, 9))
light = np.zeros((3, 1))
for k in range(c):
Y[k*n_vis_ind:(k+1)*n_vis_ind, :] = image_pixel[k, vis_ind][:, np.newaxis]
A[k*n_vis_ind:(k+1)*n_vis_ind, :] = texture[k, vis_ind][:, np.newaxis] * harmonic[vis_ind, :]
Ac = texture[k, vis_ind][:, np.newaxis]
Yc = image_pixel[k, vis_ind][:, np.newaxis]
light[k] = (Ac.T.dot(Yc))/(Ac.T.dot(Ac))
for i in range(max_iter):
Yc = Y.copy()
for k in range(c):
Yc[k*n_vis_ind:(k+1)*n_vis_ind, :] /= light[k]
# update alpha
equation_left = np.dot(A.T, A) + lamb*np.eye(harmonic_dim) # ridge (Tikhonov) term from the lambda*(alpha'*alpha) regularizer above
equation_right = np.dot(A.T, Yc)
alpha = np.dot(np.linalg.inv(equation_left), equation_right)
# update light
for k in range(c):
Ac = A[k*n_vis_ind:(k+1)*n_vis_ind, :].dot(alpha)
Yc = Y[k*n_vis_ind:(k+1)*n_vis_ind, :]
light[k] = (Ac.T.dot(Yc))/(Ac.T.dot(Ac))
appearance = np.zeros_like(texture)
for k in range(c):
tmp = np.dot(harmonic*texture[k, :][:, np.newaxis], alpha*light[k])
appearance[k,:] = tmp.T
appearance = np.minimum(np.maximum(appearance, 0), 1)
return appearance

View File

@ -0,0 +1,135 @@
'''
Functions for rendering a mesh (from a 3D obj to a 2D image).
Only rasterization rendering is used here.
Note that:
1. Generally, a render function includes camera, lighting, and rasterization. There is no camera or lighting here (those are written in other files).
2. Generally, the input vertices are normalized to [-1, 1] and centered on [0, 0] (in world space).
Here, the vertices use image coordinates, centered on [w/2, h/2] with the y-axis pointing in the opposite direction.
This means the renderer here only performs interpolation (to keep the input flexible).
Author: Yao Feng
Mail: yaofeng1995@gmail.com
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from time import time
from .cython import mesh_core_cython
def rasterize_triangles(vertices, triangles, h, w):
'''
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
h: height
w: width
Returns:
depth_buffer: [h, w] saves the depth; here, the larger the z, the closer the point to the camera.
triangle_buffer: [h, w] saves the triangle id (-1 for no triangle).
barycentric_weight: [h, w, 3] saves corresponding barycentric weight.
# Each triangle has 3 vertices & Each vertex has 3 coordinates x, y, z.
# h, w is the size of rendering
'''
# initialize
depth_buffer = np.zeros([h, w], dtype = np.float32) - 999999. # set the initial z to the farthest position
triangle_buffer = np.zeros([h, w], dtype = np.int32) - 1 # tri id = -1 means the pixel has no triangle correspondence
barycentric_weight = np.zeros([h, w, 3], dtype = np.float32)
vertices = vertices.astype(np.float32).copy()
triangles = triangles.astype(np.int32).copy()
mesh_core_cython.rasterize_triangles_core(
vertices, triangles,
depth_buffer, triangle_buffer, barycentric_weight,
vertices.shape[0], triangles.shape[0],
h, w)
return depth_buffer, triangle_buffer, barycentric_weight
def render_colors(vertices, triangles, colors, h, w, c = 3, BG = None):
''' render mesh with colors
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
colors: [nver, 3]
h: height
w: width
c: channel
BG: background image
Returns:
image: [h, w, c]. The rendered image.
'''
# initial
if BG is None:
image = np.zeros((h, w, c), dtype = np.float32)
else:
assert BG.shape[0] == h and BG.shape[1] == w and BG.shape[2] == c
image = BG
depth_buffer = np.zeros([h, w], dtype = np.float32, order = 'C') - 999999.
# convert to C-contiguous (row-major) arrays with the dtypes the extension expects
vertices = vertices.astype(np.float32).copy()
triangles = triangles.astype(np.int32).copy()
colors = colors.astype(np.float32).copy()
mesh_core_cython.render_colors_core(
image, vertices, triangles,
colors,
depth_buffer,
vertices.shape[0], triangles.shape[0],
h, w, c)
return image
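# Usage sketch (hypothetical inputs, image-space vertices as described in the module docstring):
#   rendering = render_colors(vertices, triangles, colors, h=256, w=256)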
def render_texture(vertices, triangles, texture, tex_coords, tex_triangles, h, w, c = 3, mapping_type = 'nearest', BG = None):
''' render mesh with texture map
Args:
vertices: [nver, 3]
triangles: [ntri, 3]
texture: [tex_h, tex_w, 3]
tex_coords: [ntexcoords, 3]
tex_triangles: [ntri, 3]
h: height of rendering
w: width of rendering
c: channel
mapping_type: 'bilinear' or 'nearest'
'''
# initial
if BG is None:
image = np.zeros((h, w, c), dtype = np.float32)
else:
assert BG.shape[0] == h and BG.shape[1] == w and BG.shape[2] == c
image = BG
depth_buffer = np.zeros([h, w], dtype = np.float32, order = 'C') - 999999.
tex_h, tex_w, tex_c = texture.shape
if mapping_type == 'nearest':
mt = int(0)
elif mapping_type == 'bilinear':
mt = int(1)
else:
mt = int(0)
# -> C order
vertices = vertices.astype(np.float32).copy()
triangles = triangles.astype(np.int32).copy()
texture = texture.astype(np.float32).copy()
tex_coords = tex_coords.astype(np.float32).copy()
tex_triangles = tex_triangles.astype(np.int32).copy()
mesh_core_cython.render_texture_core(
image, vertices, triangles,
texture, tex_coords, tex_triangles,
depth_buffer,
vertices.shape[0], tex_coords.shape[0], triangles.shape[0],
h, w, c,
tex_h, tex_w, tex_c,
mt)
return image

Some files were not shown because too many files have changed in this diff