diff --git a/.gitignore b/.gitignore index 1f85f19..c68fd1e 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,4 @@ pretrained_weights/docs # Temporary files or benchmark resources animations/* tmp/* -.vscode/launch.json +venv/* diff --git a/inference.py b/inference.py index dd7a768..cd6d16b 100644 --- a/inference.py +++ b/inference.py @@ -52,7 +52,9 @@ def main(): # run live_portrait_pipeline.execute(args) - + # live_portrait_pipeline_cp = torch.compile(live_portrait_pipeline.execute, backend="inductor") + # with torch.no_grad(): + # live_portrait_pipeline_cp(args) if __name__ == "__main__": main() diff --git a/nuke_live_portrait.py b/nuke_live_portrait.py new file mode 100644 index 0000000..d289d14 --- /dev/null +++ b/nuke_live_portrait.py @@ -0,0 +1,53 @@ +import torch +from src.utils.helper import load_model + +model_config = {'model_params': {'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6}, 'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'}, 'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}}, 'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2}, 'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}}, + + +def trace_appearance_feature_extractor(): + appearance_feature_extractor = load_model( + ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth", + model_config=model_config, + device=0, + model_type='appearance_feature_extractor') + + with torch.no_grad(): + appearance_feature_extractor.eval() + appearance_feature_extractor = torch.jit.script(appearance_feature_extractor) + + torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt") + +# def trace_appearance_feature_extractor(): + + +def trace_motion_extractor(): + motion_extractor = load_model( + ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth", + model_config=model_config, + device=0, + model_type='motion_extractor') + + with torch.no_grad(): + motion_extractor.eval() + motion_extractor = torch.jit.script(motion_extractor) + + # torch.jit.save(motion_extractor, "build/motion_extractor.pt") + +trace_motion_extractor() + + +def trace_warping_module(): + warping_module = load_model( + ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/warping_module.pth", + model_config=model_config, + device=0, + model_type='warping_module') + + with torch.no_grad(): + warping_module.eval() + warping_module = torch.jit.script(warping_module) + + torch.jit.save(warping_module, "build/warping_module.pt") + +# def trace_warping_module(): + diff --git a/src/modules/appearance_feature_extractor.py b/src/modules/appearance_feature_extractor.py index 8d89e4f..8b7e85c 100644 --- a/src/modules/appearance_feature_extractor.py +++ b/src/modules/appearance_feature_extractor.py @@ -38,8 +38,8 @@ class AppearanceFeatureExtractor(nn.Module): def forward(self, source_image): out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) + for _, block in enumerate(self.down_blocks): + out = block(out) out = self.second(out) bs, c, h, w = out.shape # ->Bx512x64x64 diff --git a/src/modules/util.py b/src/modules/util.py index f83980b..d8e7768 100644 --- a/src/modules/util.py +++ b/src/modules/util.py @@ -5,6 +5,7 @@ This file defines various neural network modules and utility functions, includin normalizations, and functions for spatial transformation and tensor manipulation. """ +from typing import List from torch import nn import torch.nn.functional as F import torch @@ -13,7 +14,7 @@ import math import warnings -def kp2gaussian(kp, spatial_size, kp_variance): +def kp2gaussian(kp, spatial_size: List[int], kp_variance: float): """ Transform a keypoint into gaussian like representation """ @@ -22,13 +23,13 @@ def kp2gaussian(kp, spatial_size, kp_variance): coordinate_grid = make_coordinate_grid(spatial_size, mean) number_of_leading_dimensions = len(mean.shape) - 1 shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape - coordinate_grid = coordinate_grid.view(*shape) + coordinate_grid = coordinate_grid.view(torch.Size(shape)) repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) - coordinate_grid = coordinate_grid.repeat(*repeats) + coordinate_grid = coordinate_grid.repeat(torch.Size(repeats)) # Preprocess kp shape shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) - mean = mean.view(*shape) + mean = mean.view(torch.Size(shape)) mean_sub = (coordinate_grid - mean) @@ -37,7 +38,7 @@ def kp2gaussian(kp, spatial_size, kp_variance): return out -def make_coordinate_grid(spatial_size, ref, **kwargs): +def make_coordinate_grid(spatial_size: List[int], ref): d, h, w = spatial_size x = torch.arange(w).type(ref.dtype).to(ref.device) y = torch.arange(h).type(ref.dtype).to(ref.device) @@ -112,7 +113,7 @@ class UpBlock3d(nn.Module): self.norm = nn.BatchNorm3d(out_features, affine=True) def forward(self, x): - out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = F.interpolate(x, scale_factor=(1.0, 2.0, 2.0)) out = self.conv(out) out = self.norm(out) out = F.relu(out) @@ -224,7 +225,7 @@ class Decoder(nn.Module): self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) self.norm = nn.BatchNorm3d(self.out_filters, affine=True) - def forward(self, x): + def forward(self, x: List[torch.Tensor]): out = x.pop() for up_block in self.up_blocks: out = up_block(out)