mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-15 14:02:12 +00:00
spade_generator compiled
This commit is contained in:
parent
314b71a3bd
commit
dbb3bf87f3
@ -1,52 +1,7 @@
|
||||
import torch
|
||||
from src.utils.helper import load_model
|
||||
|
||||
model_config = (
|
||||
{
|
||||
'model_params': {
|
||||
'appearance_feature_extractor_params': {
|
||||
'image_channel': 3,
|
||||
'block_expansion': 64,
|
||||
'num_down_blocks': 2,
|
||||
'max_features': 512,
|
||||
'reshape_channel': 32,
|
||||
'reshape_depth': 16,
|
||||
'num_resblocks': 6,
|
||||
},
|
||||
'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'},
|
||||
'warping_module_params': {
|
||||
'num_kp': 21,
|
||||
'block_expansion': 64,
|
||||
'max_features': 512,
|
||||
'num_down_blocks': 2,
|
||||
'reshape_channel': 32,
|
||||
'estimate_occlusion_map': True,
|
||||
'dense_motion_params': {
|
||||
'block_expansion': 32,
|
||||
'max_features': 1024,
|
||||
'num_blocks': 5,
|
||||
'reshape_depth': 16,
|
||||
'compress': 4,
|
||||
},
|
||||
},
|
||||
'spade_generator_params': {
|
||||
'upscale': 2,
|
||||
'block_expansion': 64,
|
||||
'max_features': 512,
|
||||
'num_down_blocks': 2,
|
||||
},
|
||||
'stitching_retargeting_module_params': {
|
||||
'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65},
|
||||
'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63},
|
||||
'eye': {
|
||||
'input_size': 66,
|
||||
'hidden_sizes': [256, 256, 128, 128, 64],
|
||||
'output_size': 63,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
model_config = {'model_params': {'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6}, 'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'}, 'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}}, 'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2}, 'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}}
|
||||
|
||||
|
||||
def trace_appearance_feature_extractor():
|
||||
@ -70,20 +25,21 @@ def trace_appearance_feature_extractor():
|
||||
def trace_motion_extractor():
|
||||
motion_extractor = load_model(
|
||||
ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
|
||||
model_config=model_config,
|
||||
model_config = {'model_params': {'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6}, 'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'}, 'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}}, 'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2}, 'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}},
|
||||
device=0,
|
||||
model_type='motion_extractor',
|
||||
)
|
||||
# print(motion_extractor)
|
||||
|
||||
with torch.no_grad():
|
||||
motion_extractor.eval()
|
||||
motion_extractor = torch.jit.script(motion_extractor)
|
||||
# with torch.no_grad():
|
||||
# motion_extractor.eval()
|
||||
# torch.jit.script(self.motion_extractor)
|
||||
|
||||
motion_extractor = torch.jit.script(motion_extractor)
|
||||
|
||||
# torch.jit.save(motion_extractor, "build/motion_extractor.pt")
|
||||
|
||||
|
||||
trace_motion_extractor()
|
||||
torch.jit.save(motion_extractor, "build/motion_extractor.pt")
|
||||
|
||||
# trace_motion_extractor()
|
||||
|
||||
def trace_warping_module():
|
||||
warping_module = load_model(
|
||||
@ -101,3 +57,212 @@ def trace_warping_module():
|
||||
|
||||
|
||||
# def trace_warping_module():
|
||||
|
||||
|
||||
def trace_spade_generator():
|
||||
spade_generator = load_model(
|
||||
ckpt_path='./pretrained_weights/liveportrait/base_models/spade_generator.pth',
|
||||
model_config=model_config,
|
||||
device=0,
|
||||
model_type='spade_generator',
|
||||
)
|
||||
|
||||
# with torch.no_grad():
|
||||
# spade_generator.eval()
|
||||
# print(spade_generator)
|
||||
spade_generator = torch.jit.script(spade_generator)
|
||||
torch.jit.save(spade_generator, "build/spade_generator.pt")
|
||||
|
||||
trace_spade_generator()
|
||||
|
||||
|
||||
|
||||
def trace_stitching_retargeting_module():
|
||||
stitching_retargeting_module = load_model(
|
||||
ckpt_path='/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth',
|
||||
model_config=model_config,
|
||||
device=0,
|
||||
model_type='stitching_retargeting_module',
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
stitching_retargeting_module.eval()
|
||||
stitching_retargeting_module = torch.jit.script(stitching_retargeting_module)
|
||||
|
||||
torch.jit.save(stitching_retargeting_module, "build/stitching_retargeting_module.pt")
|
||||
|
||||
|
||||
# trace_stitching_retargeting_module()
|
||||
|
||||
'''
|
||||
class SPADEResnetBlock(nn.Module):
|
||||
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
|
||||
super().__init__()
|
||||
# Attributes
|
||||
self.learned_shortcut = (fin != fout)
|
||||
fmiddle = min(fin, fout)
|
||||
self.use_se = use_se
|
||||
print(f"SPADEResnetBlock: fin={fin}, fout={fout}, norm_G={norm_G}, fmiddle={fmiddle}, learned_shortcut={self.learned_shortcut}, use_se={use_se}")
|
||||
# create conv layers
|
||||
# self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
# self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
# if self.learned_shortcut:
|
||||
# self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
||||
# apply spectral norm if specified
|
||||
# if 'spectral' in norm_G:
|
||||
self.conv_0 = spectral_norm(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation))
|
||||
# self.conv_0: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
|
||||
self.conv_1 = spectral_norm(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation))
|
||||
# self.conv_1: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation), "weight", 1, 0, 1e-12)
|
||||
# if self.learned_shortcut:
|
||||
self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))
|
||||
# self.conv_s: SpectralNorm = SpectralNorm.apply(nn.Conv2d(fin, fout, kernel_size=1, bias=False), "weight", 1, 0, 1e-12)
|
||||
# define normalization layers
|
||||
self.norm_0 = SPADE(fin, label_nc)
|
||||
self.norm_1 = SPADE(fmiddle, label_nc)
|
||||
# if self.learned_shortcut:
|
||||
self.norm_s = SPADE(fin, label_nc)
|
||||
|
||||
|
||||
from src.modules.util import SPADEResnetBlock
|
||||
from torch.nn.utils.spectral_norm import spectral_norm, SpectralNorm
|
||||
from torch.nn.utils.parametrizations import spectral_norm
|
||||
|
||||
fin=512; fout=512; norm_G='spadespectralinstance'; label_nc=256; use_se=False; dilation=1
|
||||
fmiddle = min(fin, fout)
|
||||
c = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
s = spectral_norm(c)
|
||||
j = spectral_norm(c).eval()
|
||||
torch.jit.script(s)
|
||||
torch.jit.trace(s, torch.randn(1, 512, 64, 64))
|
||||
m = torch.jit.trace(j, torch.randn(1, 512, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
torch.jit.script(s)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
torch.jit.trace(s, torch.randn(1, 512, 64, 64))
|
||||
|
||||
|
||||
|
||||
sp = SPADEResnetBlock(fin=512, fout=512, norm_G="spadespectralinstance", label_nc=256, use_se=False, dilation=1)
|
||||
sp.eval()
|
||||
sp_traced = torch.jit.trace(sp, (torch.randn(1, 512, 64, 64), torch.randn(1, 256, 64, 64)))
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
with torch.no_grad():
|
||||
c1 = torch.nn.utils.spectral_norm(torch.nn.Conv2d(1,2,3))
|
||||
torch.jit.script(c1)
|
||||
|
||||
|
||||
c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
|
||||
torch.jit.script(c2)
|
||||
|
||||
with torch.no_grad():
|
||||
c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
|
||||
torch.jit.script(c2)
|
||||
|
||||
l = nn.Linear(20, 40)
|
||||
l.weight.size()
|
||||
|
||||
m = torch.nn.utils.spectral_norm(nn.Linear(20, 40))
|
||||
m
|
||||
m.weight_u.size()
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
def modify_state_dict_inplace(model):
|
||||
state_dict = model.state_dict()
|
||||
keys_to_delete = []
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
k = key.split(".")
|
||||
|
||||
if len(k) == 3 and k[-1] == "weight" and k[-2] in ["conv_0", "conv_1"]:
|
||||
# Register new parameters
|
||||
model.register_parameter(f"{k[0]}.{k[-2]}.weight_orig", torch.nn.Parameter(value))
|
||||
model.register_parameter(f"{k[0]}.{k[-2]}.weight_u", torch.nn.Parameter(torch.zeros_like(value)))
|
||||
model.register_parameter(f"{k[0]}.{k[-2]}.weight_v", torch.nn.Parameter(torch.zeros_like(value)))
|
||||
|
||||
state_dict[f"{k[0]}.{k[-2]}.weight_orig"] = value
|
||||
keys_to_delete.append(key)
|
||||
state_dict[f"{k[0]}.{k[-2]}.weight_u"] = torch.zeros_like(value)
|
||||
state_dict[f"{k[0]}.{k[-2]}.weight_v"] = torch.zeros_like(value)
|
||||
|
||||
# for key in keys_to_delete:
|
||||
# delattr(module, key)
|
||||
|
||||
# Load the modified state_dict back into the model
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
model = SPADEDecoder(**model_params).cuda(device)
|
||||
|
||||
|
||||
modify_state_dict_inplace(model)
|
||||
model.state_dict().keys()
|
||||
|
||||
model._parameters[name]
|
||||
|
||||
model = modify_state_dict_inplace(model)
|
||||
model.state_dict().keys()
|
||||
|
||||
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
|
||||
|
||||
|
||||
|
||||
modify_state_dict_inplace(model.state_dict())
|
||||
|
||||
modified_state_dict.keys()
|
||||
model.register_parameter("aaalero", torch.nn.Parameter(torch.randn(1, 2, 3)))
|
||||
model.state_dict = modified_state_dict
|
||||
model = SPADEDecoder(**model_params).cuda(device)
|
||||
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
|
||||
|
||||
type(state_dict.keys())
|
||||
type(modified_state_dict)
|
||||
|
||||
missing_keys = ["G_middle_0.conv_0.weight_orig", "G_middle_0.conv_0.weight_u", "G_middle_0.conv_0.weight_v", "G_middle_0.conv_1.weight_orig", "G_middle_0.conv_1.weight_u", "G_middle_0.conv_1.weight_v", "G_middle_1.conv_0.weight_orig", "G_middle_1.conv_0.weight_u", "G_middle_1.conv_0.weight_v", "G_middle_1.conv_1.weight_orig", "G_middle_1.conv_1.weight_u", "G_middle_1.conv_1.weight_v", "G_middle_2.conv_0.weight_orig", "G_middle_2.conv_0.weight_u", "G_middle_2.conv_0.weight_v", "G_middle_2.conv_1.weight_orig", "G_middle_2.conv_1.weight_u", "G_middle_2.conv_1.weight_v", "G_middle_3.conv_0.weight_orig", "G_middle_3.conv_0.weight_u", "G_middle_3.conv_0.weight_v", "G_middle_3.conv_1.weight_orig", "G_middle_3.conv_1.weight_u", "G_middle_3.conv_1.weight_v", "G_middle_4.conv_0.weight_orig", "G_middle_4.conv_0.weight_u", "G_middle_4.conv_0.weight_v", "G_middle_4.conv_1.weight_orig", "G_middle_4.conv_1.weight_u", "G_middle_4.conv_1.weight_v", "G_middle_5.conv_0.weight_orig", "G_middle_5.conv_0.weight_u", "G_middle_5.conv_0.weight_v", "G_middle_5.conv_1.weight_orig", "G_middle_5.conv_1.weight_u", "G_middle_5.conv_1.weight_v", "up_0.conv_0.weight_orig", "up_0.conv_0.weight_u", "up_0.conv_0.weight_v", "up_0.conv_1.weight_orig", "up_0.conv_1.weight_u", "up_0.conv_1.weight_v", "up_0.conv_s.weight_orig", "up_0.conv_s.weight_u", "up_0.conv_s.weight_v", "up_1.conv_0.weight_orig", "up_1.conv_0.weight_u", "up_1.conv_0.weight_v", "up_1.conv_1.weight_orig", "up_1.conv_1.weight_u", "up_1.conv_1.weight_v", "up_1.conv_s.weight_orig", "up_1.conv_s.weight_u", "up_1.conv_s.weight_v"]
|
||||
|
||||
missing_keys in list(modified_state_dict.keys())
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def apply(module: Module, name: str):
|
||||
weight = module._parameters[name]
|
||||
|
||||
delattr(module, fn.name)
|
||||
module.register_parameter(fn.name + "_orig", weight)
|
||||
# We still need to assign weight back as fn.name because all sorts of
|
||||
# things may assume that it exists, e.g., when initializing weights.
|
||||
# However, we can't directly assign as it could be an nn.Parameter and
|
||||
# gets added as a parameter. Instead, we register weight.data as a plain
|
||||
# attribute.
|
||||
setattr(module, fn.name, weight.data)
|
||||
module.register_buffer(fn.name + "_u", u)
|
||||
module.register_buffer(fn.name + "_v", v)
|
||||
|
||||
module.register_forward_pre_hook(fn)
|
||||
module._register_state_dict_hook(SpectralNormStateDictHook(fn))
|
||||
module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
|
||||
return fn
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
'''
|
@ -20,8 +20,10 @@ class MotionExtractor(nn.Module):
|
||||
super(MotionExtractor, self).__init__()
|
||||
|
||||
# default is convnextv2_base
|
||||
backbone = kwargs.get('backbone', 'convnextv2_tiny')
|
||||
self.detector = model_dict.get(backbone)(**kwargs)
|
||||
# backbone = kwargs.get('backbone', 'convnextv2_tiny')
|
||||
# self.detector = model_dict.get(backbone)(**kwargs)
|
||||
self.detector = convnextv2_tiny(num_kp=21)
|
||||
# print("---> %s", kwargs)
|
||||
|
||||
def load_pretrained(self, init_path: str):
|
||||
if init_path not in (None, ''):
|
||||
|
@ -37,23 +37,48 @@ class SPADEDecoder(nn.Module):
|
||||
nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
|
||||
nn.PixelShuffle(upscale_factor=2)
|
||||
)
|
||||
self.patch_nonscriptable_classes()
|
||||
|
||||
def forward(self, feature):
|
||||
seg = feature # Bx256x64x64
|
||||
x = self.fc(feature) # Bx512x64x64
|
||||
print("self.G_middle_0: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_0(x, seg)
|
||||
print("self.G_middle_1: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_1(x, seg)
|
||||
print("self.G_middle_2: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_2(x, seg)
|
||||
print("self.G_middle_3: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_3(x, seg)
|
||||
print("self.G_middle_4: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_4(x, seg)
|
||||
print("self.G_middle_5: %s", x.shape, seg.shape)
|
||||
x = self.G_middle_5(x, seg)
|
||||
|
||||
x = self.up(x) # Bx512x64x64 -> Bx512x128x128
|
||||
print("self.up_0: %s", x.shape, seg.shape)
|
||||
x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
|
||||
x = self.up(x) # Bx256x128x128 -> Bx256x256x256
|
||||
print("self.up_1: %s", x.shape, seg.shape)
|
||||
x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
|
||||
|
||||
x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
|
||||
x = torch.sigmoid(x) # Bx3xHxW
|
||||
|
||||
return x
|
||||
return x
|
||||
|
||||
def patch_nonscriptable_classes(self):
|
||||
in_1 = torch.rand(1, 512, 64, 64)
|
||||
in_2 = torch.rand(1, 256, 64, 64)
|
||||
|
||||
in_3 = torch.rand(1, 512, 128, 128)
|
||||
in_4 = torch.rand(1, 256, 256, 256)
|
||||
|
||||
self.G_middle_0 = torch.jit.trace(self.G_middle_0.eval(), (in_1, in_2))
|
||||
self.G_middle_1 = torch.jit.trace(self.G_middle_1.eval(), (in_1, in_2))
|
||||
self.G_middle_2 = torch.jit.trace(self.G_middle_2.eval(), (in_1, in_2))
|
||||
self.G_middle_3 = torch.jit.trace(self.G_middle_3.eval(), (in_1, in_2))
|
||||
self.G_middle_4 = torch.jit.trace(self.G_middle_4.eval(), (in_1, in_2))
|
||||
self.G_middle_5 = torch.jit.trace(self.G_middle_5.eval(), (in_1, in_2))
|
||||
self.up_0 = torch.jit.trace(self.up_0.eval(), (in_3, in_2))
|
||||
self.up_1 = torch.jit.trace(self.up_1.eval(), (in_4, in_2))
|
||||
|
@ -9,7 +9,9 @@ from typing import List
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import torch.nn.utils.spectral_norm as spectral_norm
|
||||
from torch.nn.utils.spectral_norm import spectral_norm
|
||||
# from torch.nn.utils.parametrizations import spectral_norm
|
||||
|
||||
import math
|
||||
import warnings
|
||||
|
||||
@ -274,10 +276,10 @@ class SPADE(nn.Module):
|
||||
out = normalized * (1 + gamma) + beta
|
||||
return out
|
||||
|
||||
|
||||
class SPADEResnetBlock(nn.Module):
|
||||
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
|
||||
super().__init__()
|
||||
# print(f"{fin=}, {fout=}, {norm_G=}, {label_nc=}, {use_se=}, {dilation=}")
|
||||
# Attributes
|
||||
self.learned_shortcut = (fin != fout)
|
||||
fmiddle = min(fin, fout)
|
||||
@ -299,7 +301,23 @@ class SPADEResnetBlock(nn.Module):
|
||||
if self.learned_shortcut:
|
||||
self.norm_s = SPADE(fin, label_nc)
|
||||
|
||||
|
||||
# def __prepare_scriptable__(self):
|
||||
# m = [self.conv_0, self.conv_1, self.norm_0, self.norm_1]
|
||||
# for module in m:
|
||||
# for hook in module._forward_pre_hooks.values():
|
||||
# # The hook we want to remove is an instance of WeightNorm class, so
|
||||
# # normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# # because of shadowing, so we check the module name directly.
|
||||
# # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
# if hook.__module__ == "torch.nn.utils.spectral_norm" and hook.__class__.__name__ == "SpectralNorm":
|
||||
# print(f"Remove spectral norm from {module}")
|
||||
# torch.nn.utils.remove_spectral_norm(module)
|
||||
# module.eval()
|
||||
# return self
|
||||
|
||||
def forward(self, x, seg1):
|
||||
# print(x.shape, seg1.shape)
|
||||
x_s = self.shortcut(x, seg1)
|
||||
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
|
||||
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
|
||||
|
@ -134,7 +134,7 @@ def load_model(ckpt_path, model_config, device, model_type):
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
|
||||
model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user