mirror of https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-15 14:02:12 +00:00

commit c921dfafa5 (parent a8e40677dd): trace_motion_extractor
@@ -76,10 +76,11 @@ def trace_appearance_feature_extractor():
     torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")


-trace_appearance_feature_extractor()
+# trace_appearance_feature_extractor() # done


 def trace_motion_extractor():
+    import src.modules.convnextv2
     motion_extractor = load_model(
         ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
         model_config=model_config,
@@ -90,12 +91,14 @@ def trace_motion_extractor():
     with torch.no_grad():
         motion_extractor.eval()
         motion_extractor = torch.jit.script(motion_extractor)
+        # model = src.modules.convnextv2.ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_kp=21)
+        # model = torch.jit.script(model.downsample_layers)

     LOGGER.info("Traced motion_extractor")
     torch.jit.save(motion_extractor, "build/motion_extractor.pt")


-# trace_motion_extractor()
+trace_motion_extractor()


 def trace_warping_module():
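Note: despite the trace_* naming, every exporter in this file goes through torch.jit.script, which compiles the Python source and keeps data-dependent control flow, whereas torch.jit.trace would record a single execution with fixed input shapes. A minimal sketch of the shared pattern (export_scripted is an illustrative helper, not part of the commit):

    import torch

    def export_scripted(module: torch.nn.Module, path: str) -> None:
        # switch to eval mode and disable autograd before compiling,
        # mirroring the exporters in this file
        with torch.no_grad():
            module.eval()
            scripted = torch.jit.script(module)
        torch.jit.save(scripted, path)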
@@ -114,7 +117,7 @@ def trace_warping_module():
     torch.jit.save(warping_module, "build/warping_module.pt")


-trace_warping_module()
+# trace_warping_module() # done


 def trace_spade_generator():
@@ -133,7 +136,7 @@ def trace_spade_generator():
     torch.jit.save(spade_generator, "build/spade_generator.pt")


-trace_spade_generator()
+trace_spade_generator() # done


 def trace_stitching_retargeting_module():
@@ -145,14 +148,21 @@ def trace_stitching_retargeting_module():
     )

     with torch.no_grad():
         stitching_retargeting_module.eval()
-        stitching_retargeting_module = torch.jit.script(stitching_retargeting_module)
+        stitching = stitching_retargeting_module['stitching'].eval()
+        lip = stitching_retargeting_module['lip'].eval()
+        eye = stitching_retargeting_module['eye'].eval()
+
+        stitching_trace = torch.jit.script(stitching)
+        lip_trace = torch.jit.script(lip)
+        eye_trace = torch.jit.script(eye)

     LOGGER.info("Traced stitching_retargeting_module")
-    torch.jit.save(stitching_retargeting_module, "build/stitching_retargeting_module.pt")
+    torch.jit.save(stitching_trace, "build/stitching_retargeting_module_stitching.pt")
+    torch.jit.save(lip_trace, "build/stitching_retargeting_module_lip.pt")
+    torch.jit.save(eye_trace, "build/stitching_retargeting_module_eye.pt")


-# trace_stitching_retargeting_module()
+trace_stitching_retargeting_module() # done

 """
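Note: the stitching/retargeting checkpoint loads as a plain Python dict of three networks, and torch.jit.script does not accept a dict of modules, so the hunk above scripts and saves each entry on its own. A hedged sketch of the same idea in reusable form (export_submodules is illustrative; the 'stitching'/'lip'/'eye' keys come from the code above):

    from typing import Dict

    import torch
    import torch.nn as nn

    def export_submodules(modules: Dict[str, nn.Module], out_dir: str = "build") -> None:
        # script and save each entry of the checkpoint dict separately
        for name, module in modules.items():
            scripted = torch.jit.script(module.eval())
            torch.jit.save(scripted, f"{out_dir}/stitching_retargeting_module_{name}.pt")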
@@ -284,5 +294,62 @@ def apply(module: Module, name: str):
     return fn


+dims = [96, 192, 384, 768]
+from typing import List
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+        else:
+            raise ValueError(f"Unsupported data_format: {self.data_format}")
+
+    def extra_repr(self) -> str:
+        return f'normalized_shape={self.normalized_shape}, ' \
+               f'eps={self.eps}, data_format={self.data_format}'
+
+
+class YourModule(nn.Module):
+    def __init__(self, in_chans: int, dims: List[int]):
+        super().__init__()
+        self.downsample_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+        )
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for layer in self.downsample_layers:
+            x = layer(x)
+        return x
+
+
+# Now try to script the entire module
+model = YourModule(3, dims)
+torch.jit.script(model)
+
 """
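Note: the docstring above preserves a standalone repro: YourModule iterates its nn.ModuleList with a plain for-each loop, and torch.jit.script(model) succeeds. The implied counterpoint (my reading; the commit does not state it) is that indexing a ModuleList with a loop variable is what TorchScript rejects, which is exactly the pattern ConvNeXtV2.forward_features used before the next hunk. A sketch of the failing shape (Indexed is illustrative only):

    import torch
    import torch.nn as nn

    class Indexed(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for i in range(4):
                x = self.layers[i](x)  # non-constant ModuleList index
            return x

    # torch.jit.script(Indexed())  # fails: ModuleList indices must be constants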
@@ -67,18 +67,39 @@ class ConvNeXtV2(nn.Module):
     ):
         super().__init__()
         self.depths = depths
-        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
-        stem = nn.Sequential(
-            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
-            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
-        )
-        self.downsample_layers.append(stem)
-        for i in range(3):
-            downsample_layer = nn.Sequential(
-                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
-                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
-            )
-            self.downsample_layers.append(downsample_layer)
+        # self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
+        # stem = nn.Sequential(
+        #     nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+        #     LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+        # )
+        # self.downsample_layers.append(stem)
+        # for i in range(3):
+        #     downsample_layer = nn.Sequential(
+        #         LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+        #         nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+        #     )
+        #     self.downsample_layers.append(downsample_layer)
+
+        self.downsample_layers = nn.ModuleList((
+            nn.Sequential(
+                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+                LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+            ),
+            nn.Sequential(
+                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[0], dims[0+1], kernel_size=2, stride=2),
+            ),
+            nn.Sequential(
+                LayerNorm(dims[1], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[1], dims[1+1], kernel_size=2, stride=2),
+            ),
+            nn.Sequential(
+                LayerNorm(dims[2], eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(dims[2], dims[2+1], kernel_size=2, stride=2),
+            ),
+        ))  # stem and 3 intermediate downsampling conv layers


         self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
@@ -111,9 +132,18 @@ class ConvNeXtV2(nn.Module):
             nn.init.constant_(m.bias, 0)

     def forward_features(self, x):
-        for i in range(4):
-            x = self.downsample_layers[i](x)
-            x = self.stages[i](x)
+        x = self.downsample_layers[0](x)
+        x = self.stages[0](x)
+
+        x = self.downsample_layers[1](x)
+        x = self.stages[1](x)
+
+        x = self.downsample_layers[2](x)
+        x = self.stages[2](x)
+
+        x = self.downsample_layers[3](x)
+        x = self.stages[3](x)
+
         return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

     def forward(self, x):
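Note: forward_features is unrolled because literal indices into an nn.ModuleList can be resolved statically by TorchScript; the companion __init__ change, which enumerates the same four blocks as a literal tuple, mainly mirrors this, since module construction runs eagerly in Python and is not itself compiled. A minimal self-contained version of the unrolled pattern (Unrolled is illustrative only):

    import torch
    import torch.nn as nn

    class Unrolled(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.blocks[0](x)  # constant indices are scriptable
            x = self.blocks[1](x)
            x = self.blocks[2](x)
            x = self.blocks[3](x)
            return x

    scripted = torch.jit.script(Unrolled())  # compiles without the indexing error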
@@ -42,24 +42,24 @@ class SPADEDecoder(nn.Module):
     def forward(self, feature):
         seg = feature  # Bx256x64x64
         x = self.fc(feature)  # Bx512x64x64
-        print("self.G_middle_0: %s", x.shape, seg.shape)
+        # print("self.G_middle_0: %s", x.shape, seg.shape)
         x = self.G_middle_0(x, seg)
-        print("self.G_middle_1: %s", x.shape, seg.shape)
+        # print("self.G_middle_1: %s", x.shape, seg.shape)
         x = self.G_middle_1(x, seg)
-        print("self.G_middle_2: %s", x.shape, seg.shape)
+        # print("self.G_middle_2: %s", x.shape, seg.shape)
         x = self.G_middle_2(x, seg)
-        print("self.G_middle_3: %s", x.shape, seg.shape)
+        # print("self.G_middle_3: %s", x.shape, seg.shape)
         x = self.G_middle_3(x, seg)
-        print("self.G_middle_4: %s", x.shape, seg.shape)
+        # print("self.G_middle_4: %s", x.shape, seg.shape)
         x = self.G_middle_4(x, seg)
-        print("self.G_middle_5: %s", x.shape, seg.shape)
+        # print("self.G_middle_5: %s", x.shape, seg.shape)
         x = self.G_middle_5(x, seg)

         x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
-        print("self.up_0: %s", x.shape, seg.shape)
+        # print("self.up_0: %s", x.shape, seg.shape)
         x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
         x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
-        print("self.up_1: %s", x.shape, seg.shape)
+        # print("self.up_1: %s", x.shape, seg.shape)
         x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256

         x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
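Aside on the commented-out debug prints: they used logging-style "%s" placeholders, which print() never interpolates (it just prints the literal string followed by the shape tuples). If shape tracing is wanted again later, a logger defers the formatting properly; a hedged sketch, not part of the commit:

    import logging

    import torch

    LOGGER = logging.getLogger(__name__)

    def log_shapes(tag: str, x: torch.Tensor, seg: torch.Tensor) -> None:
        # %-style placeholders are interpolated lazily by the logging module
        LOGGER.debug("%s: x=%s seg=%s", tag, tuple(x.shape), tuple(seg.shape))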
@@ -70,7 +70,7 @@
 def patch_nonscriptable_classes(self):
     in_1 = torch.rand(1, 512, 64, 64)
     in_2 = torch.rand(1, 256, 64, 64)

     in_3 = torch.rand(1, 512, 128, 128)
     in_4 = torch.rand(1, 256, 256, 256)
@@ -5,12 +5,13 @@ This file defines various neural network modules and utility functions, including
 normalizations, and functions for spatial transformation and tensor manipulation.
 """

-from typing import List
+from typing import List, Optional
 from torch import nn
 import torch.nn.functional as F
 import torch
 from torch.nn.utils.spectral_norm import spectral_norm
+# from torch.nn.utils.parametrizations import spectral_norm
+torch.backends.cudnn.enabled = True

 import math
 import warnings
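Note: the commented import points at the newer parametrization-based spectral norm, while the legacy hook-based torch.nn.utils.spectral_norm stays in use here. Both APIs exist in current PyTorch; a side-by-side sketch (illustrative, not from the commit):

    import torch.nn as nn
    from torch.nn.utils.spectral_norm import spectral_norm as legacy_sn
    from torch.nn.utils.parametrizations import spectral_norm as param_sn

    legacy = legacy_sn(nn.Conv2d(3, 8, 3))  # wraps the weight via a pre-forward hook
    modern = param_sn(nn.Conv2d(3, 8, 3))   # registers a parametrization on .weight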
@@ -378,8 +379,10 @@ class LayerNorm(nn.Module):

     def forward(self, x):
         if self.data_format == "channels_last":
             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            return x
+
+        else:  # elif self.data_format == "channels_first":
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
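One caveat worth flagging about this hunk: F.layer_norm normalizes over the trailing len(normalized_shape) dimensions, so for a channels-first (N, C, H, W) tensor with normalized_shape=(C,), it acts on W rather than on the channel dimension that the manual branch (now under else) normalizes. A quick self-contained check (illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 5, 4)  # C == W == 4, so both computations are shape-legal
    eps = 1e-6
    ln = F.layer_norm(x, (4,), eps=eps)    # normalizes over the last dim (W)
    u = x.mean(1, keepdim=True)            # manual branch: statistics over C
    s = (x - u).pow(2).mean(1, keepdim=True)
    manual = (x - u) / torch.sqrt(s + eps)
    print(torch.allclose(ln, manual))      # False: the two branches disagree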