cleanup unneeded files for Cat file creation
19
.vscode/settings.json
vendored
@ -1,19 +0,0 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"files.eol": "\n",
|
||||
"files.insertFinalNewline": true,
|
||||
"files.trimFinalNewlines": true,
|
||||
"files.trimTrailingWhitespace": true,
|
||||
"files.exclude": {
|
||||
"**/.git": true,
|
||||
"**/.svn": true,
|
||||
"**/.hg": true,
|
||||
"**/CVS": true,
|
||||
"**/.DS_Store": true,
|
||||
"**/Thumbs.db": true,
|
||||
"**/*.crswap": true,
|
||||
"**/__pycache__": true
|
||||
}
|
||||
}
|
215
app.py
@ -1,215 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
The entrance of the gradio
|
||||
"""
|
||||
|
||||
import tyro
|
||||
import subprocess
|
||||
import gradio as gr
|
||||
import os.path as osp
|
||||
from src.utils.helper import load_description
|
||||
from src.gradio_pipeline import GradioPipeline
|
||||
from src.config.crop_config import CropConfig
|
||||
from src.config.argument_config import ArgumentConfig
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
|
||||
def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||
|
||||
gradio_pipeline = GradioPipeline(
|
||||
inference_cfg=inference_cfg,
|
||||
crop_cfg=crop_cfg,
|
||||
args=args
|
||||
)
|
||||
|
||||
|
||||
def gpu_wrapped_execute_video(*args, **kwargs):
|
||||
return gradio_pipeline.execute_video(*args, **kwargs)
|
||||
|
||||
|
||||
def gpu_wrapped_execute_image(*args, **kwargs):
|
||||
return gradio_pipeline.execute_image(*args, **kwargs)
|
||||
|
||||
|
||||
# assets
|
||||
title_md = "assets/gradio_title.md"
|
||||
example_portrait_dir = "assets/examples/source"
|
||||
example_video_dir = "assets/examples/driving"
|
||||
data_examples = [
|
||||
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
|
||||
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
|
||||
]
|
||||
#################### interface logic ####################
|
||||
|
||||
# Define components first
|
||||
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
|
||||
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
||||
retargeting_input_image = gr.Image(type="filepath")
|
||||
output_image = gr.Image(type="numpy")
|
||||
output_image_paste_back = gr.Image(type="numpy")
|
||||
output_video = gr.Video()
|
||||
output_video_concat = gr.Video()
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
||||
gr.HTML(load_description(title_md))
|
||||
gr.Markdown(load_description("assets/gradio_description_upload.md"))
|
||||
with gr.Row():
|
||||
with gr.Accordion(open=True, label="Source Portrait"):
|
||||
image_input = gr.Image(type="filepath")
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s9.jpg")],
|
||||
[osp.join(example_portrait_dir, "s6.jpg")],
|
||||
[osp.join(example_portrait_dir, "s10.jpg")],
|
||||
[osp.join(example_portrait_dir, "s5.jpg")],
|
||||
[osp.join(example_portrait_dir, "s7.jpg")],
|
||||
[osp.join(example_portrait_dir, "s12.jpg")],
|
||||
],
|
||||
inputs=[image_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Accordion(open=True, label="Driving Video"):
|
||||
video_input = gr.Video()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_video_dir, "d0.mp4")],
|
||||
[osp.join(example_video_dir, "d18.mp4")],
|
||||
[osp.join(example_video_dir, "d19.mp4")],
|
||||
[osp.join(example_video_dir, "d14.mp4")],
|
||||
[osp.join(example_video_dir, "d6.mp4")],
|
||||
],
|
||||
inputs=[video_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Accordion(open=False, label="Animation Instructions and Options"):
|
||||
gr.Markdown(load_description("assets/gradio_description_animation.md"))
|
||||
with gr.Row():
|
||||
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
|
||||
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
|
||||
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
|
||||
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
||||
with gr.Column():
|
||||
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video in the original image space"):
|
||||
output_video.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video"):
|
||||
output_video_concat.render()
|
||||
with gr.Row():
|
||||
# Examples
|
||||
gr.Markdown("## You could also choose the examples below by one click ⬇️")
|
||||
with gr.Row():
|
||||
gr.Examples(
|
||||
examples=data_examples,
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
image_input,
|
||||
video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
examples_per_page=len(data_examples),
|
||||
cache_examples=False,
|
||||
)
|
||||
gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
|
||||
with gr.Row(visible=True):
|
||||
eye_retargeting_slider.render()
|
||||
lip_retargeting_slider.render()
|
||||
with gr.Row(visible=True):
|
||||
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
|
||||
process_button_reset_retargeting = gr.ClearButton(
|
||||
[
|
||||
eye_retargeting_slider,
|
||||
lip_retargeting_slider,
|
||||
retargeting_input_image,
|
||||
output_image,
|
||||
output_image_paste_back
|
||||
],
|
||||
value="🧹 Clear"
|
||||
)
|
||||
with gr.Row(visible=True):
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Input"):
|
||||
retargeting_input_image.render()
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s9.jpg")],
|
||||
[osp.join(example_portrait_dir, "s6.jpg")],
|
||||
[osp.join(example_portrait_dir, "s10.jpg")],
|
||||
[osp.join(example_portrait_dir, "s5.jpg")],
|
||||
[osp.join(example_portrait_dir, "s7.jpg")],
|
||||
[osp.join(example_portrait_dir, "s12.jpg")],
|
||||
],
|
||||
inputs=[retargeting_input_image],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Retargeting Result"):
|
||||
output_image.render()
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="Paste-back Result"):
|
||||
output_image_paste_back.render()
|
||||
# binding functions for buttons
|
||||
process_button_retargeting.click(
|
||||
# fn=gradio_pipeline.execute_image,
|
||||
fn=gpu_wrapped_execute_image,
|
||||
inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
show_progress=True
|
||||
)
|
||||
process_button_animation.click(
|
||||
fn=gpu_wrapped_execute_video,
|
||||
inputs=[
|
||||
image_input,
|
||||
video_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input
|
||||
],
|
||||
outputs=[output_video, output_video_concat],
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
|
||||
demo.launch(
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
server_name=args.server_name
|
||||
)
|
@ -1,22 +0,0 @@
|
||||
## 2024/07/10
|
||||
|
||||
**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
|
||||
The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
|
||||
|
||||
### Updates
|
||||
|
||||
- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
|
||||
|
||||
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`.
|
||||
|
||||
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
|
||||
|
||||
|
||||
### About driving video
|
||||
|
||||
- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
|
||||
|
||||
|
||||
### Others
|
||||
|
||||
- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
|
Before Width: | Height: | Size: 801 KiB |
Before Width: | Height: | Size: 6.3 MiB |
Before Width: | Height: | Size: 2.7 MiB |
@ -1,16 +0,0 @@
|
||||
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
|
||||
</div>
|
||||
<div style="font-size: 1.2em; margin-left: 20px;">
|
||||
3. If you want to upload your own driving video, <strong>the best practice</strong>:
|
||||
|
||||
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
|
||||
- Focus on the head area, similar to the example videos.
|
||||
- Minimize shoulder movement.
|
||||
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
|
||||
|
||||
</div>
|
@ -1,4 +0,0 @@
|
||||
<br>
|
||||
|
||||
## Retargeting
|
||||
<span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
|
@ -1,2 +0,0 @@
|
||||
## 🤗 This is the official gradio demo for **LivePortrait**.
|
||||
<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>
|
@ -1,11 +0,0 @@
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
||||
<div>
|
||||
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
|
||||
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
|
||||
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
|
||||
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
|
||||
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
|
||||
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
60
inference.py
@ -1,60 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
import os.path as osp
|
||||
import tyro
|
||||
import subprocess
|
||||
from src.config.argument_config import ArgumentConfig
|
||||
from src.config.inference_config import InferenceConfig
|
||||
from src.config.crop_config import CropConfig
|
||||
from src.live_portrait_pipeline import LivePortraitPipeline
|
||||
|
||||
|
||||
def partial_fields(target_class, kwargs):
|
||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def fast_check_args(args: ArgumentConfig):
|
||||
if not osp.exists(args.source_image):
|
||||
raise FileNotFoundError(f"source image not found: {args.source_image}")
|
||||
if not osp.exists(args.driving_info):
|
||||
raise FileNotFoundError(f"driving info not found: {args.driving_info}")
|
||||
|
||||
|
||||
def main():
|
||||
# set tyro theme
|
||||
tyro.extras.set_accent_color("bright_cyan")
|
||||
args = tyro.cli(ArgumentConfig)
|
||||
|
||||
if not fast_check_ffmpeg():
|
||||
raise ImportError(
|
||||
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
|
||||
)
|
||||
|
||||
# fast check the args
|
||||
fast_check_args(args)
|
||||
|
||||
# specify configs for inference
|
||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
||||
|
||||
live_portrait_pipeline = LivePortraitPipeline(
|
||||
inference_cfg=inference_cfg,
|
||||
crop_cfg=crop_cfg
|
||||
)
|
||||
|
||||
# run
|
||||
live_portrait_pipeline.execute(args)
|
||||
# live_portrait_pipeline_cp = torch.compile(live_portrait_pipeline.execute, backend="inductor")
|
||||
# with torch.no_grad():
|
||||
# live_portrait_pipeline_cp(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
195
speed.py
@ -1,195 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Benchmark the inference speed of each module in LivePortrait.
|
||||
|
||||
TODO: heavy GPT style, need to refactor
|
||||
"""
|
||||
|
||||
import torch
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
|
||||
import yaml
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from src.utils.helper import load_model, concat_feat
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
|
||||
def initialize_inputs(batch_size=1, device_id=0):
|
||||
"""
|
||||
Generate random input tensors and move them to GPU
|
||||
"""
|
||||
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
|
||||
kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
|
||||
generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
|
||||
eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
|
||||
lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
|
||||
feat_stitching = concat_feat(kp_source, kp_driving).half()
|
||||
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
|
||||
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
|
||||
|
||||
inputs = {
|
||||
'feature_3d': feature_3d,
|
||||
'kp_source': kp_source,
|
||||
'kp_driving': kp_driving,
|
||||
'source_image': source_image,
|
||||
'generator_input': generator_input,
|
||||
'feat_stitching': feat_stitching,
|
||||
'feat_eye': feat_eye,
|
||||
'feat_lip': feat_lip
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def load_and_compile_models(cfg, model_config):
|
||||
"""
|
||||
Load and compile models for inference
|
||||
"""
|
||||
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
||||
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
||||
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
||||
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
||||
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
||||
|
||||
models_with_params = [
|
||||
('Appearance Feature Extractor', appearance_feature_extractor),
|
||||
('Motion Extractor', motion_extractor),
|
||||
('Warping Network', warping_module),
|
||||
('SPADE Decoder', spade_generator)
|
||||
]
|
||||
|
||||
compiled_models = {}
|
||||
for name, model in models_with_params:
|
||||
model = model.half()
|
||||
model = torch.compile(model, mode='max-autotune') # Optimize for inference
|
||||
model.eval() # Switch to evaluation mode
|
||||
compiled_models[name] = model
|
||||
|
||||
retargeting_models = ['stitching', 'eye', 'lip']
|
||||
for retarget in retargeting_models:
|
||||
module = stitching_retargeting_module[retarget].half()
|
||||
module = torch.compile(module, mode='max-autotune') # Optimize for inference
|
||||
module.eval() # Switch to evaluation mode
|
||||
stitching_retargeting_module[retarget] = module
|
||||
|
||||
return compiled_models, stitching_retargeting_module
|
||||
|
||||
|
||||
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
|
||||
"""
|
||||
Warm up models to prepare them for benchmarking
|
||||
"""
|
||||
print("Warm up start!")
|
||||
with torch.no_grad():
|
||||
for _ in range(10):
|
||||
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
|
||||
compiled_models['Motion Extractor'](inputs['source_image'])
|
||||
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
|
||||
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
|
||||
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
|
||||
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
||||
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
||||
print("Warm up end!")
|
||||
|
||||
|
||||
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
|
||||
"""
|
||||
Measure inference times for each model
|
||||
"""
|
||||
times = {name: [] for name in compiled_models.keys()}
|
||||
times['Stitching and Retargeting Modules'] = []
|
||||
|
||||
overall_times = []
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(100):
|
||||
torch.cuda.synchronize()
|
||||
overall_start = time.time()
|
||||
|
||||
start = time.time()
|
||||
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
|
||||
torch.cuda.synchronize()
|
||||
times['Appearance Feature Extractor'].append(time.time() - start)
|
||||
|
||||
start = time.time()
|
||||
compiled_models['Motion Extractor'](inputs['source_image'])
|
||||
torch.cuda.synchronize()
|
||||
times['Motion Extractor'].append(time.time() - start)
|
||||
|
||||
start = time.time()
|
||||
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
|
||||
torch.cuda.synchronize()
|
||||
times['Warping Network'].append(time.time() - start)
|
||||
|
||||
start = time.time()
|
||||
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
|
||||
torch.cuda.synchronize()
|
||||
times['SPADE Decoder'].append(time.time() - start)
|
||||
|
||||
start = time.time()
|
||||
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
|
||||
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
||||
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
||||
torch.cuda.synchronize()
|
||||
times['Stitching and Retargeting Modules'].append(time.time() - start)
|
||||
|
||||
overall_times.append(time.time() - overall_start)
|
||||
|
||||
return times, overall_times
|
||||
|
||||
|
||||
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
|
||||
"""
|
||||
Print benchmark results with average and standard deviation of inference times
|
||||
"""
|
||||
average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
|
||||
std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
|
||||
|
||||
for name, model in compiled_models.items():
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
num_params_in_millions = num_params / 1e6
|
||||
print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
|
||||
|
||||
for index, retarget in enumerate(retargeting_models):
|
||||
num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
|
||||
num_params_in_millions = num_params / 1e6
|
||||
print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
|
||||
|
||||
for name, avg_time in average_times.items():
|
||||
std_time = std_times[name]
|
||||
print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to benchmark speed and model parameters
|
||||
"""
|
||||
# Load configuration
|
||||
cfg = InferenceConfig()
|
||||
model_config_path = cfg.models_config
|
||||
with open(model_config_path, 'r') as file:
|
||||
model_config = yaml.safe_load(file)
|
||||
|
||||
# Sample input tensors
|
||||
inputs = initialize_inputs(device_id = cfg.device_id)
|
||||
|
||||
# Load and compile models
|
||||
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
|
||||
|
||||
# Warm up models
|
||||
warm_up_models(compiled_models, stitching_retargeting_module, inputs)
|
||||
|
||||
# Measure inference times
|
||||
times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
|
||||
|
||||
# Print benchmark results
|
||||
print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,117 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Pipeline for gradio
|
||||
"""
|
||||
import gradio as gr
|
||||
|
||||
from .config.argument_config import ArgumentConfig
|
||||
from .live_portrait_pipeline import LivePortraitPipeline
|
||||
from .utils.io import load_img_online
|
||||
from .utils.rprint import rlog as log
|
||||
from .utils.crop import prepare_paste_back, paste_back
|
||||
from .utils.camera import get_rotation_matrix
|
||||
|
||||
|
||||
def update_args(args, user_args):
|
||||
"""update the args according to user inputs
|
||||
"""
|
||||
for k, v in user_args.items():
|
||||
if hasattr(args, k):
|
||||
setattr(args, k, v)
|
||||
return args
|
||||
|
||||
|
||||
class GradioPipeline(LivePortraitPipeline):
|
||||
|
||||
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
|
||||
super().__init__(inference_cfg, crop_cfg)
|
||||
# self.live_portrait_wrapper = self.live_portrait_wrapper
|
||||
self.args = args
|
||||
|
||||
def execute_video(
|
||||
self,
|
||||
input_image_path,
|
||||
input_video_path,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input
|
||||
):
|
||||
""" for video driven potrait animation
|
||||
"""
|
||||
if input_image_path is not None and input_video_path is not None:
|
||||
args_user = {
|
||||
'source_image': input_image_path,
|
||||
'driving_info': input_video_path,
|
||||
'flag_relative': flag_relative_input,
|
||||
'flag_do_crop': flag_do_crop_input,
|
||||
'flag_pasteback': flag_remap_input,
|
||||
'flag_crop_driving_video': flag_crop_driving_video_input
|
||||
}
|
||||
# update config from user input
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.live_portrait_wrapper.update_config(self.args.__dict__)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
# video driven animation
|
||||
video_path, video_path_concat = self.execute(self.args)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return video_path, video_path_concat,
|
||||
else:
|
||||
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
|
||||
|
||||
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
|
||||
""" for single image retargeting
|
||||
"""
|
||||
# disposable feature
|
||||
f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
|
||||
self.prepare_retargeting(input_image, flag_do_crop)
|
||||
|
||||
if input_eye_ratio is None or input_lip_ratio is None:
|
||||
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
||||
else:
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
|
||||
f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
|
||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
|
||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
|
||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
|
||||
num_kp = x_s_user.shape[1]
|
||||
# default: use x_s
|
||||
x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
|
||||
# D(W(f_s; x_s, x′_d))
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
|
||||
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return out, out_to_ori_blend
|
||||
|
||||
def prepare_retargeting(self, input_image, flag_do_crop=True):
|
||||
""" for single image retargeting
|
||||
"""
|
||||
if input_image is not None:
|
||||
# gr.Info("Upload successfully!", duration=2)
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
######## process source portrait ########
|
||||
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
|
||||
log(f"Load source image from {input_image}.")
|
||||
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
|
||||
if flag_do_crop:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
|
||||
else:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
############################################
|
||||
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
source_lmk_user = crop_info['lmk_crop']
|
||||
crop_M_c2o = crop_info['M_c2o']
|
||||
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
|
||||
else:
|
||||
# when press the clear button, go here
|
||||
raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
|
@ -1,285 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Pipeline of LivePortrait
|
||||
"""
|
||||
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
|
||||
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
from rich.progress import track
|
||||
|
||||
from .config.argument_config import ArgumentConfig
|
||||
from .config.inference_config import InferenceConfig
|
||||
from .config.crop_config import CropConfig
|
||||
from .utils.cropper import Cropper
|
||||
from .utils.camera import get_rotation_matrix
|
||||
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
|
||||
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
||||
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
|
||||
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
|
||||
from .utils.rprint import rlog as log
|
||||
# from .utils.viz import viz_lmk
|
||||
from .live_portrait_wrapper import LivePortraitWrapper
|
||||
|
||||
|
||||
def make_abs_path(fn):
|
||||
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
||||
|
||||
|
||||
class LivePortraitPipeline(object):
|
||||
|
||||
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
|
||||
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
|
||||
|
||||
def execute(self, args: ArgumentConfig):
|
||||
# for convenience
|
||||
inf_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
device = self.live_portrait_wrapper.device
|
||||
crop_cfg = self.cropper.crop_cfg
|
||||
|
||||
######## process source portrait ########
|
||||
img_rgb = load_image_rgb(args.source_image)
|
||||
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
|
||||
log(f"Load source image from {args.source_image}")
|
||||
|
||||
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
|
||||
if crop_info is None:
|
||||
raise Exception("No face detected in the source image!")
|
||||
source_lmk = crop_info['lmk_crop']
|
||||
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
||||
|
||||
if inf_cfg.flag_do_crop:
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||
else:
|
||||
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
|
||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||
|
||||
flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite
|
||||
if flag_lip_zero:
|
||||
# let lip-open scalar to be 0 at first
|
||||
c_d_lip_before_animation = [0.]
|
||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
||||
if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
|
||||
flag_lip_zero = False
|
||||
else:
|
||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||
############################################
|
||||
|
||||
######## process driving info ########
|
||||
flag_load_from_template = is_template(args.driving_info)
|
||||
driving_rgb_crop_256x256_lst = None
|
||||
wfp_template = None
|
||||
|
||||
if flag_load_from_template:
|
||||
# NOTE: load from template, it is fast, but the cropping video is None
|
||||
log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green')
|
||||
template_dct = load(args.driving_info)
|
||||
n_frames = template_dct['n_frames']
|
||||
|
||||
# set output_fps
|
||||
output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
|
||||
log(f'The FPS of template: {output_fps}')
|
||||
|
||||
if args.flag_crop_driving_video:
|
||||
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||
|
||||
elif osp.exists(args.driving_info) and is_video(args.driving_info):
|
||||
# load from video file, AND make motion template
|
||||
log(f"Load video: {args.driving_info}")
|
||||
if osp.isdir(args.driving_info):
|
||||
output_fps = inf_cfg.output_fps
|
||||
else:
|
||||
output_fps = int(get_fps(args.driving_info))
|
||||
log(f'The FPS of {args.driving_info} is: {output_fps}')
|
||||
|
||||
log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
|
||||
driving_rgb_lst = load_driving_info(args.driving_info)
|
||||
|
||||
######## make motion template ########
|
||||
log("Start making motion template...")
|
||||
if inf_cfg.flag_crop_driving_video:
|
||||
ret = self.cropper.crop_driving_video(driving_rgb_lst)
|
||||
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
|
||||
driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
|
||||
else:
|
||||
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
|
||||
|
||||
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
|
||||
# save the motion template
|
||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
|
||||
template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
|
||||
|
||||
wfp_template = remove_suffix(args.driving_info) + '.pkl'
|
||||
dump(wfp_template, template_dct)
|
||||
log(f"Dump motion template to {wfp_template}")
|
||||
|
||||
n_frames = I_d_lst.shape[0]
|
||||
else:
|
||||
raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
|
||||
#########################################
|
||||
|
||||
######## prepare for pasteback ########
|
||||
I_p_pstbk_lst = None
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||
I_p_pstbk_lst = []
|
||||
log("Prepared pasteback mask done.")
|
||||
#########################################
|
||||
|
||||
I_p_lst = []
|
||||
R_d_0, x_d_0_info = None, None
|
||||
|
||||
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
|
||||
x_d_i_info = template_dct['motion'][i]
|
||||
x_d_i_info = dct2device(x_d_i_info, device)
|
||||
R_d_i = x_d_i_info['R_d']
|
||||
|
||||
if i == 0:
|
||||
R_d_0 = R_d_i
|
||||
x_d_0_info = x_d_i_info
|
||||
|
||||
if inf_cfg.flag_relative_motion:
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
else:
|
||||
R_new = R_d_i
|
||||
delta_new = x_d_i_info['exp']
|
||||
scale_new = x_s_info['scale']
|
||||
t_new = x_d_i_info['t']
|
||||
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
||||
|
||||
# Algorithm 1:
|
||||
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||
# without stitching or retargeting
|
||||
if flag_lip_zero:
|
||||
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||
else:
|
||||
pass
|
||||
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
|
||||
# with stitching and without retargeting
|
||||
if flag_lip_zero:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||
else:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||
else:
|
||||
eyes_delta, lip_delta = None, None
|
||||
if inf_cfg.flag_eye_retargeting:
|
||||
c_d_eyes_i = c_d_eyes_lst[i]
|
||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
||||
if inf_cfg.flag_lip_retargeting:
|
||||
c_d_lip_i = c_d_lip_lst[i]
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
||||
|
||||
if inf_cfg.flag_relative_motion: # use x_s
|
||||
x_d_i_new = x_s + \
|
||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||
else: # use x_d,i
|
||||
x_d_i_new = x_d_i_new + \
|
||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
||||
|
||||
if inf_cfg.flag_stitching:
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
||||
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
I_p_lst.append(I_p_i)
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
|
||||
# TODO: pasteback is slow, considering optimize it using multi-threading or GPU
|
||||
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
|
||||
mkdir(args.output_dir)
|
||||
wfp_concat = None
|
||||
flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
|
||||
|
||||
######### build final concact result #########
|
||||
# driving frame | source image | generation, or source image | generation
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
|
||||
if flag_has_audio:
|
||||
# final result with concact
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
|
||||
add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
|
||||
|
||||
# save drived result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build final result #########
|
||||
if flag_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
|
||||
add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp} with {wfp_with_audio}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concact: {wfp_concat}')
|
||||
|
||||
return wfp, wfp_concat
|
||||
|
||||
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
|
||||
n_frames = I_d_lst.shape[0]
|
||||
template_dct = {
|
||||
'n_frames': n_frames,
|
||||
'output_fps': kwargs.get('output_fps', 25),
|
||||
'motion': [],
|
||||
'c_d_eyes_lst': [],
|
||||
'c_d_lip_lst': [],
|
||||
}
|
||||
|
||||
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
|
||||
# collect s_d, R_d, δ_d and t_d for inference
|
||||
I_d_i = I_d_lst[i]
|
||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
||||
|
||||
item_dct = {
|
||||
'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
|
||||
'R_d': R_d_i.cpu().numpy().astype(np.float32),
|
||||
'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
|
||||
't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
|
||||
}
|
||||
|
||||
template_dct['motion'].append(item_dct)
|
||||
|
||||
c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
|
||||
template_dct['c_d_eyes_lst'].append(c_d_eyes)
|
||||
|
||||
c_d_lip = c_d_lip_lst[i].astype(np.float32)
|
||||
template_dct['c_d_lip_lst'].append(c_d_lip)
|
||||
|
||||
return template_dct
|
@ -1,333 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Wrapper for LivePortrait core functions
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from .utils.timer import Timer
|
||||
from .utils.helper import load_model, concat_feat
|
||||
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
|
||||
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
||||
from .config.inference_config import InferenceConfig
|
||||
from .utils.rprint import rlog as log
|
||||
|
||||
|
||||
class LivePortraitWrapper(object):
|
||||
|
||||
def __init__(self, inference_cfg: InferenceConfig):
|
||||
|
||||
self.inference_cfg = inference_cfg
|
||||
self.device_id = inference_cfg.device_id
|
||||
self.compile = inference_cfg.flag_do_torch_compile
|
||||
if inference_cfg.flag_force_cpu:
|
||||
self.device = 'cpu'
|
||||
else:
|
||||
self.device = 'cuda:' + str(self.device_id)
|
||||
|
||||
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
||||
jit = False
|
||||
jit = True
|
||||
if jit:
|
||||
self.appearance_feature_extractor = torch.jit.load("build/appearance_feature_extractor.pt")
|
||||
self.motion_extractor = torch.jit.load("build/motion_extractor.pt")
|
||||
self.warping_module = torch.jit.load("build/warping_module.pt")
|
||||
self.spade_generator = torch.jit.load("build/spade_generator.pt")
|
||||
|
||||
eyes = torch.jit.load("build/stitching_retargeting_module_eye.pt")
|
||||
lips = torch.jit.load("build/stitching_retargeting_module_lip.pt")
|
||||
stitching = torch.jit.load("build/stitching_retargeting_module_stitching.pt")
|
||||
self.stitching_retargeting_module = {'eye': eyes, 'lip': lips, 'stitching': stitching}
|
||||
else:
|
||||
# init F
|
||||
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
|
||||
log(f'Load appearance_feature_extractor done.')
|
||||
# init M
|
||||
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
|
||||
log(f'Load motion_extractor done.')
|
||||
# init W
|
||||
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
|
||||
log(f'Load warping_module done.')
|
||||
# init G
|
||||
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
|
||||
log(f'Load spade_generator done.')
|
||||
# init S and R
|
||||
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
|
||||
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
|
||||
log(f'Load stitching_retargeting_module done.')
|
||||
else:
|
||||
self.stitching_retargeting_module = None
|
||||
|
||||
# Optimize for inference
|
||||
if self.compile:
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
|
||||
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
|
||||
|
||||
self.timer = Timer()
|
||||
|
||||
def update_config(self, user_args):
|
||||
for k, v in user_args.items():
|
||||
if hasattr(self.inference_cfg, k):
|
||||
setattr(self.inference_cfg, k, v)
|
||||
|
||||
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
||||
""" construct the input as standard
|
||||
img: HxWx3, uint8, 256x256
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
|
||||
x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
|
||||
else:
|
||||
x = img.copy()
|
||||
|
||||
if x.ndim == 3:
|
||||
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
|
||||
elif x.ndim == 4:
|
||||
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
|
||||
else:
|
||||
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
||||
x = np.clip(x, 0, 1) # clip to 0~1
|
||||
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
||||
x = x.to(self.device)
|
||||
return x
|
||||
|
||||
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
||||
""" construct the input as standard
|
||||
imgs: NxBxHxWx3, uint8
|
||||
"""
|
||||
if isinstance(imgs, list):
|
||||
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
|
||||
elif isinstance(imgs, np.ndarray):
|
||||
_imgs = imgs
|
||||
else:
|
||||
raise ValueError(f'imgs type error: {type(imgs)}')
|
||||
|
||||
y = _imgs.astype(np.float32) / 255.
|
||||
y = np.clip(y, 0, 1) # clip to 0~1
|
||||
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
||||
y = y.to(self.device)
|
||||
|
||||
return y
|
||||
|
||||
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" get the appearance feature of the image by F
|
||||
x: Bx3xHxW, normalized to 0~1
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||
feature_3d = self.appearance_feature_extractor(x)
|
||||
|
||||
return feature_3d.float()
|
||||
|
||||
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
|
||||
""" get the implicit keypoint information
|
||||
x: Bx3xHxW, normalized to 0~1
|
||||
flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
|
||||
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||
kp_info = self.motion_extractor(x)
|
||||
|
||||
if self.inference_cfg.flag_use_half_precision:
|
||||
# float the dict
|
||||
for k, v in kp_info.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
kp_info[k] = v.float()
|
||||
|
||||
flag_refine_info: bool = kwargs.get('flag_refine_info', True)
|
||||
if flag_refine_info:
|
||||
bs = kp_info['kp'].shape[0]
|
||||
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
|
||||
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
|
||||
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
|
||||
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
|
||||
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
|
||||
|
||||
return kp_info
|
||||
|
||||
def get_pose_dct(self, kp_info: dict) -> dict:
|
||||
pose_dct = dict(
|
||||
pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
|
||||
yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
|
||||
roll=headpose_pred_to_degree(kp_info['roll']).item(),
|
||||
)
|
||||
return pose_dct
|
||||
|
||||
def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
|
||||
|
||||
# get the canonical keypoints of source image by M
|
||||
source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
|
||||
source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
|
||||
|
||||
# get the canonical keypoints of first driving frame by M
|
||||
driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
|
||||
driving_first_frame_rotation = get_rotation_matrix(
|
||||
driving_first_frame_kp_info['pitch'],
|
||||
driving_first_frame_kp_info['yaw'],
|
||||
driving_first_frame_kp_info['roll']
|
||||
)
|
||||
|
||||
# get feature volume by F
|
||||
source_feature_3d = self.extract_feature_3d(source_prepared)
|
||||
|
||||
return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
|
||||
|
||||
def transform_keypoint(self, kp_info: dict):
|
||||
"""
|
||||
transform the implicit keypoints with the pose, shift, and expression deformation
|
||||
kp: BxNx3
|
||||
"""
|
||||
kp = kp_info['kp'] # (bs, k, 3)
|
||||
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
|
||||
|
||||
t, exp = kp_info['t'], kp_info['exp']
|
||||
scale = kp_info['scale']
|
||||
|
||||
pitch = headpose_pred_to_degree(pitch)
|
||||
yaw = headpose_pred_to_degree(yaw)
|
||||
roll = headpose_pred_to_degree(roll)
|
||||
|
||||
bs = kp.shape[0]
|
||||
if kp.ndim == 2:
|
||||
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
|
||||
else:
|
||||
num_kp = kp.shape[1] # Bxnum_kpx3
|
||||
|
||||
rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
|
||||
|
||||
# Eqn.2: s * (R * x_c,s + exp) + t
|
||||
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
|
||||
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
|
||||
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
|
||||
|
||||
return kp_transformed
|
||||
|
||||
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
kp_source: BxNx3
|
||||
eye_close_ratio: Bx3
|
||||
Return: Bx(3*num_kp)
|
||||
"""
|
||||
feat_eye = concat_feat(kp_source, eye_close_ratio)
|
||||
|
||||
with torch.no_grad():
|
||||
delta = self.stitching_retargeting_module['eye'](feat_eye)
|
||||
|
||||
return delta
|
||||
|
||||
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
kp_source: BxNx3
|
||||
lip_close_ratio: Bx2
|
||||
Return: Bx(3*num_kp)
|
||||
"""
|
||||
feat_lip = concat_feat(kp_source, lip_close_ratio)
|
||||
|
||||
with torch.no_grad():
|
||||
delta = self.stitching_retargeting_module['lip'](feat_lip)
|
||||
|
||||
return delta
|
||||
|
||||
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
kp_source: BxNx3
|
||||
kp_driving: BxNx3
|
||||
Return: Bx(3*num_kp+2)
|
||||
"""
|
||||
feat_stiching = concat_feat(kp_source, kp_driving)
|
||||
|
||||
with torch.no_grad():
|
||||
delta = self.stitching_retargeting_module['stitching'](feat_stiching)
|
||||
|
||||
return delta
|
||||
|
||||
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
||||
""" conduct the stitching
|
||||
kp_source: Bxnum_kpx3
|
||||
kp_driving: Bxnum_kpx3
|
||||
"""
|
||||
|
||||
if self.stitching_retargeting_module is not None:
|
||||
|
||||
bs, num_kp = kp_source.shape[:2]
|
||||
|
||||
kp_driving_new = kp_driving.clone()
|
||||
delta = self.stitch(kp_source, kp_driving_new)
|
||||
|
||||
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
|
||||
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
|
||||
|
||||
kp_driving_new += delta_exp
|
||||
kp_driving_new[..., :2] += delta_tx_ty
|
||||
|
||||
return kp_driving_new
|
||||
|
||||
return kp_driving
|
||||
|
||||
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
||||
""" get the image after the warping of the implicit keypoints
|
||||
feature_3d: Bx32x16x64x64, feature volume
|
||||
kp_source: BxNx3
|
||||
kp_driving: BxNx3
|
||||
"""
|
||||
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
|
||||
if self.compile:
|
||||
# Mark the beginning of a new CUDA Graph step
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
# get decoder input
|
||||
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
||||
# decode
|
||||
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
||||
|
||||
# float the dict
|
||||
if self.inference_cfg.flag_use_half_precision:
|
||||
for k, v in ret_dct.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
ret_dct[k] = v.float()
|
||||
|
||||
return ret_dct
|
||||
|
||||
def parse_output(self, out: torch.Tensor) -> np.ndarray:
|
||||
""" construct the output as standard
|
||||
return: 1xHxWx3, uint8
|
||||
"""
|
||||
out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
|
||||
out = np.clip(out, 0, 1) # clip to 0~1
|
||||
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
|
||||
|
||||
return out
|
||||
|
||||
def calc_driving_ratio(self, driving_lmk_lst):
|
||||
input_eye_ratio_lst = []
|
||||
input_lip_ratio_lst = []
|
||||
for lmk in driving_lmk_lst:
|
||||
# for eyes retargeting
|
||||
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
|
||||
# for lip retargeting
|
||||
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
||||
return input_eye_ratio_lst, input_lip_ratio_lst
|
||||
|
||||
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
|
||||
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
|
||||
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
|
||||
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
|
||||
# [c_s,eyes, c_d,eyes,i]
|
||||
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
|
||||
return combined_eye_ratio_tensor
|
||||
|
||||
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
|
||||
c_s_lip = calc_lip_close_ratio(source_lmk[None])
|
||||
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
|
||||
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
|
||||
# [c_s,lip, c_d,lip,i]
|
||||
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
|
||||
return combined_lip_ratio_tensor
|
@ -1,20 +0,0 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=wrong-import-position
|
||||
"""InsightFace: A Face Analysis Toolkit."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
try:
|
||||
#import mxnet as mx
|
||||
import onnxruntime
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Unable to import dependency onnxruntime. "
|
||||
)
|
||||
|
||||
__version__ = '0.7.3'
|
||||
|
||||
from . import model_zoo
|
||||
from . import utils
|
||||
from . import app
|
||||
from . import data
|
||||
|
@ -1 +0,0 @@
|
||||
from .face_analysis import *
|
@ -1,49 +0,0 @@
|
||||
import numpy as np
|
||||
from numpy.linalg import norm as l2norm
|
||||
#from easydict import EasyDict
|
||||
|
||||
class Face(dict):
|
||||
|
||||
def __init__(self, d=None, **kwargs):
|
||||
if d is None:
|
||||
d = {}
|
||||
if kwargs:
|
||||
d.update(**kwargs)
|
||||
for k, v in d.items():
|
||||
setattr(self, k, v)
|
||||
# Class attributes
|
||||
#for k in self.__class__.__dict__.keys():
|
||||
# if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
|
||||
# setattr(self, k, getattr(self, k))
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = [self.__class__(x)
|
||||
if isinstance(x, dict) else x for x in value]
|
||||
elif isinstance(value, dict) and not isinstance(value, self.__class__):
|
||||
value = self.__class__(value)
|
||||
super(Face, self).__setattr__(name, value)
|
||||
super(Face, self).__setitem__(name, value)
|
||||
|
||||
__setitem__ = __setattr__
|
||||
|
||||
def __getattr__(self, name):
|
||||
return None
|
||||
|
||||
@property
|
||||
def embedding_norm(self):
|
||||
if self.embedding is None:
|
||||
return None
|
||||
return l2norm(self.embedding)
|
||||
|
||||
@property
|
||||
def normed_embedding(self):
|
||||
if self.embedding is None:
|
||||
return None
|
||||
return self.embedding / self.embedding_norm
|
||||
|
||||
@property
|
||||
def sex(self):
|
||||
if self.gender is None:
|
||||
return None
|
||||
return 'M' if self.gender==1 else 'F'
|
@ -1,110 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-05-04
|
||||
# @Function :
|
||||
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import glob
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
from numpy.linalg import norm
|
||||
|
||||
from ..model_zoo import model_zoo
|
||||
from ..utils import ensure_available
|
||||
from .common import Face
|
||||
|
||||
|
||||
DEFAULT_MP_NAME = 'buffalo_l'
|
||||
__all__ = ['FaceAnalysis']
|
||||
|
||||
class FaceAnalysis:
|
||||
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
self.models = {}
|
||||
self.model_dir = ensure_available('models', name, root=root)
|
||||
onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
|
||||
onnx_files = sorted(onnx_files)
|
||||
for onnx_file in onnx_files:
|
||||
model = model_zoo.get_model(onnx_file, **kwargs)
|
||||
if model is None:
|
||||
print('model not recognized:', onnx_file)
|
||||
elif allowed_modules is not None and model.taskname not in allowed_modules:
|
||||
print('model ignore:', onnx_file, model.taskname)
|
||||
del model
|
||||
elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
|
||||
# print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std)
|
||||
self.models[model.taskname] = model
|
||||
else:
|
||||
print('duplicated model task type, ignore:', onnx_file, model.taskname)
|
||||
del model
|
||||
assert 'detection' in self.models
|
||||
self.det_model = self.models['detection']
|
||||
|
||||
|
||||
def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
|
||||
self.det_thresh = det_thresh
|
||||
assert det_size is not None
|
||||
# print('set det-size:', det_size)
|
||||
self.det_size = det_size
|
||||
for taskname, model in self.models.items():
|
||||
if taskname=='detection':
|
||||
model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
|
||||
else:
|
||||
model.prepare(ctx_id)
|
||||
|
||||
def get(self, img, max_num=0):
|
||||
bboxes, kpss = self.det_model.detect(img,
|
||||
max_num=max_num,
|
||||
metric='default')
|
||||
if bboxes.shape[0] == 0:
|
||||
return []
|
||||
ret = []
|
||||
for i in range(bboxes.shape[0]):
|
||||
bbox = bboxes[i, 0:4]
|
||||
det_score = bboxes[i, 4]
|
||||
kps = None
|
||||
if kpss is not None:
|
||||
kps = kpss[i]
|
||||
face = Face(bbox=bbox, kps=kps, det_score=det_score)
|
||||
for taskname, model in self.models.items():
|
||||
if taskname=='detection':
|
||||
continue
|
||||
model.get(img, face)
|
||||
ret.append(face)
|
||||
return ret
|
||||
|
||||
def draw_on(self, img, faces):
|
||||
import cv2
|
||||
dimg = img.copy()
|
||||
for i in range(len(faces)):
|
||||
face = faces[i]
|
||||
box = face.bbox.astype(np.int)
|
||||
color = (0, 0, 255)
|
||||
cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
|
||||
if face.kps is not None:
|
||||
kps = face.kps.astype(np.int)
|
||||
#print(landmark.shape)
|
||||
for l in range(kps.shape[0]):
|
||||
color = (0, 0, 255)
|
||||
if l == 0 or l == 3:
|
||||
color = (0, 255, 0)
|
||||
cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color,
|
||||
2)
|
||||
if face.gender is not None and face.age is not None:
|
||||
cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
|
||||
|
||||
#for key, value in face.items():
|
||||
# if key.startswith('landmark_3d'):
|
||||
# print(key, value.shape)
|
||||
# print(value[0:10,:])
|
||||
# lmk = np.round(value).astype(np.int)
|
||||
# for l in range(lmk.shape[0]):
|
||||
# color = (255, 0, 0)
|
||||
# cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
|
||||
# 2)
|
||||
return dimg
|
@ -1,2 +0,0 @@
|
||||
from .image import get_image
|
||||
from .pickle_object import get_object
|
@ -1,27 +0,0 @@
|
||||
import cv2
|
||||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
|
||||
class ImageCache:
|
||||
data = {}
|
||||
|
||||
def get_image(name, to_rgb=False):
|
||||
key = (name, to_rgb)
|
||||
if key in ImageCache.data:
|
||||
return ImageCache.data[key]
|
||||
images_dir = osp.join(Path(__file__).parent.absolute(), 'images')
|
||||
ext_names = ['.jpg', '.png', '.jpeg']
|
||||
image_file = None
|
||||
for ext_name in ext_names:
|
||||
_image_file = osp.join(images_dir, "%s%s"%(name, ext_name))
|
||||
if osp.exists(_image_file):
|
||||
image_file = _image_file
|
||||
break
|
||||
assert image_file is not None, '%s not found'%name
|
||||
img = cv2.imread(image_file)
|
||||
if to_rgb:
|
||||
img = img[:,:,::-1]
|
||||
ImageCache.data[key] = img
|
||||
return img
|
||||
|
Before Width: | Height: | Size: 12 KiB |
Before Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 44 KiB |
Before Width: | Height: | Size: 6.0 KiB |
Before Width: | Height: | Size: 77 KiB |
Before Width: | Height: | Size: 126 KiB |
@ -1,17 +0,0 @@
|
||||
import cv2
|
||||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
def get_object(name):
|
||||
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
|
||||
if not name.endswith('.pkl'):
|
||||
name = name+".pkl"
|
||||
filepath = osp.join(objects_dir, name)
|
||||
if not osp.exists(filepath):
|
||||
return None
|
||||
with open(filepath, 'rb') as f:
|
||||
obj = pickle.load(f)
|
||||
return obj
|
||||
|
@ -1,71 +0,0 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import mxnet as mx
|
||||
|
||||
|
||||
class RecBuilder():
|
||||
def __init__(self, path, image_size=(112, 112)):
|
||||
self.path = path
|
||||
self.image_size = image_size
|
||||
self.widx = 0
|
||||
self.wlabel = 0
|
||||
self.max_label = -1
|
||||
assert not osp.exists(path), '%s exists' % path
|
||||
os.makedirs(path)
|
||||
self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'),
|
||||
os.path.join(path, 'train.rec'),
|
||||
'w')
|
||||
self.meta = []
|
||||
|
||||
def add(self, imgs):
|
||||
#!!! img should be BGR!!!!
|
||||
#assert label >= 0
|
||||
#assert label > self.last_label
|
||||
assert len(imgs) > 0
|
||||
label = self.wlabel
|
||||
for img in imgs:
|
||||
idx = self.widx
|
||||
image_meta = {'image_index': idx, 'image_classes': [label]}
|
||||
header = mx.recordio.IRHeader(0, label, idx, 0)
|
||||
if isinstance(img, np.ndarray):
|
||||
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
|
||||
else:
|
||||
s = mx.recordio.pack(header, img)
|
||||
self.writer.write_idx(idx, s)
|
||||
self.meta.append(image_meta)
|
||||
self.widx += 1
|
||||
self.max_label = label
|
||||
self.wlabel += 1
|
||||
|
||||
|
||||
def add_image(self, img, label):
|
||||
#!!! img should be BGR!!!!
|
||||
#assert label >= 0
|
||||
#assert label > self.last_label
|
||||
idx = self.widx
|
||||
header = mx.recordio.IRHeader(0, label, idx, 0)
|
||||
if isinstance(label, list):
|
||||
idlabel = label[0]
|
||||
else:
|
||||
idlabel = label
|
||||
image_meta = {'image_index': idx, 'image_classes': [idlabel]}
|
||||
if isinstance(img, np.ndarray):
|
||||
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
|
||||
else:
|
||||
s = mx.recordio.pack(header, img)
|
||||
self.writer.write_idx(idx, s)
|
||||
self.meta.append(image_meta)
|
||||
self.widx += 1
|
||||
self.max_label = max(self.max_label, idlabel)
|
||||
|
||||
def close(self):
|
||||
with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
|
||||
pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
print('stat:', self.widx, self.wlabel)
|
||||
with open(os.path.join(self.path, 'property'), 'w') as f:
|
||||
f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1]))
|
||||
f.write("%d\n" % (self.widx))
|
||||
|
@ -1,6 +0,0 @@
|
||||
from .model_zoo import get_model
|
||||
from .arcface_onnx import ArcFaceONNX
|
||||
from .retinaface import RetinaFace
|
||||
from .scrfd import SCRFD
|
||||
from .landmark import Landmark
|
||||
from .attribute import Attribute
|
@ -1,92 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-05-04
|
||||
# @Function :
|
||||
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from ..utils import face_align
|
||||
|
||||
__all__ = [
|
||||
'ArcFaceONNX',
|
||||
]
|
||||
|
||||
|
||||
class ArcFaceONNX:
|
||||
def __init__(self, model_file=None, session=None):
|
||||
assert model_file is not None
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
self.taskname = 'recognition'
|
||||
find_sub = False
|
||||
find_mul = False
|
||||
model = onnx.load(self.model_file)
|
||||
graph = model.graph
|
||||
for nid, node in enumerate(graph.node[:8]):
|
||||
#print(nid, node.name)
|
||||
if node.name.startswith('Sub') or node.name.startswith('_minus'):
|
||||
find_sub = True
|
||||
if node.name.startswith('Mul') or node.name.startswith('_mul'):
|
||||
find_mul = True
|
||||
if find_sub and find_mul:
|
||||
#mxnet arcface model
|
||||
input_mean = 0.0
|
||||
input_std = 1.0
|
||||
else:
|
||||
input_mean = 127.5
|
||||
input_std = 127.5
|
||||
self.input_mean = input_mean
|
||||
self.input_std = input_std
|
||||
#print('input mean and std:', self.input_mean, self.input_std)
|
||||
if self.session is None:
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
input_name = input_cfg.name
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
output_names = []
|
||||
for out in outputs:
|
||||
output_names.append(out.name)
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
self.output_shape = outputs[0].shape
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
|
||||
def get(self, img, face):
|
||||
aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
|
||||
face.embedding = self.get_feat(aimg).flatten()
|
||||
return face.embedding
|
||||
|
||||
def compute_sim(self, feat1, feat2):
|
||||
from numpy.linalg import norm
|
||||
feat1 = feat1.ravel()
|
||||
feat2 = feat2.ravel()
|
||||
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
|
||||
return sim
|
||||
|
||||
def get_feat(self, imgs):
|
||||
if not isinstance(imgs, list):
|
||||
imgs = [imgs]
|
||||
input_size = self.input_size
|
||||
|
||||
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
|
||||
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
||||
return net_out
|
||||
|
||||
def forward(self, batch_data):
|
||||
blob = (batch_data - self.input_mean) / self.input_std
|
||||
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
||||
return net_out
|
||||
|
||||
|
@ -1,94 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-06-19
|
||||
# @Function :
|
||||
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from ..utils import face_align
|
||||
|
||||
__all__ = [
|
||||
'Attribute',
|
||||
]
|
||||
|
||||
|
||||
class Attribute:
|
||||
def __init__(self, model_file=None, session=None):
|
||||
assert model_file is not None
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
find_sub = False
|
||||
find_mul = False
|
||||
model = onnx.load(self.model_file)
|
||||
graph = model.graph
|
||||
for nid, node in enumerate(graph.node[:8]):
|
||||
#print(nid, node.name)
|
||||
if node.name.startswith('Sub') or node.name.startswith('_minus'):
|
||||
find_sub = True
|
||||
if node.name.startswith('Mul') or node.name.startswith('_mul'):
|
||||
find_mul = True
|
||||
if nid<3 and node.name=='bn_data':
|
||||
find_sub = True
|
||||
find_mul = True
|
||||
if find_sub and find_mul:
|
||||
#mxnet arcface model
|
||||
input_mean = 0.0
|
||||
input_std = 1.0
|
||||
else:
|
||||
input_mean = 127.5
|
||||
input_std = 128.0
|
||||
self.input_mean = input_mean
|
||||
self.input_std = input_std
|
||||
#print('input mean and std:', model_file, self.input_mean, self.input_std)
|
||||
if self.session is None:
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
input_name = input_cfg.name
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
output_names = []
|
||||
for out in outputs:
|
||||
output_names.append(out.name)
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
output_shape = outputs[0].shape
|
||||
#print('init output_shape:', output_shape)
|
||||
if output_shape[1]==3:
|
||||
self.taskname = 'genderage'
|
||||
else:
|
||||
self.taskname = 'attribute_%d'%output_shape[1]
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
|
||||
def get(self, img, face):
|
||||
bbox = face.bbox
|
||||
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
|
||||
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
|
||||
rotate = 0
|
||||
_scale = self.input_size[0] / (max(w, h)*1.5)
|
||||
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
|
||||
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
|
||||
input_size = tuple(aimg.shape[0:2][::-1])
|
||||
#assert input_size==self.input_size
|
||||
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
|
||||
if self.taskname=='genderage':
|
||||
assert len(pred)==3
|
||||
gender = np.argmax(pred[:2])
|
||||
age = int(np.round(pred[2]*100))
|
||||
face['gender'] = gender
|
||||
face['age'] = age
|
||||
return gender, age
|
||||
else:
|
||||
return pred
|
||||
|
||||
|
@ -1,114 +0,0 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import cv2
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from ..utils import face_align
|
||||
|
||||
|
||||
|
||||
|
||||
class INSwapper():
|
||||
def __init__(self, model_file=None, session=None):
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
model = onnx.load(self.model_file)
|
||||
graph = model.graph
|
||||
self.emap = numpy_helper.to_array(graph.initializer[-1])
|
||||
self.input_mean = 0.0
|
||||
self.input_std = 255.0
|
||||
#print('input mean and std:', model_file, self.input_mean, self.input_std)
|
||||
if self.session is None:
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
inputs = self.session.get_inputs()
|
||||
self.input_names = []
|
||||
for inp in inputs:
|
||||
self.input_names.append(inp.name)
|
||||
outputs = self.session.get_outputs()
|
||||
output_names = []
|
||||
for out in outputs:
|
||||
output_names.append(out.name)
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
output_shape = outputs[0].shape
|
||||
input_cfg = inputs[0]
|
||||
input_shape = input_cfg.shape
|
||||
self.input_shape = input_shape
|
||||
# print('inswapper-shape:', self.input_shape)
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
|
||||
def forward(self, img, latent):
|
||||
img = (img - self.input_mean) / self.input_std
|
||||
pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
|
||||
return pred
|
||||
|
||||
def get(self, img, target_face, source_face, paste_back=True):
|
||||
face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
|
||||
cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1)
|
||||
aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
|
||||
blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
|
||||
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
latent = source_face.normed_embedding.reshape((1,-1))
|
||||
latent = np.dot(latent, self.emap)
|
||||
latent /= np.linalg.norm(latent)
|
||||
pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
|
||||
#print(latent.shape, latent.dtype, pred.shape)
|
||||
img_fake = pred.transpose((0,2,3,1))[0]
|
||||
bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
|
||||
if not paste_back:
|
||||
return bgr_fake, M
|
||||
else:
|
||||
target_img = img
|
||||
fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
|
||||
fake_diff = np.abs(fake_diff).mean(axis=2)
|
||||
fake_diff[:2,:] = 0
|
||||
fake_diff[-2:,:] = 0
|
||||
fake_diff[:,:2] = 0
|
||||
fake_diff[:,-2:] = 0
|
||||
IM = cv2.invertAffineTransform(M)
|
||||
img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
|
||||
bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
|
||||
img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
|
||||
fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
|
||||
img_white[img_white>20] = 255
|
||||
fthresh = 10
|
||||
fake_diff[fake_diff<fthresh] = 0
|
||||
fake_diff[fake_diff>=fthresh] = 255
|
||||
img_mask = img_white
|
||||
mask_h_inds, mask_w_inds = np.where(img_mask==255)
|
||||
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
|
||||
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
|
||||
mask_size = int(np.sqrt(mask_h*mask_w))
|
||||
k = max(mask_size//10, 10)
|
||||
#k = max(mask_size//20, 6)
|
||||
#k = 6
|
||||
kernel = np.ones((k,k),np.uint8)
|
||||
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
|
||||
kernel = np.ones((2,2),np.uint8)
|
||||
fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
|
||||
|
||||
face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
|
||||
fake_diff[face_mask==1] = 255
|
||||
|
||||
k = max(mask_size//20, 5)
|
||||
#k = 3
|
||||
#k = 3
|
||||
kernel_size = (k, k)
|
||||
blur_size = tuple(2*i+1 for i in kernel_size)
|
||||
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
|
||||
k = 5
|
||||
kernel_size = (k, k)
|
||||
blur_size = tuple(2*i+1 for i in kernel_size)
|
||||
fake_diff = cv2.blur(fake_diff, (11,11), 0)
|
||||
##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
|
||||
# print('blur_size: ', blur_size)
|
||||
# fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
|
||||
img_mask /= 255
|
||||
fake_diff /= 255
|
||||
# img_mask = fake_diff
|
||||
img_mask = img_mask*fake_diff
|
||||
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
|
||||
fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
|
||||
fake_merged = fake_merged.astype(np.uint8)
|
||||
return fake_merged
|
@ -1,114 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-05-04
|
||||
# @Function :
|
||||
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from ..utils import face_align
|
||||
from ..utils import transform
|
||||
from ..data import get_object
|
||||
|
||||
__all__ = [
|
||||
'Landmark',
|
||||
]
|
||||
|
||||
|
||||
class Landmark:
|
||||
def __init__(self, model_file=None, session=None):
|
||||
assert model_file is not None
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
find_sub = False
|
||||
find_mul = False
|
||||
model = onnx.load(self.model_file)
|
||||
graph = model.graph
|
||||
for nid, node in enumerate(graph.node[:8]):
|
||||
#print(nid, node.name)
|
||||
if node.name.startswith('Sub') or node.name.startswith('_minus'):
|
||||
find_sub = True
|
||||
if node.name.startswith('Mul') or node.name.startswith('_mul'):
|
||||
find_mul = True
|
||||
if nid<3 and node.name=='bn_data':
|
||||
find_sub = True
|
||||
find_mul = True
|
||||
if find_sub and find_mul:
|
||||
#mxnet arcface model
|
||||
input_mean = 0.0
|
||||
input_std = 1.0
|
||||
else:
|
||||
input_mean = 127.5
|
||||
input_std = 128.0
|
||||
self.input_mean = input_mean
|
||||
self.input_std = input_std
|
||||
#print('input mean and std:', model_file, self.input_mean, self.input_std)
|
||||
if self.session is None:
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
input_name = input_cfg.name
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
output_names = []
|
||||
for out in outputs:
|
||||
output_names.append(out.name)
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
output_shape = outputs[0].shape
|
||||
self.require_pose = False
|
||||
#print('init output_shape:', output_shape)
|
||||
if output_shape[1]==3309:
|
||||
self.lmk_dim = 3
|
||||
self.lmk_num = 68
|
||||
self.mean_lmk = get_object('meanshape_68.pkl')
|
||||
self.require_pose = True
|
||||
else:
|
||||
self.lmk_dim = 2
|
||||
self.lmk_num = output_shape[1]//self.lmk_dim
|
||||
self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num)
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
|
||||
def get(self, img, face):
|
||||
bbox = face.bbox
|
||||
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
|
||||
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
|
||||
rotate = 0
|
||||
_scale = self.input_size[0] / (max(w, h)*1.5)
|
||||
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
|
||||
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
|
||||
input_size = tuple(aimg.shape[0:2][::-1])
|
||||
#assert input_size==self.input_size
|
||||
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
|
||||
if pred.shape[0] >= 3000:
|
||||
pred = pred.reshape((-1, 3))
|
||||
else:
|
||||
pred = pred.reshape((-1, 2))
|
||||
if self.lmk_num < pred.shape[0]:
|
||||
pred = pred[self.lmk_num*-1:,:]
|
||||
pred[:, 0:2] += 1
|
||||
pred[:, 0:2] *= (self.input_size[0] // 2)
|
||||
if pred.shape[1] == 3:
|
||||
pred[:, 2] *= (self.input_size[0] // 2)
|
||||
|
||||
IM = cv2.invertAffineTransform(M)
|
||||
pred = face_align.trans_points(pred, IM)
|
||||
face[self.taskname] = pred
|
||||
if self.require_pose:
|
||||
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
|
||||
s, R, t = transform.P2sRt(P)
|
||||
rx, ry, rz = transform.matrix2angle(R)
|
||||
pose = np.array( [rx, ry, rz], dtype=np.float32 )
|
||||
face['pose'] = pose #pitch, yaw, roll
|
||||
return pred
|
||||
|
||||
|
@ -1,103 +0,0 @@
|
||||
"""
|
||||
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['get_model_file']
|
||||
import os
|
||||
import zipfile
|
||||
import glob
|
||||
|
||||
from ..utils import download, check_sha1
|
||||
|
||||
_model_sha1 = {
|
||||
name: checksum
|
||||
for checksum, name in [
|
||||
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
|
||||
('', 'arcface_mfn_v1'),
|
||||
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
|
||||
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
|
||||
('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'),
|
||||
('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'),
|
||||
]
|
||||
}
|
||||
|
||||
base_repo_url = 'https://insightface.ai/files/'
|
||||
_url_format = '{repo_url}models/{file_name}.zip'
|
||||
|
||||
|
||||
def short_hash(name):
|
||||
if name not in _model_sha1:
|
||||
raise ValueError(
|
||||
'Pretrained model for {name} is not available.'.format(name=name))
|
||||
return _model_sha1[name][:8]
|
||||
|
||||
|
||||
def find_params_file(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
return None
|
||||
paths = glob.glob("%s/*.params" % dir_path)
|
||||
if len(paths) == 0:
|
||||
return None
|
||||
paths = sorted(paths)
|
||||
return paths[-1]
|
||||
|
||||
|
||||
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
|
||||
r"""Return location for the pretrained on local file system.
|
||||
|
||||
This function will download from online model zoo when model cannot be found or has mismatch.
|
||||
The root directory will be created if it doesn't exist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Name of the model.
|
||||
root : str, default '~/.mxnet/models'
|
||||
Location for keeping the model parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
file_path
|
||||
Path to the requested pretrained model file.
|
||||
"""
|
||||
|
||||
file_name = name
|
||||
root = os.path.expanduser(root)
|
||||
dir_path = os.path.join(root, name)
|
||||
file_path = find_params_file(dir_path)
|
||||
#file_path = os.path.join(root, file_name + '.params')
|
||||
sha1_hash = _model_sha1[name]
|
||||
if file_path is not None:
|
||||
if check_sha1(file_path, sha1_hash):
|
||||
return file_path
|
||||
else:
|
||||
print(
|
||||
'Mismatch in the content of model file detected. Downloading again.'
|
||||
)
|
||||
else:
|
||||
print('Model file is not found. Downloading.')
|
||||
|
||||
if not os.path.exists(root):
|
||||
os.makedirs(root)
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
zip_file_path = os.path.join(root, file_name + '.zip')
|
||||
repo_url = base_repo_url
|
||||
if repo_url[-1] != '/':
|
||||
repo_url = repo_url + '/'
|
||||
download(_url_format.format(repo_url=repo_url, file_name=file_name),
|
||||
path=zip_file_path,
|
||||
overwrite=True)
|
||||
with zipfile.ZipFile(zip_file_path) as zf:
|
||||
zf.extractall(dir_path)
|
||||
os.remove(zip_file_path)
|
||||
file_path = find_params_file(dir_path)
|
||||
|
||||
if check_sha1(file_path, sha1_hash):
|
||||
return file_path
|
||||
else:
|
||||
raise ValueError(
|
||||
'Downloaded file has different hash. Please try again.')
|
||||
|
@ -1,97 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-05-04
|
||||
# @Function :
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import glob
|
||||
import onnxruntime
|
||||
from .arcface_onnx import *
|
||||
from .retinaface import *
|
||||
#from .scrfd import *
|
||||
from .landmark import *
|
||||
from .attribute import Attribute
|
||||
from .inswapper import INSwapper
|
||||
from ..utils import download_onnx
|
||||
|
||||
__all__ = ['get_model']
|
||||
|
||||
|
||||
class PickableInferenceSession(onnxruntime.InferenceSession):
|
||||
# This is a wrapper to make the current InferenceSession class pickable.
|
||||
def __init__(self, model_path, **kwargs):
|
||||
super().__init__(model_path, **kwargs)
|
||||
self.model_path = model_path
|
||||
|
||||
def __getstate__(self):
|
||||
return {'model_path': self.model_path}
|
||||
|
||||
def __setstate__(self, values):
|
||||
model_path = values['model_path']
|
||||
self.__init__(model_path)
|
||||
|
||||
class ModelRouter:
|
||||
def __init__(self, onnx_file):
|
||||
self.onnx_file = onnx_file
|
||||
|
||||
def get_model(self, **kwargs):
|
||||
session = PickableInferenceSession(self.onnx_file, **kwargs)
|
||||
# print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
|
||||
inputs = session.get_inputs()
|
||||
input_cfg = inputs[0]
|
||||
input_shape = input_cfg.shape
|
||||
outputs = session.get_outputs()
|
||||
|
||||
if len(outputs)>=5:
|
||||
return RetinaFace(model_file=self.onnx_file, session=session)
|
||||
elif input_shape[2]==192 and input_shape[3]==192:
|
||||
return Landmark(model_file=self.onnx_file, session=session)
|
||||
elif input_shape[2]==96 and input_shape[3]==96:
|
||||
return Attribute(model_file=self.onnx_file, session=session)
|
||||
elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128:
|
||||
return INSwapper(model_file=self.onnx_file, session=session)
|
||||
elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0:
|
||||
return ArcFaceONNX(model_file=self.onnx_file, session=session)
|
||||
else:
|
||||
#raise RuntimeError('error on model routing')
|
||||
return None
|
||||
|
||||
def find_onnx_file(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
return None
|
||||
paths = glob.glob("%s/*.onnx" % dir_path)
|
||||
if len(paths) == 0:
|
||||
return None
|
||||
paths = sorted(paths)
|
||||
return paths[-1]
|
||||
|
||||
def get_default_providers():
|
||||
return ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
def get_default_provider_options():
|
||||
return None
|
||||
|
||||
def get_model(name, **kwargs):
|
||||
root = kwargs.get('root', '~/.insightface')
|
||||
root = os.path.expanduser(root)
|
||||
model_root = osp.join(root, 'models')
|
||||
allow_download = kwargs.get('download', False)
|
||||
download_zip = kwargs.get('download_zip', False)
|
||||
if not name.endswith('.onnx'):
|
||||
model_dir = os.path.join(model_root, name)
|
||||
model_file = find_onnx_file(model_dir)
|
||||
if model_file is None:
|
||||
return None
|
||||
else:
|
||||
model_file = name
|
||||
if not osp.exists(model_file) and allow_download:
|
||||
model_file = download_onnx('models', model_file, root=root, download_zip=download_zip)
|
||||
assert osp.exists(model_file), 'model_file %s should exist'%model_file
|
||||
assert osp.isfile(model_file), 'model_file %s should be a file'%model_file
|
||||
router = ModelRouter(model_file)
|
||||
providers = kwargs.get('providers', get_default_providers())
|
||||
provider_options = kwargs.get('provider_options', get_default_provider_options())
|
||||
model = router.get_model(providers=providers, provider_options=provider_options)
|
||||
return model
|
@ -1,301 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-09-18
|
||||
# @Function :
|
||||
|
||||
from __future__ import division
|
||||
import datetime
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import os
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
def softmax(z):
|
||||
assert len(z.shape) == 2
|
||||
s = np.max(z, axis=1)
|
||||
s = s[:, np.newaxis] # necessary step to do broadcasting
|
||||
e_x = np.exp(z - s)
|
||||
div = np.sum(e_x, axis=1)
|
||||
div = div[:, np.newaxis] # dito
|
||||
return e_x / div
|
||||
|
||||
def distance2bbox(points, distance, max_shape=None):
|
||||
"""Decode distance prediction to bounding box.
|
||||
|
||||
Args:
|
||||
points (Tensor): Shape (n, 2), [x, y].
|
||||
distance (Tensor): Distance from the given point to 4
|
||||
boundaries (left, top, right, bottom).
|
||||
max_shape (tuple): Shape of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Decoded bboxes.
|
||||
"""
|
||||
x1 = points[:, 0] - distance[:, 0]
|
||||
y1 = points[:, 1] - distance[:, 1]
|
||||
x2 = points[:, 0] + distance[:, 2]
|
||||
y2 = points[:, 1] + distance[:, 3]
|
||||
if max_shape is not None:
|
||||
x1 = x1.clamp(min=0, max=max_shape[1])
|
||||
y1 = y1.clamp(min=0, max=max_shape[0])
|
||||
x2 = x2.clamp(min=0, max=max_shape[1])
|
||||
y2 = y2.clamp(min=0, max=max_shape[0])
|
||||
return np.stack([x1, y1, x2, y2], axis=-1)
|
||||
|
||||
def distance2kps(points, distance, max_shape=None):
|
||||
"""Decode distance prediction to bounding box.
|
||||
|
||||
Args:
|
||||
points (Tensor): Shape (n, 2), [x, y].
|
||||
distance (Tensor): Distance from the given point to 4
|
||||
boundaries (left, top, right, bottom).
|
||||
max_shape (tuple): Shape of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Decoded bboxes.
|
||||
"""
|
||||
preds = []
|
||||
for i in range(0, distance.shape[1], 2):
|
||||
px = points[:, i%2] + distance[:, i]
|
||||
py = points[:, i%2+1] + distance[:, i+1]
|
||||
if max_shape is not None:
|
||||
px = px.clamp(min=0, max=max_shape[1])
|
||||
py = py.clamp(min=0, max=max_shape[0])
|
||||
preds.append(px)
|
||||
preds.append(py)
|
||||
return np.stack(preds, axis=-1)
|
||||
|
||||
class RetinaFace:
|
||||
def __init__(self, model_file=None, session=None):
|
||||
import onnxruntime
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
self.taskname = 'detection'
|
||||
if self.session is None:
|
||||
assert self.model_file is not None
|
||||
assert osp.exists(self.model_file)
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
self.center_cache = {}
|
||||
self.nms_thresh = 0.4
|
||||
self.det_thresh = 0.5
|
||||
self._init_vars()
|
||||
|
||||
def _init_vars(self):
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
#print(input_shape)
|
||||
if isinstance(input_shape[2], str):
|
||||
self.input_size = None
|
||||
else:
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
#print('image_size:', self.image_size)
|
||||
input_name = input_cfg.name
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
output_names = []
|
||||
for o in outputs:
|
||||
output_names.append(o.name)
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
self.input_mean = 127.5
|
||||
self.input_std = 128.0
|
||||
#print(self.output_names)
|
||||
#assert len(outputs)==10 or len(outputs)==15
|
||||
self.use_kps = False
|
||||
self._anchor_ratio = 1.0
|
||||
self._num_anchors = 1
|
||||
if len(outputs)==6:
|
||||
self.fmc = 3
|
||||
self._feat_stride_fpn = [8, 16, 32]
|
||||
self._num_anchors = 2
|
||||
elif len(outputs)==9:
|
||||
self.fmc = 3
|
||||
self._feat_stride_fpn = [8, 16, 32]
|
||||
self._num_anchors = 2
|
||||
self.use_kps = True
|
||||
elif len(outputs)==10:
|
||||
self.fmc = 5
|
||||
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
||||
self._num_anchors = 1
|
||||
elif len(outputs)==15:
|
||||
self.fmc = 5
|
||||
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
||||
self._num_anchors = 1
|
||||
self.use_kps = True
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
nms_thresh = kwargs.get('nms_thresh', None)
|
||||
if nms_thresh is not None:
|
||||
self.nms_thresh = nms_thresh
|
||||
det_thresh = kwargs.get('det_thresh', None)
|
||||
if det_thresh is not None:
|
||||
self.det_thresh = det_thresh
|
||||
input_size = kwargs.get('input_size', None)
|
||||
if input_size is not None:
|
||||
if self.input_size is not None:
|
||||
print('warning: det_size is already set in detection model, ignore')
|
||||
else:
|
||||
self.input_size = input_size
|
||||
|
||||
def forward(self, img, threshold):
|
||||
scores_list = []
|
||||
bboxes_list = []
|
||||
kpss_list = []
|
||||
input_size = tuple(img.shape[0:2][::-1])
|
||||
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
net_outs = self.session.run(self.output_names, {self.input_name : blob})
|
||||
|
||||
input_height = blob.shape[2]
|
||||
input_width = blob.shape[3]
|
||||
fmc = self.fmc
|
||||
for idx, stride in enumerate(self._feat_stride_fpn):
|
||||
scores = net_outs[idx]
|
||||
bbox_preds = net_outs[idx+fmc]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx+fmc*2] * stride
|
||||
height = input_height // stride
|
||||
width = input_width // stride
|
||||
K = height * width
|
||||
key = (height, width, stride)
|
||||
if key in self.center_cache:
|
||||
anchor_centers = self.center_cache[key]
|
||||
else:
|
||||
#solution-1, c style:
|
||||
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
|
||||
#for i in range(height):
|
||||
# anchor_centers[i, :, 1] = i
|
||||
#for i in range(width):
|
||||
# anchor_centers[:, i, 0] = i
|
||||
|
||||
#solution-2:
|
||||
#ax = np.arange(width, dtype=np.float32)
|
||||
#ay = np.arange(height, dtype=np.float32)
|
||||
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
|
||||
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
|
||||
|
||||
#solution-3:
|
||||
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
|
||||
#print(anchor_centers.shape)
|
||||
|
||||
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
|
||||
if self._num_anchors>1:
|
||||
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
|
||||
if len(self.center_cache)<100:
|
||||
self.center_cache[key] = anchor_centers
|
||||
|
||||
pos_inds = np.where(scores>=threshold)[0]
|
||||
bboxes = distance2bbox(anchor_centers, bbox_preds)
|
||||
pos_scores = scores[pos_inds]
|
||||
pos_bboxes = bboxes[pos_inds]
|
||||
scores_list.append(pos_scores)
|
||||
bboxes_list.append(pos_bboxes)
|
||||
if self.use_kps:
|
||||
kpss = distance2kps(anchor_centers, kps_preds)
|
||||
#kpss = kps_preds
|
||||
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
|
||||
pos_kpss = kpss[pos_inds]
|
||||
kpss_list.append(pos_kpss)
|
||||
return scores_list, bboxes_list, kpss_list
|
||||
|
||||
def detect(self, img, input_size = None, max_num=0, metric='default'):
|
||||
assert input_size is not None or self.input_size is not None
|
||||
input_size = self.input_size if input_size is None else input_size
|
||||
|
||||
im_ratio = float(img.shape[0]) / img.shape[1]
|
||||
model_ratio = float(input_size[1]) / input_size[0]
|
||||
if im_ratio>model_ratio:
|
||||
new_height = input_size[1]
|
||||
new_width = int(new_height / im_ratio)
|
||||
else:
|
||||
new_width = input_size[0]
|
||||
new_height = int(new_width * im_ratio)
|
||||
det_scale = float(new_height) / img.shape[0]
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
|
||||
det_img[:new_height, :new_width, :] = resized_img
|
||||
|
||||
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
|
||||
|
||||
scores = np.vstack(scores_list)
|
||||
scores_ravel = scores.ravel()
|
||||
order = scores_ravel.argsort()[::-1]
|
||||
bboxes = np.vstack(bboxes_list) / det_scale
|
||||
if self.use_kps:
|
||||
kpss = np.vstack(kpss_list) / det_scale
|
||||
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
|
||||
pre_det = pre_det[order, :]
|
||||
keep = self.nms(pre_det)
|
||||
det = pre_det[keep, :]
|
||||
if self.use_kps:
|
||||
kpss = kpss[order,:,:]
|
||||
kpss = kpss[keep,:,:]
|
||||
else:
|
||||
kpss = None
|
||||
if max_num > 0 and det.shape[0] > max_num:
|
||||
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
|
||||
det[:, 1])
|
||||
img_center = img.shape[0] // 2, img.shape[1] // 2
|
||||
offsets = np.vstack([
|
||||
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
|
||||
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
|
||||
])
|
||||
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
|
||||
if metric=='max':
|
||||
values = area
|
||||
else:
|
||||
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
|
||||
bindex = np.argsort(
|
||||
values)[::-1] # some extra weight on the centering
|
||||
bindex = bindex[0:max_num]
|
||||
det = det[bindex, :]
|
||||
if kpss is not None:
|
||||
kpss = kpss[bindex, :]
|
||||
return det, kpss
|
||||
|
||||
def nms(self, dets):
|
||||
thresh = self.nms_thresh
|
||||
x1 = dets[:, 0]
|
||||
y1 = dets[:, 1]
|
||||
x2 = dets[:, 2]
|
||||
y2 = dets[:, 3]
|
||||
scores = dets[:, 4]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs):
|
||||
if not download:
|
||||
assert os.path.exists(name)
|
||||
return RetinaFace(name)
|
||||
else:
|
||||
from .model_store import get_model_file
|
||||
_file = get_model_file("retinaface_%s" % name, root=root)
|
||||
return retinaface(_file)
|
||||
|
||||
|
@ -1,348 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Organization : insightface.ai
|
||||
# @Author : Jia Guo
|
||||
# @Time : 2021-05-04
|
||||
# @Function :
|
||||
|
||||
from __future__ import division
|
||||
import datetime
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import os
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
def softmax(z):
|
||||
assert len(z.shape) == 2
|
||||
s = np.max(z, axis=1)
|
||||
s = s[:, np.newaxis] # necessary step to do broadcasting
|
||||
e_x = np.exp(z - s)
|
||||
div = np.sum(e_x, axis=1)
|
||||
div = div[:, np.newaxis] # dito
|
||||
return e_x / div
|
||||
|
||||
def distance2bbox(points, distance, max_shape=None):
|
||||
"""Decode distance prediction to bounding box.
|
||||
|
||||
Args:
|
||||
points (Tensor): Shape (n, 2), [x, y].
|
||||
distance (Tensor): Distance from the given point to 4
|
||||
boundaries (left, top, right, bottom).
|
||||
max_shape (tuple): Shape of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Decoded bboxes.
|
||||
"""
|
||||
x1 = points[:, 0] - distance[:, 0]
|
||||
y1 = points[:, 1] - distance[:, 1]
|
||||
x2 = points[:, 0] + distance[:, 2]
|
||||
y2 = points[:, 1] + distance[:, 3]
|
||||
if max_shape is not None:
|
||||
x1 = x1.clamp(min=0, max=max_shape[1])
|
||||
y1 = y1.clamp(min=0, max=max_shape[0])
|
||||
x2 = x2.clamp(min=0, max=max_shape[1])
|
||||
y2 = y2.clamp(min=0, max=max_shape[0])
|
||||
return np.stack([x1, y1, x2, y2], axis=-1)
|
||||
|
||||
def distance2kps(points, distance, max_shape=None):
|
||||
"""Decode distance prediction to bounding box.
|
||||
|
||||
Args:
|
||||
points (Tensor): Shape (n, 2), [x, y].
|
||||
distance (Tensor): Distance from the given point to 4
|
||||
boundaries (left, top, right, bottom).
|
||||
max_shape (tuple): Shape of the image.
|
||||
|
||||
Returns:
|
||||
Tensor: Decoded bboxes.
|
||||
"""
|
||||
preds = []
|
||||
for i in range(0, distance.shape[1], 2):
|
||||
px = points[:, i%2] + distance[:, i]
|
||||
py = points[:, i%2+1] + distance[:, i+1]
|
||||
if max_shape is not None:
|
||||
px = px.clamp(min=0, max=max_shape[1])
|
||||
py = py.clamp(min=0, max=max_shape[0])
|
||||
preds.append(px)
|
||||
preds.append(py)
|
||||
return np.stack(preds, axis=-1)
|
||||
|
||||
class SCRFD:
|
||||
def __init__(self, model_file=None, session=None):
|
||||
import onnxruntime
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
self.taskname = 'detection'
|
||||
self.batched = False
|
||||
if self.session is None:
|
||||
assert self.model_file is not None
|
||||
assert osp.exists(self.model_file)
|
||||
self.session = onnxruntime.InferenceSession(self.model_file, None)
|
||||
self.center_cache = {}
|
||||
self.nms_thresh = 0.4
|
||||
self.det_thresh = 0.5
|
||||
self._init_vars()
|
||||
|
||||
def _init_vars(self):
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
#print(input_shape)
|
||||
if isinstance(input_shape[2], str):
|
||||
self.input_size = None
|
||||
else:
|
||||
self.input_size = tuple(input_shape[2:4][::-1])
|
||||
#print('image_size:', self.image_size)
|
||||
input_name = input_cfg.name
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
if len(outputs[0].shape) == 3:
|
||||
self.batched = True
|
||||
output_names = []
|
||||
for o in outputs:
|
||||
output_names.append(o.name)
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
self.input_mean = 127.5
|
||||
self.input_std = 128.0
|
||||
#print(self.output_names)
|
||||
#assert len(outputs)==10 or len(outputs)==15
|
||||
self.use_kps = False
|
||||
self._anchor_ratio = 1.0
|
||||
self._num_anchors = 1
|
||||
if len(outputs)==6:
|
||||
self.fmc = 3
|
||||
self._feat_stride_fpn = [8, 16, 32]
|
||||
self._num_anchors = 2
|
||||
elif len(outputs)==9:
|
||||
self.fmc = 3
|
||||
self._feat_stride_fpn = [8, 16, 32]
|
||||
self._num_anchors = 2
|
||||
self.use_kps = True
|
||||
elif len(outputs)==10:
|
||||
self.fmc = 5
|
||||
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
||||
self._num_anchors = 1
|
||||
elif len(outputs)==15:
|
||||
self.fmc = 5
|
||||
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
||||
self._num_anchors = 1
|
||||
self.use_kps = True
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
nms_thresh = kwargs.get('nms_thresh', None)
|
||||
if nms_thresh is not None:
|
||||
self.nms_thresh = nms_thresh
|
||||
det_thresh = kwargs.get('det_thresh', None)
|
||||
if det_thresh is not None:
|
||||
self.det_thresh = det_thresh
|
||||
input_size = kwargs.get('input_size', None)
|
||||
if input_size is not None:
|
||||
if self.input_size is not None:
|
||||
print('warning: det_size is already set in scrfd model, ignore')
|
||||
else:
|
||||
self.input_size = input_size
|
||||
|
||||
def forward(self, img, threshold):
|
||||
scores_list = []
|
||||
bboxes_list = []
|
||||
kpss_list = []
|
||||
input_size = tuple(img.shape[0:2][::-1])
|
||||
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
net_outs = self.session.run(self.output_names, {self.input_name : blob})
|
||||
|
||||
input_height = blob.shape[2]
|
||||
input_width = blob.shape[3]
|
||||
fmc = self.fmc
|
||||
for idx, stride in enumerate(self._feat_stride_fpn):
|
||||
# If model support batch dim, take first output
|
||||
if self.batched:
|
||||
scores = net_outs[idx][0]
|
||||
bbox_preds = net_outs[idx + fmc][0]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx + fmc * 2][0] * stride
|
||||
# If model doesn't support batching take output as is
|
||||
else:
|
||||
scores = net_outs[idx]
|
||||
bbox_preds = net_outs[idx + fmc]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx + fmc * 2] * stride
|
||||
|
||||
height = input_height // stride
|
||||
width = input_width // stride
|
||||
K = height * width
|
||||
key = (height, width, stride)
|
||||
if key in self.center_cache:
|
||||
anchor_centers = self.center_cache[key]
|
||||
else:
|
||||
#solution-1, c style:
|
||||
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
|
||||
#for i in range(height):
|
||||
# anchor_centers[i, :, 1] = i
|
||||
#for i in range(width):
|
||||
# anchor_centers[:, i, 0] = i
|
||||
|
||||
#solution-2:
|
||||
#ax = np.arange(width, dtype=np.float32)
|
||||
#ay = np.arange(height, dtype=np.float32)
|
||||
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
|
||||
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
|
||||
|
||||
#solution-3:
|
||||
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
|
||||
#print(anchor_centers.shape)
|
||||
|
||||
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
|
||||
if self._num_anchors>1:
|
||||
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
|
||||
if len(self.center_cache)<100:
|
||||
self.center_cache[key] = anchor_centers
|
||||
|
||||
pos_inds = np.where(scores>=threshold)[0]
|
||||
bboxes = distance2bbox(anchor_centers, bbox_preds)
|
||||
pos_scores = scores[pos_inds]
|
||||
pos_bboxes = bboxes[pos_inds]
|
||||
scores_list.append(pos_scores)
|
||||
bboxes_list.append(pos_bboxes)
|
||||
if self.use_kps:
|
||||
kpss = distance2kps(anchor_centers, kps_preds)
|
||||
#kpss = kps_preds
|
||||
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
|
||||
pos_kpss = kpss[pos_inds]
|
||||
kpss_list.append(pos_kpss)
|
||||
return scores_list, bboxes_list, kpss_list
|
||||
|
||||
def detect(self, img, input_size = None, max_num=0, metric='default'):
|
||||
assert input_size is not None or self.input_size is not None
|
||||
input_size = self.input_size if input_size is None else input_size
|
||||
|
||||
im_ratio = float(img.shape[0]) / img.shape[1]
|
||||
model_ratio = float(input_size[1]) / input_size[0]
|
||||
if im_ratio>model_ratio:
|
||||
new_height = input_size[1]
|
||||
new_width = int(new_height / im_ratio)
|
||||
else:
|
||||
new_width = input_size[0]
|
||||
new_height = int(new_width * im_ratio)
|
||||
det_scale = float(new_height) / img.shape[0]
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
|
||||
det_img[:new_height, :new_width, :] = resized_img
|
||||
|
||||
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
|
||||
|
||||
scores = np.vstack(scores_list)
|
||||
scores_ravel = scores.ravel()
|
||||
order = scores_ravel.argsort()[::-1]
|
||||
bboxes = np.vstack(bboxes_list) / det_scale
|
||||
if self.use_kps:
|
||||
kpss = np.vstack(kpss_list) / det_scale
|
||||
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
|
||||
pre_det = pre_det[order, :]
|
||||
keep = self.nms(pre_det)
|
||||
det = pre_det[keep, :]
|
||||
if self.use_kps:
|
||||
kpss = kpss[order,:,:]
|
||||
kpss = kpss[keep,:,:]
|
||||
else:
|
||||
kpss = None
|
||||
if max_num > 0 and det.shape[0] > max_num:
|
||||
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
|
||||
det[:, 1])
|
||||
img_center = img.shape[0] // 2, img.shape[1] // 2
|
||||
offsets = np.vstack([
|
||||
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
|
||||
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
|
||||
])
|
||||
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
|
||||
if metric=='max':
|
||||
values = area
|
||||
else:
|
||||
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
|
||||
bindex = np.argsort(
|
||||
values)[::-1] # some extra weight on the centering
|
||||
bindex = bindex[0:max_num]
|
||||
det = det[bindex, :]
|
||||
if kpss is not None:
|
||||
kpss = kpss[bindex, :]
|
||||
return det, kpss
|
||||
|
||||
def nms(self, dets):
|
||||
thresh = self.nms_thresh
|
||||
x1 = dets[:, 0]
|
||||
y1 = dets[:, 1]
|
||||
x2 = dets[:, 2]
|
||||
y2 = dets[:, 3]
|
||||
scores = dets[:, 4]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs):
|
||||
if not download:
|
||||
assert os.path.exists(name)
|
||||
return SCRFD(name)
|
||||
else:
|
||||
from .model_store import get_model_file
|
||||
_file = get_model_file("scrfd_%s" % name, root=root)
|
||||
return SCRFD(_file)
|
||||
|
||||
|
||||
def scrfd_2p5gkps(**kwargs):
|
||||
return get_scrfd("2p5gkps", download=True, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import glob
|
||||
detector = SCRFD(model_file='./det.onnx')
|
||||
detector.prepare(-1)
|
||||
img_paths = ['tests/data/t1.jpg']
|
||||
for img_path in img_paths:
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
for _ in range(1):
|
||||
ta = datetime.datetime.now()
|
||||
#bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640))
|
||||
bboxes, kpss = detector.detect(img, 0.5)
|
||||
tb = datetime.datetime.now()
|
||||
print('all cost:', (tb-ta).total_seconds()*1000)
|
||||
print(img_path, bboxes.shape)
|
||||
if kpss is not None:
|
||||
print(kpss.shape)
|
||||
for i in range(bboxes.shape[0]):
|
||||
bbox = bboxes[i]
|
||||
x1,y1,x2,y2,score = bbox.astype(np.int)
|
||||
cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2)
|
||||
if kpss is not None:
|
||||
kps = kpss[i]
|
||||
for kp in kps:
|
||||
kp = kp.astype(np.int)
|
||||
cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2)
|
||||
filename = img_path.split('/')[-1]
|
||||
print('output:', filename)
|
||||
cv2.imwrite('./outputs/%s'%filename, img)
|
||||
|
@ -1,6 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .storage import download, ensure_available, download_onnx
|
||||
from .filesystem import get_model_dir
|
||||
from .filesystem import makedirs, try_import_dali
|
||||
from .constant import *
|
@ -1,3 +0,0 @@
|
||||
|
||||
DEFAULT_MP_NAME = 'buffalo_l'
|
||||
|
@ -1,95 +0,0 @@
|
||||
"""
|
||||
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def check_sha1(filename, sha1_hash):
|
||||
"""Check whether the sha1 hash of the file content matches the expected hash.
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
Path to the file.
|
||||
sha1_hash : str
|
||||
Expected sha1 hash in hexadecimal digits.
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
Whether the file content matches the expected hash.
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(filename, 'rb') as f:
|
||||
while True:
|
||||
data = f.read(1048576)
|
||||
if not data:
|
||||
break
|
||||
sha1.update(data)
|
||||
|
||||
sha1_file = sha1.hexdigest()
|
||||
l = min(len(sha1_file), len(sha1_hash))
|
||||
return sha1.hexdigest()[0:l] == sha1_hash[0:l]
|
||||
|
||||
|
||||
def download_file(url, path=None, overwrite=False, sha1_hash=None):
|
||||
"""Download an given URL
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
URL to download
|
||||
path : str, optional
|
||||
Destination path to store downloaded file. By default stores to the
|
||||
current directory with same name as in url.
|
||||
overwrite : bool, optional
|
||||
Whether to overwrite destination file if already exists.
|
||||
sha1_hash : str, optional
|
||||
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
|
||||
but doesn't match.
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The file path of the downloaded file.
|
||||
"""
|
||||
if path is None:
|
||||
fname = url.split('/')[-1]
|
||||
else:
|
||||
path = os.path.expanduser(path)
|
||||
if os.path.isdir(path):
|
||||
fname = os.path.join(path, url.split('/')[-1])
|
||||
else:
|
||||
fname = path
|
||||
|
||||
if overwrite or not os.path.exists(fname) or (
|
||||
sha1_hash and not check_sha1(fname, sha1_hash)):
|
||||
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
print('Downloading %s from %s...' % (fname, url))
|
||||
r = requests.get(url, stream=True)
|
||||
if r.status_code != 200:
|
||||
raise RuntimeError("Failed downloading url %s" % url)
|
||||
total_length = r.headers.get('content-length')
|
||||
with open(fname, 'wb') as f:
|
||||
if total_length is None: # no content length header
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
else:
|
||||
total_length = int(total_length)
|
||||
for chunk in tqdm(r.iter_content(chunk_size=1024),
|
||||
total=int(total_length / 1024. + 0.5),
|
||||
unit='KB',
|
||||
unit_scale=False,
|
||||
dynamic_ncols=True):
|
||||
f.write(chunk)
|
||||
|
||||
if sha1_hash and not check_sha1(fname, sha1_hash):
|
||||
raise UserWarning('File {} is downloaded but the content hash does not match. ' \
|
||||
'The repo may be outdated or download may be incomplete. ' \
|
||||
'If the "repo_url" is overridden, consider switching to ' \
|
||||
'the default repo.'.format(fname))
|
||||
|
||||
return fname
|
@ -1,103 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage import transform as trans
|
||||
|
||||
|
||||
arcface_dst = np.array(
|
||||
[[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
||||
[41.5493, 92.3655], [70.7299, 92.2041]],
|
||||
dtype=np.float32)
|
||||
|
||||
def estimate_norm(lmk, image_size=112,mode='arcface'):
|
||||
assert lmk.shape == (5, 2)
|
||||
assert image_size%112==0 or image_size%128==0
|
||||
if image_size%112==0:
|
||||
ratio = float(image_size)/112.0
|
||||
diff_x = 0
|
||||
else:
|
||||
ratio = float(image_size)/128.0
|
||||
diff_x = 8.0*ratio
|
||||
dst = arcface_dst * ratio
|
||||
dst[:,0] += diff_x
|
||||
tform = trans.SimilarityTransform()
|
||||
tform.estimate(lmk, dst)
|
||||
M = tform.params[0:2, :]
|
||||
return M
|
||||
|
||||
def norm_crop(img, landmark, image_size=112, mode='arcface'):
|
||||
M = estimate_norm(landmark, image_size, mode)
|
||||
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
||||
return warped
|
||||
|
||||
def norm_crop2(img, landmark, image_size=112, mode='arcface'):
|
||||
M = estimate_norm(landmark, image_size, mode)
|
||||
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
||||
return warped, M
|
||||
|
||||
def square_crop(im, S):
|
||||
if im.shape[0] > im.shape[1]:
|
||||
height = S
|
||||
width = int(float(im.shape[1]) / im.shape[0] * S)
|
||||
scale = float(S) / im.shape[0]
|
||||
else:
|
||||
width = S
|
||||
height = int(float(im.shape[0]) / im.shape[1] * S)
|
||||
scale = float(S) / im.shape[1]
|
||||
resized_im = cv2.resize(im, (width, height))
|
||||
det_im = np.zeros((S, S, 3), dtype=np.uint8)
|
||||
det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
|
||||
return det_im, scale
|
||||
|
||||
|
||||
def transform(data, center, output_size, scale, rotation):
|
||||
scale_ratio = scale
|
||||
rot = float(rotation) * np.pi / 180.0
|
||||
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
||||
t1 = trans.SimilarityTransform(scale=scale_ratio)
|
||||
cx = center[0] * scale_ratio
|
||||
cy = center[1] * scale_ratio
|
||||
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
||||
t3 = trans.SimilarityTransform(rotation=rot)
|
||||
t4 = trans.SimilarityTransform(translation=(output_size / 2,
|
||||
output_size / 2))
|
||||
t = t1 + t2 + t3 + t4
|
||||
M = t.params[0:2]
|
||||
cropped = cv2.warpAffine(data,
|
||||
M, (output_size, output_size),
|
||||
borderValue=0.0)
|
||||
return cropped, M
|
||||
|
||||
|
||||
def trans_points2d(pts, M):
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i] = new_pt[0:2]
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points3d(pts, M):
|
||||
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
||||
#print(scale)
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i][0:2] = new_pt[0:2]
|
||||
new_pts[i][2] = pts[i][2] * scale
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points(pts, M):
|
||||
if pts.shape[1] == 2:
|
||||
return trans_points2d(pts, M)
|
||||
else:
|
||||
return trans_points3d(pts, M)
|
||||
|
@ -1,157 +0,0 @@
|
||||
"""
|
||||
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py
|
||||
"""
|
||||
import os
|
||||
import os.path as osp
|
||||
import errno
|
||||
|
||||
|
||||
def get_model_dir(name, root='~/.insightface'):
|
||||
root = os.path.expanduser(root)
|
||||
model_dir = osp.join(root, 'models', name)
|
||||
return model_dir
|
||||
|
||||
def makedirs(path):
|
||||
"""Create directory recursively if not exists.
|
||||
Similar to `makedir -p`, you can skip checking existence before this function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path of the desired dir
|
||||
"""
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
|
||||
def try_import(package, message=None):
|
||||
"""Try import specified package, with custom message support.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
package : str
|
||||
The name of the targeting package.
|
||||
message : str, default is None
|
||||
If not None, this function will raise customized error message when import error is found.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
module if found, raise ImportError otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
return __import__(package)
|
||||
except ImportError as e:
|
||||
if not message:
|
||||
raise e
|
||||
raise ImportError(message)
|
||||
|
||||
|
||||
def try_import_cv2():
|
||||
"""Try import cv2 at runtime.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cv2 module if found. Raise ImportError otherwise
|
||||
|
||||
"""
|
||||
msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \
|
||||
or `pip install opencv-python --user` (note that this is unofficial PYPI package)."
|
||||
|
||||
return try_import('cv2', msg)
|
||||
|
||||
|
||||
def try_import_mmcv():
|
||||
"""Try import mmcv at runtime.
|
||||
|
||||
Returns
|
||||
-------
|
||||
mmcv module if found. Raise ImportError otherwise
|
||||
|
||||
"""
|
||||
msg = "mmcv is required, you can install by first `pip install Cython --user` \
|
||||
and then `pip install mmcv --user` (note that this is unofficial PYPI package)."
|
||||
|
||||
return try_import('mmcv', msg)
|
||||
|
||||
|
||||
def try_import_rarfile():
|
||||
"""Try import rarfile at runtime.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rarfile module if found. Raise ImportError otherwise
|
||||
|
||||
"""
|
||||
msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \
|
||||
and then `pip install rarfile --user` (note that this is unofficial PYPI package)."
|
||||
|
||||
return try_import('rarfile', msg)
|
||||
|
||||
|
||||
def import_try_install(package, extern_url=None):
|
||||
"""Try import the specified package.
|
||||
If the package not installed, try use pip to install and import if success.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
package : str
|
||||
The name of the package trying to import.
|
||||
extern_url : str or None, optional
|
||||
The external url if package is not hosted on PyPI.
|
||||
For example, you can install a package using:
|
||||
"pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
|
||||
In this case, you can pass the url to the extern_url.
|
||||
|
||||
Returns
|
||||
-------
|
||||
<class 'Module'>
|
||||
The imported python module.
|
||||
|
||||
"""
|
||||
try:
|
||||
return __import__(package)
|
||||
except ImportError:
|
||||
try:
|
||||
from pip import main as pipmain
|
||||
except ImportError:
|
||||
from pip._internal import main as pipmain
|
||||
|
||||
# trying to install package
|
||||
url = package if extern_url is None else extern_url
|
||||
pipmain(['install', '--user',
|
||||
url]) # will raise SystemExit Error if fails
|
||||
|
||||
# trying to load again
|
||||
try:
|
||||
return __import__(package)
|
||||
except ImportError:
|
||||
import sys
|
||||
import site
|
||||
user_site = site.getusersitepackages()
|
||||
if user_site not in sys.path:
|
||||
sys.path.append(user_site)
|
||||
return __import__(package)
|
||||
return __import__(package)
|
||||
|
||||
|
||||
def try_import_dali():
|
||||
"""Try import NVIDIA DALI at runtime.
|
||||
"""
|
||||
try:
|
||||
dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types'])
|
||||
dali.Pipeline = dali.pipeline.Pipeline
|
||||
except ImportError:
|
||||
|
||||
class dali:
|
||||
class Pipeline:
|
||||
def __init__(self):
|
||||
raise NotImplementedError(
|
||||
"DALI not found, please check if you installed it correctly."
|
||||
)
|
||||
|
||||
return dali
|
@ -1,52 +0,0 @@
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import zipfile
|
||||
from .download import download_file
|
||||
|
||||
BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7'
|
||||
|
||||
def download(sub_dir, name, force=False, root='~/.insightface'):
|
||||
_root = os.path.expanduser(root)
|
||||
dir_path = os.path.join(_root, sub_dir, name)
|
||||
if osp.exists(dir_path) and not force:
|
||||
return dir_path
|
||||
print('download_path:', dir_path)
|
||||
zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
|
||||
model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
|
||||
download_file(model_url,
|
||||
path=zip_file_path,
|
||||
overwrite=True)
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
with zipfile.ZipFile(zip_file_path) as zf:
|
||||
zf.extractall(dir_path)
|
||||
#os.remove(zip_file_path)
|
||||
return dir_path
|
||||
|
||||
def ensure_available(sub_dir, name, root='~/.insightface'):
|
||||
return download(sub_dir, name, force=False, root=root)
|
||||
|
||||
def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False):
|
||||
_root = os.path.expanduser(root)
|
||||
model_root = osp.join(_root, sub_dir)
|
||||
new_model_file = osp.join(model_root, model_file)
|
||||
if osp.exists(new_model_file) and not force:
|
||||
return new_model_file
|
||||
if not osp.exists(model_root):
|
||||
os.makedirs(model_root)
|
||||
print('download_path:', new_model_file)
|
||||
if not download_zip:
|
||||
model_url = "%s/%s"%(BASE_REPO_URL, model_file)
|
||||
download_file(model_url,
|
||||
path=new_model_file,
|
||||
overwrite=True)
|
||||
else:
|
||||
model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file)
|
||||
zip_file_path = new_model_file+".zip"
|
||||
download_file(model_url,
|
||||
path=zip_file_path,
|
||||
overwrite=True)
|
||||
with zipfile.ZipFile(zip_file_path) as zf:
|
||||
zf.extractall(model_root)
|
||||
return new_model_file
|
@ -1,116 +0,0 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
from skimage import transform as trans
|
||||
|
||||
|
||||
def transform(data, center, output_size, scale, rotation):
|
||||
scale_ratio = scale
|
||||
rot = float(rotation) * np.pi / 180.0
|
||||
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
||||
t1 = trans.SimilarityTransform(scale=scale_ratio)
|
||||
cx = center[0] * scale_ratio
|
||||
cy = center[1] * scale_ratio
|
||||
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
||||
t3 = trans.SimilarityTransform(rotation=rot)
|
||||
t4 = trans.SimilarityTransform(translation=(output_size / 2,
|
||||
output_size / 2))
|
||||
t = t1 + t2 + t3 + t4
|
||||
M = t.params[0:2]
|
||||
cropped = cv2.warpAffine(data,
|
||||
M, (output_size, output_size),
|
||||
borderValue=0.0)
|
||||
return cropped, M
|
||||
|
||||
|
||||
def trans_points2d(pts, M):
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i] = new_pt[0:2]
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points3d(pts, M):
|
||||
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
||||
#print(scale)
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i][0:2] = new_pt[0:2]
|
||||
new_pts[i][2] = pts[i][2] * scale
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points(pts, M):
|
||||
if pts.shape[1] == 2:
|
||||
return trans_points2d(pts, M)
|
||||
else:
|
||||
return trans_points3d(pts, M)
|
||||
|
||||
def estimate_affine_matrix_3d23d(X, Y):
|
||||
''' Using least-squares solution
|
||||
Args:
|
||||
X: [n, 3]. 3d points(fixed)
|
||||
Y: [n, 3]. corresponding 3d points(moving). Y = PX
|
||||
Returns:
|
||||
P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]).
|
||||
'''
|
||||
X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4
|
||||
P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4
|
||||
return P
|
||||
|
||||
def P2sRt(P):
|
||||
''' decompositing camera matrix P
|
||||
Args:
|
||||
P: (3, 4). Affine Camera Matrix.
|
||||
Returns:
|
||||
s: scale factor.
|
||||
R: (3, 3). rotation matrix.
|
||||
t: (3,). translation.
|
||||
'''
|
||||
t = P[:, 3]
|
||||
R1 = P[0:1, :3]
|
||||
R2 = P[1:2, :3]
|
||||
s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0
|
||||
r1 = R1/np.linalg.norm(R1)
|
||||
r2 = R2/np.linalg.norm(R2)
|
||||
r3 = np.cross(r1, r2)
|
||||
|
||||
R = np.concatenate((r1, r2, r3), 0)
|
||||
return s, R, t
|
||||
|
||||
def matrix2angle(R):
|
||||
''' get three Euler angles from Rotation Matrix
|
||||
Args:
|
||||
R: (3,3). rotation matrix
|
||||
Returns:
|
||||
x: pitch
|
||||
y: yaw
|
||||
z: roll
|
||||
'''
|
||||
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
|
||||
|
||||
singular = sy < 1e-6
|
||||
|
||||
if not singular :
|
||||
x = math.atan2(R[2,1] , R[2,2])
|
||||
y = math.atan2(-R[2,0], sy)
|
||||
z = math.atan2(R[1,0], R[0,0])
|
||||
else :
|
||||
x = math.atan2(-R[1,2], R[1,1])
|
||||
y = math.atan2(-R[2,0], sy)
|
||||
z = 0
|
||||
|
||||
# rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
|
||||
rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi
|
||||
return rx, ry, rz
|
||||
|
125
src/utils/io.py
@ -1,125 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
import os.path as osp
|
||||
import imageio
|
||||
import numpy as np
|
||||
import pickle
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
from .helper import mkdir, suffix
|
||||
|
||||
|
||||
def load_image_rgb(image_path: str):
|
||||
if not osp.exists(image_path):
|
||||
raise FileNotFoundError(f"Image not found: {image_path}")
|
||||
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
|
||||
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def load_driving_info(driving_info):
|
||||
driving_video_ori = []
|
||||
|
||||
def load_images_from_directory(directory):
|
||||
image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg')))
|
||||
return [load_image_rgb(im_path) for im_path in image_paths]
|
||||
|
||||
def load_images_from_video(file_path):
|
||||
reader = imageio.get_reader(file_path, "ffmpeg")
|
||||
return [image for _, image in enumerate(reader)]
|
||||
|
||||
if osp.isdir(driving_info):
|
||||
driving_video_ori = load_images_from_directory(driving_info)
|
||||
elif osp.isfile(driving_info):
|
||||
driving_video_ori = load_images_from_video(driving_info)
|
||||
|
||||
return driving_video_ori
|
||||
|
||||
|
||||
def contiguous(obj):
|
||||
if not obj.flags.c_contiguous:
|
||||
obj = obj.copy(order="C")
|
||||
return obj
|
||||
|
||||
|
||||
def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
|
||||
"""
|
||||
ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.
|
||||
:param img: the image to be processed.
|
||||
:param max_dim: the maximum dimension constraint.
|
||||
:param n: the number that needs to be multiples of.
|
||||
:return: the adjusted image.
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
|
||||
# ajust the size of the image according to the maximum dimension
|
||||
if max_dim > 0 and max(h, w) > max_dim:
|
||||
if h > w:
|
||||
new_h = max_dim
|
||||
new_w = int(w * (max_dim / h))
|
||||
else:
|
||||
new_w = max_dim
|
||||
new_h = int(h * (max_dim / w))
|
||||
img = cv2.resize(img, (new_w, new_h))
|
||||
|
||||
# ensure that the image dimensions are multiples of n
|
||||
division = max(division, 1)
|
||||
new_h = img.shape[0] - (img.shape[0] % division)
|
||||
new_w = img.shape[1] - (img.shape[1] % division)
|
||||
|
||||
if new_h == 0 or new_w == 0:
|
||||
# when the width or height is less than n, no need to process
|
||||
return img
|
||||
|
||||
if new_h != img.shape[0] or new_w != img.shape[1]:
|
||||
img = img[:new_h, :new_w]
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def load_img_online(obj, mode="bgr", **kwargs):
|
||||
max_dim = kwargs.get("max_dim", 1920)
|
||||
n = kwargs.get("n", 2)
|
||||
if isinstance(obj, str):
|
||||
if mode.lower() == "gray":
|
||||
img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE)
|
||||
else:
|
||||
img = cv2.imread(obj, cv2.IMREAD_COLOR)
|
||||
else:
|
||||
img = obj
|
||||
|
||||
# Resize image to satisfy constraints
|
||||
img = resize_to_limit(img, max_dim=max_dim, division=n)
|
||||
|
||||
if mode.lower() == "bgr":
|
||||
return contiguous(img)
|
||||
elif mode.lower() == "rgb":
|
||||
return contiguous(img[..., ::-1])
|
||||
else:
|
||||
raise Exception(f"Unknown mode {mode}")
|
||||
|
||||
|
||||
def load(fp):
|
||||
suffix_ = suffix(fp)
|
||||
|
||||
if suffix_ == "npy":
|
||||
return np.load(fp)
|
||||
elif suffix_ == "pkl":
|
||||
return pickle.load(open(fp, "rb"))
|
||||
else:
|
||||
raise Exception(f"Unknown type: {suffix}")
|
||||
|
||||
|
||||
def dump(wfp, obj):
|
||||
wd = osp.split(wfp)[0]
|
||||
if wd != "" and not osp.exists(wd):
|
||||
mkdir(wd)
|
||||
|
||||
_suffix = suffix(wfp)
|
||||
if _suffix == "npy":
|
||||
np.save(wfp, obj)
|
||||
elif _suffix == "pkl":
|
||||
pickle.dump(obj, open(wfp, "wb"))
|
||||
else:
|
||||
raise Exception("Unknown type: {}".format(_suffix))
|
Before Width: | Height: | Size: 3.4 KiB |
@ -1,16 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
custom print and log functions
|
||||
"""
|
||||
|
||||
__all__ = ['rprint', 'rlog']
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
rprint = console.print
|
||||
rlog = console.log
|
||||
except:
|
||||
rprint = print
|
||||
rlog = print
|
@ -1,29 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
tools to measure elapsed time
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
class Timer(object):
|
||||
"""A simple timer."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_time = 0.
|
||||
self.calls = 0
|
||||
self.start_time = 0.
|
||||
self.diff = 0.
|
||||
|
||||
def tic(self):
|
||||
# using time.time instead of time.clock because time time.clock
|
||||
# does not normalize for multithreading
|
||||
self.start_time = time.time()
|
||||
|
||||
def toc(self, average=True):
|
||||
self.diff = time.time() - self.start_time
|
||||
return self.diff
|
||||
|
||||
def clear(self):
|
||||
self.start_time = 0.
|
||||
self.diff = 0.
|
@ -1,211 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
"""
|
||||
Functions for processing video
|
||||
|
||||
ATTENTION: you need to install ffmpeg and ffprobe in your env!
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import imageio
|
||||
import cv2
|
||||
from rich.progress import track
|
||||
|
||||
from .rprint import rlog as log
|
||||
from .rprint import rprint as print
|
||||
from .helper import prefix
|
||||
|
||||
|
||||
def exec_cmd(cmd):
|
||||
return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
|
||||
def images2video(images, wfp, **kwargs):
|
||||
fps = kwargs.get('fps', 30)
|
||||
video_format = kwargs.get('format', 'mp4') # default is mp4 format
|
||||
codec = kwargs.get('codec', 'libx264') # default is libx264 encoding
|
||||
quality = kwargs.get('quality') # video quality
|
||||
pixelformat = kwargs.get('pixelformat', 'yuv420p') # video pixel format
|
||||
image_mode = kwargs.get('image_mode', 'rgb')
|
||||
macro_block_size = kwargs.get('macro_block_size', 2)
|
||||
ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))]
|
||||
|
||||
writer = imageio.get_writer(
|
||||
wfp, fps=fps, format=video_format,
|
||||
codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size
|
||||
)
|
||||
|
||||
n = len(images)
|
||||
for i in track(range(n), description='Writing', transient=True):
|
||||
if image_mode.lower() == 'bgr':
|
||||
writer.append_data(images[i][..., ::-1])
|
||||
else:
|
||||
writer.append_data(images[i])
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def video2gif(video_fp, fps=30, size=256):
|
||||
if osp.exists(video_fp):
|
||||
d = osp.split(video_fp)[0]
|
||||
fn = prefix(osp.basename(video_fp))
|
||||
palette_wfp = osp.join(d, 'palette.png')
|
||||
gif_wfp = osp.join(d, f'{fn}.gif')
|
||||
# generate the palette
|
||||
cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y'
|
||||
exec_cmd(cmd)
|
||||
# use the palette to generate the gif
|
||||
cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
|
||||
exec_cmd(cmd)
|
||||
else:
|
||||
print(f'video_fp: {video_fp} not exists!')
|
||||
|
||||
|
||||
def merge_audio_video(video_fp, audio_fp, wfp):
|
||||
if osp.exists(video_fp) and osp.exists(audio_fp):
|
||||
cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
|
||||
exec_cmd(cmd)
|
||||
print(f'merge {video_fp} and {audio_fp} to {wfp}')
|
||||
else:
|
||||
print(f'video_fp: {video_fp} or audio_fp: {audio_fp} not exists!')
|
||||
|
||||
|
||||
def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
|
||||
mask_float = mask.astype(np.float32) / 255.
|
||||
background_color = np.array(background_color).reshape([1, 1, 3])
|
||||
bg = np.ones_like(img) * background_color
|
||||
img = np.clip(mask_float * img + (1 - mask_float) * bg, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def concat_frames(driving_image_lst, source_image, I_p_lst):
|
||||
# TODO: add more concat style, e.g., left-down corner driving
|
||||
out_lst = []
|
||||
h, w, _ = I_p_lst[0].shape
|
||||
|
||||
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
|
||||
I_p = I_p_lst[idx]
|
||||
source_image_resized = cv2.resize(source_image, (w, h))
|
||||
|
||||
if driving_image_lst is None:
|
||||
out = np.hstack((source_image_resized, I_p))
|
||||
else:
|
||||
driving_image = driving_image_lst[idx]
|
||||
driving_image_resized = cv2.resize(driving_image, (w, h))
|
||||
out = np.hstack((driving_image_resized, source_image_resized, I_p))
|
||||
|
||||
out_lst.append(out)
|
||||
return out_lst
|
||||
|
||||
|
||||
class VideoWriter:
|
||||
def __init__(self, **kwargs):
|
||||
self.fps = kwargs.get('fps', 30)
|
||||
self.wfp = kwargs.get('wfp', 'video.mp4')
|
||||
self.video_format = kwargs.get('format', 'mp4')
|
||||
self.codec = kwargs.get('codec', 'libx264')
|
||||
self.quality = kwargs.get('quality')
|
||||
self.pixelformat = kwargs.get('pixelformat', 'yuv420p')
|
||||
self.image_mode = kwargs.get('image_mode', 'rgb')
|
||||
self.ffmpeg_params = kwargs.get('ffmpeg_params')
|
||||
|
||||
self.writer = imageio.get_writer(
|
||||
self.wfp, fps=self.fps, format=self.video_format,
|
||||
codec=self.codec, quality=self.quality,
|
||||
ffmpeg_params=self.ffmpeg_params, pixelformat=self.pixelformat
|
||||
)
|
||||
|
||||
def write(self, image):
|
||||
if self.image_mode.lower() == 'bgr':
|
||||
self.writer.append_data(image[..., ::-1])
|
||||
else:
|
||||
self.writer.append_data(image)
|
||||
|
||||
def close(self):
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
|
||||
|
||||
def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12):
|
||||
cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y'
|
||||
exec_cmd(cmd)
|
||||
|
||||
|
||||
def get_fps(filepath, default_fps=25):
|
||||
try:
|
||||
fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS)
|
||||
|
||||
if fps in (0, None):
|
||||
fps = default_fps
|
||||
except Exception as e:
|
||||
log(e)
|
||||
fps = default_fps
|
||||
|
||||
return fps
|
||||
|
||||
|
||||
def has_audio_stream(video_path: str) -> bool:
|
||||
"""
|
||||
Check if the video file contains an audio stream.
|
||||
|
||||
:param video_path: Path to the video file
|
||||
:return: True if the video contains an audio stream, False otherwise
|
||||
"""
|
||||
if osp.isdir(video_path):
|
||||
return False
|
||||
|
||||
cmd = [
|
||||
'ffprobe',
|
||||
'-v', 'error',
|
||||
'-select_streams', 'a',
|
||||
'-show_entries', 'stream=codec_type',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
f'"{video_path}"'
|
||||
]
|
||||
|
||||
try:
|
||||
# result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
result = exec_cmd(' '.join(cmd))
|
||||
if result.returncode != 0:
|
||||
log(f"Error occurred while probing video: {result.stderr}")
|
||||
return False
|
||||
|
||||
# Check if there is any output from ffprobe command
|
||||
return bool(result.stdout.strip())
|
||||
except Exception as e:
|
||||
log(f"Error occurred while probing video: {video_path}, you may need to install ffprobe! Now set audio to false!", style="bold red")
|
||||
return False
|
||||
|
||||
|
||||
def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
|
||||
cmd = [
|
||||
'ffmpeg',
|
||||
'-y',
|
||||
'-i', f'"{silent_video_path}"',
|
||||
'-i', f'"{audio_video_path}"',
|
||||
'-map', '0:v',
|
||||
'-map', '1:a',
|
||||
'-c:v', 'copy',
|
||||
'-shortest',
|
||||
f'"{output_video_path}"'
|
||||
]
|
||||
|
||||
try:
|
||||
exec_cmd(' '.join(cmd))
|
||||
log(f"Video with audio generated successfully: {output_video_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def bb_intersection_over_union(boxA, boxB):
|
||||
xA = max(boxA[0], boxB[0])
|
||||
yA = max(boxA[1], boxB[1])
|
||||
xB = min(boxA[2], boxB[2])
|
||||
yB = min(boxA[3], boxB[3])
|
||||
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
||||
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
||||
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||
return iou
|
@ -1,19 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def viz_lmk(img_, vps, **kwargs):
|
||||
"""可视化点"""
|
||||
lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA
|
||||
img_for_viz = img_.copy()
|
||||
for pt in vps:
|
||||
cv2.circle(
|
||||
img_for_viz,
|
||||
(int(pt[0]), int(pt[1])),
|
||||
radius=kwargs.get("radius", 1),
|
||||
color=(0, 255, 0),
|
||||
thickness=kwargs.get("thickness", 1),
|
||||
lineType=lineType,
|
||||
)
|
||||
return img_for_viz
|