more black

Rafael Silva 2024-07-11 22:36:33 -04:00
parent dbb3bf87f3
commit a8e40677dd


@@ -1,7 +1,63 @@
import torch
import logging
from src.utils.helper import load_model
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
model_config = {
"model_params": {
"appearance_feature_extractor_params": {
"image_channel": 3,
"block_expansion": 64,
"num_down_blocks": 2,
"max_features": 512,
"reshape_channel": 32,
"reshape_depth": 16,
"num_resblocks": 6,
},
"motion_extractor_params": {"num_kp": 21, "backbone": "convnextv2_tiny"},
"warping_module_params": {
"num_kp": 21,
"block_expansion": 64,
"max_features": 512,
"num_down_blocks": 2,
"reshape_channel": 32,
"estimate_occlusion_map": True,
"dense_motion_params": {
"block_expansion": 32,
"max_features": 1024,
"num_blocks": 5,
"reshape_depth": 16,
"compress": 4,
},
},
"spade_generator_params": {
"upscale": 2,
"block_expansion": 64,
"max_features": 512,
"num_down_blocks": 2,
},
"stitching_retargeting_module_params": {
"stitching": {
"input_size": 126,
"hidden_sizes": [128, 128, 64],
"output_size": 65,
},
"lip": {
"input_size": 65,
"hidden_sizes": [128, 128, 64],
"output_size": 63,
},
"eye": {
"input_size": 66,
"hidden_sizes": [256, 256, 128, 128, 64],
"output_size": 63,
},
},
}
}
def trace_appearance_feature_extractor():
@@ -9,120 +65,96 @@ def trace_appearance_feature_extractor():
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth",
model_config=model_config,
device=0,
model_type="appearance_feature_extractor",
)
with torch.no_grad():
appearance_feature_extractor.eval()
appearance_feature_extractor = torch.jit.script(appearance_feature_extractor)
LOGGER.info("Traced appearance_feature_extractor")
torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")
# def trace_appearance_feature_extractor():
trace_appearance_feature_extractor()
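# Hedged sketch, not part of the original script: a small helper (never called here)
# that reloads the exported archive with torch.jit.load and runs a dummy forward
# pass. The 1x3x256x256 input shape and the "cuda:0" device are illustrative
# assumptions, not values taken from this repo.
def _check_appearance_feature_extractor_export():
    loaded = torch.jit.load("build/appearance_feature_extractor.pt", map_location="cuda:0")
    with torch.no_grad():
        _ = loaded(torch.randn(1, 3, 256, 256, device="cuda:0"))
    LOGGER.info("Reloaded appearance_feature_extractor and ran a dummy forward pass")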
def trace_motion_extractor():
motion_extractor = load_model(
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
model_config=model_config,
device=0,
model_type="motion_extractor",
)
# print(motion_extractor)
# with torch.no_grad():
# motion_extractor.eval()
# torch.jit.script(self.motion_extractor)
with torch.no_grad():
motion_extractor.eval()
motion_extractor = torch.jit.script(motion_extractor)
LOGGER.info("Traced motion_extractor")
torch.jit.save(motion_extractor, "build/motion_extractor.pt")
# trace_motion_extractor()
def trace_warping_module():
warping_module = load_model(
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/warping_module.pth",
model_config=model_config,
device=0,
model_type="warping_module",
)
with torch.no_grad():
warping_module.eval()
warping_module = torch.jit.script(warping_module)
LOGGER.info("Traced warping_module")
torch.jit.save(warping_module, "build/warping_module.pt")
# def trace_warping_module():
trace_warping_module()
def trace_spade_generator():
spade_generator = load_model(
ckpt_path="./pretrained_weights/liveportrait/base_models/spade_generator.pth",
model_config=model_config,
device=0,
model_type="spade_generator",
)
# with torch.no_grad():
# spade_generator.eval()
# print(spade_generator)
with torch.no_grad():
spade_generator.eval()
spade_generator = torch.jit.script(spade_generator)
LOGGER.info("Traced spade_generator")
torch.jit.save(spade_generator, "build/spade_generator.pt")
trace_spade_generator()
def trace_stitching_retargeting_module():
stitching_retargeting_module = load_model(
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth",
model_config=model_config,
device=0,
model_type="stitching_retargeting_module",
)
with torch.no_grad():
stitching_retargeting_module.eval()
stitching_retargeting_module = torch.jit.script(stitching_retargeting_module)
LOGGER.info("Traced stitching_retargeting_module")
torch.jit.save(stitching_retargeting_module, "build/stitching_retargeting_module.pt")
# trace_stitching_retargeting_module()
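# Hedged sketch, not part of the original script: a helper (never called here) that
# tries to reload every archive the trace_* functions above may have written to
# build/. A broad except keeps the check non-fatal for modules whose trace_* call
# is commented out and therefore never produced a file.
def _check_exports():
    for name in (
        "appearance_feature_extractor",
        "motion_extractor",
        "warping_module",
        "spade_generator",
        "stitching_retargeting_module",
    ):
        path = "build/" + name + ".pt"
        try:
            torch.jit.load(path, map_location="cpu")
            LOGGER.info("Loaded %s", path)
        except Exception as exc:  # broad catch keeps the sanity check non-fatal
            LOGGER.warning("Could not load %s: %s", path, exc)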
'''
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
self.use_se = use_se
print(f"SPADEResnetBlock: fin={fin}, fout={fout}, norm_G={norm_G}, fmiddle={fmiddle}, learned_shortcut={self.learned_shortcut}, use_se={use_se}")
# create conv layers
# self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
# self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
# if self.learned_shortcut:
# self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
# if 'spectral' in norm_G:
self.conv_0 = spectral_norm(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation))
# self.conv_0: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
self.conv_1 = spectral_norm(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation))
# self.conv_1: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
# if self.learned_shortcut:
self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))
# self.conv_s: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fout, kernel_size=1, bias=False), "weight", 1, 0, 1e-12)
# define normalization layers
self.norm_0 = SPADE(fin, label_nc)
self.norm_1 = SPADE(fmiddle, label_nc)
# if self.learned_shortcut:
self.norm_s = SPADE(fin, label_nc)
"""
from src.modules.util import SPADEResnetBlock
from torch.nn.utils.spectral_norm import spectral_norm, SpectralNorm
@@ -173,8 +205,6 @@ m = torch.nn.utils.spectral_norm(nn.Linear(20, 40))
m
m.weight_u.size()
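# Hedged aside, not from the original notes: torch.nn.utils.spectral_norm re-registers
# the layer's weight as weight_orig plus the power-iteration buffers weight_u / weight_v,
# which is why the checkpoint keys below end in .weight_orig and .weight_u. Calling
# torch.nn.utils.remove_spectral_norm(m) folds them back into a plain .weight.
m.weight_orig.size()
torch.nn.utils.remove_spectral_norm(m)
m.weight.size()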
---
def modify_state_dict_inplace(model):
@@ -234,11 +264,6 @@ missing_keys = ["G_middle_0.conv_0.weight_orig", "G_middle_0.conv_0.weight_u", "
missing_keys in list(modified_state_dict.keys())
def apply(module: Module, name: str):
weight = module._parameters[name]
@@ -260,9 +285,4 @@ def apply(module: Module, name: str):
'''
"""