spade_generator compiled

Rafael Silva 2024-07-11 22:23:26 -04:00
parent 314b71a3bd
commit dbb3bf87f3
5 changed files with 270 additions and 60 deletions

View File

@@ -1,52 +1,7 @@
import torch
from src.utils.helper import load_model
model_config = (
    {
        'model_params': {
            'appearance_feature_extractor_params': {
                'image_channel': 3,
                'block_expansion': 64,
                'num_down_blocks': 2,
                'max_features': 512,
                'reshape_channel': 32,
                'reshape_depth': 16,
                'num_resblocks': 6,
            },
            'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'},
            'warping_module_params': {
                'num_kp': 21,
                'block_expansion': 64,
                'max_features': 512,
                'num_down_blocks': 2,
                'reshape_channel': 32,
                'estimate_occlusion_map': True,
                'dense_motion_params': {
                    'block_expansion': 32,
                    'max_features': 1024,
                    'num_blocks': 5,
                    'reshape_depth': 16,
                    'compress': 4,
                },
            },
            'spade_generator_params': {
                'upscale': 2,
                'block_expansion': 64,
                'max_features': 512,
                'num_down_blocks': 2,
            },
            'stitching_retargeting_module_params': {
                'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65},
                'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63},
                'eye': {
                    'input_size': 66,
                    'hidden_sizes': [256, 256, 128, 128, 64],
                    'output_size': 63,
                },
            },
        }
    },
)
model_config = {'model_params': {'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6}, 'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'}, 'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}}, 'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2}, 'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}}
def trace_appearance_feature_extractor():
@@ -70,20 +25,21 @@ def trace_appearance_feature_extractor():
def trace_motion_extractor():
    motion_extractor = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
        model_config=model_config,
        model_config={'model_params': {'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6}, 'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'}, 'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}}, 'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2}, 'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}},
        device=0,
        model_type='motion_extractor',
    )
    # print(motion_extractor)
    with torch.no_grad():
        motion_extractor.eval()
        motion_extractor = torch.jit.script(motion_extractor)
    # with torch.no_grad():
    #     motion_extractor.eval()
    #     torch.jit.script(self.motion_extractor)
    motion_extractor = torch.jit.script(motion_extractor)
    # torch.jit.save(motion_extractor, "build/motion_extractor.pt")
    torch.jit.save(motion_extractor, "build/motion_extractor.pt")
trace_motion_extractor()
# trace_motion_extractor()

def trace_warping_module():
    warping_module = load_model(
@@ -101,3 +57,212 @@ def trace_warping_module():
# def trace_warping_module():

def trace_spade_generator():
    spade_generator = load_model(
        ckpt_path='./pretrained_weights/liveportrait/base_models/spade_generator.pth',
        model_config=model_config,
        device=0,
        model_type='spade_generator',
    )
    # with torch.no_grad():
    #     spade_generator.eval()
    # print(spade_generator)
    spade_generator = torch.jit.script(spade_generator)
    torch.jit.save(spade_generator, "build/spade_generator.pt")
trace_spade_generator()

def trace_stitching_retargeting_module():
    stitching_retargeting_module = load_model(
        ckpt_path='/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth',
        model_config=model_config,
        device=0,
        model_type='stitching_retargeting_module',
    )
    with torch.no_grad():
        stitching_retargeting_module.eval()
        stitching_retargeting_module = torch.jit.script(stitching_retargeting_module)
        torch.jit.save(stitching_retargeting_module, "build/stitching_retargeting_module.pt")
# trace_stitching_retargeting_module()
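# Note: torch.jit.script(spade_generator) above only succeeds because
# SPADEDecoder now calls patch_nonscriptable_classes() in __init__ (see the
# SPADEDecoder diff below), which torch.jit.trace-s every spectral-norm'd
# SPADEResnetBlock so the scripted decoder only calls into traced submodules.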
'''
class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        self.use_se = use_se
        print(f"SPADEResnetBlock: fin={fin}, fout={fout}, norm_G={norm_G}, fmiddle={fmiddle}, learned_shortcut={self.learned_shortcut}, use_se={use_se}")
        # create conv layers
        # self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
        # self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
        # if self.learned_shortcut:
        #     self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # apply spectral norm if specified
        # if 'spectral' in norm_G:
        self.conv_0 = spectral_norm(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation))
        # self.conv_0: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
        self.conv_1 = spectral_norm(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation))
        # self.conv_1: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
        # if self.learned_shortcut:
        self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))
        # self.conv_s: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fout, kernel_size=1, bias=False), "weight", 1, 0, 1e-12)
        # define normalization layers
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        # if self.learned_shortcut:
        self.norm_s = SPADE(fin, label_nc)
from src.modules.util import SPADEResnetBlock
from torch.nn.utils.spectral_norm import spectral_norm, SpectralNorm
from torch.nn.utils.parametrizations import spectral_norm
fin=512; fout=512; norm_G='spadespectralinstance'; label_nc=256; use_se=False; dilation=1
fmiddle = min(fin, fout)
c = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
s = spectral_norm(c)
j = spectral_norm(c).eval()
torch.jit.script(s)
torch.jit.trace(s, torch.randn(1, 512, 64, 64))
m = torch.jit.trace(j, torch.randn(1, 512, 64, 64))
with torch.no_grad():
    torch.jit.script(s)
with torch.no_grad():
    torch.jit.trace(s, torch.randn(1, 512, 64, 64))
sp = SPADEResnetBlock(fin=512, fout=512, norm_G="spadespectralinstance", label_nc=256, use_se=False, dilation=1)
sp.eval()
sp_traced = torch.jit.trace(sp, (torch.randn(1, 512, 64, 64), torch.randn(1, 256, 64, 64)))
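# Editor's note (a sketch, not from the original scratchpad): the experiments
# above boil down to this contrast. torch.jit.script fails on a module wrapped
# by the hook-based torch.nn.utils.spectral_norm, because the SpectralNorm
# forward-pre-hook is plain Python the compiler cannot ingest; torch.jit.trace
# merely records the tensor ops of one forward pass, so it succeeds on the
# same module once it is in eval() mode.
import torch
from torch import nn
from torch.nn.utils import spectral_norm

conv = spectral_norm(nn.Conv2d(512, 512, kernel_size=3, padding=1)).eval()
# torch.jit.script(conv)  # fails: the spectral-norm pre-hook is not scriptable
with torch.no_grad():
    traced = torch.jit.trace(conv, torch.randn(1, 512, 64, 64))  # succeeds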
import torch
from torch import nn
with torch.no_grad():
    c1 = torch.nn.utils.spectral_norm(torch.nn.Conv2d(1,2,3))
    torch.jit.script(c1)
    c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
    torch.jit.script(c2)
with torch.no_grad():
    c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
    torch.jit.script(c2)
l = nn.Linear(20, 40)
l.weight.size()
m = torch.nn.utils.spectral_norm(nn.Linear(20, 40))
m
m.weight_u.size()
---
def modify_state_dict_inplace(model):
    state_dict = model.state_dict()
    keys_to_delete = []
    for key in list(state_dict.keys()):
        value = state_dict[key]
        k = key.split(".")
        if len(k) == 3 and k[-1] == "weight" and k[-2] in ["conv_0", "conv_1"]:
            # Register new parameters
            model.register_parameter(f"{k[0]}.{k[-2]}.weight_orig", torch.nn.Parameter(value))
            model.register_parameter(f"{k[0]}.{k[-2]}.weight_u", torch.nn.Parameter(torch.zeros_like(value)))
            model.register_parameter(f"{k[0]}.{k[-2]}.weight_v", torch.nn.Parameter(torch.zeros_like(value)))
            state_dict[f"{k[0]}.{k[-2]}.weight_orig"] = value
            keys_to_delete.append(key)
            state_dict[f"{k[0]}.{k[-2]}.weight_u"] = torch.zeros_like(value)
            state_dict[f"{k[0]}.{k[-2]}.weight_v"] = torch.zeros_like(value)
    # for key in keys_to_delete:
    #     delattr(module, key)
    # Load the modified state_dict back into the model
    model.load_state_dict(state_dict)
    return model
model = SPADEDecoder(**model_params).cuda(device)
modify_state_dict_inplace(model)
model.state_dict().keys()
model._parameters[name]
model = modify_state_dict_inplace(model)
model.state_dict().keys()
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
modify_state_dict_inplace(model.state_dict())
modified_state_dict.keys()
model.register_parameter("aaalero", torch.nn.Parameter(torch.randn(1, 2, 3)))
model.state_dict = modified_state_dict
model = SPADEDecoder(**model_params).cuda(device)
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
type(state_dict.keys())
type(modified_state_dict)
missing_keys = ["G_middle_0.conv_0.weight_orig", "G_middle_0.conv_0.weight_u", "G_middle_0.conv_0.weight_v", "G_middle_0.conv_1.weight_orig", "G_middle_0.conv_1.weight_u", "G_middle_0.conv_1.weight_v", "G_middle_1.conv_0.weight_orig", "G_middle_1.conv_0.weight_u", "G_middle_1.conv_0.weight_v", "G_middle_1.conv_1.weight_orig", "G_middle_1.conv_1.weight_u", "G_middle_1.conv_1.weight_v", "G_middle_2.conv_0.weight_orig", "G_middle_2.conv_0.weight_u", "G_middle_2.conv_0.weight_v", "G_middle_2.conv_1.weight_orig", "G_middle_2.conv_1.weight_u", "G_middle_2.conv_1.weight_v", "G_middle_3.conv_0.weight_orig", "G_middle_3.conv_0.weight_u", "G_middle_3.conv_0.weight_v", "G_middle_3.conv_1.weight_orig", "G_middle_3.conv_1.weight_u", "G_middle_3.conv_1.weight_v", "G_middle_4.conv_0.weight_orig", "G_middle_4.conv_0.weight_u", "G_middle_4.conv_0.weight_v", "G_middle_4.conv_1.weight_orig", "G_middle_4.conv_1.weight_u", "G_middle_4.conv_1.weight_v", "G_middle_5.conv_0.weight_orig", "G_middle_5.conv_0.weight_u", "G_middle_5.conv_0.weight_v", "G_middle_5.conv_1.weight_orig", "G_middle_5.conv_1.weight_u", "G_middle_5.conv_1.weight_v", "up_0.conv_0.weight_orig", "up_0.conv_0.weight_u", "up_0.conv_0.weight_v", "up_0.conv_1.weight_orig", "up_0.conv_1.weight_u", "up_0.conv_1.weight_v", "up_0.conv_s.weight_orig", "up_0.conv_s.weight_u", "up_0.conv_s.weight_v", "up_1.conv_0.weight_orig", "up_1.conv_0.weight_u", "up_1.conv_0.weight_v", "up_1.conv_1.weight_orig", "up_1.conv_1.weight_u", "up_1.conv_1.weight_v", "up_1.conv_s.weight_orig", "up_1.conv_s.weight_u", "up_1.conv_s.weight_v"]
missing_keys in list(modified_state_dict.keys())  # NB: tests whether the whole list is an element; the intended check is all(k in modified_state_dict for k in missing_keys)
def apply(module: Module, name: str):
    weight = module._parameters[name]
    delattr(module, fn.name)
    module.register_parameter(fn.name + "_orig", weight)
    # We still need to assign weight back as fn.name because all sorts of
    # things may assume that it exists, e.g., when initializing weights.
    # However, we can't directly assign as it could be an nn.Parameter and
    # gets added as a parameter. Instead, we register weight.data as a plain
    # attribute.
    setattr(module, fn.name, weight.data)
    module.register_buffer(fn.name + "_u", u)
    module.register_buffer(fn.name + "_v", v)
    module.register_forward_pre_hook(fn)
    module._register_state_dict_hook(SpectralNormStateDictHook(fn))
    module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
    return fn
'''
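A quick smoke test of the exported artifact (an editor's sketch, not part of the commit; it assumes the CUDA device 0 used above and the Bx256x64x64 feature shape noted in SPADEDecoder.forward):

import torch

spade_generator = torch.jit.load("build/spade_generator.pt")
spade_generator.eval()
with torch.no_grad():
    out = spade_generator(torch.randn(1, 256, 64, 64, device="cuda"))
print(out.shape)  # expect torch.Size([1, 3, 512, 512]) with upscale=2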

View File

@@ -20,8 +20,10 @@ class MotionExtractor(nn.Module):
        super(MotionExtractor, self).__init__()

        # default is convnextv2_base
        backbone = kwargs.get('backbone', 'convnextv2_tiny')
        self.detector = model_dict.get(backbone)(**kwargs)
        # backbone = kwargs.get('backbone', 'convnextv2_tiny')
        # self.detector = model_dict.get(backbone)(**kwargs)
        self.detector = convnextv2_tiny(num_kp=21)
        # print("---> %s", kwargs)

    def load_pretrained(self, init_path: str):
        if init_path not in (None, ''):
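Pinning the detector to convnextv2_tiny(num_kp=21) trades configurability for a constructor the TorchScript work can reason about: no dict lookup, no **kwargs forwarding. For reference, a self-contained sketch (TinyBackbone and Wrapper are hypothetical names, not from this repo) of the underlying rule: factory dispatch in eager __init__ is fine, as long as the attribute the compiler later sees has one concrete module type:

import torch
from torch import nn

class TinyBackbone(nn.Module):
    def __init__(self, num_kp: int = 21):
        super().__init__()
        self.head = nn.Linear(8, num_kp)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)

model_dict = {'convnextv2_tiny': TinyBackbone}

class Wrapper(nn.Module):
    def __init__(self, backbone: str = 'convnextv2_tiny', **kwargs):
        super().__init__()
        # __init__ runs eagerly, so the dict lookup itself never reaches
        # TorchScript; self.detector ends up as one concrete module type.
        self.detector = model_dict[backbone](num_kp=kwargs.get('num_kp', 21))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.detector(x)

scripted = torch.jit.script(Wrapper())
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 21])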

View File

@@ -37,23 +37,48 @@ class SPADEDecoder(nn.Module):
            nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2)
        )
        self.patch_nonscriptable_classes()

    def forward(self, feature):
        seg = feature  # Bx256x64x64
        x = self.fc(feature)  # Bx512x64x64
        print("self.G_middle_0: %s", x.shape, seg.shape)
        x = self.G_middle_0(x, seg)
        print("self.G_middle_1: %s", x.shape, seg.shape)
        x = self.G_middle_1(x, seg)
        print("self.G_middle_2: %s", x.shape, seg.shape)
        x = self.G_middle_2(x, seg)
        print("self.G_middle_3: %s", x.shape, seg.shape)
        x = self.G_middle_3(x, seg)
        print("self.G_middle_4: %s", x.shape, seg.shape)
        x = self.G_middle_4(x, seg)
        print("self.G_middle_5: %s", x.shape, seg.shape)
        x = self.G_middle_5(x, seg)
        x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
        print("self.up_0: %s", x.shape, seg.shape)
        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
        x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
        print("self.up_1: %s", x.shape, seg.shape)
        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256
        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
        x = torch.sigmoid(x)  # Bx3xHxW
        return x

    def patch_nonscriptable_classes(self):
        in_1 = torch.rand(1, 512, 64, 64)
        in_2 = torch.rand(1, 256, 64, 64)
        in_3 = torch.rand(1, 512, 128, 128)
        in_4 = torch.rand(1, 256, 256, 256)
        self.G_middle_0 = torch.jit.trace(self.G_middle_0.eval(), (in_1, in_2))
        self.G_middle_1 = torch.jit.trace(self.G_middle_1.eval(), (in_1, in_2))
        self.G_middle_2 = torch.jit.trace(self.G_middle_2.eval(), (in_1, in_2))
        self.G_middle_3 = torch.jit.trace(self.G_middle_3.eval(), (in_1, in_2))
        self.G_middle_4 = torch.jit.trace(self.G_middle_4.eval(), (in_1, in_2))
        self.G_middle_5 = torch.jit.trace(self.G_middle_5.eval(), (in_1, in_2))
        self.up_0 = torch.jit.trace(self.up_0.eval(), (in_3, in_2))
        self.up_1 = torch.jit.trace(self.up_1.eval(), (in_4, in_2))
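This patch_nonscriptable_classes() is the heart of the commit: the SPADEResnetBlocks carry spectral-norm forward-pre-hooks that torch.jit.script cannot compile, so each block is torch.jit.trace-d at the fixed resolutions the upscale=2 pipeline actually uses (64, 128, and 256), after which the decoder that owns them can be scripted, since scripted modules may freely call traced submodules. Note the traced blocks are specialized to exactly those shapes. A minimal self-contained sketch of the same trick (Parent is a hypothetical stand-in):

import torch
from torch import nn
from torch.nn.utils import spectral_norm

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        # the hook-based spectral_norm makes this child non-scriptable
        self.child = spectral_norm(nn.Conv2d(4, 4, 3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.child(x) + x

p = Parent().eval()
with torch.no_grad():
    # trace only the offending submodule; the traced graph is shape- and
    # device-specialized to this example input
    p.child = torch.jit.trace(p.child, torch.randn(1, 4, 16, 16))
scripted = torch.jit.script(p)  # the parent now scripts cleanly
print(scripted(torch.randn(1, 4, 16, 16)).shape)  # torch.Size([1, 4, 16, 16])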

View File

@@ -9,7 +9,9 @@ from typing import List
from torch import nn
import torch.nn.functional as F
import torch
import torch.nn.utils.spectral_norm as spectral_norm
from torch.nn.utils.spectral_norm import spectral_norm
# from torch.nn.utils.parametrizations import spectral_norm
import math
import warnings
@@ -274,10 +276,10 @@ class SPADE(nn.Module):
        out = normalized * (1 + gamma) + beta
        return out

class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        # print(f"{fin=}, {fout=}, {norm_G=}, {label_nc=}, {use_se=}, {dilation=}")
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
@@ -299,7 +301,23 @@ class SPADEResnetBlock(nn.Module):
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    # def __prepare_scriptable__(self):
    #     m = [self.conv_0, self.conv_1, self.norm_0, self.norm_1]
    #     for module in m:
    #         for hook in module._forward_pre_hooks.values():
    #             # The hook we want to remove is an instance of WeightNorm class, so
    #             # normally we would do `if isinstance(...)` but this class is not accessible
    #             # because of shadowing, so we check the module name directly.
    #             # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
    #             if hook.__module__ == "torch.nn.utils.spectral_norm" and hook.__class__.__name__ == "SpectralNorm":
    #                 print(f"Remove spectral norm from {module}")
    #                 torch.nn.utils.remove_spectral_norm(module)
    #                 module.eval()
    #     return self

    def forward(self, x, seg1):
        # print(x.shape, seg1.shape)
        x_s = self.shortcut(x, seg1)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
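The commented-out __prepare_scriptable__ above records the road not taken: strip the SpectralNorm pre-hook before scripting, the way torchaudio handles weight_norm. At inference time this is sound, since the hook only re-derives weight from weight_orig, weight_u and weight_v, and removing it folds that product back into a plain parameter. A minimal sketch of that alternative:

import torch
from torch import nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

conv = spectral_norm(nn.Conv2d(8, 8, 3, padding=1)).eval()
# folds weight_orig / weight_u / weight_v back into a plain `weight`,
# leaving nothing hook-based for the TorchScript compiler to reject
remove_spectral_norm(conv)
scripted = torch.jit.script(conv)
print(scripted(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])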

View File

@@ -134,7 +134,7 @@ def load_model(ckpt_path, model_config, device, model_type):
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
    model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
    model.eval()
    return model
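strict=True is already the default for load_state_dict, so this change only makes the contract explicit: a checkpoint whose keys do not match the module (for instance the spectral-norm weight_orig/weight_u/weight_v triples listed in the scratch notes) should fail loudly rather than load partially. A small illustration of the two modes:

import torch
from torch import nn

src = nn.Linear(4, 4)
dst = nn.Sequential(nn.Linear(4, 4))  # keys are prefixed: "0.weight", "0.bias"
try:
    dst.load_state_dict(src.state_dict(), strict=True)
except RuntimeError as err:
    print("strict load failed:", err)
# strict=False reports the mismatch instead of raising:
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)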