Clean up unneeded files for Cat file creation

Rafael Silva 2024-10-07 07:09:32 -04:00
parent 2bfc0d13ec
commit 827d45041b
57 changed files with 0 additions and 3777 deletions

.vscode/settings.json (vendored): 19 lines deleted

@@ -1,19 +0,0 @@
{
"[python]": {
"editor.tabSize": 4
},
"files.eol": "\n",
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,
"files.trimTrailingWhitespace": true,
"files.exclude": {
"**/.git": true,
"**/.svn": true,
"**/.hg": true,
"**/CVS": true,
"**/.DS_Store": true,
"**/Thumbs.db": true,
"**/*.crswap": true,
"**/__pycache__": true
}
}

app.py: 215 lines deleted

@@ -1,215 +0,0 @@
# coding: utf-8
"""
Entry point of the Gradio demo
"""
import tyro
import subprocess
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
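# Illustrative note (not in the original file): partial_fields keeps only the
# kwargs that match attributes of the target class, e.g. for a dataclass Cfg
# with a single field `size`, partial_fields(Cfg, {'size': 8, 'x': 1}) drops
# 'x' and returns Cfg(size=8).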
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except Exception:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig
gradio_pipeline = GradioPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
def gpu_wrapped_execute_image(*args, **kwargs):
return gradio_pipeline.execute_image(*args, **kwargs)
# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
[osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
#################### interface logic ####################
# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
output_video_concat = gr.Video()
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Source Portrait"):
image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
],
inputs=[image_input],
cache_examples=False,
)
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video()
gr.Examples(
examples=[
[osp.join(example_video_dir, "d0.mp4")],
[osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
],
inputs=[video_input],
cache_examples=False,
)
with gr.Row():
with gr.Accordion(open=False, label="Animation Instructions and Options"):
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
with gr.Row():
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat.render()
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
gr.Examples(
examples=data_examples,
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
],
outputs=[output_image, output_image_paste_back],
examples_per_page=len(data_examples),
cache_examples=False,
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
retargeting_input_image.render()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s9.jpg")],
[osp.join(example_portrait_dir, "s6.jpg")],
[osp.join(example_portrait_dir, "s10.jpg")],
[osp.join(example_portrait_dir, "s5.jpg")],
[osp.join(example_portrait_dir, "s7.jpg")],
[osp.join(example_portrait_dir, "s12.jpg")],
],
inputs=[retargeting_input_image],
cache_examples=False,
)
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
image_input,
video_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
],
outputs=[output_video, output_video_concat],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,
server_name=args.server_name
)
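A usage note, not part of the diff: tyro turns the ArgumentConfig fields into CLI flags, so the demo could plausibly be launched as `python app.py --server-port 8890 --share` (flag spelling depends on tyro's defaults and assumes ArgumentConfig defines these fields, as the `demo.launch(...)` call above suggests).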

assets/docs/changelog/2024-07-10.md

@@ -1,22 +0,0 @@
## 2024/07/10
**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
The popularity of LivePortrait has exceeded our expectations. If you encounter any issues and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
### Updates
- <strong>Audio and video concatenating: </strong> If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
- <strong>Driving video auto-cropping: </strong> Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop the video to a 1:1 ratio yourself with video editing software or other tools. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`.
- <strong>Motion template making: </strong> Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option.
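For example (paths and the `-s` flag are illustrative, assuming the standard LivePortrait CLI): `python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.pkl` reuses a previously generated template in place of the original driving video.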
### About driving video
- For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
### Others
- If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).

Binary file not shown (before: 801 KiB)

Binary file not shown (before: 6.3 MiB)

Binary file not shown (before: 2.7 MiB)

4 binary files not shown

assets/gradio_description_animation.md

@@ -1,16 +0,0 @@
<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
<div style="font-size: 1.2em; margin-left: 20px;">
1. In the <strong>Animation Options</strong> section, we recommend enabling the <strong>do crop (source)</strong> option if faces occupy a small portion of your image.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
2. Press the <strong>🚀 Animate</strong> button and wait; your animated video will appear in the result block. This may take a few moments.
</div>
<div style="font-size: 1.2em; margin-left: 20px;">
3. If you want to upload your own driving video, <strong>the best practices are</strong>:
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of the driving video shows a frontal face with a **neutral expression**.
</div>

assets/gradio_description_retargeting.md

@@ -1,4 +0,0 @@
<br>
## Retargeting
<span style="font-size: 1.2em;">🔥 To edit the eye-open and lip-open ratios of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can run it multiple times. <strong>😊 Set both ratios to 0.8 to see what happens!</strong> </span>

assets/gradio_description_upload.md

@@ -1,2 +0,0 @@
## 🤗 This is the official gradio demo for **LivePortrait**.
<div style="font-size: 1.2em;">Please upload or use a webcam to get a <strong>Source Portrait</strong> (any aspect ratio) and upload a <strong>Driving Video</strong> (1:1 aspect ratio, or any aspect ratio with <code>do crop (driving video)</code> checked).</div>

assets/gradio_title.md

@@ -1,11 +0,0 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
<a href='https://huggingface.co/spaces/KwaiVGI/liveportrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
</div>
</div>
</div>

inference.py

@@ -1,60 +0,0 @@
# coding: utf-8
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except Exception:
return False
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source_image):
raise FileNotFoundError(f"source image not found: {args.source_image}")
if not osp.exists(args.driving_info):
raise FileNotFoundError(f"driving info not found: {args.driving_info}")
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
)
# fast check the args
fast_check_args(args)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig
live_portrait_pipeline = LivePortraitPipeline(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg
)
# run
live_portrait_pipeline.execute(args)
# live_portrait_pipeline_cp = torch.compile(live_portrait_pipeline.execute, backend="inductor")
# with torch.no_grad():
# live_portrait_pipeline_cp(args)
if __name__ == "__main__":
main()

speed.py: 195 lines deleted

@@ -1,195 +0,0 @@
# coding: utf-8
"""
Benchmark the inference speed of each module in LivePortrait.
TODO: heavily GPT-styled; needs refactoring
"""
import torch
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
import yaml
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1, device_id=0):
"""
Generate random input tensors and move them to GPU
"""
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
feat_stitching = concat_feat(kp_source, kp_driving).half()
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
inputs = {
'feature_3d': feature_3d,
'kp_source': kp_source,
'kp_driving': kp_driving,
'source_image': source_image,
'generator_input': generator_input,
'feat_stitching': feat_stitching,
'feat_eye': feat_eye,
'feat_lip': feat_lip
}
return inputs
def load_and_compile_models(cfg, model_config):
"""
Load and compile models for inference
"""
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
models_with_params = [
('Appearance Feature Extractor', appearance_feature_extractor),
('Motion Extractor', motion_extractor),
('Warping Network', warping_module),
('SPADE Decoder', spade_generator)
]
compiled_models = {}
for name, model in models_with_params:
model = model.half()
model = torch.compile(model, mode='max-autotune') # Optimize for inference
model.eval() # Switch to evaluation mode
compiled_models[name] = model
retargeting_models = ['stitching', 'eye', 'lip']
for retarget in retargeting_models:
module = stitching_retargeting_module[retarget].half()
module = torch.compile(module, mode='max-autotune') # Optimize for inference
module.eval() # Switch to evaluation mode
stitching_retargeting_module[retarget] = module
return compiled_models, stitching_retargeting_module
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
"""
Warm up models to prepare them for benchmarking
"""
print("Warm up start!")
with torch.no_grad():
for _ in range(10):
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
compiled_models['Motion Extractor'](inputs['source_image'])
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip'])
print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
"""
Measure inference times for each model
"""
times = {name: [] for name in compiled_models.keys()}
times['Stitching and Retargeting Modules'] = []
overall_times = []
with torch.no_grad():
for _ in range(100):
torch.cuda.synchronize()
overall_start = time.time()
start = time.time()
compiled_models['Appearance Feature Extractor'](inputs['source_image'])
torch.cuda.synchronize()
times['Appearance Feature Extractor'].append(time.time() - start)
start = time.time()
compiled_models['Motion Extractor'](inputs['source_image'])
torch.cuda.synchronize()
times['Motion Extractor'].append(time.time() - start)
start = time.time()
compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
torch.cuda.synchronize()
times['Warping Network'].append(time.time() - start)
start = time.time()
compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
torch.cuda.synchronize()
times['SPADE Decoder'].append(time.time() - start)
start = time.time()
stitching_retargeting_module['stitching'](inputs['feat_stitching'])
stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip'])
torch.cuda.synchronize()
times['Stitching and Retargeting Modules'].append(time.time() - start)
overall_times.append(time.time() - overall_start)
return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
"""
Print benchmark results with average and standard deviation of inference times
"""
average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
for name, model in compiled_models.items():
num_params = sum(p.numel() for p in model.parameters())
num_params_in_millions = num_params / 1e6
print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
for index, retarget in enumerate(retargeting_models):
num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
num_params_in_millions = num_params / 1e6
print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
for name, avg_time in average_times.items():
std_time = std_times[name]
print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
def main():
"""
Main function to benchmark speed and model parameters
"""
# Load configuration
cfg = InferenceConfig()
model_config_path = cfg.models_config
with open(model_config_path, 'r') as file:
model_config = yaml.safe_load(file)
# Sample input tensors
inputs = initialize_inputs(device_id=cfg.device_id)
# Load and compile models
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
# Warm up models
warm_up_models(compiled_models, stitching_retargeting_module, inputs)
# Measure inference times
times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
# Print benchmark results
print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
if __name__ == "__main__":
main()
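A side note on methodology, not from the repository: `time.time()` bracketed by `torch.cuda.synchronize()` works, but CUDA events avoid host-timer jitter. A minimal sketch, assuming a CUDA device and any callable module:

import torch

def time_with_cuda_events(module, *inputs, iters=100):
    """Average latency in milliseconds, measured with CUDA events instead of time.time()."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times_ms = []
    with torch.no_grad():
        for _ in range(iters):
            start.record()
            module(*inputs)
            end.record()
            torch.cuda.synchronize()  # wait so elapsed_time is valid
            times_ms.append(start.elapsed_time(end))
    return sum(times_ms) / len(times_ms)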

src/gradio_pipeline.py

@@ -1,117 +0,0 @@
# coding: utf-8
"""
Pipeline for the Gradio demo
"""
import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
def update_args(args, user_args):
"""update the args according to user inputs
"""
for k, v in user_args.items():
if hasattr(args, k):
setattr(args, k, v)
return args
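# Illustrative note (not in the original file): update_args mutates only
# fields that already exist on args, e.g.
#   update_args(args, {'flag_do_crop': False, 'bogus': 1})
# sets args.flag_do_crop = False and silently ignores 'bogus'.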
class GradioPipeline(LivePortraitPipeline):
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
def execute_video(
self,
input_image_path,
input_video_path,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input
):
""" for video driven potrait animation
"""
if input_image_path is not None and input_video_path is not None:
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_crop_driving_video': flag_crop_driving_video_input
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
gr.Info("Run completed successfully!", duration=2)
return video_path, video_path_concat
else:
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
""" for single image retargeting
"""
# disposable feature
f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, flag_do_crop)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
x_s_user = x_s_user.to(self.live_portrait_wrapper.device)
f_s_user = f_s_user.to(self.live_portrait_wrapper.device)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
num_kp = x_s_user.shape[1]
# default: use x_s
x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
gr.Info("Run completed successfully!", duration=2)
return out, out_to_ori_blend
def prepare_retargeting(self, input_image, flag_do_crop=True):
""" for single image retargeting
"""
if input_image is not None:
# gr.Info("Upload successfully!", duration=2)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)

src/live_portrait_pipeline.py

@@ -1,285 +0,0 @@
# coding: utf-8
"""
Pipeline of LivePortrait
"""
import torch
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import os
import os.path as osp
from rich.progress import track
from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapper
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class LivePortraitPipeline(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg)
def execute(self, args: ArgumentConfig):
# for convenience
inf_cfg = self.live_portrait_wrapper.inference_cfg
device = self.live_portrait_wrapper.device
crop_cfg = self.cropper.crop_cfg
######## process source portrait ########
img_rgb = load_image_rgb(args.source_image)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source_image}")
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
if crop_info is None:
raise Exception("No face detected in the source image!")
source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
if inf_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
flag_lip_zero = inf_cfg.flag_lip_zero  # do not overwrite the config value
if flag_lip_zero:
# set the lip-open scalar to 0 first
c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold:
flag_lip_zero = False
else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ########
flag_load_from_template = is_template(args.driving_info)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: loading from a template is fast, but the cropped driving video is unavailable
log(f"Load from template: {args.driving_info}, NOT the video, so the cropped video and audio are both unavailable.", style='bold green')
template_dct = load(args.driving_info)
n_frames = template_dct['n_frames']
# set output_fps
output_fps = template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving_info) and is_video(args.driving_info):
# load from video file, AND make motion template
log(f"Load video: {args.driving_info}")
if osp.isdir(args.driving_info):
output_fps = inf_cfg.output_fps
else:
output_fps = int(get_fps(args.driving_info))
log(f'The FPS of {args.driving_info} is: {output_fps}')
log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
driving_rgb_lst = load_driving_info(args.driving_info)
######## make motion template ########
log("Start making motion template...")
if inf_cfg.flag_crop_driving_video:
ret = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.')
driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256
c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst)
# save the motion template
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst)
template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving_info) + '.pkl'
dump(wfp_template, template_dct)
log(f"Dump motion template to {wfp_template}")
n_frames = I_d_lst.shape[0]
else:
raise Exception(f"{args.driving_info} not exists or unsupported driving info types!")
#########################################
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
#########################################
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
x_d_i_info = template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inf_cfg.flag_relative_motion:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
# Algorithm 1:
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
if flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if inf_cfg.flag_eye_retargeting:
c_d_eyes_i = c_d_eyes_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inf_cfg.flag_lip_retargeting:
c_d_lip_i = c_d_lip_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
if inf_cfg.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
# TODO: pasteback is slow; consider optimizing it with multi-threading or on the GPU
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)
######### build final concat result #########
# driving frame | source image | generation, or source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_has_audio:
# final result with concat
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
# save the driven result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build final result #########
if flag_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4')
add_audio_to_video(wfp, args.driving_info, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp} with {wfp_with_audio}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}. You can pass this template path to `-d` next time to skip driving-video cropping and motion-template creation, and to protect privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
return wfp, wfp_concat
def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs):
n_frames = I_d_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
'c_d_eyes_lst': [],
'c_d_lip_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s_d, R_d, δ_d and t_d for inference
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
item_dct = {
'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32),
'R_d': R_d_i.cpu().numpy().astype(np.float32),
'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_d_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
c_d_eyes = c_d_eyes_lst[i].astype(np.float32)
template_dct['c_d_eyes_lst'].append(c_d_eyes)
c_d_lip = c_d_lip_lst[i].astype(np.float32)
template_dct['c_d_lip_lst'].append(c_d_lip)
return template_dct
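# Illustrative note (not in the original file): the dumped .pkl can later be
# passed as the driving info, e.g.
#   template_dct = load('d0.pkl')           # via .utils.io.load
#   template_dct['n_frames']                # int
#   template_dct['motion'][i]['R_d']        # float32 rotation, shape (1, 3, 3)
# so animation can run without the original driving video.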

src/live_portrait_wrapper.py

@@ -1,333 +0,0 @@
# coding: utf-8
"""
Wrapper for LivePortrait core functions
"""
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
from .utils.timer import Timer
from .utils.helper import load_model, concat_feat
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .config.inference_config import InferenceConfig
from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
def __init__(self, inference_cfg: InferenceConfig):
self.inference_cfg = inference_cfg
self.device_id = inference_cfg.device_id
self.compile = inference_cfg.flag_do_torch_compile
if inference_cfg.flag_force_cpu:
self.device = 'cpu'
else:
self.device = 'cuda:' + str(self.device_id)
with open(inference_cfg.models_config, 'r') as f:
model_config = yaml.safe_load(f)
jit = True  # set to False to load from the original checkpoints instead of the TorchScript builds
if jit:
self.appearance_feature_extractor = torch.jit.load("build/appearance_feature_extractor.pt")
self.motion_extractor = torch.jit.load("build/motion_extractor.pt")
self.warping_module = torch.jit.load("build/warping_module.pt")
self.spade_generator = torch.jit.load("build/spade_generator.pt")
eyes = torch.jit.load("build/stitching_retargeting_module_eye.pt")
lips = torch.jit.load("build/stitching_retargeting_module_lip.pt")
stitching = torch.jit.load("build/stitching_retargeting_module_stitching.pt")
self.stitching_retargeting_module = {'eye': eyes, 'lip': lips, 'stitching': stitching}
else:
# init F
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.')
# init M
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor done.')
# init W
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
log(f'Load warping_module done.')
# init G
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
log(f'Load spade_generator done.')
# init S and R
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.')
else:
self.stitching_retargeting_module = None
# Optimize for inference
if self.compile:
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
self.timer = Timer()
def update_config(self, user_args):
for k, v in user_args.items():
if hasattr(self.inference_cfg, k):
setattr(self.inference_cfg, k, v)
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
""" construct the input as standard
img: HxWx3, uint8, 256x256
"""
h, w = img.shape[:2]
if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]:
x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]))
else:
x = img.copy()
if x.ndim == 3:
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
elif x.ndim == 4:
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
else:
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
x = np.clip(x, 0, 1) # clip to 0~1
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
x = x.to(self.device)
return x
def prepare_driving_videos(self, imgs) -> torch.Tensor:
""" construct the input as standard
imgs: a list of HxWx3 uint8 frames, or an ndarray already shaped TxHxWx3x1
"""
if isinstance(imgs, list):
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
elif isinstance(imgs, np.ndarray):
_imgs = imgs
else:
raise ValueError(f'imgs type error: {type(imgs)}')
y = _imgs.astype(np.float32) / 255.
y = np.clip(y, 0, 1) # clip to 0~1
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
y = y.to(self.device)
return y
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
""" get the appearance feature of the image by F
x: Bx3xHxW, normalized to 0~1
"""
with torch.no_grad():
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
feature_3d = self.appearance_feature_extractor(x)
return feature_3d.float()
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
""" get the implicit keypoint information
x: Bx3xHxW, normalized to 0~1
flag_refine_info: whether to transform the pose to degrees and reshape the keypoint dimensions
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
"""
with torch.no_grad():
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
kp_info = self.motion_extractor(x)
if self.inference_cfg.flag_use_half_precision:
# float the dict
for k, v in kp_info.items():
if isinstance(v, torch.Tensor):
kp_info[k] = v.float()
flag_refine_info: bool = kwargs.get('flag_refine_info', True)
if flag_refine_info:
bs = kp_info['kp'].shape[0]
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
return kp_info
def get_pose_dct(self, kp_info: dict) -> dict:
pose_dct = dict(
pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
roll=headpose_pred_to_degree(kp_info['roll']).item(),
)
return pose_dct
def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
# get the canonical keypoints of source image by M
source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
# get the canonical keypoints of first driving frame by M
driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
driving_first_frame_rotation = get_rotation_matrix(
driving_first_frame_kp_info['pitch'],
driving_first_frame_kp_info['yaw'],
driving_first_frame_kp_info['roll']
)
# get feature volume by F
source_feature_3d = self.extract_feature_3d(source_prepared)
return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
def transform_keypoint(self, kp_info: dict):
"""
transform the implicit keypoints with the pose, shift, and expression deformation
kp: BxNx3
"""
kp = kp_info['kp'] # (bs, k, 3)
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
t, exp = kp_info['t'], kp_info['exp']
scale = kp_info['scale']
pitch = headpose_pred_to_degree(pitch)
yaw = headpose_pred_to_degree(yaw)
roll = headpose_pred_to_degree(roll)
bs = kp.shape[0]
if kp.ndim == 2:
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
else:
num_kp = kp.shape[1] # Bxnum_kpx3
rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
# Eqn.2: s * (R * x_c,s + exp) + t
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
return kp_transformed
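# Shape walk-through (illustrative, not in the original file), for bs=1, num_kp=21:
#   kp.view(1, 21, 3) @ rot_mat   -> (1, 21, 3) @ (1, 3, 3) = (1, 21, 3)
#   + exp.view(1, 21, 3)          -> (1, 21, 3)
#   * scale[..., None]            -> broadcast with (1, 1, 1)
#   += t[:, None, 0:2]            -> shifts x/y only, so tz never moves the keypoints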
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
eye_close_ratio: Bx3
Return: Bx(3*num_kp)
"""
feat_eye = concat_feat(kp_source, eye_close_ratio)
with torch.no_grad():
delta = self.stitching_retargeting_module['eye'](feat_eye)
return delta
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
lip_close_ratio: Bx2
Return: Bx(3*num_kp)
"""
feat_lip = concat_feat(kp_source, lip_close_ratio)
with torch.no_grad():
delta = self.stitching_retargeting_module['lip'](feat_lip)
return delta
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
"""
kp_source: BxNx3
kp_driving: BxNx3
Return: Bx(3*num_kp+2)
"""
feat_stitching = concat_feat(kp_source, kp_driving)
with torch.no_grad():
delta = self.stitching_retargeting_module['stitching'](feat_stitching)
return delta
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" conduct the stitching
kp_source: Bxnum_kpx3
kp_driving: Bxnum_kpx3
"""
if self.stitching_retargeting_module is not None:
bs, num_kp = kp_source.shape[:2]
kp_driving_new = kp_driving.clone()
delta = self.stitch(kp_source, kp_driving_new)
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # (bs, num_kp, 3)
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # (bs, 1, 2)
kp_driving_new += delta_exp
kp_driving_new[..., :2] += delta_tx_ty
return kp_driving_new
return kp_driving
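# Layout note (illustrative, not in the original file): for num_kp = 21 the
# stitching MLP emits 3*21 + 2 = 65 values; the first 63 perturb every keypoint
# and the last two are a single (tx, ty) shift applied to all keypoints' x/y.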
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
""" get the image after the warping of the implicit keypoints
feature_3d: Bx32x16x64x64, feature volume
kp_source: BxNx3
kp_driving: BxNx3
"""
# Line 18 in Algorithm 1: D(W(f_s; x_s, x_d,i))
with torch.no_grad():
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
if self.compile:
# Mark the beginning of a new CUDA Graph step
torch.compiler.cudagraph_mark_step_begin()
# get decoder input
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
# float the dict
if self.inference_cfg.flag_use_half_precision:
for k, v in ret_dct.items():
if isinstance(v, torch.Tensor):
ret_dct[k] = v.float()
return ret_dct
def parse_output(self, out: torch.Tensor) -> np.ndarray:
""" construct the output as standard
return: 1xHxWx3, uint8
"""
out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
out = np.clip(out, 0, 1) # clip to 0~1
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
return out
def calc_driving_ratio(self, driving_lmk_lst):
input_eye_ratio_lst = []
input_lip_ratio_lst = []
for lmk in driving_lmk_lst:
# for eyes retargeting
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
# for lip retargeting
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
return input_eye_ratio_lst, input_lip_ratio_lst
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
# [c_s,eyes, c_d,eyes,i]
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
return combined_eye_ratio_tensor
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
c_s_lip = calc_lip_close_ratio(source_lmk[None])
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
# [c_s,lip, c_d,lip,i]
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
return combined_lip_ratio_tensor

insightface/__init__.py

@@ -1,20 +0,0 @@
# coding: utf-8
# pylint: disable=wrong-import-position
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import
try:
#import mxnet as mx
import onnxruntime
except ImportError:
raise ImportError(
"Unable to import dependency onnxruntime. "
)
__version__ = '0.7.3'
from . import model_zoo
from . import utils
from . import app
from . import data

insightface/app/__init__.py

@@ -1 +0,0 @@
from .face_analysis import *

insightface/app/common.py

@@ -1,49 +0,0 @@
import numpy as np
from numpy.linalg import norm as l2norm
#from easydict import EasyDict
class Face(dict):
def __init__(self, d=None, **kwargs):
if d is None:
d = {}
if kwargs:
d.update(**kwargs)
for k, v in d.items():
setattr(self, k, v)
# Class attributes
#for k in self.__class__.__dict__.keys():
# if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
# setattr(self, k, getattr(self, k))
def __setattr__(self, name, value):
if isinstance(value, (list, tuple)):
value = [self.__class__(x)
if isinstance(x, dict) else x for x in value]
elif isinstance(value, dict) and not isinstance(value, self.__class__):
value = self.__class__(value)
super(Face, self).__setattr__(name, value)
super(Face, self).__setitem__(name, value)
__setitem__ = __setattr__
def __getattr__(self, name):
return None
@property
def embedding_norm(self):
if self.embedding is None:
return None
return l2norm(self.embedding)
@property
def normed_embedding(self):
if self.embedding is None:
return None
return self.embedding / self.embedding_norm
@property
def sex(self):
if self.gender is None:
return None
return 'M' if self.gender==1 else 'F'
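# Usage sketch (not part of the original file):
#   face = Face(gender=1, embedding=np.ones(4))
#   face.sex               -> 'M'
#   face.normed_embedding  -> embedding scaled to unit L2 norm
#   face.missing_key       -> None (attribute access never raises)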

insightface/app/face_analysis.py

@@ -1,110 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import glob
import os.path as osp
import numpy as np
import onnxruntime
from numpy.linalg import norm
from ..model_zoo import model_zoo
from ..utils import ensure_available
from .common import Face
DEFAULT_MP_NAME = 'buffalo_l'
__all__ = ['FaceAnalysis']
class FaceAnalysis:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
onnxruntime.set_default_logger_severity(3)
self.models = {}
self.model_dir = ensure_available('models', name, root=root)
onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
onnx_files = sorted(onnx_files)
for onnx_file in onnx_files:
model = model_zoo.get_model(onnx_file, **kwargs)
if model is None:
print('model not recognized:', onnx_file)
elif allowed_modules is not None and model.taskname not in allowed_modules:
print('model ignore:', onnx_file, model.taskname)
del model
elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
# print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std)
self.models[model.taskname] = model
else:
print('duplicated model task type, ignore:', onnx_file, model.taskname)
del model
assert 'detection' in self.models
self.det_model = self.models['detection']
def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
self.det_thresh = det_thresh
assert det_size is not None
# print('set det-size:', det_size)
self.det_size = det_size
for taskname, model in self.models.items():
if taskname=='detection':
model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
else:
model.prepare(ctx_id)
def get(self, img, max_num=0):
bboxes, kpss = self.det_model.detect(img,
max_num=max_num,
metric='default')
if bboxes.shape[0] == 0:
return []
ret = []
for i in range(bboxes.shape[0]):
bbox = bboxes[i, 0:4]
det_score = bboxes[i, 4]
kps = None
if kpss is not None:
kps = kpss[i]
face = Face(bbox=bbox, kps=kps, det_score=det_score)
for taskname, model in self.models.items():
if taskname=='detection':
continue
model.get(img, face)
ret.append(face)
return ret
def draw_on(self, img, faces):
import cv2
dimg = img.copy()
for i in range(len(faces)):
face = faces[i]
box = face.bbox.astype(int)  # np.int was removed in NumPy 1.20+
color = (0, 0, 255)
cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
if face.kps is not None:
kps = face.kps.astype(int)
#print(landmark.shape)
for l in range(kps.shape[0]):
color = (0, 0, 255)
if l == 0 or l == 3:
color = (0, 255, 0)
cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, 2)
if face.gender is not None and face.age is not None:
cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
#for key, value in face.items():
# if key.startswith('landmark_3d'):
# print(key, value.shape)
# print(value[0:10,:])
# lmk = np.round(value).astype(np.int)
# for l in range(lmk.shape[0]):
# color = (255, 0, 0)
# cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
# 2)
return dimg
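# Usage sketch (assumes the standard insightface flow and a BGR image from cv2):
#   app = FaceAnalysis(name='buffalo_l')
#   app.prepare(ctx_id=0, det_size=(640, 640))
#   faces = app.get(cv2.imread('t1.jpg'))    # list of Face objects
#   vis = app.draw_on(cv2.imread('t1.jpg'), faces)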

insightface/data/__init__.py

@@ -1,2 +0,0 @@
from .image import get_image
from .pickle_object import get_object

insightface/data/image.py

@@ -1,27 +0,0 @@
import cv2
import os
import os.path as osp
from pathlib import Path
class ImageCache:
data = {}
def get_image(name, to_rgb=False):
key = (name, to_rgb)
if key in ImageCache.data:
return ImageCache.data[key]
images_dir = osp.join(Path(__file__).parent.absolute(), 'images')
ext_names = ['.jpg', '.png', '.jpeg']
image_file = None
for ext_name in ext_names:
_image_file = osp.join(images_dir, "%s%s"%(name, ext_name))
if osp.exists(_image_file):
image_file = _image_file
break
assert image_file is not None, '%s not found' % name
img = cv2.imread(image_file)
if to_rgb:
img = img[:,:,::-1]
ImageCache.data[key] = img
return img

Binary file not shown (before: 12 KiB)

Binary file not shown (before: 21 KiB)

Binary file not shown (before: 44 KiB)

Binary file not shown (before: 6.0 KiB)

Binary file not shown (before: 77 KiB)

Binary file not shown (before: 126 KiB)

insightface/data/pickle_object.py

@@ -1,17 +0,0 @@
import cv2
import os
import os.path as osp
from pathlib import Path
import pickle
def get_object(name):
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
if not name.endswith('.pkl'):
name = name+".pkl"
filepath = osp.join(objects_dir, name)
if not osp.exists(filepath):
return None
with open(filepath, 'rb') as f:
obj = pickle.load(f)
return obj

insightface/data/rec_builder.py

@@ -1,71 +0,0 @@
import pickle
import numpy as np
import os
import os.path as osp
import sys
import mxnet as mx
class RecBuilder():
def __init__(self, path, image_size=(112, 112)):
self.path = path
self.image_size = image_size
self.widx = 0
self.wlabel = 0
self.max_label = -1
assert not osp.exists(path), '%s exists' % path
os.makedirs(path)
self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'),
os.path.join(path, 'train.rec'),
'w')
self.meta = []
def add(self, imgs):
# NOTE: img should be BGR
#assert label >= 0
#assert label > self.last_label
assert len(imgs) > 0
label = self.wlabel
for img in imgs:
idx = self.widx
image_meta = {'image_index': idx, 'image_classes': [label]}
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = label
self.wlabel += 1
def add_image(self, img, label):
# NOTE: img should be BGR
#assert label >= 0
#assert label > self.last_label
idx = self.widx
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(label, list):
idlabel = label[0]
else:
idlabel = label
image_meta = {'image_index': idx, 'image_classes': [idlabel]}
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = max(self.max_label, idlabel)
def close(self):
with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
print('stat:', self.widx, self.wlabel)
with open(os.path.join(self.path, 'property'), 'w') as f:
f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1]))
f.write("%d\n" % (self.widx))
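A hedged sketch of driving the RecBuilder class above (requires mxnet); 'faces_demo' is a hypothetical output directory and must not already exist. It writes two synthetic identities of three random crops each.

import numpy as np

builder = RecBuilder('faces_demo', image_size=(112, 112))
for _identity in range(2):
    crops = [np.random.randint(0, 256, (112, 112, 3), dtype=np.uint8)
             for _ in range(3)]
    builder.add(crops)      # all crops in one call share one fresh label
builder.close()             # writes train.meta and the property file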

View File

@ -1,6 +0,0 @@
from .model_zoo import get_model
from .arcface_onnx import ArcFaceONNX
from .retinaface import RetinaFace
from .scrfd import SCRFD
from .landmark import Landmark
from .attribute import Attribute

View File

@ -1,92 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'ArcFaceONNX',
]
class ArcFaceONNX:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
self.taskname = 'recognition'
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 127.5
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
self.output_shape = outputs[0].shape
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
face.embedding = self.get_feat(aimg).flatten()
return face.embedding
def compute_sim(self, feat1, feat2):
from numpy.linalg import norm
feat1 = feat1.ravel()
feat2 = feat2.ravel()
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
return sim
def get_feat(self, imgs):
if not isinstance(imgs, list):
imgs = [imgs]
input_size = self.input_size
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
def forward(self, batch_data):
blob = (batch_data - self.input_mean) / self.input_std
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
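A usage sketch for ArcFaceONNX; 'w600k_r50.onnx' is an assumed local recognition model and 'face_a.jpg'/'face_b.jpg' are assumed pre-aligned face crops (all placeholders). The 0.28 cutoff is only a common heuristic, not part of this class.

import cv2

rec = ArcFaceONNX(model_file='w600k_r50.onnx')
rec.prepare(ctx_id=-1)                        # ctx_id < 0 forces CPU execution
feat_a = rec.get_feat(cv2.imread('face_a.jpg'))
feat_b = rec.get_feat(cv2.imread('face_b.jpg'))
sim = rec.compute_sim(feat_a, feat_b)         # cosine similarity in [-1, 1]
print('same person?', sim > 0.28, sim)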

View File

@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-06-19
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'Attribute',
]
class Attribute:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
#print('init output_shape:', output_shape)
if output_shape[1]==3:
self.taskname = 'genderage'
else:
self.taskname = 'attribute_%d'%output_shape[1]
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if self.taskname=='genderage':
assert len(pred)==3
gender = np.argmax(pred[:2])
age = int(np.round(pred[2]*100))
face['gender'] = gender
face['age'] = age
return gender, age
else:
return pred
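A sketch of running the attribute head on one detection; 'genderage.onnx' and the bbox values are placeholders, and Face is the dict-like container from insightface.app.common.

import cv2
import numpy as np
from insightface.app.common import Face

attr = Attribute(model_file='genderage.onnx')
attr.prepare(ctx_id=-1)
img = cv2.imread('photo.jpg')
face = Face(bbox=np.array([50, 60, 210, 260]))  # hypothetical detector output
gender, age = attr.get(img, face)               # also stored as face['gender'], face['age']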

View File

@ -1,114 +0,0 @@
import time
import numpy as np
import onnxruntime
import cv2
import onnx
from onnx import numpy_helper
from ..utils import face_align
class INSwapper():
def __init__(self, model_file=None, session=None):
self.model_file = model_file
self.session = session
model = onnx.load(self.model_file)
graph = model.graph
self.emap = numpy_helper.to_array(graph.initializer[-1])
self.input_mean = 0.0
self.input_std = 255.0
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
inputs = self.session.get_inputs()
self.input_names = []
for inp in inputs:
self.input_names.append(inp.name)
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
input_cfg = inputs[0]
input_shape = input_cfg.shape
self.input_shape = input_shape
# print('inswapper-shape:', self.input_shape)
self.input_size = tuple(input_shape[2:4][::-1])
def forward(self, img, latent):
img = (img - self.input_mean) / self.input_std
pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
return pred
def get(self, img, target_face, source_face, paste_back=True):
face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1)
aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
latent = source_face.normed_embedding.reshape((1,-1))
latent = np.dot(latent, self.emap)
latent /= np.linalg.norm(latent)
pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
#print(latent.shape, latent.dtype, pred.shape)
img_fake = pred.transpose((0,2,3,1))[0]
bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
if not paste_back:
return bgr_fake, M
else:
target_img = img
fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
fake_diff = np.abs(fake_diff).mean(axis=2)
fake_diff[:2,:] = 0
fake_diff[-2:,:] = 0
fake_diff[:,:2] = 0
fake_diff[:,-2:] = 0
IM = cv2.invertAffineTransform(M)
img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white[img_white>20] = 255
fthresh = 10
fake_diff[fake_diff<fthresh] = 0
fake_diff[fake_diff>=fthresh] = 255
img_mask = img_white
mask_h_inds, mask_w_inds = np.where(img_mask==255)
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
mask_size = int(np.sqrt(mask_h*mask_w))
k = max(mask_size//10, 10)
#k = max(mask_size//20, 6)
#k = 6
kernel = np.ones((k,k),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel = np.ones((2,2),np.uint8)
fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
fake_diff[face_mask==1] = 255
k = max(mask_size//20, 5)
#k = 3
#k = 3
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
k = 5
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
fake_diff = cv2.blur(fake_diff, (11,11), 0)
##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
# print('blur_size: ', blur_size)
# fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
img_mask /= 255
fake_diff /= 255
# img_mask = fake_diff
img_mask = img_mask*fake_diff
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
fake_merged = fake_merged.astype(np.uint8)
return fake_merged
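A sketch of the swapper above; 'inswapper_128.onnx' and 'group.jpg' are placeholder paths, and the faces come from a prior FaceAnalysis pass so that kps and normed_embedding are populated.

import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name='buffalo_l')
app.prepare(ctx_id=-1, det_size=(640, 640))
img = cv2.imread('group.jpg')
faces = app.get(img)                       # detection + landmarks + embedding
swapper = INSwapper(model_file='inswapper_128.onnx')
if len(faces) >= 2:
    out = swapper.get(img, faces[0], faces[1], paste_back=True)
    cv2.imwrite('swapped.jpg', out)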

View File

@ -1,114 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
from ..utils import transform
from ..data import get_object
__all__ = [
'Landmark',
]
class Landmark:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
self.require_pose = False
#print('init output_shape:', output_shape)
if output_shape[1]==3309:
self.lmk_dim = 3
self.lmk_num = 68
self.mean_lmk = get_object('meanshape_68.pkl')
self.require_pose = True
else:
self.lmk_dim = 2
self.lmk_num = output_shape[1]//self.lmk_dim
self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num)
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if pred.shape[0] >= 3000:
pred = pred.reshape((-1, 3))
else:
pred = pred.reshape((-1, 2))
if self.lmk_num < pred.shape[0]:
pred = pred[self.lmk_num*-1:,:]
pred[:, 0:2] += 1
pred[:, 0:2] *= (self.input_size[0] // 2)
if pred.shape[1] == 3:
pred[:, 2] *= (self.input_size[0] // 2)
IM = cv2.invertAffineTransform(M)
pred = face_align.trans_points(pred, IM)
face[self.taskname] = pred
if self.require_pose:
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
s, R, t = transform.P2sRt(P)
rx, ry, rz = transform.matrix2angle(R)
pose = np.array( [rx, ry, rz], dtype=np.float32 )
face['pose'] = pose #pitch, yaw, roll
return pred
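A usage sketch for the Landmark model above; '2d106det.onnx', 'photo.jpg', and the bbox values are placeholders, with the bbox standing in for a prior detector's output.

import cv2
import numpy as np
from insightface.app.common import Face

lmk = Landmark(model_file='2d106det.onnx')
lmk.prepare(ctx_id=-1)
img = cv2.imread('photo.jpg')
face = Face(bbox=np.array([50, 60, 210, 260]))
pts = lmk.get(img, face)    # e.g. (106, 2) points mapped back to image coordinates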

View File

@ -1,103 +0,0 @@
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
"""
from __future__ import print_function
__all__ = ['get_model_file']
import os
import zipfile
import glob
from ..utils import download, check_sha1
_model_sha1 = {
name: checksum
for checksum, name in [
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
('', 'arcface_mfn_v1'),
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'),
('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'),
]
}
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
def short_hash(name):
if name not in _model_sha1:
raise ValueError(
'Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]
def find_params_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.params" % dir_path)
if len(paths) == 0:
return None
paths = sorted(paths)
return paths[-1]
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
The root directory will be created if it doesn't exist.
Parameters
----------
name : str
Name of the model.
root : str, default '~/.insightface/models'
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
file_name = name
root = os.path.expanduser(root)
dir_path = os.path.join(root, name)
file_path = find_params_file(dir_path)
#file_path = os.path.join(root, file_name + '.params')
sha1_hash = _model_sha1[name]
if file_path is not None:
if check_sha1(file_path, sha1_hash):
return file_path
else:
print(
'Mismatch in the content of model file detected. Downloading again.'
)
else:
print('Model file is not found. Downloading.')
if not os.path.exists(root):
os.makedirs(root)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
zip_file_path = os.path.join(root, file_name + '.zip')
repo_url = base_repo_url
if repo_url[-1] != '/':
repo_url = repo_url + '/'
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(dir_path)
os.remove(zip_file_path)
file_path = find_params_file(dir_path)
if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError(
'Downloaded file has different hash. Please try again.')

View File

@ -1,97 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .retinaface import *
#from .scrfd import *
from .landmark import *
from .attribute import Attribute
from .inswapper import INSwapper
from ..utils import download_onnx
__all__ = ['get_model']
class PickableInferenceSession(onnxruntime.InferenceSession):
# This is a wrapper to make the current InferenceSession class pickable.
def __init__(self, model_path, **kwargs):
super().__init__(model_path, **kwargs)
self.model_path = model_path
def __getstate__(self):
return {'model_path': self.model_path}
def __setstate__(self, values):
model_path = values['model_path']
self.__init__(model_path)
class ModelRouter:
def __init__(self, onnx_file):
self.onnx_file = onnx_file
def get_model(self, **kwargs):
session = PickableInferenceSession(self.onnx_file, **kwargs)
# print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
inputs = session.get_inputs()
input_cfg = inputs[0]
input_shape = input_cfg.shape
outputs = session.get_outputs()
if len(outputs)>=5:
return RetinaFace(model_file=self.onnx_file, session=session)
elif input_shape[2]==192 and input_shape[3]==192:
return Landmark(model_file=self.onnx_file, session=session)
elif input_shape[2]==96 and input_shape[3]==96:
return Attribute(model_file=self.onnx_file, session=session)
elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128:
return INSwapper(model_file=self.onnx_file, session=session)
elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0:
return ArcFaceONNX(model_file=self.onnx_file, session=session)
else:
#raise RuntimeError('error on model routing')
return None
def find_onnx_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.onnx" % dir_path)
if len(paths) == 0:
return None
paths = sorted(paths)
return paths[-1]
def get_default_providers():
return ['CUDAExecutionProvider', 'CPUExecutionProvider']
def get_default_provider_options():
return None
def get_model(name, **kwargs):
root = kwargs.get('root', '~/.insightface')
root = os.path.expanduser(root)
model_root = osp.join(root, 'models')
allow_download = kwargs.get('download', False)
download_zip = kwargs.get('download_zip', False)
if not name.endswith('.onnx'):
model_dir = os.path.join(model_root, name)
model_file = find_onnx_file(model_dir)
if model_file is None:
return None
else:
model_file = name
if not osp.exists(model_file) and allow_download:
model_file = download_onnx('models', model_file, root=root, download_zip=download_zip)
assert osp.exists(model_file), 'model_file %s should exist'%model_file
assert osp.isfile(model_file), 'model_file %s should be a file'%model_file
router = ModelRouter(model_file)
providers = kwargs.get('providers', get_default_providers())
provider_options = kwargs.get('provider_options', get_default_provider_options())
model = router.get_model(providers=providers, provider_options=provider_options)
return model
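A sketch of the router above, assuming the 'buffalo_l' pack has already been downloaded to the default location; the file name is a placeholder for any local ONNX model.

import os.path as osp

model_path = osp.expanduser('~/.insightface/models/buffalo_l/det_10g.onnx')
model = get_model(model_path, providers=['CPUExecutionProvider'])
if model is not None:
    model.prepare(ctx_id=-1)
    print(type(model).__name__)   # a detection model routes to RetinaFace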

View File

@ -1,301 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-09-18
# @Function :
from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys
def softmax(z):
assert len(z.shape) == 2
s = np.max(z, axis=1)
s = s[:, np.newaxis] # necessary step to do broadcasting
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis] # ditto
return e_x / div
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
x1 = points[:, 0] - distance[:, 0]
y1 = points[:, 1] - distance[:, 1]
x2 = points[:, 0] + distance[:, 2]
y2 = points[:, 1] + distance[:, 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return np.stack([x1, y1, x2, y2], axis=-1)
def distance2kps(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
preds = []
for i in range(0, distance.shape[1], 2):
px = points[:, i%2] + distance[:, i]
py = points[:, i%2+1] + distance[:, i+1]
if max_shape is not None:
px = px.clamp(min=0, max=max_shape[1])
py = py.clamp(min=0, max=max_shape[0])
preds.append(px)
preds.append(py)
return np.stack(preds, axis=-1)
class RetinaFace:
def __init__(self, model_file=None, session=None):
import onnxruntime
self.model_file = model_file
self.session = session
self.taskname = 'detection'
if self.session is None:
assert self.model_file is not None
assert osp.exists(self.model_file)
self.session = onnxruntime.InferenceSession(self.model_file, None)
self.center_cache = {}
self.nms_thresh = 0.4
self.det_thresh = 0.5
self._init_vars()
def _init_vars(self):
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
#print(input_shape)
if isinstance(input_shape[2], str):
self.input_size = None
else:
self.input_size = tuple(input_shape[2:4][::-1])
#print('image_size:', self.image_size)
input_name = input_cfg.name
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for o in outputs:
output_names.append(o.name)
self.input_name = input_name
self.output_names = output_names
self.input_mean = 127.5
self.input_std = 128.0
#print(self.output_names)
#assert len(outputs)==10 or len(outputs)==15
self.use_kps = False
self._anchor_ratio = 1.0
self._num_anchors = 1
if len(outputs)==6:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
elif len(outputs)==9:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
self.use_kps = True
elif len(outputs)==10:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
elif len(outputs)==15:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
self.use_kps = True
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
nms_thresh = kwargs.get('nms_thresh', None)
if nms_thresh is not None:
self.nms_thresh = nms_thresh
det_thresh = kwargs.get('det_thresh', None)
if det_thresh is not None:
self.det_thresh = det_thresh
input_size = kwargs.get('input_size', None)
if input_size is not None:
if self.input_size is not None:
print('warning: det_size is already set in detection model, ignoring the given input_size')
else:
self.input_size = input_size
def forward(self, img, threshold):
scores_list = []
bboxes_list = []
kpss_list = []
input_size = tuple(img.shape[0:2][::-1])
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_outs = self.session.run(self.output_names, {self.input_name : blob})
input_height = blob.shape[2]
input_width = blob.shape[3]
fmc = self.fmc
for idx, stride in enumerate(self._feat_stride_fpn):
scores = net_outs[idx]
bbox_preds = net_outs[idx+fmc]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx+fmc*2] * stride
height = input_height // stride
width = input_width // stride
K = height * width
key = (height, width, stride)
if key in self.center_cache:
anchor_centers = self.center_cache[key]
else:
#solution-1, c style:
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
#for i in range(height):
# anchor_centers[i, :, 1] = i
#for i in range(width):
# anchor_centers[:, i, 0] = i
#solution-2:
#ax = np.arange(width, dtype=np.float32)
#ay = np.arange(height, dtype=np.float32)
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
#solution-3:
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
#print(anchor_centers.shape)
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
if self._num_anchors>1:
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
if len(self.center_cache)<100:
self.center_cache[key] = anchor_centers
pos_inds = np.where(scores>=threshold)[0]
bboxes = distance2bbox(anchor_centers, bbox_preds)
pos_scores = scores[pos_inds]
pos_bboxes = bboxes[pos_inds]
scores_list.append(pos_scores)
bboxes_list.append(pos_bboxes)
if self.use_kps:
kpss = distance2kps(anchor_centers, kps_preds)
#kpss = kps_preds
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
pos_kpss = kpss[pos_inds]
kpss_list.append(pos_kpss)
return scores_list, bboxes_list, kpss_list
def detect(self, img, input_size = None, max_num=0, metric='default'):
assert input_size is not None or self.input_size is not None
input_size = self.input_size if input_size is None else input_size
im_ratio = float(img.shape[0]) / img.shape[1]
model_ratio = float(input_size[1]) / input_size[0]
if im_ratio>model_ratio:
new_height = input_size[1]
new_width = int(new_height / im_ratio)
else:
new_width = input_size[0]
new_height = int(new_width * im_ratio)
det_scale = float(new_height) / img.shape[0]
resized_img = cv2.resize(img, (new_width, new_height))
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
det_img[:new_height, :new_width, :] = resized_img
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
scores = np.vstack(scores_list)
scores_ravel = scores.ravel()
order = scores_ravel.argsort()[::-1]
bboxes = np.vstack(bboxes_list) / det_scale
if self.use_kps:
kpss = np.vstack(kpss_list) / det_scale
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
pre_det = pre_det[order, :]
keep = self.nms(pre_det)
det = pre_det[keep, :]
if self.use_kps:
kpss = kpss[order,:,:]
kpss = kpss[keep,:,:]
else:
kpss = None
if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max':
values = area
else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
bindex = np.argsort(
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num]
det = det[bindex, :]
if kpss is not None:
kpss = kpss[bindex, :]
return det, kpss
def nms(self, dets):
thresh = self.nms_thresh
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs):
if not download:
assert os.path.exists(name)
return RetinaFace(name)
else:
from .model_store import get_model_file
_file = get_model_file("retinaface_%s" % name, root=root)
return RetinaFace(_file)
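A usage sketch for the detector above; 'det.onnx' and 'photo.jpg' are placeholder paths, and each returned row is [x1, y1, x2, y2, score].

import cv2

det = RetinaFace(model_file='det.onnx')
det.prepare(ctx_id=-1, input_size=(640, 640))
img = cv2.imread('photo.jpg')
boxes, kpss = det.detect(img, max_num=5)
for box in boxes:
    x1, y1, x2, y2 = box[:4].astype(int)
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.imwrite('out.jpg', img)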

View File

@ -1,348 +0,0 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys
def softmax(z):
assert len(z.shape) == 2
s = np.max(z, axis=1)
s = s[:, np.newaxis] # necessary step to do broadcasting
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis] # ditto
return e_x / div
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
x1 = points[:, 0] - distance[:, 0]
y1 = points[:, 1] - distance[:, 1]
x2 = points[:, 0] + distance[:, 2]
y2 = points[:, 1] + distance[:, 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return np.stack([x1, y1, x2, y2], axis=-1)
def distance2kps(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.
Returns:
Tensor: Decoded bboxes.
"""
preds = []
for i in range(0, distance.shape[1], 2):
px = points[:, i%2] + distance[:, i]
py = points[:, i%2+1] + distance[:, i+1]
if max_shape is not None:
px = px.clamp(min=0, max=max_shape[1])
py = py.clamp(min=0, max=max_shape[0])
preds.append(px)
preds.append(py)
return np.stack(preds, axis=-1)
class SCRFD:
def __init__(self, model_file=None, session=None):
import onnxruntime
self.model_file = model_file
self.session = session
self.taskname = 'detection'
self.batched = False
if self.session is None:
assert self.model_file is not None
assert osp.exists(self.model_file)
self.session = onnxruntime.InferenceSession(self.model_file, None)
self.center_cache = {}
self.nms_thresh = 0.4
self.det_thresh = 0.5
self._init_vars()
def _init_vars(self):
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
#print(input_shape)
if isinstance(input_shape[2], str):
self.input_size = None
else:
self.input_size = tuple(input_shape[2:4][::-1])
#print('image_size:', self.image_size)
input_name = input_cfg.name
self.input_shape = input_shape
outputs = self.session.get_outputs()
if len(outputs[0].shape) == 3:
self.batched = True
output_names = []
for o in outputs:
output_names.append(o.name)
self.input_name = input_name
self.output_names = output_names
self.input_mean = 127.5
self.input_std = 128.0
#print(self.output_names)
#assert len(outputs)==10 or len(outputs)==15
self.use_kps = False
self._anchor_ratio = 1.0
self._num_anchors = 1
if len(outputs)==6:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
elif len(outputs)==9:
self.fmc = 3
self._feat_stride_fpn = [8, 16, 32]
self._num_anchors = 2
self.use_kps = True
elif len(outputs)==10:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
elif len(outputs)==15:
self.fmc = 5
self._feat_stride_fpn = [8, 16, 32, 64, 128]
self._num_anchors = 1
self.use_kps = True
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
nms_thresh = kwargs.get('nms_thresh', None)
if nms_thresh is not None:
self.nms_thresh = nms_thresh
det_thresh = kwargs.get('det_thresh', None)
if det_thresh is not None:
self.det_thresh = det_thresh
input_size = kwargs.get('input_size', None)
if input_size is not None:
if self.input_size is not None:
print('warning: det_size is already set in scrfd model, ignoring the given input_size')
else:
self.input_size = input_size
def forward(self, img, threshold):
scores_list = []
bboxes_list = []
kpss_list = []
input_size = tuple(img.shape[0:2][::-1])
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_outs = self.session.run(self.output_names, {self.input_name : blob})
input_height = blob.shape[2]
input_width = blob.shape[3]
fmc = self.fmc
for idx, stride in enumerate(self._feat_stride_fpn):
# If model support batch dim, take first output
if self.batched:
scores = net_outs[idx][0]
bbox_preds = net_outs[idx + fmc][0]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx + fmc * 2][0] * stride
# If model doesn't support batching take output as is
else:
scores = net_outs[idx]
bbox_preds = net_outs[idx + fmc]
bbox_preds = bbox_preds * stride
if self.use_kps:
kps_preds = net_outs[idx + fmc * 2] * stride
height = input_height // stride
width = input_width // stride
K = height * width
key = (height, width, stride)
if key in self.center_cache:
anchor_centers = self.center_cache[key]
else:
#solution-1, c style:
#anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
#for i in range(height):
# anchor_centers[i, :, 1] = i
#for i in range(width):
# anchor_centers[:, i, 0] = i
#solution-2:
#ax = np.arange(width, dtype=np.float32)
#ay = np.arange(height, dtype=np.float32)
#xv, yv = np.meshgrid(np.arange(width), np.arange(height))
#anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
#solution-3:
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
#print(anchor_centers.shape)
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
if self._num_anchors>1:
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
if len(self.center_cache)<100:
self.center_cache[key] = anchor_centers
pos_inds = np.where(scores>=threshold)[0]
bboxes = distance2bbox(anchor_centers, bbox_preds)
pos_scores = scores[pos_inds]
pos_bboxes = bboxes[pos_inds]
scores_list.append(pos_scores)
bboxes_list.append(pos_bboxes)
if self.use_kps:
kpss = distance2kps(anchor_centers, kps_preds)
#kpss = kps_preds
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
pos_kpss = kpss[pos_inds]
kpss_list.append(pos_kpss)
return scores_list, bboxes_list, kpss_list
def detect(self, img, input_size = None, max_num=0, metric='default'):
assert input_size is not None or self.input_size is not None
input_size = self.input_size if input_size is None else input_size
im_ratio = float(img.shape[0]) / img.shape[1]
model_ratio = float(input_size[1]) / input_size[0]
if im_ratio>model_ratio:
new_height = input_size[1]
new_width = int(new_height / im_ratio)
else:
new_width = input_size[0]
new_height = int(new_width * im_ratio)
det_scale = float(new_height) / img.shape[0]
resized_img = cv2.resize(img, (new_width, new_height))
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
det_img[:new_height, :new_width, :] = resized_img
scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
scores = np.vstack(scores_list)
scores_ravel = scores.ravel()
order = scores_ravel.argsort()[::-1]
bboxes = np.vstack(bboxes_list) / det_scale
if self.use_kps:
kpss = np.vstack(kpss_list) / det_scale
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
pre_det = pre_det[order, :]
keep = self.nms(pre_det)
det = pre_det[keep, :]
if self.use_kps:
kpss = kpss[order,:,:]
kpss = kpss[keep,:,:]
else:
kpss = None
if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max':
values = area
else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
bindex = np.argsort(
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num]
det = det[bindex, :]
if kpss is not None:
kpss = kpss[bindex, :]
return det, kpss
def nms(self, dets):
thresh = self.nms_thresh
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs):
if not download:
assert os.path.exists(name)
return SCRFD(name)
else:
from .model_store import get_model_file
_file = get_model_file("scrfd_%s" % name, root=root)
return SCRFD(_file)
def scrfd_2p5gkps(**kwargs):
return get_scrfd("2p5gkps", download=True, **kwargs)
if __name__ == '__main__':
import glob
detector = SCRFD(model_file='./det.onnx')
detector.prepare(-1)
img_paths = ['tests/data/t1.jpg']
for img_path in img_paths:
img = cv2.imread(img_path)
for _ in range(1):
ta = datetime.datetime.now()
#bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640))
bboxes, kpss = detector.detect(img, input_size=(640, 640)) # detect() takes input_size here; the threshold comes from self.det_thresh
tb = datetime.datetime.now()
print('all cost:', (tb-ta).total_seconds()*1000)
print(img_path, bboxes.shape)
if kpss is not None:
print(kpss.shape)
for i in range(bboxes.shape[0]):
bbox = bboxes[i]
x1, y1, x2, y2, score = bbox.astype(int) # np.int was removed from NumPy; use the builtin int
cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2)
if kpss is not None:
kps = kpss[i]
for kp in kps:
kp = kp.astype(int)
cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2)
filename = img_path.split('/')[-1]
print('output:', filename)
cv2.imwrite('./outputs/%s'%filename, img)

View File

@ -1,6 +0,0 @@
from __future__ import absolute_import
from .storage import download, ensure_available, download_onnx
from .filesystem import get_model_dir
from .filesystem import makedirs, try_import_dali
from .constant import *

View File

@ -1,3 +0,0 @@
DEFAULT_MP_NAME = 'buffalo_l'

View File

@ -1,95 +0,0 @@
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py
"""
import os
import hashlib
import requests
from tqdm import tqdm
def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
sha1_file = sha1.hexdigest()
l = min(len(sha1_file), len(sha1_hash))
return sha1.hexdigest()[0:l] == sha1_hash[0:l]
def download_file(url, path=None, overwrite=False, sha1_hash=None):
"""Download an given URL
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store downloaded file. By default stores to the
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
Returns
-------
str
The file path of the downloaded file.
"""
if path is None:
fname = url.split('/')[-1]
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
if overwrite or not os.path.exists(fname) or (
sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
print('Downloading %s from %s...' % (fname, url))
r = requests.get(url, stream=True)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s" % url)
total_length = r.headers.get('content-length')
with open(fname, 'wb') as f:
if total_length is None: # no content length header
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
else:
total_length = int(total_length)
for chunk in tqdm(r.iter_content(chunk_size=1024),
total=int(total_length / 1024. + 0.5),
unit='KB',
unit_scale=False,
dynamic_ncols=True):
f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match. ' \
'The repo may be outdated or download may be incomplete. ' \
'If the "repo_url" is overridden, consider switching to ' \
'the default repo.'.format(fname))
return fname
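A short sketch of the two helpers above; the URL and hash are placeholders, not a real artifact. Note that check_sha1 compares only up to the length of the shorter string, so a hash prefix is accepted.

path = download_file('https://example.com/files/demo.bin',
                     path='/tmp/demo.bin',
                     sha1_hash='0123456789abcdef')   # a prefix match is enough
print('hash ok:', check_sha1(path, '0123456789abcdef'))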

View File

@ -1,103 +0,0 @@
import cv2
import numpy as np
from skimage import transform as trans
arcface_dst = np.array(
[[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
[41.5493, 92.3655], [70.7299, 92.2041]],
dtype=np.float32)
def estimate_norm(lmk, image_size=112,mode='arcface'):
assert lmk.shape == (5, 2)
assert image_size%112==0 or image_size%128==0
if image_size%112==0:
ratio = float(image_size)/112.0
diff_x = 0
else:
ratio = float(image_size)/128.0
diff_x = 8.0*ratio
dst = arcface_dst * ratio
dst[:,0] += diff_x
tform = trans.SimilarityTransform()
tform.estimate(lmk, dst)
M = tform.params[0:2, :]
return M
def norm_crop(img, landmark, image_size=112, mode='arcface'):
M = estimate_norm(landmark, image_size, mode)
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
return warped
def norm_crop2(img, landmark, image_size=112, mode='arcface'):
M = estimate_norm(landmark, image_size, mode)
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
return warped, M
def square_crop(im, S):
if im.shape[0] > im.shape[1]:
height = S
width = int(float(im.shape[1]) / im.shape[0] * S)
scale = float(S) / im.shape[0]
else:
width = S
height = int(float(im.shape[0]) / im.shape[1] * S)
scale = float(S) / im.shape[1]
resized_im = cv2.resize(im, (width, height))
det_im = np.zeros((S, S, 3), dtype=np.uint8)
det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
return det_im, scale
def transform(data, center, output_size, scale, rotation):
scale_ratio = scale
rot = float(rotation) * np.pi / 180.0
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
t1 = trans.SimilarityTransform(scale=scale_ratio)
cx = center[0] * scale_ratio
cy = center[1] * scale_ratio
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
t3 = trans.SimilarityTransform(rotation=rot)
t4 = trans.SimilarityTransform(translation=(output_size / 2,
output_size / 2))
t = t1 + t2 + t3 + t4
M = t.params[0:2]
cropped = cv2.warpAffine(data,
M, (output_size, output_size),
borderValue=0.0)
return cropped, M
def trans_points2d(pts, M):
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i] = new_pt[0:2]
return new_pts
def trans_points3d(pts, M):
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
#print(scale)
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i][0:2] = new_pt[0:2]
new_pts[i][2] = pts[i][2] * scale
return new_pts
def trans_points(pts, M):
if pts.shape[1] == 2:
return trans_points2d(pts, M)
else:
return trans_points3d(pts, M)
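A sketch of aligning a face with the functions above; the five keypoint values and the image path are made up and would normally come from a detector.

import cv2
import numpy as np

# hypothetical 5-point landmarks: left eye, right eye, nose, mouth corners
kps = np.array([[187., 242.], [312., 240.], [250., 320.],
                [200., 380.], [300., 378.]], dtype=np.float32)
img = cv2.imread('photo.jpg')
aligned = norm_crop(img, kps, image_size=112)    # 112x112 BGR crop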

View File

@ -1,157 +0,0 @@
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py
"""
import os
import os.path as osp
import errno
def get_model_dir(name, root='~/.insightface'):
root = os.path.expanduser(root)
model_dir = osp.join(root, 'models', name)
return model_dir
def makedirs(path):
"""Create directory recursively if not exists.
Similar to `mkdir -p`, you can skip checking existence before calling this function.
Parameters
----------
path : str
Path of the desired dir
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
def try_import(package, message=None):
"""Try import specified package, with custom message support.
Parameters
----------
package : str
The name of the targeting package.
message : str, default is None
If not None, this function will raise customized error message when import error is found.
Returns
-------
module if found, raise ImportError otherwise
"""
try:
return __import__(package)
except ImportError as e:
if not message:
raise e
raise ImportError(message)
def try_import_cv2():
"""Try import cv2 at runtime.
Returns
-------
cv2 module if found. Raise ImportError otherwise
"""
msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \
or `pip install opencv-python --user` (note that this is unofficial PYPI package)."
return try_import('cv2', msg)
def try_import_mmcv():
"""Try import mmcv at runtime.
Returns
-------
mmcv module if found. Raise ImportError otherwise
"""
msg = "mmcv is required, you can install by first `pip install Cython --user` \
and then `pip install mmcv --user` (note that this is unofficial PYPI package)."
return try_import('mmcv', msg)
def try_import_rarfile():
"""Try import rarfile at runtime.
Returns
-------
rarfile module if found. Raise ImportError otherwise
"""
msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \
and then `pip install rarfile --user` (note that this is unofficial PYPI package)."
return try_import('rarfile', msg)
def import_try_install(package, extern_url=None):
"""Try import the specified package.
If the package not installed, try use pip to install and import if success.
Parameters
----------
package : str
The name of the package trying to import.
extern_url : str or None, optional
The external url if package is not hosted on PyPI.
For example, you can install a package using:
"pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
In this case, you can pass the url to the extern_url.
Returns
-------
<class 'Module'>
The imported python module.
"""
try:
return __import__(package)
except ImportError:
try:
from pip import main as pipmain
except ImportError:
from pip._internal import main as pipmain
# trying to install package
url = package if extern_url is None else extern_url
pipmain(['install', '--user',
url]) # will raise SystemExit Error if fails
# trying to load again
try:
return __import__(package)
except ImportError:
import sys
import site
user_site = site.getusersitepackages()
if user_site not in sys.path:
sys.path.append(user_site)
return __import__(package)
return __import__(package)
def try_import_dali():
"""Try import NVIDIA DALI at runtime.
"""
try:
dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types'])
dali.Pipeline = dali.pipeline.Pipeline
except ImportError:
class dali:
class Pipeline:
def __init__(self):
raise NotImplementedError(
"DALI not found, please check if you installed it correctly."
)
return dali

View File

@ -1,52 +0,0 @@
import os
import os.path as osp
import zipfile
from .download import download_file
BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7'
def download(sub_dir, name, force=False, root='~/.insightface'):
_root = os.path.expanduser(root)
dir_path = os.path.join(_root, sub_dir, name)
if osp.exists(dir_path) and not force:
return dir_path
print('download_path:', dir_path)
zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
download_file(model_url,
path=zip_file_path,
overwrite=True)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(dir_path)
#os.remove(zip_file_path)
return dir_path
def ensure_available(sub_dir, name, root='~/.insightface'):
return download(sub_dir, name, force=False, root=root)
def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False):
_root = os.path.expanduser(root)
model_root = osp.join(_root, sub_dir)
new_model_file = osp.join(model_root, model_file)
if osp.exists(new_model_file) and not force:
return new_model_file
if not osp.exists(model_root):
os.makedirs(model_root)
print('download_path:', new_model_file)
if not download_zip:
model_url = "%s/%s"%(BASE_REPO_URL, model_file)
download_file(model_url,
path=new_model_file,
overwrite=True)
else:
model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file)
zip_file_path = new_model_file+".zip"
download_file(model_url,
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(model_root)
return new_model_file
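A usage sketch for the storage helpers above, fetching the standard 'buffalo_l' model pack into ~/.insightface/models (the import path follows the package __init__ shown earlier).

from insightface.utils import ensure_available

# downloads and unpacks the pack on first use, then returns the cached dir
model_dir = ensure_available('models', 'buffalo_l')
print(model_dir)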

View File

@ -1,116 +0,0 @@
import cv2
import math
import numpy as np
from skimage import transform as trans
def transform(data, center, output_size, scale, rotation):
scale_ratio = scale
rot = float(rotation) * np.pi / 180.0
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
t1 = trans.SimilarityTransform(scale=scale_ratio)
cx = center[0] * scale_ratio
cy = center[1] * scale_ratio
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
t3 = trans.SimilarityTransform(rotation=rot)
t4 = trans.SimilarityTransform(translation=(output_size / 2,
output_size / 2))
t = t1 + t2 + t3 + t4
M = t.params[0:2]
cropped = cv2.warpAffine(data,
M, (output_size, output_size),
borderValue=0.0)
return cropped, M
def trans_points2d(pts, M):
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i] = new_pt[0:2]
return new_pts
def trans_points3d(pts, M):
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
#print(scale)
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i][0:2] = new_pt[0:2]
new_pts[i][2] = pts[i][2] * scale
return new_pts
def trans_points(pts, M):
if pts.shape[1] == 2:
return trans_points2d(pts, M)
else:
return trans_points3d(pts, M)
def estimate_affine_matrix_3d23d(X, Y):
''' Using least-squares solution
Args:
X: [n, 3]. 3d points(fixed)
Y: [n, 3]. corresponding 3d points(moving). Y = PX
Returns:
P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]).
'''
X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4
P = np.linalg.lstsq(X_homo, Y, rcond=None)[0].T # Affine matrix. 3 x 4
return P
def P2sRt(P):
''' decompositing camera matrix P
Args:
P: (3, 4). Affine Camera Matrix.
Returns:
s: scale factor.
R: (3, 3). rotation matrix.
t: (3,). translation.
'''
t = P[:, 3]
R1 = P[0:1, :3]
R2 = P[1:2, :3]
s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0
r1 = R1/np.linalg.norm(R1)
r2 = R2/np.linalg.norm(R2)
r3 = np.cross(r1, r2)
R = np.concatenate((r1, r2, r3), 0)
return s, R, t
def matrix2angle(R):
''' get three Euler angles from Rotation Matrix
Args:
R: (3,3). rotation matrix
Returns:
x: pitch
y: yaw
z: roll
'''
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
singular = sy < 1e-6
if not singular :
x = math.atan2(R[2,1] , R[2,2])
y = math.atan2(-R[2,0], sy)
z = math.atan2(R[1,0], R[0,0])
else :
x = math.atan2(-R[1,2], R[1,1])
y = math.atan2(-R[2,0], sy)
z = 0
# rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi
return rx, ry, rz
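A small sanity check for matrix2angle above: a pure 30-degree yaw rotation should come back as roughly (0, 30, 0) in degrees.

import numpy as np

yaw = np.deg2rad(30.0)
R = np.array([[ np.cos(yaw), 0., np.sin(yaw)],
              [ 0.,          1., 0.         ],
              [-np.sin(yaw), 0., np.cos(yaw)]])
rx, ry, rz = matrix2angle(R)
print(round(ry, 1))   # expected: 30.0, with rx and rz near 0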

View File

@ -1,125 +0,0 @@
# coding: utf-8
import os
from glob import glob
import os.path as osp
import imageio
import numpy as np
import pickle
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from .helper import mkdir, suffix
def load_image_rgb(image_path: str):
if not osp.exists(image_path):
raise FileNotFoundError(f"Image not found: {image_path}")
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def load_driving_info(driving_info):
driving_video_ori = []
def load_images_from_directory(directory):
image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg')))
return [load_image_rgb(im_path) for im_path in image_paths]
def load_images_from_video(file_path):
reader = imageio.get_reader(file_path, "ffmpeg")
return [image for _, image in enumerate(reader)]
if osp.isdir(driving_info):
driving_video_ori = load_images_from_directory(driving_info)
elif osp.isfile(driving_info):
driving_video_ori = load_images_from_video(driving_info)
return driving_video_ori
def contiguous(obj):
if not obj.flags.c_contiguous:
obj = obj.copy(order="C")
return obj
def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
"""
adjust the size of the image so that its maximum dimension does not exceed max_dim, and its width and height are multiples of `division`.
:param img: the image to be processed.
:param max_dim: the maximum dimension constraint.
:param division: the number that the width and height need to be multiples of.
:return: the adjusted image.
"""
h, w = img.shape[:2]
# adjust the size of the image according to the maximum dimension
if max_dim > 0 and max(h, w) > max_dim:
if h > w:
new_h = max_dim
new_w = int(w * (max_dim / h))
else:
new_w = max_dim
new_h = int(h * (max_dim / w))
img = cv2.resize(img, (new_w, new_h))
# ensure that the image dimensions are multiples of division
division = max(division, 1)
new_h = img.shape[0] - (img.shape[0] % division)
new_w = img.shape[1] - (img.shape[1] % division)
if new_h == 0 or new_w == 0:
# when the width or height is less than division, no need to process
return img
if new_h != img.shape[0] or new_w != img.shape[1]:
img = img[:new_h, :new_w]
return img
def load_img_online(obj, mode="bgr", **kwargs):
max_dim = kwargs.get("max_dim", 1920)
n = kwargs.get("n", 2)
if isinstance(obj, str):
if mode.lower() == "gray":
img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE)
else:
img = cv2.imread(obj, cv2.IMREAD_COLOR)
else:
img = obj
# Resize image to satisfy constraints
img = resize_to_limit(img, max_dim=max_dim, division=n)
if mode.lower() == "bgr":
return contiguous(img)
elif mode.lower() == "rgb":
return contiguous(img[..., ::-1])
else:
raise Exception(f"Unknown mode {mode}")
def load(fp):
suffix_ = suffix(fp)
if suffix_ == "npy":
return np.load(fp)
elif suffix_ == "pkl":
return pickle.load(open(fp, "rb"))
else:
raise Exception(f"Unknown type: {suffix}")
def dump(wfp, obj):
wd = osp.split(wfp)[0]
if wd != "" and not osp.exists(wd):
mkdir(wd)
_suffix = suffix(wfp)
if _suffix == "npy":
np.save(wfp, obj)
elif _suffix == "pkl":
pickle.dump(obj, open(wfp, "wb"))
else:
raise Exception("Unknown type: {}".format(_suffix))
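A round-trip sketch for the dump/load pair above, using the npy branch and a throwaway path.

import numpy as np

arr = np.arange(6).reshape(2, 3)
dump('/tmp/demo.npy', arr)                  # dispatches on the file suffix
assert (load('/tmp/demo.npy') == arr).all()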

Binary file not shown.

Before: 3.4 KiB

View File

@ -1,16 +0,0 @@
# coding: utf-8
"""
custom print and log functions
"""
__all__ = ['rprint', 'rlog']
try:
from rich.console import Console
console = Console()
rprint = console.print
rlog = console.log
except ImportError:
rprint = print
rlog = print

View File

@ -1,29 +0,0 @@
# coding: utf-8
"""
tools to measure elapsed time
"""
import time
class Timer(object):
"""A simple timer."""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
def tic(self):
# use time.time instead of time.clock because time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
return self.diff
def clear(self):
self.start_time = 0.
self.diff = 0.
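A minimal use of the Timer above.

t = Timer()
t.tic()
sum(range(1_000_000))
print('%.1f ms' % (t.toc() * 1000))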

View File

@ -1,211 +0,0 @@
# coding: utf-8
"""
Functions for processing video
ATTENTION: you need to install ffmpeg and ffprobe in your env!
"""
import os.path as osp
import numpy as np
import subprocess
import imageio
import cv2
from rich.progress import track
from .rprint import rlog as log
from .rprint import rprint as print
from .helper import prefix
def exec_cmd(cmd):
return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
def images2video(images, wfp, **kwargs):
fps = kwargs.get('fps', 30)
video_format = kwargs.get('format', 'mp4') # default is mp4 format
codec = kwargs.get('codec', 'libx264') # default is libx264 encoding
quality = kwargs.get('quality') # video quality
pixelformat = kwargs.get('pixelformat', 'yuv420p') # video pixel format
image_mode = kwargs.get('image_mode', 'rgb')
macro_block_size = kwargs.get('macro_block_size', 2)
ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))]
writer = imageio.get_writer(
wfp, fps=fps, format=video_format,
codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size
)
n = len(images)
for i in track(range(n), description='Writing', transient=True):
if image_mode.lower() == 'bgr':
writer.append_data(images[i][..., ::-1])
else:
writer.append_data(images[i])
writer.close()
def video2gif(video_fp, fps=30, size=256):
if osp.exists(video_fp):
d = osp.split(video_fp)[0]
fn = prefix(osp.basename(video_fp))
palette_wfp = osp.join(d, 'palette.png')
gif_wfp = osp.join(d, f'{fn}.gif')
# generate the palette
cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y'
exec_cmd(cmd)
# use the palette to generate the gif
cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
exec_cmd(cmd)
else:
        print(f'video_fp: {video_fp} does not exist!')
def merge_audio_video(video_fp, audio_fp, wfp):
if osp.exists(video_fp) and osp.exists(audio_fp):
cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
exec_cmd(cmd)
        print(f'merged {video_fp} and {audio_fp} into {wfp}')
    else:
        print(f'video_fp: {video_fp} or audio_fp: {audio_fp} does not exist!')
def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):
mask_float = mask.astype(np.float32) / 255.
background_color = np.array(background_color).reshape([1, 1, 3])
bg = np.ones_like(img) * background_color
img = np.clip(mask_float * img + (1 - mask_float) * bg, 0, 255).astype(np.uint8)
return img
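# A hypothetical usage sketch for blend: mask is a uint8 alpha matte in
# [0, 255], shaped HxWx3 (or HxWx1) so it broadcasts against img;
# 255 keeps the foreground pixel, 0 shows the background color:
#   out_white = blend(img, mask)                                 # white background
#   out_black = blend(img, mask, background_color=(0, 0, 0))     # black background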
def concat_frames(driving_image_lst, source_image, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
h, w, _ = I_p_lst[0].shape
for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
I_p = I_p_lst[idx]
source_image_resized = cv2.resize(source_image, (w, h))
if driving_image_lst is None:
out = np.hstack((source_image_resized, I_p))
else:
driving_image = driving_image_lst[idx]
driving_image_resized = cv2.resize(driving_image, (w, h))
out = np.hstack((driving_image_resized, source_image_resized, I_p))
out_lst.append(out)
return out_lst
class VideoWriter:
def __init__(self, **kwargs):
self.fps = kwargs.get('fps', 30)
self.wfp = kwargs.get('wfp', 'video.mp4')
self.video_format = kwargs.get('format', 'mp4')
self.codec = kwargs.get('codec', 'libx264')
self.quality = kwargs.get('quality')
self.pixelformat = kwargs.get('pixelformat', 'yuv420p')
self.image_mode = kwargs.get('image_mode', 'rgb')
self.ffmpeg_params = kwargs.get('ffmpeg_params')
self.writer = imageio.get_writer(
self.wfp, fps=self.fps, format=self.video_format,
codec=self.codec, quality=self.quality,
ffmpeg_params=self.ffmpeg_params, pixelformat=self.pixelformat
)
def write(self, image):
if self.image_mode.lower() == 'bgr':
self.writer.append_data(image[..., ::-1])
else:
self.writer.append_data(image)
def close(self):
if self.writer is not None:
self.writer.close()
def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12):
cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y'
exec_cmd(cmd)
def get_fps(filepath, default_fps=25):
    try:
        cap = cv2.VideoCapture(filepath)
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()
        if not fps:  # cv2 returns 0.0 when the fps cannot be determined
            fps = default_fps
    except Exception as e:
        log(e)
        fps = default_fps
    return fps
def has_audio_stream(video_path: str) -> bool:
"""
Check if the video file contains an audio stream.
:param video_path: Path to the video file
:return: True if the video contains an audio stream, False otherwise
"""
if osp.isdir(video_path):
return False
cmd = [
'ffprobe',
'-v', 'error',
'-select_streams', 'a',
'-show_entries', 'stream=codec_type',
'-of', 'default=noprint_wrappers=1:nokey=1',
f'"{video_path}"'
]
    try:
        result = exec_cmd(' '.join(cmd))
        if result.returncode != 0:
            log(f"Error occurred while probing video: {result.stderr}")
            return False
        # ffprobe prints one codec_type line per audio stream; any output means audio exists
        return bool(result.stdout.strip())
    except Exception:
        # note: the plain-print fallback for rlog does not accept rich-only kwargs such as style=
        log(f"Error occurred while probing video: {video_path}; you may need to install ffprobe! Treating the video as silent.")
        return False
def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
cmd = [
'ffmpeg',
'-y',
'-i', f'"{silent_video_path}"',
'-i', f'"{audio_video_path}"',
'-map', '0:v',
'-map', '1:a',
'-c:v', 'copy',
'-shortest',
f'"{output_video_path}"'
]
try:
exec_cmd(' '.join(cmd))
log(f"Video with audio generated successfully: {output_video_path}")
except subprocess.CalledProcessError as e:
log(f"Error occurred: {e}")
def bb_intersection_over_union(boxA, boxB):
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
iou = interArea / float(boxAArea + boxBArea - interArea)
return iou
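To make the IoU convention above concrete (the +1 terms treat box coordinates as inclusive pixel indices), a small worked example:

    boxA = [0, 0, 10, 10]   # 11 x 11 = 121 px
    boxB = [5, 5, 15, 15]   # 11 x 11 = 121 px
    # intersection: x and y both span [5, 10] -> 6 x 6 = 36 px
    # union: 121 + 121 - 36 = 206 px
    print(bb_intersection_over_union(boxA, boxB))  # 36 / 206 ≈ 0.175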

View File

@ -1,19 +0,0 @@
# coding: utf-8
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
def viz_lmk(img_, vps, **kwargs):
"""可视化点"""
lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA
img_for_viz = img_.copy()
for pt in vps:
cv2.circle(
img_for_viz,
(int(pt[0]), int(pt[1])),
radius=kwargs.get("radius", 1),
color=(0, 255, 0),
thickness=kwargs.get("thickness", 1),
lineType=lineType,
)
return img_for_viz
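A minimal usage sketch for viz_lmk (the file path and landmark coordinates below are hypothetical):

    import numpy as np
    img = cv2.imread('face.jpg')
    lmks = np.array([[120, 140], [180, 142], [150, 190]])  # (x, y) points
    viz = viz_lmk(img, lmks, radius=2, thickness=-1)        # thickness=-1 draws filled dots
    cv2.imwrite('face_lmk.jpg', viz)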