feat: animals mode, several updates and improvements (#264)

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* chore: refactor

* chore: refactor

* chore: refactor

* fix: video cropping

* chore: refactor

* chore: remove timm

* merge: animal support (#258)

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>

feat: update

feat: update

chore: stage

* chore: stage

* chore: refactor

* chore: refactor

* doc: update readme

* doc: update readme

* doc: update readme

* chore: refactor

* doc: update

* doc: update

* doc: update

* doc: update

* chore: rename

* doc: update

* doc: update

* chore: refactor

* doc: update

* chore: refactor

* chore: refactor

* doc: update

* chore: update clip feature

* chore: add landmark option

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

* doc: update

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
Co-authored-by: zzzweakman <1819489045@qq.com>
Jianzhu Guo, 2024-08-02 22:39:05 +08:00 (committed by GitHub)
commit bbb2e33599, parent 3f394785fb
90 changed files with 10433 additions and 155 deletions

.gitignore

@ -12,6 +12,7 @@ __pycache__/
pretrained_weights/*.md
pretrained_weights/docs
pretrained_weights/liveportrait
pretrained_weights/liveportrait_animals
# Ipython notebook
*.ipynb
@ -26,3 +27,8 @@ gradio_temp/**
# Windows dependencies
ffmpeg/
LivePortrait_env/
# XPose build files
src/utils/dependencies/XPose/models/UniPose/ops/build
src/utils/dependencies/XPose/models/UniPose/ops/dist
src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info

app.py

@ -1,7 +1,7 @@
# coding: utf-8
"""
The entrance of the gradio
The entrance of the gradio for human
"""
import os
@ -59,9 +59,11 @@ def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline.execute_video(*args, **kwargs)
def gpu_wrapped_execute_image(*args, **kwargs):
return gradio_pipeline.execute_image(*args, **kwargs)
def gpu_wrapped_execute_image_retargeting(*args, **kwargs):
return gradio_pipeline.execute_image_retargeting(*args, **kwargs)
def gpu_wrapped_execute_video_retargeting(*args, **kwargs):
return gradio_pipeline.execute_video_retargeting(*args, **kwargs)
# assets
title_md = "assets/gradio/gradio_title.md"
@ -87,14 +89,20 @@ data_examples_v2v = [
# Define components first
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
retargeting_input_image = gr.Image(type="filepath")
retargeting_input_video = gr.Video()
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video(autoplay=False)
output_video_paste_back = gr.Video(autoplay=False)
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
@ -179,6 +187,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Row():
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
@ -235,9 +246,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
cache_examples=False,
)
# Retargeting
# Retargeting Image
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
with gr.Row(visible=True):
flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)")
retargeting_source_scale.render()
eye_retargeting_slider.render()
lip_retargeting_slider.render()
@ -246,10 +258,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
head_yaw_slider.render()
head_roll_slider.render()
with gr.Row(visible=True):
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_retargeting = gr.Button("🚗 Retargeting Image", variant="primary")
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Input"):
with gr.Accordion(open=True, label="Retargeting Image Input"):
retargeting_input_image.render()
gr.Examples(
examples=[
@ -286,15 +298,48 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
value="🧹 Clear"
)
# binding functions for buttons
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
# Retargeting Video
gr.Markdown(load_description("assets/gradio/gradio_description_retargeting_video.md"), visible=True)
with gr.Row(visible=True):
flag_do_crop_input_retargeting_video = gr.Checkbox(value=True, label="do crop (source)")
video_retargeting_source_scale.render()
video_lip_retargeting_slider.render()
driving_smooth_observation_variance_retargeting.render()
with gr.Row(visible=True):
process_button_retargeting_video = gr.Button("🍄 Retargeting Video", variant="primary")
with gr.Row(visible=True):
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Video Input"):
retargeting_input_video.render()
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s18.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")],
[osp.join(example_portrait_dir, "s29.mp4")],
[osp.join(example_portrait_dir, "s32.mp4")],
],
inputs=[retargeting_input_video],
cache_examples=False,
)
with gr.Column():
with gr.Accordion(open=True, label="Retargeting Result"):
output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_video_paste_back.render()
with gr.Row(visible=True):
process_button_reset_retargeting = gr.ClearButton(
[
video_lip_retargeting_slider,
retargeting_input_video,
output_video,
output_video_paste_back
],
value="🧹 Clear"
)
# binding functions for buttons
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
@ -304,6 +349,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_stitching_input,
driving_option_input,
driving_multiplier,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale,
@ -320,11 +368,26 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
)
retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting,
fn=gradio_pipeline.init_retargeting_image,
inputs=[retargeting_source_scale, retargeting_input_image],
outputs=[eye_retargeting_slider, lip_retargeting_slider]
)
process_button_retargeting.click(
# fn=gradio_pipeline.execute_image,
fn=gpu_wrapped_execute_image_retargeting,
inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input_retargeting_image],
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_retargeting_video.click(
fn=gpu_wrapped_execute_video_retargeting,
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
outputs=[output_video, output_video_paste_back],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,

app_animals.py (new file)

@ -0,0 +1,248 @@
# coding: utf-8
"""
The entrance of the gradio for animal
"""
import os
import tyro
import subprocess
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipelineAnimal
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
gradio_pipeline_animal: GradioPipelineAnimal = GradioPipelineAnimal(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg,
args=args
)
if args.gradio_temp_dir not in (None, ''):
os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
os.makedirs(args.gradio_temp_dir, exist_ok=True)
def gpu_wrapped_execute_video(*args, **kwargs):
return gradio_pipeline_animal.execute_video(*args, **kwargs)
# assets
title_md = "assets/gradio/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples_i2v = [
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "d3.mp4"), True, False, False, False],
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "d6.mp4"), True, False, False, False],
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "d19.mp4"), True, False, False, False],
]
data_examples_i2v_pickle = [
[osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "wink.pkl"), True, False, False, False],
[osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "talking.pkl"), True, False, False, False],
[osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "aggrieved.pkl"), True, False, False, False],
]
#################### interface logic ####################
# Define components first
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
output_video_i2v_gif = gr.Image(type="numpy")
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio/gradio_description_upload_animal.md"))
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="🐱 Source Animal Image"):
source_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_portrait_dir, "s25.jpg")],
[osp.join(example_portrait_dir, "s30.jpg")],
[osp.join(example_portrait_dir, "s31.jpg")],
[osp.join(example_portrait_dir, "s32.jpg")],
[osp.join(example_portrait_dir, "s39.jpg")],
[osp.join(example_portrait_dir, "s40.jpg")],
[osp.join(example_portrait_dir, "s41.jpg")],
[osp.join(example_portrait_dir, "s38.jpg")],
[osp.join(example_portrait_dir, "s36.jpg")],
],
inputs=[source_image_input],
cache_examples=False,
)
with gr.Accordion(open=True, label="Cropping Options for Source Image"):
with gr.Row():
flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Column():
with gr.Tabs():
with gr.TabItem("📁 Driving Pickle") as tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File()
gr.Examples(
examples=[
[osp.join(example_video_dir, "wink.pkl")],
[osp.join(example_video_dir, "shy.pkl")],
[osp.join(example_video_dir, "aggrieved.pkl")],
[osp.join(example_video_dir, "open_lip.pkl")],
[osp.join(example_video_dir, "laugh.pkl")],
[osp.join(example_video_dir, "talking.pkl")],
[osp.join(example_video_dir, "shake_face.pkl")],
],
inputs=[driving_video_pickle_input],
cache_examples=False,
)
with gr.TabItem("🎞️ Driving Video") as tab_video:
with gr.Accordion(open=True, label="Driving Video"):
driving_video_input = gr.Video()
gr.Examples(
examples=[
# [osp.join(example_video_dir, "d0.mp4")],
# [osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d14.mp4")],
[osp.join(example_video_dir, "d6.mp4")],
[osp.join(example_video_dir, "d3.mp4")],
],
inputs=[driving_video_input],
cache_examples=False,
)
tab_selection = gr.Textbox(visible=False)
tab_pickle.select(lambda: "Pickle", None, tab_selection)
tab_video.select(lambda: "Video", None, tab_selection)
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
with gr.Row():
flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Row():
with gr.Accordion(open=False, label="Animation Options"):
with gr.Row():
flag_stitching = gr.Checkbox(value=False, label="stitching (not recommended)")
flag_remap_input = gr.Checkbox(value=False, label="paste-back (not recommended)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
with gr.Row():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the cropped image space"):
output_video_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
output_video_i2v_gif.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], value="🧹 Clear")
with gr.Row():
# Examples
gr.Markdown("## You could also choose the examples below by one click ⬇️")
with gr.Row():
with gr.Tabs():
with gr.TabItem("📁 Driving Pickle") as tab_video:
gr.Examples(
examples=data_examples_i2v_pickle,
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_pickle_input,
flag_do_crop_input,
flag_stitching,
flag_remap_input,
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
examples_per_page=len(data_examples_i2v_pickle),
cache_examples=False,
)
with gr.TabItem("🎞️ Driving Video") as tab_video:
gr.Examples(
examples=data_examples_i2v,
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_input,
flag_do_crop_input,
flag_stitching,
flag_remap_input,
flag_crop_driving_video_input,
],
outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
examples_per_page=len(data_examples_i2v),
cache_examples=False,
)
process_button_animation.click(
fn=gpu_wrapped_execute_video,
inputs=[
source_image_input,
driving_video_input,
driving_video_pickle_input,
flag_do_crop_input,
flag_remap_input,
driving_multiplier,
flag_stitching,
flag_crop_driving_video_input,
scale,
vx_ratio,
vy_ratio,
scale_crop_driving_video,
vx_ratio_crop_driving_video,
vy_ratio_crop_driving_video,
tab_selection,
],
outputs=[output_video_i2v, output_video_concat_i2v, output_video_i2v_gif],
show_progress=True
)
demo.launch(
server_port=args.server_port,
share=args.share,
server_name=args.server_name
)

Binary file not shown (new image, 344 KiB).

assets/docs/changelog/2024-08-02.md (new file)

@ -0,0 +1,68 @@
## 2024/08/02
<table class="center" style="width: 80%; margin-left: auto; margin-right: auto;">
<tr>
<td style="text-align: center"><b>Animals Singing Dance Monkey 🎤</b></td>
</tr>
<tr>
<td style="border: none; text-align: center;">
<video controls loop src="https://github.com/user-attachments/assets/38d5b6e5-d29b-458d-9f2c-4dd52546cb41" muted="false" style="width: 60%;"></video>
</td>
</tr>
</table>
🎉 We are excited to announce the release of a new version featuring animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪
### Updates on Animals mode
We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode.
> Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. This may be addressed in future updates. Therefore, we recommend using the `--no_flag_stitching` option when running the model.
Before launching, ensure you have installed `transformers==4.22.0` and `pillow>=10.2.0`, which are already listed in [`requirements_base.txt`](../../../requirements_base.txt). We have chosen [XPose](https://github.com/IDEA-Research/X-Pose) as the keypoint detector for animals. It relies on `transformers` and requires building an op named `MultiScaleDeformableAttention`:
```bash
cd src/utils/dependencies/XPose/models/UniPose/ops
python setup.py build install
cd - # equal to cd ../../../../../../../
```
You can run the model using the script `inference_animal.py`:
```bash
python inference_animal.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75
```
Alternatively, we recommend using Gradio. Simply launch it by running:
```bash
python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share
```
> [!WARNING]
> [XPose](https://github.com/IDEA-Research/X-Pose) is licensed for non-commercial scientific research purposes only. If you use this project commercially, you should remove XPose and replace it with another detector.
### Updates on Humans Mode
- **Driving Options**: We have introduced an `expression-friendly` driving option that **reduces head wobbling**; it is now the default. It may be less effective with large head poses, in which case you can switch to the `pose-friendly` option, which behaves like the previous version. The option can be set with `--driving_option` or selected in the Gradio interface. We have also added a `--driving_multiplier` option to adjust the driving intensity (default `1.0`), which can likewise be set in the Gradio interface; see the command sketch after this list.
- **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐.
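For reference, below is a command-line sketch that combines these new human-mode options with the bundled example assets. The flag names come from the updated `ArgumentConfig`; the specific values are only illustrative (`expression-friendly` is already the default):

```bash
# illustrative values only; see `python inference.py -h` for all options
python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4 \
    --driving_option pose-friendly --driving_multiplier 1.2
```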
### Others
- [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260).
- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0.
**Below are some screenshots of the new features and improvements:**
| ![The Gradio Interface of Animals Mode](../animals-mode-gradio-2024-08-02.jpg) |
|:---:|
| **The Gradio Interface of Animals Mode** |
| ![Driving Options and Multiplier](../driving-option-multiplier-2024-08-02.jpg) |
|:---:|
| **Driving Options and Multiplier** |
| ![The Feature of Retargeting Video](../retargeting-video-2024-08-02.jpg) |
|:---:|
| **The Feature of Retargeting Video** |

assets/docs/directory-structure.md (new file)

@ -0,0 +1,28 @@
## The directory structure of `pretrained_weights`
```text
pretrained_weights
├── insightface
│ └── models
│ └── buffalo_l
│ ├── 2d106det.onnx
│ └── det_10g.onnx
├── liveportrait
│ ├── base_models
│ │ ├── appearance_feature_extractor.pth
│ │ ├── motion_extractor.pth
│ │ ├── spade_generator.pth
│ │ └── warping_module.pth
│ ├── landmark.onnx
│ └── retargeting_models
│ └── stitching_retargeting_module.pth
└── liveportrait_animals
├── base_models
│ ├── appearance_feature_extractor.pth
│ ├── motion_extractor.pth
│ ├── spade_generator.pth
│ └── warping_module.pth
├── retargeting_models
│ └── stitching_retargeting_module.pth
└── xpose.pth
```
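To spot-check a local download against this layout, one simple option (assuming the common `tree` utility is available) is:

```bash
tree pretrained_weights
```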

Binary file not shown (new image, 82 KiB).

assets/docs/how-to-install-ffmpeg.md (new file)

@ -0,0 +1,29 @@
## Install FFmpeg
Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below.
> [!Note]
> The installation instructions are copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗
### Conda Users
```bash
conda install ffmpeg
```
### Ubuntu/Debian Users
```bash
sudo apt install ffmpeg
sudo apt install libsox-dev
conda install -c conda-forge 'ffmpeg<7'
```
### Windows Users
Download [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe), and place them in an `ffmpeg/` folder under the LivePortrait root (this folder is added to `PATH` automatically) or anywhere on your `PATH`.
### MacOS Users
```bash
brew install ffmpeg
```
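Whichever route you take, you can quickly verify that both tools are on your `PATH`:

```bash
ffmpeg -version
ffprobe -version
```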

Binary files not shown (two new images: 491 KiB and 115 KiB).

assets/docs/speed.md (new file)

@ -0,0 +1,13 @@
### Speed
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|-----------------------------------|:-------------:|:--------------:|:-------------:|
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*

Multiple binary asset files changed or added (not shown).

assets/gradio/gradio_description_retargeting.md

@ -6,8 +6,8 @@
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
<div>
<h2>Retargeting</h2>
<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times.
<h2>Retargeting Image</h2>
<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting Image</strong> button. You can try running it multiple times.
<br>
<strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
</div>

assets/gradio/gradio_description_retargeting_video.md (new file)

@ -0,0 +1,9 @@
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
<div>
<h2>Retargeting Video</h2>
<p>Upload a Source Video as Retargeting Input, then drag the sliders and click the <strong>🍄 Retargeting Video</strong> button. You can try running it multiple times.
<br>
<strong>🤐 Set target lip-open ratio to 0 to see what's going on!</strong></p>
</div>
</div>

assets/gradio/gradio_description_upload_animal.md (new file)

@ -0,0 +1,16 @@
<br>
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: center; margin-right: 20px;">
<div style="display: inline-block;">
Step 1: Upload a <strong>Source Animal Image</strong> (any aspect ratio) ⬇️
</div>
</div>
<div style="flex: 1; text-align: center; margin-left: 20px;">
<div style="display: inline-block;">
Step 2: Upload a <strong>Driving Pickle</strong> or <strong>Driving Video</strong> (any aspect ratio) ⬇️
</div>
<div style="display: inline-block; font-size: 0.8em;">
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
</div>
</div>
</div>

inference.py

@ -1,4 +1,7 @@
# coding: utf-8
"""
for human
"""
import os
import os.path as osp

inference_animal.py (new file)

@ -0,0 +1,65 @@
# coding: utf-8
"""
for animal
"""
import os
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source):
raise FileNotFoundError(f"source info not found: {args.source}")
if not osp.exists(args.driving):
raise FileNotFoundError(f"driving info not found: {args.driving}")
def main():
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
if osp.exists(ffmpeg_dir):
os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
if not fast_check_ffmpeg():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
)
fast_check_args(args)
# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__)
live_portrait_pipeline_animal = LivePortraitPipelineAnimal(
inference_cfg=inference_cfg,
crop_cfg=crop_cfg
)
# run
live_portrait_pipeline_animal.execute(args)
if __name__ == "__main__":
main()

readme.md

@ -39,6 +39,7 @@
## 🔥 Updates
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main) or [BaiduYun](https://pan.baidu.com/s/1FWsWqKe0eNfXrwjEhhCqlw?pwd=86q2). Simply unzip and double-click `run_windows.bat` to enjoy!
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We've also lowered the default detection threshold to increase recall. [Have fun](assets/docs/changelog/2024-07-24.md)!
- **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md).
@ -55,7 +56,7 @@ This repo, named **LivePortrait**, contains the official PyTorch implementation
We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
## Getting Started 🏁
### 1. Clone the code and prepare the environment
### 1. Clone the code and prepare the environment 🛠️
```bash
git clone https://github.com/KwaiVGI/LivePortrait
cd LivePortrait
@ -64,56 +65,42 @@ cd LivePortrait
conda create -n LivePortrait python=3.9
conda activate LivePortrait
# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
# for macOS with Apple Silicon users
pip install -r requirements_macOS.txt
```
**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/download.html) installed, including both `ffmpeg` and `ffprobe`!
> [!Note]
> Make sure your system has [`git`](https://git-scm.com/), [`conda`](https://anaconda.org/anaconda/conda), and [`FFmpeg`](https://ffmpeg.org/download.html) installed. For details on FFmpeg installation, see [**how to install FFmpeg**](assets/docs/how-to-install-ffmpeg.md).
### 2. Download pretrained weights
### 2. Download pretrained weights 📥
The easiest way to download the pretrained weights is from HuggingFace:
```bash
# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone and move the weights
git clone https://huggingface.co/KwaiVGI/LivePortrait temp_pretrained_weights
mv temp_pretrained_weights/* pretrained_weights/
rm -rf temp_pretrained_weights
# !pip install -U "huggingface_hub[cli]"
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
```
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
Ensuring the directory structure is as follows, or contains:
```text
pretrained_weights
├── insightface
│ └── models
│ └── buffalo_l
│ ├── 2d106det.onnx
│ └── det_10g.onnx
└── liveportrait
├── base_models
│ ├── appearance_feature_extractor.pth
│ ├── motion_extractor.pth
│ ├── spade_generator.pth
│ └── warping_module.pth
├── landmark.onnx
└── retargeting_models
└── stitching_retargeting_module.pth
If you cannot access HuggingFace, you can use [hf-mirror](https://hf-mirror.com/) to download:
```bash
# !pip install -U "huggingface_hub[cli]"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs"
```
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn) (WIP). Unzip and place them in `./pretrained_weights`.
Ensure the directory structure matches or contains [**this**](assets/docs/directory-structure.md).
### 3. Inference 🚀
#### Fast hands-on
#### Fast hands-on (humans) 👤
```bash
# For Linux and Windows
# For Linux and Windows users
python inference.py
# For macOS with Apple Silicon, Intel not supported, this maybe 20x slower than RTX 4090
# For macOS users with Apple Silicon (Intel is not tested). NOTE: this may be 20x slower than RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py
```
@ -136,12 +123,32 @@ python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving
python inference.py -h
```
#### Fast hands-on (animals) 🐱🐶
Animals mode is ONLY tested on Linux with an NVIDIA GPU.
You first need to build an op named `MultiScaleDeformableAttention`, which is used by [X-Pose](https://github.com/IDEA-Research/X-Pose), a general keypoint detection framework.
```bash
cd src/utils/dependencies/XPose/models/UniPose/ops
python setup.py build install
cd - # equal to cd ../../../../../../../
```
Then run:
```bash
python inference_animal.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --driving_multiplier 1.75 --no_flag_stitching
```
If the script runs successfully, you will get an output mp4 file named `animations/s39--wink_concat.mp4`.
<p align="center">
<img src="./assets/docs/inference-animals.gif" alt="image">
</p>
#### Driving video auto-cropping 📢📢📢
To use your own driving video, we **recommend**: ⬇️
- Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
- Focus on the head area, similar to the example videos.
- Minimize shoulder movement.
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
> [!IMPORTANT]
> To use your own driving video, we **recommend**: ⬇️
> - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`.
> - Focus on the head area, similar to the example videos.
> - Minimize shoulder movement.
> - Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
Below is an auto-cropping case using `--flag_crop_driving_video`:
```bash
@ -163,10 +170,15 @@ We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src
```bash
# For Linux and Windows users (and macOS with Intel??)
python app.py
python app.py # humans mode
# For macOS with Apple Silicon users, Intel not supported, this maybe 20x slower than RTX 4090
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py
PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py # humans mode
```
We also provide a Gradio interface for animals mode, which has only been tested on Linux with an NVIDIA GPU:
```bash
python app_animals.py # animals mode 🐱🐶
```
You can specify the `--server_port`, `--share`, and `--server_name` arguments to suit your needs!
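For example, a sketch that exposes the humans-mode interface on all network interfaces with a public share link (the port number is arbitrary):

```bash
python app.py --server_port 8889 --server_name "0.0.0.0" --share
```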
@ -188,17 +200,7 @@ We have also provided a script to evaluate the inference speed of each module:
python speed.py
```
Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
| Model | Parameters(M) | Model Size(MB) | Inference(ms) |
|-----------------------------------|:-------------:|:--------------:|:-------------:|
| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
| Motion Extractor | 28.12 | 108 | 0.84 |
| Spade Generator | 55.37 | 212 | 7.59 |
| Warping Module | 45.53 | 174 | 5.21 |
| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
The results are [**here**](./assets/docs/speed.md).
## Community Resources 🤗

requirements_base.txt

@ -20,3 +20,5 @@ imageio-ffmpeg==0.5.1
tyro==0.8.5
gradio==4.37.1
pykalman==0.9.7
transformers==4.22.0
pillow>=10.2.0
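If you are upgrading an existing environment rather than creating a new one, re-running the dependency install from the readme should pick up these two new packages:

```bash
pip install -r requirements.txt
```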

requirements_macOS.txt

@ -1,2 +1,4 @@
-r requirements_base.txt
--extra-index-url https://download.pytorch.org/whl/cpu
onnxruntime-silicon==1.16.3

src/config/argument_config.py

@ -13,7 +13,7 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait or video
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
@ -27,10 +27,12 @@ class ArgumentConfig(PrintableConfig):
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large
flag_relative_motion: bool = True # whether to use relative motion
flag_stitching: bool = True # recommended to be True if head movement is small; set to False if head movement is large or the source image is an animal
flag_relative_motion: bool = True # whether to use relative motion
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" would adapt the driving motion with the global multiplier, and could be used when the source is a human image
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
########## source crop arguments ##########

src/config/crop_config.py

@ -6,24 +6,28 @@ parameters used for crop faces
from dataclasses import dataclass
from .base_config import PrintableConfig
from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig):
insightface_root: str = "../../pretrained_weights/insightface"
landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
insightface_root: str = make_abs_path("../../pretrained_weights/insightface")
landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx")
xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py")
xpose_embedding_cache_path: str = make_abs_path('../utils/resources/clip_embedding')
xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth")
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
det_thresh: float = 0.1 # detection threshold
########## source image or video cropping option ##########
dsize: int = 512 # crop size
scale: float = 2.8 # scale factor
scale: float = 2.3 # scale factor
vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 mean no limit
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
########## driving video auto cropping option ##########
scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
vx_ratio_crop_driving_video: float = 0.0 # adjust y offset

src/config/inference_config.py

@ -13,7 +13,7 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig):
# MODEL CONFIG, NOT EXPORTED PARAMS
# HUMAN MODEL CONFIG, NOT EXPORTED PARAMS
models_config: str = make_abs_path('./models.yaml') # portrait animation config
checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M
@ -21,6 +21,13 @@ class InferenceConfig(PrintableConfig):
checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip
# ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS
checkpoint_F_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
checkpoint_M_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/motion_extractor.pth') # path to checkpoint of M
checkpoint_G_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/spade_generator.pth') # path to checkpoint of G
checkpoint_W_animal: str = make_abs_path('../../pretrained_weights/liveportrait_animals/base_models/warping_module.pth') # path to checkpoint of W
checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip, NOTE: use human temporarily!
# EXPORTED PARAMS
flag_use_half_precision: bool = True
flag_crop_driving_video: bool = False
@ -37,6 +44,8 @@ class InferenceConfig(PrintableConfig):
flag_do_rot: bool = True
flag_force_cpu: bool = False
flag_do_torch_compile: bool = False
driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly"
driving_multiplier: float = 1.0
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number

src/gradio_pipeline.py

@ -5,16 +5,23 @@ Pipeline for gradio
"""
import os.path as osp
import os
import cv2
from rich.progress import track
import gradio as gr
import numpy as np
import torch
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .live_portrait_pipeline_animal import LivePortraitPipelineAnimal
from .utils.io import load_img_online, load_video, resize_to_limit
from .utils.filter import smooth
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video
from .utils.helper import is_square_video, mkdir, dct2device, basename
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
@ -28,6 +35,8 @@ def update_args(args, user_args):
class GradioPipeline(LivePortraitPipeline):
"""gradio for human
"""
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
super().__init__(inference_cfg, crop_cfg)
@ -43,6 +52,9 @@ class GradioPipeline(LivePortraitPipeline):
flag_relative_input=True,
flag_do_crop_input=True,
flag_remap_input=True,
flag_stitching_input=True,
driving_option_input="pose-friendly",
driving_multiplier=1.0,
flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False,
scale=2.3,
@ -66,8 +78,8 @@ class GradioPipeline(LivePortraitPipeline):
if input_source_path is not None and input_driving_video_path is not None:
if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
flag_crop_driving_video_input = True
log("The source video is not square, the driving video will be cropped to square automatically.")
gr.Info("The source video is not square, the driving video will be cropped to square automatically.", duration=2)
log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
args_user = {
'source': input_source_path,
@ -75,6 +87,9 @@ class GradioPipeline(LivePortraitPipeline):
'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_stitching': flag_stitching_input,
'driving_option': driving_option_input,
'driving_multiplier': driving_multiplier,
'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale,
@ -92,19 +107,19 @@ class GradioPipeline(LivePortraitPipeline):
# video driven animation
video_path, video_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat,
return video_path, video_path_concat
else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
@torch.no_grad()
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop=True):
def execute_image_retargeting(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop_input_retargeting_image=True):
""" for single image retargeting
"""
if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
raise gr.Error("Invalid relative pose input 💥!", duration=5)
# disposable feature
f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
self.prepare_retargeting(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop)
self.prepare_retargeting_image(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
@ -134,12 +149,15 @@ class GradioPipeline(LivePortraitPipeline):
# D(W(f_s; x_s, x_d))
out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
if flag_do_crop_input_retargeting_image:
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
else:
out_to_ori_blend = out
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend
@torch.no_grad()
def prepare_retargeting(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
def prepare_retargeting_image(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
""" for single image retargeting
"""
if input_image is not None:
@ -151,11 +169,17 @@ class GradioPipeline(LivePortraitPipeline):
######## process source portrait ########
img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
log(f"Load source image from {input_image}.")
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
if flag_do_crop:
crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
source_lmk_user = self.cropper.calc_lmk_from_cropped_image(img_rgb)
crop_M_c2o = None
mask_ori = None
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
@ -165,15 +189,12 @@ class GradioPipeline(LivePortraitPipeline):
############################################
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
source_lmk_user = crop_info['lmk_crop']
crop_M_c2o = crop_info['M_c2o']
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
else:
# when press the clear button, go here
raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)
def init_retargeting(self, retargeting_source_scale: float, input_image = None):
def init_retargeting_image(self, retargeting_source_scale: float, input_image = None):
""" initialize the retargeting slider
"""
if input_image != None:
@ -191,3 +212,194 @@ class GradioPipeline(LivePortraitPipeline):
source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2)
return 0., 0.
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
""" retargeting the lip-open ratio of each source frame
"""
# disposable feature
device = self.live_portrait_wrapper.device
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
if input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
I_p_pstbk_lst = None
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
x_d_i_new = x_s_user_i + lip_delta_retargeting
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if flag_do_crop_input_retargeting_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(self.args.output_dir)
flag_source_has_audio = has_audio_stream(input_video)
######### build the final concatenation result #########
# source frame | generation
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
if flag_source_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=source_fps)
######### build the final result #########
if flag_source_has_audio:
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
add_audio_to_video(wfp, input_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
gr.Info("Run successfully!", duration=2)
return wfp_concat, wfp
def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
""" for video retargeting
"""
if input_video is not None:
# gr.Info("Upload successfully!", duration=2)
args_user = {'scale': retargeting_source_scale}
self.args = update_args(self.args, args_user)
self.cropper.update_config(self.args.__dict__)
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source video ########
source_rgb_lst = load_video(input_video)
source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
source_fps = int(get_fps(input_video))
n_frames = len(source_rgb_lst)
log(f"Load source video from {input_video}. FPS is {source_fps}")
if flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) != n_frames:
n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
source_M_c2o_lst, mask_ori_lst = None, None
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
c_d_lip_retargeting = [input_lip_ratio]
f_s_user_lst, x_s_user_lst, lip_delta_retargeting_lst = [], [], []
for i in track(range(n_frames), description='Preparing retargeting video...', total=n_frames):
x_s_info = source_template_dct['motion'][i]
x_s_info = dct2device(x_s_info, device)
x_s_user = x_s_info['x_s']
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
combined_lip_ratio_tensor_retargeting = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_retargeting, source_lmk)
lip_delta_retargeting = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor_retargeting)
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
else:
# when the clear button is pressed, we end up here
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
class GradioPipelineAnimal(LivePortraitPipelineAnimal):
"""gradio for animal
"""
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
inference_cfg.flag_crop_driving_video = True # ensure the face_analysis_wrapper is enabled
super().__init__(inference_cfg, crop_cfg)
# self.live_portrait_wrapper_animal = self.live_portrait_wrapper_animal
self.args = args
@torch.no_grad()
def execute_video(
self,
input_source_image_path=None,
input_driving_video_path=None,
input_driving_video_pickle_path=None,
flag_do_crop_input=False,
flag_remap_input=False,
driving_multiplier=1.0,
flag_stitching=False,
flag_crop_driving_video_input=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
scale_crop_driving_video=2.2,
vx_ratio_crop_driving_video=0.0,
vy_ratio_crop_driving_video=-0.1,
tab_selection=None,
):
""" for video-driven potrait animation
"""
input_source_path = input_source_image_path
if tab_selection == 'Video':
input_driving_path = input_driving_video_path
elif tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
input_driving_path = input_driving_video_pickle_path
if input_source_path is not None and input_driving_path is not None:
if osp.exists(input_driving_path) and tab_selection == 'Video' and is_square_video(input_driving_path) is False:
flag_crop_driving_video_input = True
log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
args_user = {
'source': input_source_path,
'driving': input_driving_path,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'driving_multiplier': driving_multiplier,
'flag_stitching': flag_stitching,
'flag_crop_driving_video': flag_crop_driving_video_input,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
'scale_crop_driving_video': scale_crop_driving_video,
'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper_animal.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat, video_gif_path = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat, video_gif_path
else:
raise gr.Error("Please upload the source animal image and driving video 🤗🤗🤗", duration=5)
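# Hedged sketch (illustrative only): when called outside the Gradio UI, execute_video needs
# just the source image path, the driving video path and tab_selection='Video'; every other
# keyword falls back to the defaults in the signature above. The configs are the same ones
# the app builds, and the media paths are placeholders.
#
#   pipeline = GradioPipelineAnimal(inference_cfg, crop_cfg, args)
#   mp4, concat_mp4, gif = pipeline.execute_video(
#       input_source_image_path='animal.jpg',
#       input_driving_video_path='driving.mp4',
#       flag_do_crop_input=True,
#       tab_selection='Video',
#   )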

View File

@ -1,7 +1,7 @@
# coding: utf-8
"""
Pipeline of LivePortrait
Pipeline of LivePortrait (Human)
"""
import torch
@ -21,7 +21,7 @@ from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import prepare_paste_back, paste_back
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video, calc_motion_multiplier
from .utils.filter import smooth
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
@ -46,13 +46,13 @@ class LivePortraitPipeline(object):
'motion': [],
'c_eyes_lst': [],
'c_lip_lst': [],
'x_i_info_lst': [],
}
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
# collect s, R, δ and t for inference
I_i = I_lst[i]
x_i_info = self.live_portrait_wrapper.get_kp_info(I_i)
x_s = self.live_portrait_wrapper.transform_keypoint(x_i_info)
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
item_dct = {
@ -60,6 +60,8 @@ class LivePortraitPipeline(object):
'R': R_i.cpu().numpy().astype(np.float32),
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_i_info['t'].cpu().numpy().astype(np.float32),
'kp': x_i_info['kp'].cpu().numpy().astype(np.float32),
'x_s': x_s.cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
@ -70,7 +72,6 @@ class LivePortraitPipeline(object):
c_lip = c_lip_lst[i].astype(np.float32)
template_dct['c_lip_lst'].append(c_lip)
template_dct['x_i_info_lst'].append(x_i_info)
return template_dct
@ -238,18 +239,17 @@ class LivePortraitPipeline(object):
log(f"The animated video consists of {n_frames} frames.")
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
if flag_is_source_video: # source video
x_s_info_tiny = source_template_dct['motion'][i]
x_s_info_tiny = dct2device(x_s_info_tiny, device)
x_s_info = source_template_dct['motion'][i]
x_s_info = dct2device(x_s_info, device)
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
x_s_info = source_template_dct['x_i_info_lst'][i]
x_c_s = x_s_info['kp']
R_s = x_s_info_tiny['R']
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
x_c_s = x_s_info['kp']
R_s = x_s_info['R']
x_s = x_s_info['x_s']
# let the lip-open scalar be 0 at first if the input is a video
if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None:
@ -308,6 +308,14 @@ class LivePortraitPipeline(object):
t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
if i == 0:
x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
# motion_multiplier *= inf_cfg.driving_multiplier
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
x_d_i_new = x_d_diff + x_s
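# Note on the expression-friendly branch above: the first driving frame is kept as an
# anchor (x_d_0_new) and every later frame applies only its offset from that anchor,
# rescaled by a multiplier computed once from the source keypoints x_s and the anchor,
# so the driven keypoints start exactly at x_s instead of jumping to the first driving pose.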
# Algorithm 1:
if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting:
# without stitching or retargeting
@ -350,6 +358,7 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
x_d_i_new = x_s + (x_d_i_new - x_s) * inf_cfg.driving_multiplier
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
@ -386,7 +395,7 @@ class LivePortraitPipeline(object):
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat} with {wfp_concat_with_audio}")
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
@ -402,7 +411,7 @@ class LivePortraitPipeline(object):
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp} with {wfp_with_audio}")
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):

View File

@ -0,0 +1,237 @@
# coding: utf-8
"""
Pipeline of LivePortrait (Animal)
"""
import warnings
warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly.")
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
import torch
torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import os
import os.path as osp
from rich.progress import track
from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream, video2gif
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, calc_motion_multiplier
from .utils.rprint import rlog as log
# from .utils.viz import viz_lmk
from .live_portrait_wrapper import LivePortraitWrapperAnimal
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
class LivePortraitPipelineAnimal(object):
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
self.live_portrait_wrapper_animal: LivePortraitWrapperAnimal = LivePortraitWrapperAnimal(inference_cfg=inference_cfg)
self.cropper: Cropper = Cropper(crop_cfg=crop_cfg, image_type='animal_face', flag_use_half_precision=inference_cfg.flag_use_half_precision)
def make_motion_template(self, I_lst, **kwargs):
n_frames = I_lst.shape[0]
template_dct = {
'n_frames': n_frames,
'output_fps': kwargs.get('output_fps', 25),
'motion': [],
}
for i in track(range(n_frames), description='Making driving motion templates...', total=n_frames):
# collect s, R, δ and t for inference
I_i = I_lst[i]
x_i_info = self.live_portrait_wrapper_animal.get_kp_info(I_i)
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
item_dct = {
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
'R': R_i.cpu().numpy().astype(np.float32),
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
't': x_i_info['t'].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
return template_dct
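# Hedged sketch: the template built above is a plain dict of numpy arrays, so a dumped
# driving template can be inspected directly with the io helpers imported in this module
# (assuming, as the .pkl suffix suggests, that they are thin pickle wrappers):
#
#   tpl = load('driving.pkl')                  # placeholder path
#   print(tpl['n_frames'], tpl['output_fps'])  # frame count and fps
#   print(tpl['motion'][0].keys())             # dict_keys(['scale', 'R', 'exp', 't'])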
def execute(self, args: ArgumentConfig):
# for convenience
inf_cfg = self.live_portrait_wrapper_animal.inference_cfg
device = self.live_portrait_wrapper_animal.device
crop_cfg = self.cropper.crop_cfg
######## load source input ########
if is_image(args.source):
img_rgb = load_image_rgb(args.source)
img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division)
log(f"Load source image from {args.source}")
else: # source input is an unknown format
raise Exception(f"Unknown source format: {args.source}")
######## process driving info ########
flag_load_from_template = is_template(args.driving)
driving_rgb_crop_256x256_lst = None
wfp_template = None
if flag_load_from_template:
# NOTE: loading from a template is fast, but the cropped driving video is unavailable
log(f"Load from template: {args.driving}, NOT the original video, so the cropped driving video and audio are both unavailable.", style='bold green')
driving_template_dct = load(args.driving)
n_frames = driving_template_dct['n_frames']
# set output_fps
output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps)
log(f'The FPS of template: {output_fps}')
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
n_frames = len(driving_rgb_lst)
######## make motion template ########
log("Start making driving motion template...")
if inf_cfg.flag_crop_driving_video:
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) != n_frames:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst = ret_d['frame_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
else:
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force resize to 256x256
#######################################
# save the motion template
I_d_lst = self.live_portrait_wrapper_animal.prepare_videos(driving_rgb_crop_256x256_lst)
driving_template_dct = self.make_motion_template(I_d_lst, output_fps=output_fps)
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} does not exist, or the driving info type is unsupported!")
######## prepare for pasteback ########
I_p_pstbk_lst = None
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk_lst = []
log("Prepared pasteback mask done.")
######## process source info ########
if inf_cfg.flag_do_crop:
crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg)
if crop_info is None:
raise Exception("No animal face detected in the source image!")
img_crop_256x256 = crop_info['img_crop_256x256']
else:
img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force resize to 256x256
I_s = self.live_portrait_wrapper_animal.prepare_source(img_crop_256x256)
x_s_info = self.live_portrait_wrapper_animal.get_kp_info(I_s)
x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper_animal.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper_animal.transform_keypoint(x_s_info)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
######## animate ########
I_p_lst = []
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
delta_new = x_d_i_info['exp']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz
scale_new = x_s_info['scale']
x_d_i = scale_new * (x_c_s @ R_d_i + delta_new) + t_new
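# The line above composes the driven keypoints from the source identity and the driving
# motion: the canonical source keypoints x_c_s are rotated by the driving rotation R_d_i,
# offset by the driving expression deltas, then scaled by the source scale and translated
# by the driving translation (with tz zeroed above so depth is left untouched).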
if i == 0:
x_d_0 = x_d_i
motion_multiplier = calc_motion_multiplier(x_s, x_d_0)
x_d_diff = (x_d_i - x_d_0) * motion_multiplier
x_d_i = x_d_diff + x_s
if not inf_cfg.flag_stitching:
pass
else:
x_d_i = self.live_portrait_wrapper_animal.stitching(x_s, x_d_i)
x_d_i = x_s + (x_d_i - x_s) * inf_cfg.driving_multiplier
out = self.live_portrait_wrapper_animal.warp_decode(f_s, x_s, x_d_i)
I_p_i = self.live_portrait_wrapper_animal.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching:
I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float)
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(args.output_dir)
wfp_concat = None
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result #########
# driving frame | source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can pass this template path to the `-d` argument next time to skip video cropping and motion-template making, and to protect privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
# build the gif
wfp_gif = video2gif(wfp)
log(f'Animated gif: {wfp_gif}')
return wfp, wfp_concat, wfp_gif
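# Hedged usage sketch (not part of the pipeline itself): run the animal pipeline with
# default configs. It assumes InferenceConfig, CropConfig and ArgumentConfig are
# default-constructible and that `source`, `driving` and `output_dir` are plain
# attributes, as the code above suggests; the media paths are placeholders.
if __name__ == "__main__":
    _args = ArgumentConfig()
    _args.source = "animal.jpg"       # placeholder source animal image
    _args.driving = "driving.mp4"     # placeholder driving video
    _args.output_dir = "animations/"
    _pipeline = LivePortraitPipelineAnimal(inference_cfg=InferenceConfig(), crop_cfg=CropConfig())
    _wfp, _wfp_concat, _wfp_gif = _pipeline.execute(_args)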

View File

@ -1,7 +1,7 @@
# coding: utf-8
"""
Wrapper for LivePortrait core functions
Wrappers for LivePortrait core functions
"""
import contextlib
@ -20,6 +20,9 @@ from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
"""
Wrapper for Human
"""
def __init__(self, inference_cfg: InferenceConfig):
@ -37,20 +40,20 @@ class LivePortraitWrapper(object):
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor done.')
log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F)} done.')
# init M
self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor done.')
log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M)} done.')
# init W
self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
log(f'Load warping_module done.')
log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W)} done.')
# init G
self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
log(f'Load spade_generator done.')
log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G)} done.')
# init S and R
if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module done.')
log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S)} done.')
else:
self.stitching_retargeting_module = None
# Optimize for inference
@ -326,3 +329,50 @@ class LivePortraitWrapper(object):
# [c_s,lip, c_d,lip,i]
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
return combined_lip_ratio_tensor
class LivePortraitWrapperAnimal(LivePortraitWrapper):
"""
Wrapper for Animal
"""
def __init__(self, inference_cfg: InferenceConfig):
# super().__init__(inference_cfg)  # call the parent class initializer
self.inference_cfg = inference_cfg
self.device_id = inference_cfg.device_id
self.compile = inference_cfg.flag_do_torch_compile
if inference_cfg.flag_force_cpu:
self.device = 'cpu'
else:
if torch.backends.mps.is_available():
self.device = 'mps'
else:
self.device = 'cuda:' + str(self.device_id)
model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
# init F
self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F_animal, model_config, self.device, 'appearance_feature_extractor')
log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F_animal)} done.')
# init M
self.motion_extractor = load_model(inference_cfg.checkpoint_M_animal, model_config, self.device, 'motion_extractor')
log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M_animal)} done.')
# init W
self.warping_module = load_model(inference_cfg.checkpoint_W_animal, model_config, self.device, 'warping_module')
log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W_animal)} done.')
# init G
self.spade_generator = load_model(inference_cfg.checkpoint_G_animal, model_config, self.device, 'spade_generator')
log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G_animal)} done.')
# init S and R
if inference_cfg.checkpoint_S_animal is not None and osp.exists(inference_cfg.checkpoint_S_animal):
self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S_animal, model_config, self.device, 'stitching_retargeting_module')
log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S_animal)} done.')
else:
self.stitching_retargeting_module = None
# Optimize for inference
if self.compile:
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
self.timer = Timer()

View File

@ -11,7 +11,8 @@ import torch
import torch.nn.utils.spectral_norm as spectral_norm
import math
import warnings
import collections.abc
from itertools import repeat
def kp2gaussian(kp, spatial_size, kp_variance):
"""
@ -439,3 +440,13 @@ class DropPath(nn.Module):
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
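# Example: to_2tuple broadcasts a scalar into a pair and passes non-string iterables
# through unchanged, e.g. to_2tuple(7) -> (7, 7) and to_2tuple((3, 5)) -> (3, 5);
# it is the usual helper for window/patch sizes that may be given as an int or a tuple.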

View File

@ -0,0 +1,138 @@
# coding: utf-8
"""
face detection and alignment using XPose
"""
import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from .timer import Timer
from .rprint import rlog as log
from .helper import clean_state_dict
from .dependencies.XPose import transforms as T
from .dependencies.XPose.models import build_model
from .dependencies.XPose.predefined_keypoints import *
from .dependencies.XPose.util import box_ops
from .dependencies.XPose.util.config import Config
class XPoseRunner(object):
def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
self.device_id = kwargs.get("device_id", 0)
self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
self.timer = Timer()
# Load cached embeddings if available
try:
with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
print("Loaded cached embeddings from file.")
except Exception:
raise ValueError("Could not load the CLIP embeddings from file; please check your file path.")
def load_animal_model(self, model_config_path, model_checkpoint_path, device):
args = Config.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model
def load_image(self, input_image):
image_pil = input_image.convert("RGB")
transform = T.Compose([
T.RandomResize([800], max_size=1333), # NOTE: fixed resize target of 800 (shorter side)
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image, _ = transform(image_pil, None)
return image_pil, image
def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
instance_list = instance_text_prompt.split(',')
if len(keypoint_text_prompt) == 9:
# torch.Size([1, 512]) torch.Size([9, 512])
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
elif len(keypoint_text_prompt) == 68:
# torch.Size([1, 512]) torch.Size([68, 512])
ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
else:
raise ValueError("Invalid number of keypoint embeddings.")
target = {
"instance_text_prompt": instance_list,
"keypoint_text_prompt": keypoint_text_prompt,
"object_embeddings_text": ins_text_embeddings.float(),
"kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0),
"kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
}
self.model = self.model.to(self.device)
image = image.to(self.device)
with torch.no_grad():
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
outputs = self.model(image[None], [target])
logits = outputs["pred_logits"].sigmoid()[0]
boxes = outputs["pred_boxes"][0]
keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]
logits_filt = logits.cpu().clone()
boxes_filt = boxes.cpu().clone()
keypoints_filt = keypoints.cpu().clone()
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
logits_filt = logits_filt[filt_mask]
boxes_filt = boxes_filt[filt_mask]
keypoints_filt = keypoints_filt[filt_mask]
keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold)
filtered_boxes = boxes_filt[keep_indices]
filtered_keypoints = keypoints_filt[keep_indices]
return filtered_boxes, filtered_keypoints
def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
if keypoint_text_example in globals():
keypoint_dict = globals()[keypoint_text_example]
elif instance_text_prompt in globals():
keypoint_dict = globals()[instance_text_prompt]
else:
keypoint_dict = globals()["animal"]
keypoint_text_prompt = keypoint_dict.get("keypoints")
keypoint_skeleton = keypoint_dict.get("skeleton")
image_pil, image = self.load_image(input_image)
boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold)
size = image_pil.size
H, W = size[1], size[0]
keypoints_filt = keypoints_filt[0].squeeze(0)
kp = np.array(keypoints_filt.cpu())
num_kpts = len(keypoint_text_prompt)
Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
Z = Z.reshape(num_kpts * 2)
x = Z[0::2]
y = Z[1::2]
return np.stack((x, y), axis=1)
def warmup(self):
self.timer.tic()
img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
elapse = self.timer.toc()
log(f'XPoseRunner warmup time: {elapse:.3f}s')
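# Hedged usage sketch: the paths below are placeholders for the XPose config, checkpoint
# and cached CLIP text-embedding prefix; run() returns an (N, 2) array of keypoints in
# pixel coordinates of the input image, where N is 9 or 68 depending on the keypoint prompt.
if __name__ == "__main__":
    runner = XPoseRunner(
        model_config_path="UniPose_SwinT.py",         # placeholder config path
        model_checkpoint_path="xpose.pth",            # placeholder checkpoint path
        embeddings_cache_path="clip_text_embedding",  # placeholder cache prefix
    )
    lmk = runner.run(Image.open("animal.png"), "face", "face", box_threshold=0.0, IoU_threshold=0.0)
    print(lmk.shape)  # e.g. (68, 2)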

View File

@ -136,6 +136,29 @@ def parse_pt2_from_pt5(pt5, use_lip=True):
], axis=0)
return pt2
def parse_pt2_from_pt9(pt9, use_lip=True):
'''
parsing the 2 points according to the 9 points, which cancels the roll
['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip']
'''
if use_lip:
pt9 = np.stack([
(pt9[2] + pt9[3]) / 2, # left eye
(pt9[0] + pt9[1]) / 2, # right eye
pt9[4],
(pt9[5] + pt9[6]) / 2 # lip
], axis=0)
pt2 = np.stack([
(pt9[0] + pt9[1]) / 2, # eye
pt9[3] # lip
], axis=0)
else:
pt2 = np.stack([
(pt9[2] + pt9[3]) / 2,
(pt9[0] + pt9[1]) / 2,
], axis=0)
return pt2
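# Worked example for the use_lip branch above: after the first np.stack, pt9 is reduced to
# [left-eye center, right-eye center, nose tip, lip center], so pt2 becomes
# [(left-eye center + right-eye center) / 2, lip center] -- one eye midpoint and one lip
# point, mirroring what the other parse_pt2_from_pt* helpers return for human faces.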
def parse_pt2_from_pt_x(pts, use_lip=True):
if pts.shape[0] == 101:
@ -151,6 +174,8 @@ def parse_pt2_from_pt_x(pts, use_lip=True):
elif pts.shape[0] > 101:
# take the first 101 points
pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
elif pts.shape[0] == 9:
pt2 = parse_pt2_from_pt9(pts, use_lip=use_lip)
else:
raise Exception(f'Unknown shape: {pts.shape}')

View File

@ -1,12 +1,13 @@
# coding: utf-8
import os.path as osp
from dataclasses import dataclass, field
from typing import List, Tuple, Union
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import torch
import numpy as np
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
from PIL import Image
from typing import List, Tuple, Union
from dataclasses import dataclass, field
from ..config.crop_config import CropConfig
from .crop import (
@ -18,8 +19,7 @@ from .crop import (
from .io import contiguous
from .rprint import rlog as log
from .face_analysis_diy import FaceAnalysisDIY
from .landmark_runner import LandmarkRunner
from .human_landmark_runner import LandmarkRunner as HumanLandmark
def make_abs_path(fn):
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
@ -41,6 +41,7 @@ class Trajectory:
class Cropper(object):
def __init__(self, **kwargs) -> None:
self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
self.image_type = kwargs.get("image_type", 'human_face')
device_id = kwargs.get("device_id", 0)
flag_force_cpu = kwargs.get("flag_force_cpu", False)
if flag_force_cpu:
@ -55,20 +56,31 @@ class Cropper(object):
else:
device = "cuda"
face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
self.landmark_runner = LandmarkRunner(
ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
self.face_analysis_wrapper = FaceAnalysisDIY(
name="buffalo_l",
root=self.crop_cfg.insightface_root,
providers=face_analysis_wrapper_provider,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup()
self.human_landmark_runner = HumanLandmark(
ckpt_path=self.crop_cfg.landmark_ckpt_path,
onnx_provider=device,
device_id=device_id,
)
self.landmark_runner.warmup()
self.human_landmark_runner.warmup()
self.face_analysis_wrapper = FaceAnalysisDIY(
name="buffalo_l",
root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provider,
)
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh)
self.face_analysis_wrapper.warmup()
if self.image_type == "animal_face":
from .animal_landmark_runner import XPoseRunner as AnimalLandmarkRunner
self.animal_landmark_runner = AnimalLandmarkRunner(
model_config_path=self.crop_cfg.xpose_config_file_path,
model_checkpoint_path=self.crop_cfg.xpose_ckpt_path,
embeddings_cache_path=self.crop_cfg.xpose_embedding_cache_path,
flag_use_half_precision=kwargs.get("flag_use_half_precision", True),
)
self.animal_landmark_runner.warmup()
def update_config(self, user_args):
for k, v in user_args.items():
@ -78,24 +90,39 @@ class Cropper(object):
def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
# crop a source image and get necessary information
img_rgb = img_rgb_.copy() # copy it
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
src_face = self.face_analysis_wrapper.get(
img_bgr,
flag_do_landmark_2d_106=True,
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
if len(src_face) == 0:
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
if self.image_type == "human_face":
src_face = self.face_analysis_wrapper.get(
img_bgr,
flag_do_landmark_2d_106=True,
direction=crop_cfg.direction,
max_face_num=crop_cfg.max_face_num,
)
# NOTE: temporarily only pick the first face, to support multiple faces in the future
src_face = src_face[0]
lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
if len(src_face) == 0:
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
# NOTE: temporarily only pick the first face, to support multiple faces in the future
src_face = src_face[0]
lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface
else:
tmp_dct = {
'animal_face_9': 'animal_face',
'animal_face_68': 'face'
}
img_rgb_pil = Image.fromarray(img_rgb)
lmk = self.animal_landmark_runner.run(
img_rgb_pil,
'face',
tmp_dct[crop_cfg.animal_face_type],
0,
0
)
# crop the face
ret_dct = crop_image(
@ -108,12 +135,15 @@ class Cropper(object):
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(img_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
if self.image_type == "human_face":
lmk = self.human_landmark_runner.run(img_rgb, lmk)
ret_dct["lmk_crop"] = lmk
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
else:
# 68x2 or 9x2
ret_dct["lmk_crop"] = lmk
return ret_dct
@ -131,7 +161,7 @@ class Cropper(object):
log(f"More than one face detected in the image, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(img_rgb_, lmk)
lmk = self.human_landmark_runner.run(img_rgb_, lmk)
return lmk
@ -155,11 +185,11 @@ class Cropper(object):
log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
# TODO: add IOU check for tracking
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
@ -174,7 +204,7 @@ class Cropper(object):
vy_ratio=crop_cfg.vy_ratio,
flag_do_rot=crop_cfg.flag_do_rot,
)
lmk = self.landmark_runner.run(frame_rgb, lmk)
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
@ -209,10 +239,10 @@ class Cropper(object):
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
lmk = self.human_landmark_runner.run(frame_rgb, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)
@ -270,10 +300,10 @@ class Cropper(object):
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb_crop, lmk)
lmk = self.human_landmark_runner.run(frame_rgb_crop, lmk)
trajectory.start, trajectory.end = idx, idx
else:
lmk = self.landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
lmk = self.human_landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
trajectory.end = idx
trajectory.lmk_lst.append(lmk)

View File

@ -0,0 +1,125 @@
_base_ = ['coco_transformer.py']
use_label_enc = True
num_classes = 2
lr = 0.0001
param_dict_type = 'default'
lr_backbone = 1e-05
lr_backbone_names = ['backbone.0']
lr_linear_proj_names = ['reference_points', 'sampling_offsets']
lr_linear_proj_mult = 0.1
ddetr_lr_param = False
batch_size = 2
weight_decay = 0.0001
epochs = 12
lr_drop = 11
save_checkpoint_interval = 100
clip_max_norm = 0.1
onecyclelr = False
multi_step_lr = False
lr_drop_list = [33, 45]
modelname = 'UniPose'
frozen_weights = None
backbone = 'swin_T_224_1k'
dilation = False
position_embedding = 'sine'
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
unic_layers = 0
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
pdetr3_bbox_embed_diff_each_layer = False
pdetr3_refHW = -1
random_refpoints_xy = False
fix_refpoints_hw = -1
dabdetr_yolo_like_anchor_update = False
dabdetr_deformable_encoder = False
dabdetr_deformable_decoder = False
use_deformable_box_attn = False
box_attn_type = 'roi_align'
dec_layer_number = None
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
decoder_layer_noise = False
dln_xy_noise = 0.2
dln_hw_noise = 0.2
add_channel_attention = False
add_pos_value = False
two_stage_type = 'standard'
two_stage_pat_embed = 0
two_stage_add_query_num = 0
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
two_stage_learn_wh = False
two_stage_default_hw = 0.05
two_stage_keep_all_tokens = False
num_select = 50
transformer_activation = 'relu'
batch_norm_type = 'FrozenBatchNorm2d'
masks = False
decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
decoder_module_seq = ['sa', 'ca', 'ffn']
nms_iou_threshold = -1
dec_pred_bbox_embed_share = True
dec_pred_class_embed_share = True
use_dn = True
dn_number = 100
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
match_unstable_error = True
# for ema
use_ema = True
ema_decay = 0.9997
ema_epoch = 0
use_detached_boxes_dec_out = False
max_text_len = 256
shuffle_type = None
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = False # True
use_transformer_ckpt = True
text_encoder_type = 'bert-base-uncased'
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
num_body_points = 68
binary_query_selection = False
use_cdn = True
ffn_extra_layernorm = False
fix_size = False
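# Hedged note: this file is read as an mmcv-style Python config by the animal landmark
# runner (Config.fromfile in XPoseRunner.load_animal_model); `_base_` pulls in
# coco_transformer.py, and every top-level name here becomes an attribute of the args
# object handed to build_model(args).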

View File

@ -0,0 +1,8 @@
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
data_aug_max_size = 1333
data_aug_scales2_resize = [400, 500, 600]
data_aug_scales2_crop = [384, 600]
data_aug_scale_overlap = None

View File

@ -0,0 +1,10 @@
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from .unipose import build_unipose

View File

@ -0,0 +1,373 @@
# ------------------------------------------------------------------------
# UniPose
# url: https://github.com/IDEA-Research/UniPose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from codes in torch.nn
# ------------------------------------------------------------------------
"""
MultiheadAttention that support query, key, and value to have different dimensions.
Query, key, and value projections are removed.
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
"""
import warnings
import torch
from torch.nn.modules.linear import Linear
from torch.nn.init import constant_
from torch.nn.modules.module import Module
from torch._jit_internal import Optional, Tuple
try:
from torch.overrides import has_torch_function, handle_torch_function
except ImportError:
from torch._overrides import has_torch_function, handle_torch_function
from torch.nn.functional import linear, pad, softmax, dropout
Tensor = torch.Tensor
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
Note: if kdim and vdim are None, they will be set to embed_dim such that
query, key, and value have the same number of features.
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
vdim = vdim if vdim is not None else embed_dim
self.out_proj = Linear(vdim, vdim)
self.in_proj_bias = None
self.in_proj_weight = None
self.bias_k = self.bias_v = None
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self):
constant_(self.out_proj.bias, 0.)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if '_qkv_same_embed_dim' not in state:
state['_qkv_same_embed_dim'] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
if not self._qkv_same_embed_dim:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
else:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, out_dim=self.vdim)
def multi_head_attention_forward(query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
out_dim: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
use_separate_proj_weight: the function accept the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
if not torch.jit.is_scripting():
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
out_proj_weight, out_proj_bias)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward, tens_ops, query, key, value,
embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
out_proj_bias, training=training, key_padding_mask=key_padding_mask,
need_weights=need_weights, attn_mask=attn_mask,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
v_head_dim = out_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
q = query * scaling
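# Scaling q by head_dim ** -0.5 up front means the bmm(q, k^T) below already yields the
# scaled dot-product logits q . k / sqrt(head_dim), so no extra division is needed later.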
k = key
v = value
if attn_mask is not None:
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 2D attn_mask is not correct.')
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 3D attn_mask is not correct.')
else:
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
key_padding_mask = key_padding_mask.to(torch.bool)
if bias_k is not None and bias_v is not None:
if static_k is None and static_v is None:
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert bias_k is None
assert bias_v is None
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == v_head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
# attn_output_weights = softmax(
# attn_output_weights, dim=-1)
attn_output_weights = softmax(
attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
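# Subtracting the per-row max before softmax leaves the result unchanged (softmax is
# shift-invariant) but keeps exp() from overflowing, which matters when the attention
# weights are computed in float16 under autocast.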
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None

View File

@ -0,0 +1,211 @@
# ------------------------------------------------------------------------
# UniPose
# url: https://github.com/IDEA-Research/UniPose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
from .swin_transformer import build_swin_transformer
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with an added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self,
backbone: nn.Module,
train_backbone: bool,
num_channels: int,
return_interm_indices: list,
):
super().__init__()
for name, parameter in backbone.named_parameters():
if (
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
parameter.requires_grad_(False)
return_layers = {}
for idx, layer_index in enumerate(return_interm_indices):
return_layers.update(
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
)
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
# import ipdb; ipdb.set_trace()
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
dilation: bool,
return_interm_indices: list,
batch_norm=FrozenBatchNorm2d,
):
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=batch_norm,
)
else:
raise NotImplementedError("Why you can get here with name {}".format(name))
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
num_channels_all = [256, 512, 1024, 2048]
num_channels = num_channels_all[4 - len(return_interm_indices) :]
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
"""
Useful args:
- backbone: backbone name
- lr_backbone:
- dilation
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
- backbone_freeze_keywords:
- use_checkpoint: for swin only for now
"""
position_embedding = build_position_encoding(args)
train_backbone = True
if not train_backbone:
raise ValueError("Please set lr_backbone > 0")
return_interm_indices = args.return_interm_indices
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
    args.backbone_freeze_keywords  # read but unused here
use_checkpoint = getattr(args, "use_checkpoint", False)
if args.backbone in ["resnet50", "resnet101"]:
backbone = Backbone(
args.backbone,
train_backbone,
args.dilation,
return_interm_indices,
batch_norm=FrozenBatchNorm2d,
)
bb_num_channels = backbone.num_channels
elif args.backbone in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]:
pretrain_img_size = int(args.backbone.split("_")[-2])
backbone = build_swin_transformer(
args.backbone,
pretrain_img_size=pretrain_img_size,
out_indices=tuple(return_interm_indices),
dilation=False,
use_checkpoint=use_checkpoint,
)
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
else:
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
assert len(bb_num_channels) == len(
return_interm_indices
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels
assert isinstance(
bb_num_channels, List
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
return model
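
# A minimal sketch, assuming only torch and the FrozenBatchNorm2d defined above: it
# checks that the frozen module reproduces an eval-mode nn.BatchNorm2d once the trained
# statistics and affine parameters are copied over (both use eps=1e-5).
import torch
from torch import nn

bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

frozen = FrozenBatchNorm2d(8)
# num_batches_tracked is dropped by the _load_from_state_dict hook above
frozen.load_state_dict(bn.state_dict(), strict=False)

x = torch.randn(2, 8, 16, 16)
print(torch.allclose(bn(x), frozen(x), atol=1e-5))  # True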

File diff suppressed because it is too large

View File

@ -0,0 +1,274 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from timm.models.layers import DropPath
from src.modules.util import DropPath
class FeatureResizer(nn.Module):
"""
    This class takes as input a set of embeddings of dimension C1 and outputs a set of
    embeddings of dimension C2, after a linear transformation, dropout and layer normalization (LN).
"""
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
super().__init__()
self.do_ln = do_ln
# Object feature encoding
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(self, encoder_features):
x = self.fc(encoder_features)
if self.do_ln:
x = self.layer_norm(x)
output = self.dropout(x)
return output
def l1norm(X, dim, eps=1e-8):
"""L1-normalize columns of X
"""
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
X = torch.div(X, norm)
return X
def l2norm(X, dim, eps=1e-8):
"""L2-normalize columns of X
"""
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
X = torch.div(X, norm)
return X
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
batch_size_q, queryL = query.size(0), query.size(1)
batch_size, sourceL = context.size(0), context.size(1)
# Get attention
# --> (batch, d, queryL)
queryT = torch.transpose(query, 1, 2)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn = torch.bmm(context, queryT)
if raw_feature_norm == "softmax":
# --> (batch*sourceL, queryL)
attn = attn.view(batch_size * sourceL, queryL)
        attn = nn.Softmax(dim=1)(attn)
# --> (batch, sourceL, queryL)
attn = attn.view(batch_size, sourceL, queryL)
elif raw_feature_norm == "l2norm":
attn = l2norm(attn, 2)
elif raw_feature_norm == "clipped_l2norm":
attn = nn.LeakyReLU(0.1)(attn)
attn = l2norm(attn, 2)
else:
raise ValueError("unknown first norm type:", raw_feature_norm)
# --> (batch, queryL, sourceL)
attn = torch.transpose(attn, 1, 2).contiguous()
# --> (batch*queryL, sourceL)
attn = attn.view(batch_size * queryL, sourceL)
    attn = nn.Softmax(dim=1)(attn * smooth)
# --> (batch, queryL, sourceL)
attn = attn.view(batch_size, queryL, sourceL)
# --> (batch, sourceL, queryL)
attnT = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, d, sourceL)
contextT = torch.transpose(context, 1, 2)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext = torch.bmm(contextT, attnT)
# --> (batch, queryL, d)
weightedContext = torch.transpose(weightedContext, 1, 2)
return weightedContext, attnT
class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
self.stable_softmax_2d = True
self.clamp_min_for_underflow = True
self.clamp_max_for_overflow = True
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
"""_summary_
Args:
v (_type_): bs, n_img, dim
l (_type_): bs, n_text, dim
attention_mask_v (_type_, optional): _description_. bs, n_img
attention_mask_l (_type_, optional): _description_. bs, n_text
Returns:
_type_: _description_
"""
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
bsz, tgt_len, _ = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
0])
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
        # mask vision for language
if attention_mask_v is not None:
attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
attn_weights_l.masked_fill_(attention_mask_v, float('-inf'))
attn_weights_l = attn_weights_l.softmax(dim=-1)
# mask language for vision
if attention_mask_l is not None:
attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
attn_weights.masked_fill_(attention_mask_l, float('-inf'))
attn_weights_v = attn_weights.softmax(dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output_v` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
            )
        if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
            raise ValueError(
                f"`attn_output_l` should be of size {(bsz * self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
            )
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
# Bi-Direction MHA (text->image, image->text)
class BiAttentionBlock(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
drop_path=.0, init_values=1e-4, cfg=None):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
hidden_dim - Dimensionality of hidden layer in feed-forward network
(usually 2-4x larger than embed_dim)
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super(BiAttentionBlock, self).__init__()
# pre layer norm
self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention(v_dim=v_dim,
l_dim=l_dim,
embed_dim=embed_dim,
num_heads=num_heads,
dropout=dropout)
# add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=False)
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=False)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
v = self.layer_norm_v(v)
l = self.layer_norm_l(l)
delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l)
# v, l = v + delta_v, l + delta_l
v = v + self.drop_path(self.gamma_v * delta_v)
l = l + self.drop_path(self.gamma_l * delta_l)
return v, l
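
# A minimal usage sketch for the bi-directional block above, with random features and
# hypothetical sizes (v_dim=256 visual tokens, l_dim=768 text tokens); masks are omitted.
import torch

block = BiAttentionBlock(v_dim=256, l_dim=768, embed_dim=256, num_heads=8, dropout=0.1)
v = torch.randn(2, 100, 256)   # bs, n_img, v_dim
l = torch.randn(2, 16, 768)    # bs, n_text, l_dim
v_out, l_out = block(v, l)
print(v_out.shape, l_out.shape)  # torch.Size([2, 100, 256]) torch.Size([2, 16, 768])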

View File

@ -0,0 +1,56 @@
import torch
def prepare_for_mask(kpt_mask):
tgt_size2 = 50 * 69
attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0
group_bbox_kpt = 69
    num_group = 50
for matchj in range(num_group * group_bbox_kpt):
sj = (matchj // group_bbox_kpt) * group_bbox_kpt
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
if sj > 0:
attn_mask2[:,:,matchj, :sj] = True
if ej < num_group * group_bbox_kpt:
attn_mask2[:,:,matchj, ej:] = True
bs, length = kpt_mask.shape
equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :]
    equal_mask = equal_mask.unsqueeze(1).repeat(1, 8, 1, 1)
for idx in range(num_group):
start_idx = idx * length
end_idx = (idx + 1) * length
attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False
attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True
input_query_label = None
input_query_bbox = None
attn_mask = None
dn_meta = None
return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta
def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
if dn_meta and dn_meta['pad_size'] > 0:
output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class]
output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord]
outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class]
outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord]
out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
if aux_loss:
out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
dn_meta['output_known_lbs_bboxes'] = out
return outputs_class, outputs_coord
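
# A minimal usage sketch: it assumes a CUDA device (the mask above is allocated on
# 'cuda') and a boolean kpt_mask with one entry per keypoint query (length 69,
# matching group_bbox_kpt above).
import torch

if torch.cuda.is_available():
    kpt_mask = torch.rand(2, 69, device='cuda') > 0.2
    _, _, _, attn_mask2, _ = prepare_for_mask(kpt_mask)
    print(attn_mask2.shape)  # torch.Size([16, 3450, 3450]) -> (bs * 8 heads, 50 * 69, 50 * 69)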

View File

@ -0,0 +1,10 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction

View File

@ -0,0 +1,61 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA
class MSDeformAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
MSDA.ms_deform_attn_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # For debugging and testing only;
    # use the CUDA version in practice.
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
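
# A CPU sketch of the pure-PyTorch reference above, to make the documented shapes
# concrete: value is (N, S, M, D) with S = sum(H*W) over levels, sampling locations
# lie in [0, 1], and attention weights cover n_levels * n_points per query. Note that
# importing this module still requires the compiled extension because of the MSDA
# import at the top of the file.
import torch

N, M, D = 1, 2, 4            # batch, heads, channels per head
Lq, L, P = 3, 2, 4           # queries, levels, sampling points
spatial_shapes = torch.as_tensor([(8, 8), (4, 4)], dtype=torch.long)
S = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 80

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.softmax(torch.rand(N, Lq, M, L, P).flatten(-2), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([1, 3, 8]) -> (N, Lq, M * D)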

View File

@ -0,0 +1,9 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn

View File

@ -0,0 +1,142 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math, os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from src.utils.dependencies.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
        # _d_per_head should ideally be a power of 2, which is more efficient in the CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("Setting d_model in MSDeformAttn so that the dimension of each attention head is a power of 2 "
                          "is more efficient in the CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self.use_4D_normalizer = use_4D_normalizer
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
# if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
# import ipdb; ipdb.set_trace()
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
if self.use_4D_normalizer:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
else:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
                'Last dim of reference_points must be 2 or 4, but got {} instead.'.format(reference_points.shape[-1]))
# if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
# import ipdb; ipdb.set_trace()
# for amp
if value.dtype == torch.float16:
# for mixed precision
output = MSDeformAttnFunction.apply(
value.to(torch.float32), input_spatial_shapes, input_level_start_index, sampling_locations.to(torch.float32), attention_weights, self.im2col_step)
output = output.to(torch.float16)
output = self.output_proj(output)
return output
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output
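
# A minimal usage sketch, assuming the MultiScaleDeformableAttention CUDA extension has
# been built (see the setup.py in the ops directory) and a GPU is available.
import torch

if torch.cuda.is_available():
    attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
    spatial_shapes = torch.as_tensor([(16, 16), (8, 8)], dtype=torch.long, device='cuda')
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    Len_in = int(spatial_shapes.prod(1).sum())                    # 320
    query = torch.rand(1, 100, 256, device='cuda')
    input_flatten = torch.rand(1, Len_in, 256, device='cuda')
    reference_points = torch.rand(1, 100, 2, 2, device='cuda')    # (N, Len_q, n_levels, 2) in [0, 1]
    out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
    print(out.shape)  # torch.Size([1, 100, 256])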

View File

@ -0,0 +1,130 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math, os
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
try:
from src.utils.dependencies.XPose.models.UniPose.ops.functions import MSDeformAttnFunction
except:
warnings.warn('Failed to import MSDeformAttnFunction.')
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
        # _d_per_head should ideally be a power of 2, which is more efficient in the CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("Setting d_model in MSDeformAttn so that the dimension of each attention head is a power of 2 "
                          "is more efficient in the CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self.use_4D_normalizer = use_4D_normalizer
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param key (N, 1, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
# if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
# import ipdb; ipdb.set_trace()
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
if self.use_4D_normalizer:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
else:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
                'Last dim of reference_points must be 2 or 4, but got {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output

View File

@ -0,0 +1,73 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import os
import glob
import torch
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
# import ipdb; ipdb.set_trace()
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
        raise NotImplementedError('CUDA is not available')
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"MultiScaleDeformableAttention",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="MultiScaleDeformableAttention",
version="1.0",
author="Weijie Su",
url="https://github.com/fundamentalvision/Deformable-DETR",
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
packages=find_packages(exclude=("configs", "tests",)),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
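
# A short note on building this extension (an assumption based on the standard
# Deformable-DETR workflow for an identical setup script):
#   python setup.py build install
# After a successful build, the compiled op can be imported directly:
try:
    import MultiScaleDeformableAttention  # noqa: F401
    print("MultiScaleDeformableAttention extension is available")
except ImportError:
    print("Extension not built yet; run the command above first")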

View File

@ -0,0 +1,41 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}

View File

@ -0,0 +1,33 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);

View File

@ -0,0 +1,153 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
    AT_ASSERTM(batch % im2col_step_ == 0, "im2col_step (%d) must divide batch (%d)", im2col_step_, batch);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
    AT_ASSERTM(batch % im2col_step_ == 0, "im2col_step (%d) must divide batch (%d)", im2col_step_, batch);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}

View File

@ -0,0 +1,30 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);

View File

@ -0,0 +1,62 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}

View File

@ -0,0 +1,16 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}

View File

@ -0,0 +1,89 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H*W).item() for H, W in shapes])
torch.manual_seed(3)
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
im2col_step = 2
output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
@torch.no_grad()
def check_forward_equal_with_pytorch_float():
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
im2col_step = 2
output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
im2col_step = 2
func = MSDeformAttnFunction.apply
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
print(f'* {gradok} check_gradient_numerical(D={channels})')
if __name__ == '__main__':
check_forward_equal_with_pytorch_double()
check_forward_equal_with_pytorch_float()
for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
check_gradient_numerical(channels, True, True, True)

View File

@ -0,0 +1,157 @@
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
# if os.environ.get("SHILONG_AMP", None) == '1':
# eps = 1e-4
# else:
# eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingSineHW(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
self.temperatureW = temperatureW
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
# import ipdb; ipdb.set_trace()
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
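# Illustrative sketch (not from the upstream XPose sources): the learned variant produces the
# same channel layout as the sine embedding, but only supports feature maps up to 50x50
# because of the fixed 50-entry embedding tables above.
def _sketch_position_embedding_learned():
    pe = PositionEmbeddingLearned(num_pos_feats=128)
    images = torch.zeros(2, 3, 32, 32)
    mask = torch.zeros(2, 32, 32, dtype=torch.bool)
    pos = pe(NestedTensor(images, mask))
    assert pos.shape == (2, 256, 32, 32)
    return pos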
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSineHW(
N_steps,
temperatureH=args.pe_temperatureH,
temperatureW=args.pe_temperatureW,
normalize=True
)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding
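# Illustrative sketch (not from the upstream XPose sources): how build_position_encoding is
# typically wired. The Namespace below only carries the attributes read above
# (hidden_dim, position_embedding, pe_temperatureH, pe_temperatureW); the values are
# placeholders, not the project's real configuration.
def _sketch_build_position_encoding():
    from argparse import Namespace
    args = Namespace(hidden_dim=256, position_embedding='sine',
                     pe_temperatureH=20, pe_temperatureW=20)
    pe = build_position_encoding(args)  # PositionEmbeddingSineHW with N_steps = 128
    images = torch.zeros(1, 3, 64, 64)
    mask = torch.zeros(1, 64, 64, dtype=torch.bool)
    pos = pe(NestedTensor(images, mask))
    assert pos.shape == (1, 256, 64, 64)
    return pos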

View File

@ -0,0 +1,701 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from util.misc import NestedTensor
# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from src.modules.util import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
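# Illustrative sketch (not from the upstream XPose sources): window_partition and
# window_reverse are exact inverses whenever H and W are multiples of the window size,
# which the Swin blocks below guarantee by padding first.
def _sketch_window_roundtrip():
    x = torch.randn(2, 14, 14, 96)                # B, H, W, C
    windows = window_partition(x, window_size=7)  # (2 * 2 * 2, 7, 7, 96)
    assert windows.shape == (8, 7, 7, 96)
    x_back = window_reverse(windows, window_size=7, H=14, W=14)
    assert torch.equal(x, x_back)
    return x_back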
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
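# Illustrative sketch (not from the upstream XPose sources): WindowAttention operates on
# flattened windows and preserves their shape; with a 7x7 window the relative position
# bias table above holds (2*7-1)**2 = 169 entries per head.
def _sketch_window_attention():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    x = torch.randn(8, 49, 96)  # num_windows*B, window_size*window_size, C
    out = attn(x)               # mask=None -> plain W-MSA
    assert out.shape == (8, 49, 96)
    return out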
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
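# Illustrative sketch (not from the upstream XPose sources): the block expects the caller
# (BasicLayer below) to set H and W before the forward pass; with shift_size=0 the
# attention mask argument is unused.
def _sketch_swin_block():
    blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
    blk.H, blk.W = 14, 14
    x = torch.randn(2, 14 * 14, 96)  # B, H*W, C
    out = blk(x, mask_matrix=None)
    assert out.shape == (2, 196, 96)
    return out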
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
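# Illustrative sketch (not from the upstream XPose sources): PatchMerging halves the spatial
# resolution and doubles the channel count, which is what builds the feature pyramid.
def _sketch_patch_merging():
    merge = PatchMerging(dim=96)
    x = torch.randn(2, 14 * 14, 96)  # B, H*W, C
    out = merge(x, H=14, W=14)
    assert out.shape == (2, 7 * 7, 192)
    return out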
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depth of this stage (number of blocks).
num_heads (int): Number of attention heads.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
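# Illustrative sketch (not from the upstream XPose sources): PatchEmbed is a strided
# convolution, so a 224x224 image with patch_size=4 becomes a 56x56 grid of embed_dim tokens.
def _sketch_patch_embed():
    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    x = torch.randn(2, 3, 224, 224)
    out = embed(x)
    assert out.shape == (2, 96, 56, 56)
    return out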
class SwinTransformer(nn.Module):
""" Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention heads of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
dilation (bool): If True, the output is 16x downsampled; otherwise 32x.
"""
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
dilation=False,
use_checkpoint=False):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.dilation = dilation
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
# prepare downsample list
downsamplelist = [PatchMerging for i in range(self.num_layers)]
downsamplelist[-1] = None
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
if self.dilation:
downsamplelist[-2] = None
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
for i_layer in range(self.num_layers):
layer = BasicLayer(
# dim=int(embed_dim * 2 ** i_layer),
dim=num_features[i_layer],
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
downsample=downsamplelist[i_layer],
use_checkpoint=use_checkpoint)
self.layers.append(layer)
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def forward_raw(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# outs:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
return tuple(outs)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# out:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
# collect for nesttensors
outs_dict = {}
for idx, out_i in enumerate(outs):
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
outs_dict[idx] = NestedTensor(out_i, mask)
return outs_dict
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def build_swin_transformer(modelname, pretrain_img_size, **kw):
assert modelname in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k', 'swin_L_224_22k', 'swin_L_384_22k']
model_para_dict = {
'swin_T_224_1k': dict(
embed_dim=96,
depths=[ 2, 2, 6, 2 ],
num_heads=[ 3, 6, 12, 24],
window_size=7
),
'swin_B_224_22k': dict(
embed_dim=128,
depths=[ 2, 2, 18, 2 ],
num_heads=[ 4, 8, 16, 32 ],
window_size=7
),
'swin_B_384_22k': dict(
embed_dim=128,
depths=[ 2, 2, 18, 2 ],
num_heads=[ 4, 8, 16, 32 ],
window_size=12
),
'swin_L_224_22k': dict(
embed_dim=192,
depths=[ 2, 2, 18, 2 ],
num_heads=[ 6, 12, 24, 48 ],
window_size=7
),
'swin_L_384_22k': dict(
embed_dim=192,
depths=[ 2, 2, 18, 2 ],
num_heads=[ 6, 12, 24, 48 ],
window_size=12
),
}
kw_cfg = model_para_dict[modelname]
kw_cfg.update(kw)
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cfg)
return model
if __name__ == "__main__":
model = build_swin_transformer('swin_L_384_22k', 384, dilation=True)
x = torch.rand(2, 3, 1024, 1024)
y = model.forward_raw(x)
import ipdb; ipdb.set_trace()
x = torch.rand(2, 3, 384, 384)
y = model.forward_raw(x)

View File

@ -0,0 +1,595 @@
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import copy
import math
import torch
from torch import nn, Tensor
from torch.nn.init import xavier_uniform_, constant_, normal_
from typing import Optional
from util.misc import inverse_sigmoid
from .ops.modules import MSDeformAttn
from .utils import MLP, _get_activation_fn, gen_sineembed_for_position
class DeformableTransformer(nn.Module):
def __init__(self, d_model=256, nhead=8,
num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
activation="relu", return_intermediate_dec=False,
num_feature_levels=4, dec_n_points=4, enc_n_points=4,
two_stage=False, two_stage_num_proposals=300,
use_dab=False, high_dim_query_update=False, no_sine_embed=False):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.two_stage = two_stage
self.two_stage_num_proposals = two_stage_num_proposals
self.use_dab = use_dab
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
dropout, activation,
num_feature_levels, nhead, enc_n_points)
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
dropout, activation,
num_feature_levels, nhead, dec_n_points)
self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
if two_stage:
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
self.pos_trans_norm = nn.LayerNorm(d_model * 2)
else:
if not self.use_dab:
self.reference_points = nn.Linear(d_model, 2)
self.high_dim_query_update = high_dim_query_update
if high_dim_query_update:
assert not self.use_dab, "high_dim_query_update is not compatible with use_dab"
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if not self.two_stage and not self.use_dab:
xavier_uniform_(self.reference_points.weight.data, gain=1.0)
constant_(self.reference_points.bias.data, 0.)
normal_(self.level_embed)
def get_proposal_pos_embed(self, proposals):
num_pos_feats = 128
temperature = 10000
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
return pos
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += (H_ * W_)
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def forward(self, srcs, masks, pos_embeds, query_embed=None):
"""
Input:
- srcs: List([bs, c, h, w])
- masks: List([bs, h, w])
"""
assert self.two_stage or query_embed is not None
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
# hack implementation for two-stage Deformable DETR
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
elif self.use_dab:
reference_points = query_embed[..., self.d_model:].sigmoid()
tgt = query_embed[..., :self.d_model]
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
init_reference_out = reference_points
else:
query_embed, tgt = torch.split(query_embed, c, dim=1)
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_embed).sigmoid()
# bs, num_queries, 2
init_reference_out = reference_points
# decoder
hs, inter_references = self.decoder(tgt, reference_points, memory,
spatial_shapes, level_start_index, valid_ratios,
query_pos=query_embed if not self.use_dab else None,
src_padding_mask=mask_flatten)
inter_references_out = inter_references
if self.two_stage:
return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
return hs, init_reference_out, inter_references_out, None, None
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4,
add_channel_attention=False,
use_deformable_box_attn=False,
box_attn_type='roi_align',
):
super().__init__()
# self attention
if use_deformable_box_attn:
self.self_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
else:
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# channel attention
self.add_channel_attention = add_channel_attention
if add_channel_attention:
self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
self.norm_channel = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
# channel attn
if self.add_channel_attention:
src = self.norm_channel(src + self.activ_channel(src))
return src
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers)
else:
self.layers = []
del encoder_layer
self.num_layers = num_layers
self.norm = norm
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- padding_mask: [bs, sum(hi*wi)]
Intermediate:
- reference_points: [bs, sum(hi*wi), num_level, 2]
"""
output = src
# bs, sum(hi*wi), 256
if self.num_layers > 0:
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
if self.norm is not None:
output = self.norm(output)
return output
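# Illustrative sketch (not from the upstream XPose sources): get_reference_points lays a
# normalized grid over every feature level and rescales it by the per-level valid ratios,
# yielding one reference point per token per level.
def _sketch_encoder_reference_points():
    spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long)
    valid_ratios = torch.ones(2, 2, 2)  # bs=2, num_levels=2, (w_ratio, h_ratio)
    ref = DeformableTransformerEncoder.get_reference_points(
        spatial_shapes, valid_ratios, device=spatial_shapes.device)
    assert ref.shape == (2, 32 * 32 + 16 * 16, 2, 2)  # bs, sum(hi*wi), num_levels, 2
    return ref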
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4,
use_deformable_box_attn=False,
box_attn_type='roi_align',
key_aware_type=None,
decoder_sa_type='ca',
module_seq=['sa', 'ca', 'ffn'],
):
super().__init__()
self.module_seq = module_seq
assert sorted(module_seq) == ['ca', 'ffn', 'sa']
# cross attention
# self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
if use_deformable_box_attn:
self.cross_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
else:
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_type = key_aware_type
self.key_aware_proj = None
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
if decoder_sa_type == 'ca_content':
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_sa(self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
# self attention
if self.self_attn is not None:
if self.decoder_sa_type == 'sa':
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
elif self.decoder_sa_type == 'ca_label':
# q = self.with_pos_embed(tgt, tgt_query_pos)
bs = tgt.shape[1]
k = v = self.label_embedding.weight[:, None, :].repeat(1, bs, 1)
tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
elif self.decoder_sa_type == 'ca_content':
tgt2 = self.self_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
tgt_reference_points.transpose(0, 1).contiguous(),
memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
else:
raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
return tgt
def forward_ca(self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
# cross attention
if self.key_aware_type is not None:
if self.key_aware_type == 'mean':
tgt = tgt + memory.mean(0, keepdim=True)
elif self.key_aware_type == 'proj_mean':
tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
else:
raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
tgt_reference_points.transpose(0, 1).contiguous(),
memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
return tgt
def forward(self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
for funcname in self.module_seq:
if funcname == 'ffn':
tgt = self.forward_ffn(tgt)
elif funcname == 'ca':
tgt = self.forward_ca(tgt, tgt_query_pos, tgt_query_sine_embed, \
tgt_key_padding_mask, tgt_reference_points, \
memory, memory_key_padding_mask, memory_level_start_index, \
memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
elif funcname == 'sa':
tgt = self.forward_sa(tgt, tgt_query_pos, tgt_query_sine_embed, \
tgt_key_padding_mask, tgt_reference_points, \
memory, memory_key_padding_mask, memory_level_start_index, \
memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
else:
raise ValueError('unknown funcname {}'.format(funcname))
return tgt
class DeformableTransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, return_intermediate=False, use_dab=False, d_model=256, query_dim=4):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
assert return_intermediate
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
self.bbox_embed = None
self.class_embed = None
self.use_dab = use_dab
self.d_model = d_model
self.query_dim = query_dim
if use_dab:
self.query_scale = MLP(d_model, d_model, d_model, 2)
self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
def forward(self, tgt, reference_points, src, src_spatial_shapes,
src_level_start_index, src_valid_ratios,
query_pos=None, src_padding_mask=None):
output = tgt
if self.use_dab:
assert query_pos is None
intermediate = []
intermediate_reference_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] \
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] # bs, nq, 4, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
if self.use_dab:
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
pos_scale = self.query_scale(output) if layer_id != 0 else 1
query_pos = pos_scale * raw_query_pos
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
box_holder = self.bbox_embed(output)
box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_points = new_reference_points.detach()
if layer_id != self.num_layers - 1:
intermediate_reference_points.append(new_reference_points)
intermediate.append(output)
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_deforamble_transformer(args):
return DeformableTransformer(
d_model=args.hidden_dim,
nhead=args.nheads,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
activation="relu",
return_intermediate_dec=True,
num_feature_levels=args.ddetr_num_feature_levels,
dec_n_points=args.ddetr_dec_n_points,
enc_n_points=args.ddetr_enc_n_points,
two_stage=args.ddetr_two_stage,
two_stage_num_proposals=args.num_queries,
use_dab=args.ddetr_use_dab,
high_dim_query_update=args.ddetr_high_dim_query_update,
no_sine_embed=args.ddetr_no_sine_embed)
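# Illustrative sketch (not from the upstream XPose sources): the argparse-style fields that
# build_deforamble_transformer reads from `args`. The values are placeholders, not the
# project's real configuration, and actually instantiating the transformer also requires
# the compiled MSDeformAttn op from ./ops.
_DEFORMABLE_TRANSFORMER_ARGS_SKETCH = dict(
    hidden_dim=256, nheads=8, enc_layers=6, dec_layers=6,
    dim_feedforward=1024, dropout=0.1,
    ddetr_num_feature_levels=4, ddetr_dec_n_points=4, ddetr_enc_n_points=4,
    ddetr_two_stage=False, num_queries=300, ddetr_use_dab=False,
    ddetr_high_dim_query_update=False, ddetr_no_sine_embed=False,
)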

View File

@ -0,0 +1,102 @@
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import torch
from torch import Tensor, nn
from typing import List, Optional
from .utils import _get_activation_fn, _get_clones
class TextTransformer(nn.Module):
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.nheads = nheads
self.dim_feedforward = dim_feedforward
self.norm = None
single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout)
self.layers = _get_clones(single_encoder_layer, num_layers)
def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor):
"""
Args:
memory_text: bs, num_token, d_model
text_attention_mask: bs, num_token
Returns:
output: bs, num_token, d_model
"""
output = memory_text.transpose(0, 1)
for layer in self.layers:
output = layer(output, src_key_padding_mask=text_attention_mask)
if self.norm is not None:
output = self.norm(output)
return output.transpose(0, 1)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.nhead = nhead
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
# repeat attn mask
if src_mask is not None and src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
# bs, num_q, num_k
src_mask = src_mask.repeat(self.nhead, 1, 1)
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
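# Illustrative sketch (not from the upstream XPose sources): TextTransformer keeps the
# (bs, num_token, d_model) layout at its interface and transposes internally for
# nn.MultiheadAttention; the mask is passed through as a key padding mask.
def _sketch_text_transformer():
    txt_enc = TextTransformer(num_layers=2, d_model=256, nheads=8, dim_feedforward=1024)
    memory_text = torch.randn(2, 20, 256)                       # bs, num_token, d_model
    text_attention_mask = torch.zeros(2, 20, dtype=torch.bool)  # nothing padded
    out = txt_enc(memory_text, text_attention_mask)
    assert out.shape == (2, 20, 256)
    return out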

View File

@ -0,0 +1,621 @@
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import os
import copy
import torch
import torch.nn.functional as F
from torch import nn
from typing import List
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
from util.misc import NestedTensor, nested_tensor_from_tensor_list,inverse_sigmoid
from .utils import MLP
from .backbone import build_backbone
from ..registry import MODULE_BUILD_FUNCS
from .mask_generate import prepare_for_mask, post_process
from .deformable_transformer import build_deformable_transformer
class UniPose(nn.Module):
""" This is the Cross-Attention Detector module that performs object detection """
def __init__(self, backbone, transformer, num_classes, num_queries,
aux_loss=False, iter_update=False,
query_dim=2,
random_refpoints_xy=False,
fix_refpoints_hw=-1,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type='no', # ['no', 'standard']
two_stage_add_query_num=0,
dec_pred_class_embed_share=True,
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
decoder_sa_type='sa',
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
use_label_enc=True,
text_encoder_type='bert-base-uncased',
binary_query_selection=False,
use_cdn=True,
sub_sentence_present=True,
num_body_points=68,
num_box_decoder_layers=2,
):
""" Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_classes: number of object classes
num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
fix_refpoints_hw: -1 (default): learn w and h for each box separately
>0 : given fixed number
-2 : learn a shared w and h
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.num_classes = num_classes
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.use_label_enc = use_label_enc
if use_label_enc:
self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
else:
raise NotImplementedError  # label encoding is required by this implementation
self.max_text_len = 256
self.binary_query_selection = binary_query_selection
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
self.random_refpoints_xy = random_refpoints_xy
self.fix_refpoints_hw = fix_refpoints_hw
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
self.use_cdn = use_cdn
self.projection = MLP(512, hidden_dim, hidden_dim, 3)
self.projection_kpt = MLP(512, hidden_dim, hidden_dim, 3)
device = "cuda" if torch.cuda.is_available() else "cpu"
# model, _ = clip.load("ViT-B/32", device=device)
# self.clip_model = model
# visual_parameters = list(self.clip_model.visual.parameters())
# #
# for param in visual_parameters:
# param.requires_grad = False
self.pos_proj = nn.Linear(hidden_dim, 768)
self.padding = nn.Embedding(1, 768)
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
))
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
))
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == 'no', "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)])
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_class_embed_share = dec_pred_class_embed_share
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveAssign()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
_pose_embed = MLP(hidden_dim, hidden_dim, 2, 3)
_pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
nn.init.constant_(_pose_embed.layers[-1].weight.data, 0)
nn.init.constant_(_pose_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)]
if dec_pred_class_embed_share:
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
else:
class_embed_layerlist = [copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers)]
if dec_pred_bbox_embed_share:
pose_embed_layerlist = [_pose_embed for i in
range(transformer.num_decoder_layers - num_box_decoder_layers + 1)]
else:
pose_embed_layerlist = [copy.deepcopy(_pose_embed) for i in
range(transformer.num_decoder_layers - num_box_decoder_layers + 1)]
pose_hw_embed_layerlist = [_pose_hw_embed for i in
range(transformer.num_decoder_layers - num_box_decoder_layers)]
self.num_box_decoder_layers = num_box_decoder_layers
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.num_body_points = num_body_points
self.pose_embed = nn.ModuleList(pose_embed_layerlist)
self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
self.transformer.decoder.pose_embed = self.pose_embed
self.transformer.decoder.pose_hw_embed = self.pose_hw_embed
self.transformer.decoder.num_body_points = num_body_points
# two stage
self.two_stage_type = two_stage_type
self.two_stage_add_query_num = two_stage_add_query_num
assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type != 'no':
if two_stage_bbox_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
if self.two_stage_add_query_num > 0:
self.init_ref_points(two_stage_add_query_num)
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
# self.replace_sa_with_double_ca = replace_sa_with_double_ca
if decoder_sa_type == 'ca_label':
self.label_embedding = nn.Embedding(num_classes, hidden_dim)
for layer in self.transformer.decoder.layers:
layer.label_embedding = self.label_embedding
else:
for layer in self.transformer.decoder.layers:
layer.label_embedding = None
self.label_embedding = None
self._reset_parameters()
def open_set_transfer_init(self):
for name, param in self.named_parameters():
if 'fusion_layers' in name:
continue
if 'ca_text' in name:
continue
if 'catext_norm' in name:
continue
if 'catext_dropout' in name:
continue
if "text_layers" in name:
continue
if 'bert' in name:
continue
if 'bbox_embed' in name:
continue
if 'label_enc.weight' in name:
continue
if 'feat_map' in name:
continue
if 'enc_output' in name:
continue
param.requires_grad_(False)
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
if self.random_refpoints_xy:
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
if self.fix_refpoints_hw > 0:
print("fix_refpoints_hw: {}".format(self.fix_refpoints_hw))
assert self.random_refpoints_xy
self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw
self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(self.refpoint_embed.weight.data[:, 2:])
self.refpoint_embed.weight.data[:, 2:].requires_grad = False
elif int(self.fix_refpoints_hw) == -1:
pass
elif int(self.fix_refpoints_hw) == -2:
print('learn a shared h and w')
assert self.random_refpoints_xy
self.refpoint_embed = nn.Embedding(use_num_queries, 2)
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
self.hw_embed = nn.Embedding(1, 1)
else:
raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(self.fix_refpoints_hw))
def forward(self, samples: NestedTensor, targets: List = None, **kw):
""" The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
captions = [t['instance_text_prompt'] for t in targets]
bs=len(captions)
tensor_list = [tgt["object_embeddings_text"] for tgt in targets]
max_size = 350  # pad object text embeddings to a fixed length of 350 tokens
padded_tensors = [torch.cat([tensor, torch.zeros(max_size - tensor.size(0), tensor.size(1),device=tensor.device)]) if tensor.size(0) < max_size else tensor for tensor in tensor_list]
object_embeddings_text = torch.stack(padded_tensors)
kpts_embeddings_text = torch.stack([tgt["kpts_embeddings_text"] for tgt in targets])[:, :self.num_body_points]
encoded_text=self.projection(object_embeddings_text) # bs, 81, 101, 256
kpt_embeddings_specific=self.projection_kpt(kpts_embeddings_text) # bs, 81, 101, 256
kpt_vis = torch.stack([tgt["kpt_vis_text"] for tgt in targets])[:, :self.num_body_points]
kpt_mask = torch.cat((torch.ones_like(kpt_vis, device=kpt_vis.device)[..., 0].unsqueeze(-1), kpt_vis), dim=-1)
num_classes = encoded_text.shape[1] # bs, 81, 101, 256
text_self_attention_masks = torch.eye(num_classes).unsqueeze(0).expand(bs, -1, -1).bool().to(samples.device)
text_token_mask = torch.zeros(samples.shape[0],num_classes).to(samples.device)>0
for i in range(bs):
text_token_mask[i,:len(captions[i])]=True
position_ids = torch.zeros(samples.shape[0], num_classes).to(samples.device)
for i in range(bs):
position_ids[i,:len(captions[i])]= 1
text_dict = {
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
'position_ids': position_ids, # bs, 195
'text_self_attention_masks': text_self_attention_masks # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, poss = self.backbone(samples)
if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
import ipdb;
ipdb.set_trace()
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
if self.label_enc is not None:
label_enc = self.label_enc
else:
raise NotImplementedError
# unreachable fallback (kept for reference): label_enc = encoded_text
if self.dn_number > 0 or targets is not None:
input_query_label, input_query_bbox, attn_mask, attn_mask2, dn_meta = \
prepare_for_mask(kpt_mask=kpt_mask)
else:
assert targets is None
input_query_bbox = input_query_label = attn_mask = attn_mask2 = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(srcs, masks, input_query_bbox, poss,
input_query_label, attn_mask, attn_mask2,
text_dict, dn_meta,targets,kpt_embeddings_specific)
# in case the number of objects is 0, touch these parameters so they still receive (zero) gradients
if self.label_enc is not None:
hs[0] += self.label_enc.weight[0, 0] * 0.0
hs[0] += self.pos_proj.weight[0, 0] * 0.0
hs[0] += self.pos_proj.bias[0] * 0.0
hs[0] += self.padding.weight[0, 0] * 0.0
num_group = 50
effective_dn_number = dn_meta['pad_size'] if self.training else 0
outputs_coord_list = []
outputs_class = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, self.class_embed, hs)):
if dec_lid < self.num_box_decoder_layers:
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
layer_cls = layer_cls_embed(layer_hs, text_dict)
outputs_coord_list.append(layer_outputs_unsig)
outputs_class.append(layer_cls)
else:
layer_hs_bbox_dn = layer_hs[:, :effective_dn_number, :]
layer_hs_bbox_norm = layer_hs[:, effective_dn_number:, :][:, 0::(self.num_body_points + 1), :]
bs = layer_ref_sig.shape[0]
reference_before_sigmoid_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]
reference_before_sigmoid_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][:,
0::(self.num_body_points + 1), :]
layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn)
layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm)
layer_outputs_unsig_dn = layer_delta_unsig_dn + inverse_sigmoid(reference_before_sigmoid_bbox_dn)
layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid()
layer_outputs_unsig_norm = layer_delta_unsig_norm + inverse_sigmoid(reference_before_sigmoid_bbox_norm)
layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid()
layer_outputs_unsig = torch.cat((layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1)
layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn, text_dict)
layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm, text_dict)
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
outputs_class.append(layer_cls)
outputs_coord_list.append(layer_outputs_unsig)
# update keypoints
outputs_keypoints_list = []
outputs_keypoints_hw = []
kpt_index = [x for x in range(num_group * (self.num_body_points + 1)) if x % (self.num_body_points + 1) != 0]
for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
if dec_lid < self.num_box_decoder_layers:
assert isinstance(layer_hs, torch.Tensor)
bs = layer_hs.shape[0]
layer_res = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points * 3))
outputs_keypoints_list.append(layer_res)
else:
bs = layer_ref_sig.shape[0]
layer_hs_kpt = layer_hs[:, effective_dn_number:, :].index_select(1, torch.tensor(kpt_index,
device=layer_hs.device))
delta_xy_unsig = self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_kpt)
layer_ref_sig_kpt = layer_ref_sig[:, effective_dn_number:, :].index_select(1, torch.tensor(kpt_index,
device=layer_hs.device))
layer_outputs_unsig_keypoints = delta_xy_unsig + inverse_sigmoid(layer_ref_sig_kpt[..., :2])
vis_xy_unsig = torch.ones_like(layer_outputs_unsig_keypoints,
device=layer_outputs_unsig_keypoints.device)
xyv = torch.cat((layer_outputs_unsig_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1)
xyv = xyv.sigmoid()
layer_res = xyv.reshape((bs, num_group, self.num_body_points, 3)).flatten(2, 3)
layer_hw = layer_ref_sig_kpt[..., 2:].reshape(bs, num_group, self.num_body_points, 2).flatten(2, 3)
layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res)
outputs_keypoints_list.append(layer_res)
outputs_keypoints_hw.append(layer_hw)
if self.dn_number > 0 and dn_meta is not None:
outputs_class, outputs_coord_list = \
post_process(outputs_class, outputs_coord_list,
dn_meta, self.aux_loss, self._set_aux_loss)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord_list[-1],
'pred_keypoints': outputs_keypoints_list[-1]}
return out
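# Illustrative sketch, not part of the original file: one way the dict returned by
# forward() could be consumed at inference time, where 'pred_logits' is
# (bs, nq, text_len), 'pred_boxes' is (bs, nq, 4) and 'pred_keypoints' is
# (bs, nq, 3 * num_body_points). The helper name and top-k policy are assumptions.
def _select_top_queries(out, topk=5):
    scores, idx = out['pred_logits'].sigmoid().max(-1)[0].topk(topk, dim=1)
    boxes = torch.gather(out['pred_boxes'], 1, idx.unsqueeze(-1).expand(-1, -1, 4))
    kpts = torch.gather(out['pred_keypoints'], 1,
                        idx.unsqueeze(-1).expand(-1, -1, out['pred_keypoints'].shape[-1]))
    return scores, boxes, kpts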
@MODULE_BUILD_FUNCS.registe_with_name(module_name='UniPose')
def build_unipose(args):
num_classes = args.num_classes
device = torch.device(args.device)
backbone = build_backbone(args)
transformer = build_deformable_transformer(args)
try:
match_unstable_error = args.match_unstable_error
dn_labelbook_size = args.dn_labelbook_size
except:
match_unstable_error = True
dn_labelbook_size = num_classes
try:
dec_pred_class_embed_share = args.dec_pred_class_embed_share
except:
dec_pred_class_embed_share = True
try:
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
except:
dec_pred_bbox_embed_share = True
binary_query_selection = False
try:
binary_query_selection = args.binary_query_selection
except:
binary_query_selection = False
use_cdn = True
try:
use_cdn = args.use_cdn
except:
use_cdn = True
sub_sentence_present = True
try:
sub_sentence_present = args.sub_sentence_present
except:
sub_sentence_present = True
# print('********* sub_sentence_present', sub_sentence_present)
model = UniPose(
backbone,
transformer,
num_classes=num_classes,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
random_refpoints_xy=args.random_refpoints_xy,
fix_refpoints_hw=args.fix_refpoints_hw,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_class_embed_share=dec_pred_class_embed_share,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
# two stage
two_stage_type=args.two_stage_type,
# box_share
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
decoder_sa_type=args.decoder_sa_type,
num_patterns=args.num_patterns,
dn_number=args.dn_number if args.use_dn else 0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
use_label_enc=args.use_label_enc,
text_encoder_type=args.text_encoder_type,
binary_query_selection=binary_query_selection,
use_cdn=use_cdn,
sub_sentence_present=sub_sentence_present
)
return model
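# Sketch only, not part of the original file: the chain of try/except blocks above
# reads optional attributes from `args` with a per-attribute fallback; getattr()
# expresses the same intent, e.g. use_cdn = _read_optional(args, 'use_cdn', True).
def _read_optional(args, name, default):
    return getattr(args, name, default)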
class ContrastiveAssign(nn.Module):
def __init__(self, project=False, cal_bias=None, max_text_len=256):
"""
:param x: query
:param y: text embed
:param proj:
:return:
"""
super().__init__()
self.project = project
self.cal_bias = cal_bias
self.max_text_len = max_text_len
def forward(self, x, text_dict):
"""_summary_
Args:
x (_type_): _description_
text_dict (_type_): _description_
{
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
# True for used tokens. False for padding tokens
}
Returns:
_type_: _description_
"""
assert isinstance(text_dict, dict)
y = text_dict['encoded_text']
max_text_len = y.shape[1]
text_token_mask = text_dict['text_token_mask']
if self.cal_bias is not None:
raise NotImplementedError
# unreachable: return x @ y.transpose(-1, -2) + self.cal_bias.weight.repeat(x.shape[0], x.shape[1], 1)
res = x @ y.transpose(-1, -2)
res.masked_fill_(~text_token_mask[:, None, :], float('-inf'))
# padding to max_text_len
new_res = torch.full((*res.shape[:-1], max_text_len), float('-inf'), device=res.device)
new_res[..., :res.shape[-1]] = res
return new_res
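# Illustrative only, not part of the original file: score 10 query features against
# 7 text tokens; padding positions in the logits are filled with -inf by forward().
def _demo_contrastive_assign():
    bs, nq, n_text, d_model = 2, 10, 7, 256
    head = ContrastiveAssign()
    text_dict_demo = {
        'encoded_text': torch.randn(bs, n_text, d_model),
        'text_token_mask': torch.ones(bs, n_text, dtype=torch.bool),
    }
    return head(torch.randn(bs, nq, d_model), text_dict_demo).shape  # (2, 10, 7)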

View File

@ -0,0 +1,348 @@
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import copy
import torch
import random
from torch import nn, Tensor
import os
import numpy as np
import math
import torch.nn.functional as F
from torch import nn
def _get_clones(module, N, layer_share=False):
# import ipdb; ipdb.set_trace()
if layer_share:
return nn.ModuleList([module for i in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def get_sine_pos_embed(
pos_tensor: torch.Tensor,
num_pos_feats: int = 128,
temperature: int = 10000,
exchange_xy: bool = True,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
return sin_x
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
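# Illustrative check, not part of the original file: embedding a batch of 2-D
# reference points; every coordinate expands to num_pos_feats channels.
def _demo_sine_pos_embed():
    pts = torch.rand(4, 100, 2)                                # [..., n] with n = 2
    return get_sine_pos_embed(pts, num_pos_feats=128).shape    # (4, 100, 256)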
def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
# import ipdb; ipdb.set_trace()
grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
# import ipdb; ipdb.set_trace()
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0 ** lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh = torch.ones_like(grid) / scale
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += (H_ * W_)
# import ipdb; ipdb.set_trace()
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
return output_memory, output_proposals
class RandomBoxPerturber():
def __init__(self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2) -> None:
self.noise_scale = torch.Tensor([x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale])
def __call__(self, refanchors: Tensor) -> Tensor:
nq, bs, query_dim = refanchors.shape
device = refanchors.device
noise_raw = torch.rand_like(refanchors)
noise_scale = self.noise_scale.to(device)[:query_dim]
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
return new_refanchors.clamp_(0, 1)
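# Usage sketch, not part of the original file: jitter normalized (cx, cy, w, h)
# anchors of shape (nq, bs, 4) and clamp them back into [0, 1].
def _demo_box_perturber():
    anchors = torch.rand(300, 2, 4)
    return RandomBoxPerturber(0.1, 0.1, 0.1, 0.1)(anchors)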
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = 0.25; a negative value disables weighting.
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if no_reduction:
return loss
return loss.mean(1).sum() / num_boxes
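# Illustrative only, not part of the original file: focal loss over 8 queries and
# 5 classes, normalized by a hypothetical box count of 4.
def _demo_focal_loss():
    logits = torch.randn(8, 5)
    targets = torch.randint(0, 2, (8, 5)).float()
    return sigmoid_focal_loss(logits, targets, num_boxes=4)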
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
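# Quick shape check, not part of the original file: a 3-layer MLP head mapping
# 256-d features to 4 box coordinates.
def _demo_mlp():
    head = MLP(256, 256, 4, 3)
    return head(torch.randn(2, 100, 256)).shape   # (2, 100, 4)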
def _get_activation_fn(activation, d_model=256, batch_dim=0):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
def gen_sineembed_for_position(pos_tensor):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
return pos
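# Illustrative only, not part of the original file: (n_query, bs, 4) reference
# points in (cx, cy, w, h) become a (n_query, bs, 512) sine embedding.
def _demo_sineembed():
    ref = torch.rand(300, 2, 4)
    return gen_sineembed_for_position(ref).shape   # (300, 2, 512)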
def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
sigmas = kpt_preds.new_tensor(sigmas)
variances = (sigmas * 2) ** 2
assert kpt_preds.size(0) == kpt_gts.size(0)
kpt_preds = kpt_preds.reshape(-1, kpt_preds.size(-1) // 2, 2)
kpt_gts = kpt_gts.reshape(-1, kpt_gts.size(-1) // 2, 2)
squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \
(kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2
# import pdb
# pdb.set_trace()
# assert (kpt_valids.sum(-1) > 0).all()
squared_distance0 = squared_distance / (kpt_areas[:, None] * variances[None, :] * 2)
squared_distance1 = torch.exp(-squared_distance0)
squared_distance1 = squared_distance1 * kpt_valids
oks = squared_distance1.sum(dim=1) / (kpt_valids.sum(dim=1) + 1e-6)
return oks
def oks_loss(pred,
target,
valid=None,
area=None,
linear=False,
sigmas=None,
eps=1e-6):
"""Oks loss.
Computing the oks loss between a set of predicted poses and target poses.
The loss is calculated as negative log of oks.
Args:
pred (torch.Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
shape (n, 2K).
target (torch.Tensor): Corresponding gt poses, shape (n, 2K).
linear (bool, optional): If True, use linear scale of loss instead of
log scale. Default: False.
eps (float): Eps to avoid log(0).
Return:
torch.Tensor: Loss tensor.
"""
oks = oks_overlaps(pred, target, valid, area, sigmas).clamp(min=eps)
if linear:
loss = 1 - oks
else:
loss = -oks.log()
return loss
class OKSLoss(nn.Module):
"""IoULoss.
Computing the oks loss between a set of predicted poses and target poses.
Args:
linear (bool): If True, use linear scale of loss instead of log scale.
Default: False.
eps (float): Eps to avoid log(0).
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
"""
def __init__(self,
linear=False,
num_keypoints=17,
eps=1e-6,
reduction='mean',
loss_weight=1.0):
super(OKSLoss, self).__init__()
self.linear = linear
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
if num_keypoints == 68:
self.sigmas = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89, .25, .25, .25, .25, .25, .25, .25, .25,
.25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25,
.25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25,
.25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25,
], dtype=np.float32) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def forward(self,
pred,
target,
valid,
area,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
valid (torch.Tensor): The visible flag of the target pose.
area (torch.Tensor): The area of the target pose.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None. Options are "none", "mean" and "sum".
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if (weight is not None) and (not torch.any(weight > 0)) and (
reduction != 'none'):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
if weight is not None and weight.dim() > 1:
# TODO: remove this in the future
# reduce the weight of shape (n, 4) to (n,) to match the
# iou_loss of shape (n,)
assert weight.shape == pred.shape
weight = weight.mean(-1)
loss = self.loss_weight * oks_loss(
pred,
target,
valid=valid,
area=area,
linear=self.linear,
sigmas=self.sigmas,
eps=self.eps)
return loss
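# Illustrative only, not part of the original file: pseudo-predictions sized from
# self.sigmas (one sigma per keypoint); returns one OKS loss value per instance.
def _demo_oks_loss():
    loss_fn = OKSLoss(num_keypoints=68)
    k = loss_fn.sigmas.shape[0]
    pred = torch.rand(3, 2 * k)
    target = torch.rand(3, 2 * k)
    valid = torch.ones(3, k)
    area = torch.ones(3)
    return loss_fn(pred, target, valid, area)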

View File

@ -0,0 +1,16 @@
# ------------------------------------------------------------------------
# ED-Pose
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .UniPose.unipose import build_unipose
def build_model(args):
# we use a registry to maintain the model build functions (from catdet6 on).
from .registry import MODULE_BUILD_FUNCS
assert args.modelname in MODULE_BUILD_FUNCS._module_dict
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
model = build_func(args)
return model

View File

@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv
import inspect
from functools import partial
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
def __len__(self):
return len(self._module_dict)
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def registe_with_name(self, module_name=None, force=False):
return partial(self.register, module_name=module_name, force=force)
def register(self, module_build_function, module_name=None, force=False):
"""Register a module build function.
Args:
module_build_function (callable): the build function to be registered.
"""
if not inspect.isfunction(module_build_function):
raise TypeError('module_build_function must be a function, but got {}'.format(
type(module_build_function)))
if module_name is None:
module_name = module_build_function.__name__
if not force and module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_build_function
return module_build_function
MODULE_BUILD_FUNCS = Registry('model build functions')
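# Usage sketch, not part of the original file: registering a build function on a
# throwaway Registry and fetching it back by name, mirroring how build_unipose is
# registered on MODULE_BUILD_FUNCS.
def _demo_registry():
    reg = Registry('demo')
    @reg.registe_with_name(module_name='demo_model')
    def build_demo_model(args):
        return None
    return reg.get('demo_model') is build_demo_model   # True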

View File

@ -0,0 +1,56 @@
person = {"keypoints":['nose', 'left eye', 'right eye', 'left ear', 'right ear', 'left shoulder', 'right shoulder', 'left elbow', 'right elbow', 'left wrist', 'right wrist', 'left hip', 'right hip', 'left knee', 'right knee', 'left ankle', 'right ankle'],"skeleton": [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]}
face = {"keypoints": ['right cheekbone 1', 'right cheekbone 2', 'right cheek 1', 'right cheek 2', 'right cheek 3', 'right cheek 4', 'right cheek 5', 'right chin', 'chin center', 'left chin', 'left cheek 5', 'left cheek 4', 'left cheek 3', 'left cheek 2', 'left cheek 1', 'left cheekbone 2', 'left cheekbone 1', 'right eyebrow 1', 'right eyebrow 2', 'right eyebrow 3', 'right eyebrow 4', 'right eyebrow 5', 'left eyebrow 1', 'left eyebrow 2', 'left eyebrow 3', 'left eyebrow 4', 'left eyebrow 5', 'nasal bridge 1', 'nasal bridge 2', 'nasal bridge 3', 'nasal bridge 4', 'right nasal wing 1', 'right nasal wing 2', 'nasal wing center', 'left nasal wing 1', 'left nasal wing 2', 'right eye eye corner 1', 'right eye upper eyelid 1', 'right eye upper eyelid 2', 'right eye eye corner 2', 'right eye lower eyelid 2', 'right eye lower eyelid 1', 'left eye eye corner 1', 'left eye upper eyelid 1', 'left eye upper eyelid 2', 'left eye eye corner 2', 'left eye lower eyelid 2', 'left eye lower eyelid 1', 'right mouth corner', 'upper lip outer edge 1', 'upper lip outer edge 2', 'upper lip outer edge 3', 'upper lip outer edge 4', 'upper lip outer edge 5', 'left mouth corner', 'lower lip outer edge 5', 'lower lip outer edge 4', 'lower lip outer edge 3', 'lower lip outer edge 2', 'lower lip outer edge 1', 'upper lip inter edge 1', 'upper lip inter edge 2', 'upper lip inter edge 3', 'upper lip inter edge 4', 'upper lip inter edge 5', 'lower lip inter edge 3', 'lower lip inter edge 2', 'lower lip inter edge 1'], "skeleton": []}
hand = {"keypoints":['wrist', 'thumb root', "thumb's third knuckle", "thumb's second knuckle", 'thumbs first knuckle', "forefinger's root", "forefinger's third knuckle", "forefinger's second knuckle", "forefinger's first knuckle", "middle finger's root", "middle finger's third knuckle", "middle finger's second knuckle", "middle finger's first knuckle", "ring finger's root", "ring finger's third knuckle", "ring finger's second knuckle", "ring finger's first knuckle", "pinky finger's root", "pinky finger's third knuckle", "pinky finger's second knuckle", "pinky finger's first knuckle"],"skeleton": []}
animal_in_AnimalKindom = {"keypoints":['head mid top', 'eye left', 'eye right', 'mouth front top', 'mouth back left', 'mouth back right', 'mouth front bottom', 'shoulder left', 'shoulder right', 'elbow left', 'elbow right', 'wrist left', 'wrist right', 'torso mid back', 'hip left', 'hip right', 'knee left', 'knee right', 'ankle left ', 'ankle right', 'tail top back', 'tail mid back', 'tail end back'],"skeleton": [[1, 0], [2, 0], [3, 4], [3, 5], [4, 6], [5, 6], [0, 7], [0, 8], [7, 9], [8, 10], [9, 11], [10, 12], [0, 13], [13, 20], [20, 14], [20, 15], [14, 16], [15, 17], [16, 18], [17, 19], [20, 21], [21, 22]]}
animal_in_AP10K = {"keypoints": ['left eye', 'right eye', 'nose', 'neck', 'root of tail', 'left shoulder', 'left elbow', 'left front paw', 'right shoulder', 'right elbow', 'right front paw', 'left hip', 'left knee', 'left back paw', 'right hip', 'right knee', 'right back paw'], "skeleton": [[1, 2], [1, 3], [2, 3], [3, 4], [4, 5], [4, 6], [6, 7], [7, 8], [4, 9], [9, 10], [10, 11], [5, 12], [12, 13], [13, 14], [5, 15], [15, 16], [16, 17]]}
animal= {"keypoints": ['left eye', 'right eye', 'nose', 'neck', 'root of tail', 'left shoulder', 'left elbow', 'left front paw', 'right shoulder', 'right elbow', 'right front paw', 'left hip', 'left knee', 'left back paw', 'right hip', 'right knee', 'right back paw'], "skeleton": [[1, 2], [1, 3], [2, 3], [3, 4], [4, 5], [4, 6], [6, 7], [7, 8], [4, 9], [9, 10], [10, 11], [5, 12], [12, 13], [13, 14], [5, 15], [15, 16], [16, 17]]}
animal_face = {"keypoints": ['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip'], "skeleton": []}
fly = {"keypoints": ['head', 'eye left', 'eye right', 'neck', 'thorax', 'abdomen', 'foreleg right base', 'foreleg right first segment', 'foreleg right second segment', 'foreleg right tip', 'midleg right base', 'midleg right first segment', 'midleg right second segment', 'midleg right tip', 'hindleg right base', 'hindleg right first segment', 'hindleg right second segment', 'hindleg right tip', 'foreleg left base', 'foreleg left first segment', 'foreleg left second segment', 'foreleg left tip', 'midleg left base', 'midleg left first segment', 'midleg left second segment', 'midleg left tip', 'hindleg left base', 'hindleg left first segment', 'hindleg left second segment', 'hindleg left tip', 'wing left', 'wing right'], "skeleton": [[2, 1], [3, 1], [4, 1], [5, 4], [6, 5], [8, 7], [9, 8], [10, 9], [12, 11], [13, 12], [14, 13], [16, 15], [17, 16], [18, 17], [20, 19], [21, 20], [22, 21], [24, 23], [25, 24], [26, 25], [28, 27], [29, 28], [30, 29], [31, 4], [32, 4]]}
locust = {"keypoints": ['head', 'neck', 'thorax', 'abdomen1', 'abdomen2', 'anttip left', 'antbase left', 'eye left', 'foreleg left base', 'foreleg left first segment', 'foreleg left second segment', 'foreleg left tip', 'midleg left base', 'midleg left first segment', 'midleg left second segment', 'midleg left tip', 'hindleg left base', 'hindleg left first segment', 'hindleg left second segment', 'hindleg left tip', 'anttip right', 'antbase right', 'eye right', 'foreleg right base', 'foreleg right first segment', 'foreleg right second segment', 'foreleg right tip', 'midleg right base', 'midleg right first segment', 'midleg right second segment', 'midleg right tip', 'hindleg right base', 'hindleg right first segment', 'hindleg right second segment', 'hindleg right tip'],"skeleton": [[2, 1], [3, 2], [4, 3], [5, 4], [7, 6], [8, 7], [10, 9], [11, 10], [12, 11], [14, 13], [15, 14],[16, 15], [18, 17], [19, 18], [20, 19], [22, 21], [23, 22], [25, 24], [26, 25], [27, 26],[29, 28], [30, 29], [31, 30], [33, 32], [34, 33], [35, 34]]}
car ={"keypoints": ['right front wheel center', 'left front wheel center', 'right rear wheel center', 'left rear wheel center', 'front right', 'front left', 'back right', 'back left', 'none', 'roof front right', 'roof front left', 'roof back right', 'roof back left', 'none'],"skeleton": [[0, 2], [1, 3], [0, 1], [2, 3], [9, 11], [10, 12], [9, 10], [11, 12], [4, 0], [4, 9], [4, 5], [5, 1], [5, 10], [6, 2], [6, 11], [7, 3], [7, 12], [6, 7]]}
short_sleeved_shirt = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []}
long_sleeved_outwear={'keypoints': ['upper center neckline', 'lower right center neckline', 'lower right neckline', 'upper right neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 1', 'right sleeve inside 2', 'right sleeve inside 3', 'right sleeve inside 4', 'right side outside 1', 'right side outside 2', 'right side outside 3', 'right side inside 3', 'left side outside 3', 'left side outside 2', 'left side outside 1', 'left sleeve inside 4', 'left sleeve inside 3', 'left sleeve inside 2', 'left sleeve inside 1', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1', 'lower left center neckline', 'left side inside 1', 'left side inside 2', 'left side inside 3', 'right side inside 1', 'right side inside 2'], 'skeleton': []}
short_sleeved_outwear={'keypoints': ['upper center neckline', 'lower right center neckline', 'lower right neckline', 'upper right neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 2', 'right sleeve inside 1', 'right side outside 1', 'right side outside 2', 'right side outside 3', 'right side inside 3', 'left side outside 3', 'left side outside 2', 'left side outside 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1', 'lower left center neckline', 'left side inside 1', 'left side inside 2', 'left side inside 3', 'right side inside 1', 'right side inside 2'], 'skeleton': []}
sling={'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve'], 'skeleton': []}
vest = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve'], 'skeleton': []}
long_sleeved_dress={'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 4', 'right sleeve inside 3', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'center hem', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left sleeve inside 3', 'left sleeve inside 4', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []}
long_sleeved_shirt = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 4', 'right sleeve inside 3', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left sleeve inside 3', 'left sleeve inside 4', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []}
trousers = {'keypoints': ['right side outside 1', 'upper center', 'left side outside 1', 'right side outside 2', 'right side outside 3', 'right cuff outside', 'right cuff inside', 'right side inside 1', 'crotch', 'left side inside 1', 'left cuff inside', 'left cuff outside', 'left side outside 3', 'left side outside 2'], 'skeleton': []}
sling_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'right side 6', 'center hem', 'left side 6', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1'], 'skeleton': []}
vest_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'right side 6', 'center hem', 'left side 6', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1'], 'skeleton': []}
skirt = {'keypoints': ['right side 1', 'upper center', 'left side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2'], 'skeleton': []}
short_sleeved_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 1', 'right sleeve inside 2', 'left side 1', 'left side 2', 'left side 3', 'left side 4', 'left side 5', 'center hem', 'right side 5', 'right side 4', 'right side 3', 'right side 2', 'right side 1', 'left sleeve inside 2', 'left sleeve inside 1', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []}
shorts = {'keypoints': ['right side outside 1', 'upper center', 'left side outside 1', 'right side outside 2', 'right cuff outside', 'right cuff inside', 'crotch', 'left cuff inside', 'left cuff outside', 'left side outside 2'], 'skeleton': []}
table = {'keypoints': ['desktop corner 1', 'desktop corner 2', 'desktop corner 3', 'desktop corner 4', 'table leg 1', 'table leg 2', 'table leg 3', 'table leg 4'], 'skeleton': []}
chair = {'keypoints': ['legs righttopcorner', 'legs lefttopcorner', 'legs leftbottomcorner', 'legs rightbottomcorner', 'base righttop', 'base lefttop', 'base leftbottom', 'base rightbottom', 'headboard righttop', 'headboard lefttop'], 'skeleton': []}
bed = {'keypoints': ['legs rightbottomcorner', 'legs righttopcorner', 'base rightbottom', 'base righttop', 'backrest righttop', 'legs leftbottomcorner', 'legs lefttopcorner', 'base leftbottom', 'base lefttop', 'backrest lefttop'], 'skeleton': []}
sofa = {'keypoints': ['legs rightbottomcorner', 'legs righttopcorner', 'base rightbottom', 'base righttop', 'armrests rightbottomcorner', 'armrests righttopcorner', 'backrest righttop', 'legs leftbottomcorner', 'legs lefttopcorner', 'base leftbottom', 'base lefttop', 'armrests leftbottomcorner', 'armrests lefttopcorner', 'backrest lefttop'], 'skeleton': []}
swivelchair = {'keypoints': ['rotatingbase 1', 'rotatingbase 2', 'rotatingbase 3', 'rotatingbase 4', 'rotatingbase 5', 'rotatingbase center', 'base center', 'base righttop', 'base lefttop', 'base leftbottom', 'base rightbottom', 'backrest righttop', 'backrest lefttop'], 'skeleton': []}
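# Illustrative lookup, not part of the original file: each template above pairs a
# list of keypoint names with an (optionally empty) skeleton of index pairs.
def _demo_template_lookup(template=animal):
    return len(template["keypoints"]), len(template["skeleton"])   # (17, 17) for `animal`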

View File

@ -0,0 +1,394 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import os
import sys
import random
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.box_ops import box_xyxy_to_cxcywh
from util.misc import interpolate
def crop(image, target, region):
cropped_image = F.crop(image, *region)
if target is not None:
target = target.copy()
i, j, h, w = region
id2catname = target["id2catname"]
caption_list = target["caption_list"]
target["size"] = torch.tensor([h, w])
fields = ["labels", "area", "iscrowd", "positive_map","keypoints"]
if "boxes" in target:
boxes = target["boxes"]
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
target["boxes"] = cropped_boxes.reshape(-1, 4)
target["area"] = area
fields.append("boxes")
if "masks" in target:
# FIXME should we update the area here if there are no boxes?
target['masks'] = target['masks'][:, i:i + h, j:j + w]
fields.append("masks")
# remove elements for which the boxes or masks that have zero area
if "boxes" in target or "masks" in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if "boxes" in target:
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target['masks'].flatten(1).any(1)
for field in fields:
if field in target:
target[field] = target[field][keep]
if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# for debug and visualization only.
if 'strings_positive' in target:
target['strings_positive'] = [_i for _i, _j in zip(target['strings_positive'], keep) if _j]
if "keypoints" in target:
max_size = torch.as_tensor([w, h], dtype=torch.float32)
keypoints = target["keypoints"]
cropped_keypoints = keypoints.view(-1, 3)[:,:2] - torch.as_tensor([j, i])
cropped_keypoints = torch.min(cropped_keypoints, max_size)
cropped_keypoints = cropped_keypoints.clamp(min=0)
cropped_keypoints = torch.cat([cropped_keypoints, keypoints.view(-1, 3)[:,2].unsqueeze(1)], dim=1)
target["keypoints"] = cropped_keypoints.view(target["keypoints"].shape[0], target["keypoints"].shape[1], 3)
target["id2catname"] = id2catname
target["caption_list"] = caption_list
return cropped_image, target
def hflip(image, target):
flipped_image = F.hflip(image)
w, h = image.size
if target is not None:
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
target["boxes"] = boxes
if "masks" in target:
target['masks'] = target['masks'].flip(-1)
if "keypoints" in target:
dataset_name=target["dataset_name"]
if dataset_name == "coco_person" or dataset_name == "macaque":
flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
elif dataset_name=="animalkindom_ak_P1_animal":
flip_pairs = [[1, 2], [4, 5],[7,8],[9,10],[11,12],[14,15],[16,17],[18,19]]
elif dataset_name=="animalweb_animal":
flip_pairs = [[0, 3], [1, 2], [5, 6]]
elif dataset_name=="face":
flip_pairs = [
[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
[17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
[31, 35], [32, 34],
[36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
[48, 54], [49, 53], [50, 52],
[55, 59], [56, 58],
[60, 64], [61, 63],
[65, 67]
]
elif dataset_name=="hand":
flip_pairs = []
elif dataset_name=="foot":
flip_pairs = []
elif dataset_name=="locust":
flip_pairs = [[5, 20], [6, 21], [7, 22], [8, 23], [9, 24], [10, 25], [11, 26], [12, 27], [13, 28], [14, 29], [15, 30], [16, 31], [17, 32], [18, 33], [19, 34]]
elif dataset_name=="fly":
flip_pairs = [[1, 2], [6, 18], [7, 19], [8, 20], [9, 21], [10, 22], [11, 23], [12, 24], [13, 25], [14, 26], [15, 27], [16, 28], [17, 29], [30, 31]]
elif dataset_name == "ap_36k_animal" or dataset_name == "ap_10k_animal":
flip_pairs = [[0, 1],[5, 8], [6, 9], [7, 10], [11, 14], [12, 15], [13, 16]]
keypoints = target["keypoints"]
keypoints[:,:,0] = w - keypoints[:,:, 0]-1
for pair in flip_pairs:
keypoints[:,pair[0], :], keypoints[:,pair[1], :] = keypoints[:,pair[1], :], keypoints[:,pair[0], :].clone()
target["keypoints"] = keypoints
return flipped_image, target
def resize(image, target, size, max_size=None):
# size can be min_size (scalar) or (w, h) tuple
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
size = get_size(image.size, size, max_size)
rescaled_image = F.resize(image, size)
if target is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
target["boxes"] = scaled_boxes
if "area" in target:
area = target["area"]
scaled_area = area * (ratio_width * ratio_height)
target["area"] = scaled_area
if "keypoints" in target:
keypoints = target["keypoints"]
scaled_keypoints = keypoints * torch.as_tensor([ratio_width, ratio_height, 1])
target["keypoints"] = scaled_keypoints
h, w = size
target["size"] = torch.tensor([h, w])
if "masks" in target:
target['masks'] = interpolate(
target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
return rescaled_image, target
def pad(image, target, padding):
# assumes that we only pad on the bottom right corners
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
if target is None:
return padded_image, None
target = target.copy()
# should we do something wrt the original size?
target["size"] = torch.tensor(padded_image.size[::-1])
if "masks" in target:
target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
return padded_image, target
class ResizeDebug(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
return resize(img, target, self.size)
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
region = T.RandomCrop.get_params(img, self.size)
return crop(img, target, region)
class RandomSizeCrop(object):
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
# respect_boxes: True to keep all boxes
# False to tolerate the crop filtering out some boxes
self.min_size = min_size
self.max_size = max_size
self.respect_boxes = respect_boxes
def __call__(self, img: PIL.Image.Image, target: dict):
init_boxes = len(target["boxes"]) if (target is not None and "boxes" in target) else 0
max_patience = 10
for i in range(max_patience):
w = random.randint(self.min_size, min(img.width, self.max_size))
h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w])
result_img, result_target = crop(img, target, region)
if target is not None:
if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
return result_img, result_target
return result_img, result_target
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return hflip(img, target)
return img, target
class RandomResize(object):
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, img, target=None):
size = random.choice(self.sizes)
return resize(img, target, size, self.max_size)
class RandomPad(object):
def __init__(self, max_pad):
self.max_pad = max_pad
def __call__(self, img, target):
pad_x = random.randint(0, self.max_pad)
pad_y = random.randint(0, self.max_pad)
return pad(img, target, (pad_x, pad_y))
class RandomSelect(object):
"""
Randomly selects between transforms1 and transforms2,
with probability p for transforms1 and (1 - p) for transforms2
"""
def __init__(self, transforms1, transforms2, p=0.5):
self.transforms1 = transforms1
self.transforms2 = transforms2
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return self.transforms1(img, target)
return self.transforms2(img, target)
class ToTensor(object):
def __call__(self, img, target):
return F.to_tensor(img), target
class RandomErasing(object):
def __init__(self, *args, **kwargs):
self.eraser = T.RandomErasing(*args, **kwargs)
def __call__(self, img, target):
return self.eraser(img), target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
h, w = image.shape[-2:]
if "boxes" in target:
boxes = target["boxes"]
boxes = box_xyxy_to_cxcywh(boxes)
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
target["boxes"] = boxes
if "area" in target:
area = target["area"]
area = area / (torch.tensor(w, dtype=torch.float32)*torch.tensor(h, dtype=torch.float32))
target["area"] = area
if "keypoints" in target:
keypoints = target["keypoints"]
V = keypoints[:, :, 2]
V[V == 2] = 1
Z=keypoints[:, :, :2]
Z = Z.contiguous().view(-1, 2 * V.shape[-1])
Z = Z / torch.tensor([w, h] * V.shape[-1], dtype=torch.float32)
target["valid_kpt_num"] = V.shape[1]
Z_pad = torch.zeros(Z.shape[0],68 * 2 - Z.shape[1])
V_pad = torch.zeros(V.shape[0],68 - V.shape[1])
V=torch.cat([V, V_pad], dim=1)
Z=torch.cat([Z, Z_pad], dim=1)
all_keypoints = torch.cat([Z, V], dim=1)
target["keypoints"] = all_keypoints
return image, target
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
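# Usage sketch, not part of the original file: a typical train-time pipeline built
# from the transforms above; the normalization statistics are the usual ImageNet
# values and are an assumption here. Call it as `pipeline(pil_image, target)`.
def _demo_pipeline():
    return Compose([
        RandomHorizontalFlip(),
        RandomResize([480, 512, 544], max_size=800),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])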

View File

@ -0,0 +1,159 @@
import copy
class Dict(dict):
def __init__(__self, *args, **kwargs):
object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
object.__setattr__(__self, '__key', kwargs.pop('__key', None))
object.__setattr__(__self, '__frozen', False)
for arg in args:
if not arg:
continue
elif isinstance(arg, dict):
for key, val in arg.items():
__self[key] = __self._hook(val)
elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
__self[arg[0]] = __self._hook(arg[1])
else:
for key, val in iter(arg):
__self[key] = __self._hook(val)
for key, val in kwargs.items():
__self[key] = __self._hook(val)
def __setattr__(self, name, value):
if hasattr(self.__class__, name):
raise AttributeError("'Dict' object attribute "
"'{0}' is read-only".format(name))
else:
self[name] = value
def __setitem__(self, name, value):
isFrozen = (hasattr(self, '__frozen') and
object.__getattribute__(self, '__frozen'))
if isFrozen and name not in super(Dict, self).keys():
raise KeyError(name)
super(Dict, self).__setitem__(name, value)
try:
p = object.__getattribute__(self, '__parent')
key = object.__getattribute__(self, '__key')
except AttributeError:
p = None
key = None
if p is not None:
p[key] = self
object.__delattr__(self, '__parent')
object.__delattr__(self, '__key')
def __add__(self, other):
if not self.keys():
return other
else:
self_type = type(self).__name__
other_type = type(other).__name__
msg = "unsupported operand type(s) for +: '{}' and '{}'"
raise TypeError(msg.format(self_type, other_type))
@classmethod
def _hook(cls, item):
if isinstance(item, dict):
return cls(item)
elif isinstance(item, (list, tuple)):
return type(item)(cls._hook(elem) for elem in item)
return item
def __getattr__(self, item):
return self.__getitem__(item)
def __missing__(self, name):
if object.__getattribute__(self, '__frozen'):
raise KeyError(name)
return self.__class__(__parent=self, __key=name)
def __delattr__(self, name):
del self[name]
def to_dict(self):
base = {}
for key, value in self.items():
if isinstance(value, type(self)):
base[key] = value.to_dict()
elif isinstance(value, (list, tuple)):
base[key] = type(value)(
item.to_dict() if isinstance(item, type(self)) else
item for item in value)
else:
base[key] = value
return base
def copy(self):
return copy.copy(self)
def deepcopy(self):
return copy.deepcopy(self)
def __deepcopy__(self, memo):
other = self.__class__()
memo[id(self)] = other
for key, value in self.items():
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
return other
def update(self, *args, **kwargs):
other = {}
if args:
if len(args) > 1:
raise TypeError()
other.update(args[0])
other.update(kwargs)
for k, v in other.items():
if ((k not in self) or
(not isinstance(self[k], dict)) or
(not isinstance(v, dict))):
self[k] = v
else:
self[k].update(v)
def __getnewargs__(self):
return tuple(self.items())
def __getstate__(self):
return self
def __setstate__(self, state):
self.update(state)
def __or__(self, other):
if not isinstance(other, (Dict, dict)):
return NotImplemented
new = Dict(self)
new.update(other)
return new
def __ror__(self, other):
if not isinstance(other, (Dict, dict)):
return NotImplemented
new = Dict(other)
new.update(self)
return new
def __ior__(self, other):
self.update(other)
return self
def setdefault(self, key, default=None):
if key in self:
return self[key]
else:
self[key] = default
return default
def freeze(self, shouldFreeze=True):
object.__setattr__(self, '__frozen', shouldFreeze)
for key, val in self.items():
if isinstance(val, Dict):
val.freeze(shouldFreeze)
def unfreeze(self):
self.freeze(False)
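# Usage sketch, not part of the original file: attribute-style access with
# recursive conversion, which is what Config relies on downstream.
def _demo_dict():
    d = Dict({'model': {'hidden_dim': 256}})
    d.model.nheads = 8
    return d.to_dict()   # {'model': {'hidden_dim': 256, 'nheads': 8}}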

View File

@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch, os
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
# import ipdb; ipdb.set_trace()
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union + 1e-6)
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
# except:
# import ipdb; ipdb.set_trace()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / (area + 1e-6)
# modified from torchvision to also return the union
def box_iou_pairwise(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
wh = (rb - lt).clamp(min=0) # [N,2]
inter = wh[:, 0] * wh[:, 1] # [N]
union = area1 + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou_pairwise(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
Input:
- boxes1, boxes2: N,4
Output:
- giou: N
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
assert boxes1.shape == boxes2.shape
iou, union = box_iou_pairwise(boxes1, boxes2)  # N
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,2]
area = wh[:, 0] * wh[:, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
if __name__ == '__main__':
x = torch.rand(5, 4)
y = torch.rand(3, 4)
iou, union = box_iou(x, y)
import ipdb; ipdb.set_trace()
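A minimal, self-contained sketch of the box helpers above, runnable as long as this module's functions are in scope; the box values are arbitrary examples.

import torch

boxes = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4],
                                         [0.3, 0.3, 0.2, 0.2]]))   # -> [x0, y0, x1, y1]
targets = torch.tensor([[0.3, 0.3, 0.7, 0.7]])

iou, union = box_iou(boxes, targets)          # both [2, 1]
giou = generalized_box_iou(boxes, targets)    # [2, 1], values in [-1, 1]
assert torch.allclose(box_xyxy_to_cxcywh(boxes),
                      torch.tensor([[0.5, 0.5, 0.4, 0.4],
                                    [0.3, 0.3, 0.2, 0.2]]), atol=1e-6)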

View File

@ -0,0 +1,426 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import sys
import os.path as osp
import ast
import tempfile
import shutil
from importlib import import_module
from argparse import Action
from .addict import Dict
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict']
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
if not osp.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{name}'")
except Exception as e:
ex = e
else:
return value
raise ex
class Config(object):
"""
A facility for config and config files.
Supports .py, .yml/.yaml and .json config files (see `_file2dict`).
ref: mmcv.utils.config
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename) as f:
content = f.read()
try:
ast.parse(content)
except SyntaxError:
raise SyntaxError('There are syntax errors in config '
f'file {filename}')
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
if filename.lower().endswith('.py'):
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix='.py')
temp_config_name = osp.basename(temp_config_file.name)
shutil.copyfile(filename,
osp.join(temp_config_dir, temp_config_name))
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules[temp_module_name]
# close temp file
temp_config_file.close()
elif filename.lower().endswith(('.yml', '.yaml', '.json')):
from .slio import slload
cfg_dict = slload(filename)
else:
raise IOError('Only py/yml/yaml/json types are supported now!')
cfg_text = filename + '\n'
with open(filename, 'r') as f:
cfg_text += f.read()
# parse the base file
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError('Duplicate key is not allowed among bases')
# TODO: allow duplicate keys while warning the user
base_cfg_dict.update(c)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b):
"""merge dict `a` into dict `b` (non-inplace).
values in `a` will overwrite `b`.
copy first to avoid inplace modification
Args:
a (dict): the dict whose values take precedence.
b (dict): the base dict to merge into.
Returns:
dict: the merged result; `b` itself is left unmodified.
"""
# import ipdb; ipdb.set_trace()
if not isinstance(a, dict):
return a
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict) and not isinstance(b[k], list):
# if :
# import ipdb; ipdb.set_trace()
raise TypeError(
f'{k}={v} in child config cannot inherit from base '
f'because {k} is a dict in the child config but is of '
f'type {type(b[k])} in base config. You may set '
f'`{DELETE_KEY}=True` to ignore the base config')
b[k] = Config._merge_a_into_b(v, b[k])
elif isinstance(b, list):
try:
_ = int(k)
except ValueError:
raise TypeError(
f'b is a list, so index {k} must be an int, '
f'but got {type(k)}'
)
b[int(k)] = Config._merge_a_into_b(v, b[int(k)])
else:
b[k] = v
return b
@staticmethod
def fromfile(filename):
cfg_dict, cfg_text = Config._file2dict(filename)
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError('cfg_dict must be a dict, but '
f'got {type(cfg_dict)}')
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file')
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, 'r') as f:
text = f.read()
else:
text = ''
super(Config, self).__setattr__('_text', text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = '[\n'
v_str += '\n'.join(
f'dict({_indent(_format_dict(v_), indent)}),'
for v_ in v).rstrip(',')
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent) + ']'
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= \
(not str(key_name).isidentifier())
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ''
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += '{'
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: dict({v_str}'
else:
attr_str = f'{str(k)}=dict({v_str}'
attr_str = _indent(attr_str, indent) + ')' + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += '\n'.join(s)
if use_mapping:
r += '}'
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
return text
def __repr__(self):
return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
# # debug
# print('+'*15)
# print('name=%s' % name)
# print("addr:", id(self))
# # print('type(self):', type(self))
# print(self.__dict__)
# print('+'*15)
# if self.__dict__ == {}:
# raise ValueError
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def dump(self, file=None):
# import ipdb; ipdb.set_trace()
if file is None:
return self.pretty_text
else:
with open(file, 'w') as f:
f.write(self.pretty_text)
def merge_from_dict(self, options):
"""Merge list into cfg_dict
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
Args:
options (dict): dict of configs to merge from.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split('.')
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
super(Config, self).__setattr__(
'_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
# for multiprocess
def __setstate__(self, state):
self.__init__(state)
def copy(self):
return Config(self._cfg_dict.copy())
def deepcopy(self):
return Config(self._cfg_dict.deepcopy())
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma-separated values, i.e. KEY=V1,V2,V3
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return val.lower() == 'true'
if val.lower() in ['none', 'null']:
return None
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(',')]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)
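A usage sketch of Config together with DictAction, mirroring the docstring examples above; Config.fromfile works the same way for an on-disk .py/.yml/.yaml/.json file. The config keys below are illustrative assumptions, not part of the project.

import argparse

cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50)), lr=0.01))

parser = argparse.ArgumentParser()
parser.add_argument('--options', nargs='+', action=DictAction, default={},
                    help='override config entries, e.g. KEY=VALUE or KEY=V1,V2')
args = parser.parse_args(['--options', 'model.backbone.depth=101', 'lr=0.001', 'use_ema=True'])
# DictAction yields {'model.backbone.depth': 101, 'lr': 0.001, 'use_ema': True}

cfg.merge_from_dict(args.options)   # dotted keys are expanded into nested dicts
assert cfg.model.backbone.depth == 101 and cfg.lr == 0.001 and cfg.use_ema is True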

View File

@ -0,0 +1,29 @@
import torch, os
def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor):
"""_summary_
Args:
keypoints (torch.Tensor): ..., 51
"""
res = torch.zeros_like(keypoints)
num_points = keypoints.shape[-1] // 3
Z = keypoints[..., :2*num_points]
V = keypoints[..., 2*num_points:]
res[...,0::3] = Z[..., 0::2]
res[...,1::3] = Z[..., 1::2]
res[...,2::3] = V[...]
return res
def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor):
"""_summary_
Args:
keypoints (torch.Tensor): ..., 51
"""
res = torch.zeros_like(keypoints)
num_points = keypoints.shape[-1] // 3
res[...,0:2*num_points:2] = keypoints[..., 0::3]
res[...,1:2*num_points:2] = keypoints[..., 1::3]
res[...,2*num_points:] = keypoints[..., 2::3]
return res
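A round-trip sketch for the two keypoint layout converters above, using the 51-dim COCO layout (17 points x 3 values) from the docstrings.

import torch

kpts_xyzxyz = torch.rand(2, 51)                        # [x1, y1, z1, ..., x17, y17, z17]
kpts_xyxyzz = keypoint_xyzxyz_to_xyxyzz(kpts_xyzxyz)   # [x1, y1, ..., x17, y17, z1, ..., z17]
assert torch.equal(keypoint_xyxyzz_to_xyzxyz(kpts_xyxyzz), kpts_xyzxyz)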

View File

@ -0,0 +1,701 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import functools
import io
import os
import random
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
import datetime
import pickle
from typing import Optional, List
import json, time
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
import colorsys
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
__torchvision_need_compat_flag = float(torchvision.__version__.split('.')[1]) < 7
if __torchvision_need_compat_flag:
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
if d.shape[0] == 0:
return 0
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
if os.environ.get("SHILONG_AMP", None) == '1':
eps = 1e-4
else:
eps = 1e-6
return self.total / (self.count + eps)
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
return dist.group.WORLD
def all_gather_cpu(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
cpu_group = _get_global_gloo_group()
buffer = io.BytesIO()
torch.save(data, buffer)
data_view = buffer.getbuffer()
device = "cuda" if cpu_group is None else "cpu"
tensor = torch.ByteTensor(data_view).to(device)
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
if cpu_group is None:
dist.all_gather(size_list, local_size)
else:
print("gathering on cpu")
dist.all_gather(size_list, local_size, group=cpu_group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
assert isinstance(local_size.item(), int)
local_size = int(local_size.item())
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
tensor = torch.cat((tensor, padding), dim=0)
if cpu_group is None:
dist.all_gather(tensor_list, tensor)
else:
dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer)
data_list.append(obj)
return data_list
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
if os.getenv("CPU_REDUCE") == "1":
return all_gather_cpu(data)
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
# print(name, str(meter))
# import ipdb;ipdb.set_trace()
if meter.count > 0:
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None, logger=None):
if logger is None:
print_func = print
else:
print_func = logger.info
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
# import ipdb; ipdb.set_trace()
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print_func(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print_func(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print_func('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
# import ipdb; ipdb.set_trace()
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
if mask == 'auto':
self.mask = torch.zeros_like(tensors).to(tensors.device)
if self.mask.dim() == 3:
self.mask = self.mask.sum(0).to(bool)
elif self.mask.dim() == 4:
self.mask = self.mask.sum(1).to(bool)
else:
raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))
def imgsize(self):
res = []
for i in range(self.tensors.shape[0]):
mask = self.mask[i]
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
res.append(torch.Tensor([maxH, maxW]))
return res
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def to_img_list_single(self, tensor, mask):
assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
img = tensor[:, :maxH, :maxW]
return img
def to_img_list(self):
"""remove the padding and convert to img list
Returns:
Tensor or list[Tensor]: the unpadded image(s)
"""
if self.tensors.dim() == 3:
return self.to_img_list_single(self.tensors, self.mask)
else:
res = []
for i in range(self.tensors.shape[0]):
tensor_i = self.tensors[i]
mask_i = self.mask[i]
res.append(self.to_img_list_single(tensor_i, mask_i))
return res
@property
def device(self):
return self.tensors.device
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
@property
def shape(self):
return {
'tensors.shape': self.tensors.shape,
'mask.shape': self.mask.shape
}
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
# launch by torch.distributed.launch
# Single node
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
# Multi nodes
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
# local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
# args.world_size = args.world_size * local_world_size
# args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
# args.rank = args.rank * local_world_size + args.local_rank
print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank))
print(json.dumps(dict(os.environ), indent=2))
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
args.world_size = int(os.environ['SLURM_NPROCS'])
if os.environ.get('HAND_DEFINE_DIST_URL', 0) == '1':
pass
else:
import util.hostlist as uh
nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST'])
gpu_ids = [int(node[3:]) for node in nodenames]
fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0))
# fixid += random.randint(0, 300)
port = str(3137 + int(min(gpu_ids)) + fixid)
args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port)
print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
else:
print('Not using distributed mode')
args.distributed = False
args.world_size = 1
args.rank = 0
args.local_rank = 0
return
print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
args.distributed = True
torch.cuda.set_device(args.local_rank)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend,
world_size=args.world_size,
rank=args.rank,
init_method=args.dist_url,
)
print("Before torch.distributed.barrier()")
torch.distributed.barrier()
print("End torch.distributed.barrier()")
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
@torch.no_grad()
def accuracy_onehot(pred, gt):
"""_summary_
Args:
pred (_type_): n, c
gt (_type_): n, c
"""
tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
acc = tp / gt.shape[0] * 100
return acc
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if __torchvision_need_compat_flag:  # the flag is already a bool (torchvision minor version < 7)
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
class color_sys():
def __init__(self, num_colors) -> None:
self.num_colors = num_colors
colors=[]
for i in np.arange(0., 360., 360. / num_colors):
hue = i/360.
lightness = (50 + np.random.rand() * 10)/100.
saturation = (90 + np.random.rand() * 10)/100.
colors.append(tuple([int(j*255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)]))
self.colors = colors
def __call__(self, idx):
return self.colors[idx]
def inverse_sigmoid(x, eps=1e-3):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1/x2)
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == 'module.':
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict
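A single-process sketch of a few helpers above: SmoothedValue/MetricLogger need no distributed init, and NestedTensor pads variably sized images into one batch plus a padding mask. The image sizes are arbitrary examples.

import torch

logger = MetricLogger(delimiter="  ")
for step in range(5):
    logger.update(loss=0.5 / (step + 1), lr=1e-4)      # windowed + global averages
print(str(logger))                                     # e.g. "loss: ... (...)  lr: ..."

imgs = [torch.rand(3, 480, 640), torch.rand(3, 360, 500)]
batch = nested_tensor_from_tensor_list(imgs)           # zero-pads both to 480x640
tensors, mask = batch.decompose()
print(tensors.shape, mask.shape)                       # [2, 3, 480, 640], [2, 480, 640]

print(inverse_sigmoid(torch.tensor([0.25, 0.5])))      # ~[-1.0986, 0.0000]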

View File

@ -9,6 +9,8 @@ import os.path as osp
import torch
from collections import OrderedDict
import numpy as np
from scipy.spatial import ConvexHull # pylint: disable=E0401,E0611
from typing import Union
import cv2
from ..modules.spade_generator import SPADEDecoder
@ -18,6 +20,27 @@ from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
def tensor_to_numpy(data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
"""transform torch.Tensor into numpy.ndarray"""
if isinstance(data, torch.Tensor):
return data.data.cpu().numpy()
return data
def calc_motion_multiplier(
kp_source: Union[np.ndarray, torch.Tensor],
kp_driving_initial: Union[np.ndarray, torch.Tensor]
) -> float:
"""calculate motion_multiplier based on the source image and the first driving frame"""
kp_source_np = tensor_to_numpy(kp_source)
kp_driving_initial_np = tensor_to_numpy(kp_driving_initial)
source_area = ConvexHull(kp_source_np.squeeze(0)).volume
driving_area = ConvexHull(kp_driving_initial_np.squeeze(0)).volume
motion_multiplier = np.sqrt(source_area) / np.sqrt(driving_area)
# motion_multiplier = np.cbrt(source_area) / np.cbrt(driving_area)
return motion_multiplier
def suffix(filename):
"""a.jpg -> jpg"""
pos = filename.rfind(".")
@ -163,3 +186,11 @@ def is_square_video(video_path):
# gr.Info(f"Uploaded video is not square, force do crop (driving) to be True")
return width == height
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == 'module.':
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict
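A sketch of the two helpers added in this hunk, on dummy data; the (1, 21, 3) keypoint shape is only an illustrative assumption.

import torch
from collections import OrderedDict

kp_source = torch.rand(1, 21, 3)
kp_driving_initial = 0.5 * torch.rand(1, 21, 3)
scale = calc_motion_multiplier(kp_source, kp_driving_initial)
print(scale)   # sqrt(convex-hull volume of source kps) / sqrt(hull volume of driving kps)

ckpt = OrderedDict({'module.conv.weight': torch.zeros(1), 'fc.bias': torch.zeros(1)})
assert set(clean_state_dict(ckpt)) == {'conv.weight', 'fc.bias'}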

Binary file not shown.

Binary file not shown.

View File

@ -59,8 +59,9 @@ def video2gif(video_fp, fps=30, size=256):
# use the palette to generate the gif
cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y'
exec_cmd(cmd)
return gif_wfp
else:
print(f'video_fp: {video_fp} not exists!')
raise FileNotFoundError(f"video_fp: {video_fp} does not exist!")
def merge_audio_video(video_fp, audio_fp, wfp):