Rafael Silva 2024-07-10 20:11:23 -04:00
parent 54e50986b2
commit b836f49555
5 changed files with 67 additions and 11 deletions

2
.gitignore vendored
View File

@@ -18,4 +18,4 @@ pretrained_weights/docs
# Temporary files or benchmark resources
animations/*
tmp/*
.vscode/launch.json
venv/*

View File

@@ -52,7 +52,9 @@ def main():
     # run
     live_portrait_pipeline.execute(args)
+    # live_portrait_pipeline_cp = torch.compile(live_portrait_pipeline.execute, backend="inductor")
+    # with torch.no_grad():
+    #     live_portrait_pipeline_cp(args)

 if __name__ == "__main__":
     main()

53
nuke_live_portrait.py Normal file
View File

@@ -0,0 +1,53 @@
import torch
from src.utils.helper import load_model

# Model hyper-parameters mirroring the LivePortrait model config
model_config = {'model_params': {
    'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6},
    'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'},
    'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}},
    'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2},
    'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}}


def trace_appearance_feature_extractor():
    # Script the appearance feature extractor and save it as a TorchScript archive.
    appearance_feature_extractor = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth",
        model_config=model_config,
        device=0,
        model_type='appearance_feature_extractor')
    with torch.no_grad():
        appearance_feature_extractor.eval()
        appearance_feature_extractor = torch.jit.script(appearance_feature_extractor)
        torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")

# def trace_appearance_feature_extractor():


def trace_motion_extractor():
    motion_extractor = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
        model_config=model_config,
        device=0,
        model_type='motion_extractor')
    with torch.no_grad():
        motion_extractor.eval()
        motion_extractor = torch.jit.script(motion_extractor)
        # torch.jit.save(motion_extractor, "build/motion_extractor.pt")

trace_motion_extractor()


def trace_warping_module():
    warping_module = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/warping_module.pth",
        model_config=model_config,
        device=0,
        model_type='warping_module')
    with torch.no_grad():
        warping_module.eval()
        warping_module = torch.jit.script(warping_module)
        torch.jit.save(warping_module, "build/warping_module.pt")

# def trace_warping_module():
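Once the scripting above succeeds, the saved archives can be reloaded without the Python class definitions. A minimal sketch of loading one back (the 1x3x256x256 input follows the Bx3x256x256 comment in AppearanceFeatureExtractor.forward below; everything else here is an assumption, not code from this commit):

import torch

# Reload the TorchScript archive written by trace_appearance_feature_extractor() above.
# map_location="cpu" because the module was scripted from a model placed on device 0.
appearance = torch.jit.load("build/appearance_feature_extractor.pt", map_location="cpu").eval()

with torch.no_grad():
    feat = appearance(torch.randn(1, 3, 256, 256))  # dummy Bx3x256x256 source image
print(feat.shape)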

View File

@@ -38,8 +38,8 @@ class AppearanceFeatureExtractor(nn.Module):

     def forward(self, source_image):
         out = self.first(source_image)  # Bx3x256x256 -> Bx64x256x256
-        for i in range(len(self.down_blocks)):
-            out = self.down_blocks[i](out)
+        for _, block in enumerate(self.down_blocks):
+            out = block(out)
         out = self.second(out)
         bs, c, h, w = out.shape  # ->Bx512x64x64
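Iterating over the nn.ModuleList directly, as the new loop does, is the form torch.jit.script handles; indexing the list with a loop variable is what it tends to reject. A small standalone sketch of the same idiom (TinyStack is illustrative, not a class from this repository):

import torch
from torch import nn

class TinyStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Conv2d(3, 3, 3, padding=1) for _ in range(2)])

    def forward(self, x):
        # Iterate the ModuleList directly so torch.jit.script can unroll the loop.
        for block in self.blocks:
            x = block(x)
        return x

scripted = torch.jit.script(TinyStack().eval())
print(scripted(torch.randn(1, 3, 8, 8)).shape)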

View File

@@ -5,6 +5,7 @@ This file defines various neural network modules and utility functions, includin
 normalizations, and functions for spatial transformation and tensor manipulation.
 """

+from typing import List
 from torch import nn
 import torch.nn.functional as F
 import torch
@@ -13,7 +14,7 @@ import math
 import warnings


-def kp2gaussian(kp, spatial_size, kp_variance):
+def kp2gaussian(kp, spatial_size: List[int], kp_variance: float):
     """
     Transform a keypoint into gaussian like representation
     """
@@ -22,13 +23,13 @@ def kp2gaussian(kp, spatial_size, kp_variance):
     coordinate_grid = make_coordinate_grid(spatial_size, mean)
     number_of_leading_dimensions = len(mean.shape) - 1
     shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
-    coordinate_grid = coordinate_grid.view(*shape)
+    coordinate_grid = coordinate_grid.view(torch.Size(shape))
     repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
-    coordinate_grid = coordinate_grid.repeat(*repeats)
+    coordinate_grid = coordinate_grid.repeat(torch.Size(repeats))

     # Preprocess kp shape
     shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
-    mean = mean.view(*shape)
+    mean = mean.view(torch.Size(shape))

     mean_sub = (coordinate_grid - mean)
@@ -37,7 +38,7 @@ def kp2gaussian(kp, spatial_size, kp_variance):
     return out

-def make_coordinate_grid(spatial_size, ref, **kwargs):
+def make_coordinate_grid(spatial_size: List[int], ref):
     d, h, w = spatial_size
     x = torch.arange(w).type(ref.dtype).to(ref.device)
     y = torch.arange(h).type(ref.dtype).to(ref.device)
@@ -112,7 +113,7 @@ class UpBlock3d(nn.Module):
         self.norm = nn.BatchNorm3d(out_features, affine=True)

     def forward(self, x):
-        out = F.interpolate(x, scale_factor=(1, 2, 2))
+        out = F.interpolate(x, scale_factor=(1.0, 2.0, 2.0))
         out = self.conv(out)
         out = self.norm(out)
         out = F.relu(out)
@@ -224,7 +225,7 @@ class Decoder(nn.Module):
         self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
         self.norm = nn.BatchNorm3d(self.out_filters, affine=True)

-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]):
         out = x.pop()
         for up_block in self.up_blocks:
             out = up_block(out)
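The util.py changes all serve torch.jit.script: annotate arguments whose type it cannot infer (un-annotated parameters default to Tensor), and pass sizes as a single sequence (torch.Size(shape) in this commit) rather than unpacking a tuple with *, which TorchScript does not accept. A minimal standalone sketch of the pattern (reshape_to_volume is illustrative, not a function from this file):

from typing import List

import torch

def reshape_to_volume(x: torch.Tensor, spatial_size: List[int]) -> torch.Tensor:
    # Without the List[int] annotation, torch.jit.script would assume spatial_size is a Tensor
    # and the integer indexing below would not type-check.
    shape: List[int] = [x.shape[0], spatial_size[0], spatial_size[1], spatial_size[2]]
    # Pass the target size as one sequence; view(*shape)-style unpacking is what the commit removes.
    return x.view(shape)

scripted = torch.jit.script(reshape_to_volume)
print(scripted(torch.randn(2, 60), [3, 4, 5]).shape)  # torch.Size([2, 3, 4, 5])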