trace_motion_extractor

This commit is contained in:
Rafael Silva 2024-07-11 23:42:19 -04:00
parent a8e40677dd
commit c921dfafa5
4 changed files with 133 additions and 33 deletions

View File

@ -76,10 +76,11 @@ def trace_appearance_feature_extractor():
torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")
trace_appearance_feature_extractor()
# trace_appearance_feature_extractor() # done
def trace_motion_extractor():
import src.modules.convnextv2
motion_extractor = load_model(
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
model_config=model_config,
@ -90,12 +91,14 @@ def trace_motion_extractor():
with torch.no_grad():
motion_extractor.eval()
motion_extractor = torch.jit.script(motion_extractor)
# model = src.modules.convnextv2.ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], nuk_kp=21)
# model = torch.jit.script(model.downsample_layers)
LOGGER.info("Traced motion_extractor")
torch.jit.save(motion_extractor, "build/motion_extractor.pt")
# trace_motion_extractor()
trace_motion_extractor()
def trace_warping_module():
@ -114,7 +117,7 @@ def trace_warping_module():
torch.jit.save(warping_module, "build/warping_module.pt")
trace_warping_module()
# trace_warping_module() # done
def trace_spade_generator():
@ -133,7 +136,7 @@ def trace_spade_generator():
torch.jit.save(spade_generator, "build/spade_generator.pt")
trace_spade_generator()
trace_spade_generator() # done
def trace_stitching_retargeting_module():
@ -145,14 +148,21 @@ def trace_stitching_retargeting_module():
)
with torch.no_grad():
stitching_retargeting_module.eval()
stitching_retargeting_module = torch.jit.script(stitching_retargeting_module)
stitching = stitching_retargeting_module['stitching'].eval()
lip = stitching_retargeting_module['lip'].eval()
eye = stitching_retargeting_module['eye'].eval()
stitching_trace = torch.jit.script(stitching)
lip_trace = torch.jit.script(lip)
eye_trace = torch.jit.script(eye)
LOGGER.info("Traced stitching_retargeting_module")
torch.jit.save(stitching_retargeting_module, "build/stitching_retargeting_module.pt")
torch.jit.save(stitching_trace, "build/stitching_retargeting_module_stitching.pt")
torch.jit.save(lip_trace, "build/stitching_retargeting_module_lip.pt")
torch.jit.save(eye_trace, "build/stitching_retargeting_module_eye.pt")
# trace_stitching_retargeting_module()
trace_stitching_retargeting_module() # done
"""
@ -284,5 +294,62 @@ def apply(module: Module, name: str):
return fn
dims = [96, 192, 384, 768]
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
else:
raise ValueError(f"Unsupported data_format: {self.data_format}")
def extra_repr(self) -> str:
return f'normalized_shape={self.normalized_shape}, ' \
f'eps={self.eps}, data_format={self.data_format}'
class YourModule(nn.Module):
def __init__(self, in_chans: int, dims: List[int]):
super().__init__()
self.downsample_layers = nn.ModuleList()
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.downsample_layers:
x = layer(x)
return x
# Now try to script the entire module
model = YourModule(3, dims)
torch.jit.script(model)
"""

View File

@ -67,18 +67,39 @@ class ConvNeXtV2(nn.Module):
):
super().__init__()
self.depths = depths
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
# self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
# stem = nn.Sequential(
# nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
# LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
# )
# self.downsample_layers.append(stem)
# for i in range(3):
# downsample_layer = nn.Sequential(
# LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
# nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
# )
# self.downsample_layers.append(downsample_layer)
self.downsample_layers = nn.ModuleList(
(
nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
),
nn.Sequential(
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[0], dims[0+1], kernel_size=2, stride=2),
),
nn.Sequential(
LayerNorm(dims[1], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[1], dims[1+1], kernel_size=2, stride=2),
),
nn.Sequential(
LayerNorm(dims[2], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[2], dims[2+1], kernel_size=2, stride=2),
),
)) # stem and 3 intermediate downsampling conv layers
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
@ -111,9 +132,18 @@ class ConvNeXtV2(nn.Module):
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
x = self.downsample_layers[0](x)
x = self.stages[0](x)
x = self.downsample_layers[1](x)
x = self.stages[1](x)
x = self.downsample_layers[2](x)
x = self.stages[2](x)
x = self.downsample_layers[3](x)
x = self.stages[3](x)
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
def forward(self, x):

View File

@ -42,24 +42,24 @@ class SPADEDecoder(nn.Module):
def forward(self, feature):
seg = feature # Bx256x64x64
x = self.fc(feature) # Bx512x64x64
print("self.G_middle_0: %s", x.shape, seg.shape)
# print("self.G_middle_0: %s", x.shape, seg.shape)
x = self.G_middle_0(x, seg)
print("self.G_middle_1: %s", x.shape, seg.shape)
# print("self.G_middle_1: %s", x.shape, seg.shape)
x = self.G_middle_1(x, seg)
print("self.G_middle_2: %s", x.shape, seg.shape)
# print("self.G_middle_2: %s", x.shape, seg.shape)
x = self.G_middle_2(x, seg)
print("self.G_middle_3: %s", x.shape, seg.shape)
# print("self.G_middle_3: %s", x.shape, seg.shape)
x = self.G_middle_3(x, seg)
print("self.G_middle_4: %s", x.shape, seg.shape)
# print("self.G_middle_4: %s", x.shape, seg.shape)
x = self.G_middle_4(x, seg)
print("self.G_middle_5: %s", x.shape, seg.shape)
# print("self.G_middle_5: %s", x.shape, seg.shape)
x = self.G_middle_5(x, seg)
x = self.up(x) # Bx512x64x64 -> Bx512x128x128
print("self.up_0: %s", x.shape, seg.shape)
# print("self.up_0: %s", x.shape, seg.shape)
x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
x = self.up(x) # Bx256x128x128 -> Bx256x256x256
print("self.up_1: %s", x.shape, seg.shape)
# print("self.up_1: %s", x.shape, seg.shape)
x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
@ -70,7 +70,7 @@ class SPADEDecoder(nn.Module):
def patch_nonscriptable_classes(self):
in_1 = torch.rand(1, 512, 64, 64)
in_2 = torch.rand(1, 256, 64, 64)
in_3 = torch.rand(1, 512, 128, 128)
in_4 = torch.rand(1, 256, 256, 256)

View File

@ -5,12 +5,13 @@ This file defines various neural network modules and utility functions, includin
normalizations, and functions for spatial transformation and tensor manipulation.
"""
from typing import List
from typing import List, Optional
from torch import nn
import torch.nn.functional as F
import torch
from torch.nn.utils.spectral_norm import spectral_norm
# from torch.nn.utils.parametrizations import spectral_norm
torch.backends.cudnn.enabled = True
import math
import warnings
@ -378,8 +379,10 @@ class LayerNorm(nn.Module):
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x
else: # elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)