mirror of https://github.com/KwaiVGI/LivePortrait.git
synced 2025-03-15 05:52:58 +00:00
wip

parent 54e50986b2
commit b836f49555
.gitignore (vendored), 2 changed lines
@@ -18,4 +18,4 @@ pretrained_weights/docs
 # Temporary files or benchmark resources
 animations/*
 tmp/*
-.vscode/launch.json
+venv/*
@@ -52,7 +52,9 @@ def main():

    # run
    live_portrait_pipeline.execute(args)

    # live_portrait_pipeline_cp = torch.compile(live_portrait_pipeline.execute, backend="inductor")
    # with torch.no_grad():
    #     live_portrait_pipeline_cp(args)


if __name__ == "__main__":
    main()
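The commented-out lines in this hunk sketch a torch.compile experiment on the pipeline's execute method. As a hedged, standalone illustration of that pattern (toy module, PyTorch 2.x assumed; not the repo's actual pipeline):

import torch
from torch import nn

# Minimal sketch: compile a callable with the inductor backend and call it
# under no_grad, mirroring the commented-out experiment above.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
compiled_forward = torch.compile(model.forward, backend="inductor")

with torch.no_grad():
    out = compiled_forward(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 4])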
nuke_live_portrait.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
from src.utils.helper import load_model


model_config = {'model_params': {
    'appearance_feature_extractor_params': {'image_channel': 3, 'block_expansion': 64, 'num_down_blocks': 2, 'max_features': 512, 'reshape_channel': 32, 'reshape_depth': 16, 'num_resblocks': 6},
    'motion_extractor_params': {'num_kp': 21, 'backbone': 'convnextv2_tiny'},
    'warping_module_params': {'num_kp': 21, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2, 'reshape_channel': 32, 'estimate_occlusion_map': True, 'dense_motion_params': {'block_expansion': 32, 'max_features': 1024, 'num_blocks': 5, 'reshape_depth': 16, 'compress': 4}},
    'spade_generator_params': {'upscale': 2, 'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2},
    'stitching_retargeting_module_params': {'stitching': {'input_size': 126, 'hidden_sizes': [128, 128, 64], 'output_size': 65}, 'lip': {'input_size': 65, 'hidden_sizes': [128, 128, 64], 'output_size': 63}, 'eye': {'input_size': 66, 'hidden_sizes': [256, 256, 128, 128, 64], 'output_size': 63}}}}
def trace_appearance_feature_extractor():
    appearance_feature_extractor = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth",
        model_config=model_config,
        device=0,
        model_type='appearance_feature_extractor')

    with torch.no_grad():
        appearance_feature_extractor.eval()
        appearance_feature_extractor = torch.jit.script(appearance_feature_extractor)

    torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")

# def trace_appearance_feature_extractor():
def trace_motion_extractor():
    motion_extractor = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
        model_config=model_config,
        device=0,
        model_type='motion_extractor')

    with torch.no_grad():
        motion_extractor.eval()
        motion_extractor = torch.jit.script(motion_extractor)

    # torch.jit.save(motion_extractor, "build/motion_extractor.pt")

trace_motion_extractor()
def trace_warping_module():
    warping_module = load_model(
        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/warping_module.pth",
        model_config=model_config,
        device=0,
        model_type='warping_module')

    with torch.no_grad():
        warping_module.eval()
        warping_module = torch.jit.script(warping_module)

    torch.jit.save(warping_module, "build/warping_module.pt")

# def trace_warping_module():
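The trace functions above export TorchScript archives with torch.jit.save (only trace_motion_extractor() is actually invoked in this wip file). A hedged sketch of loading one of those archives back and running a dummy forward pass; the 1x3x256x256 input shape is an assumption taken from the shape comment in the AppearanceFeatureExtractor hunk below:

import torch

# Load the scripted appearance feature extractor saved by the function above.
module = torch.jit.load("build/appearance_feature_extractor.pt", map_location="cuda:0")
module.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256, device="cuda:0")  # assumed input size
    features = module(dummy)
print(features.shape)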
@@ -38,8 +38,8 @@ class AppearanceFeatureExtractor(nn.Module):
    def forward(self, source_image):
        out = self.first(source_image)  # Bx3x256x256 -> Bx64x256x256

-        for i in range(len(self.down_blocks)):
-            out = self.down_blocks[i](out)
+        for _, block in enumerate(self.down_blocks):
+            out = block(out)
        out = self.second(out)
        bs, c, h, w = out.shape  # ->Bx512x64x64
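This change looks motivated by TorchScript: torch.jit.script cannot index an nn.ModuleList with a loop variable, but it can iterate over the list directly (or via enumerate). A minimal standalone sketch of the scriptable pattern (toy module, not LivePortrait code):

import torch
from torch import nn

class TinyStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Iterating the ModuleList directly is scriptable;
        # indexing it with a runtime variable (self.blocks[i]) is not.
        for block in self.blocks:
            x = block(x)
        return x

scripted = torch.jit.script(TinyStack())
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])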
@@ -5,6 +5,7 @@ This file defines various neural network modules and utility functions, including
normalizations, and functions for spatial transformation and tensor manipulation.
"""

+from typing import List
from torch import nn
import torch.nn.functional as F
import torch
@@ -13,7 +14,7 @@ import math
import warnings


-def kp2gaussian(kp, spatial_size, kp_variance):
+def kp2gaussian(kp, spatial_size: List[int], kp_variance: float):
    """
    Transform a keypoint into gaussian like representation
    """
@@ -22,13 +23,13 @@ def kp2gaussian(kp, spatial_size, kp_variance):
    coordinate_grid = make_coordinate_grid(spatial_size, mean)
    number_of_leading_dimensions = len(mean.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
-    coordinate_grid = coordinate_grid.view(*shape)
+    coordinate_grid = coordinate_grid.view(torch.Size(shape))
    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
-    coordinate_grid = coordinate_grid.repeat(*repeats)
+    coordinate_grid = coordinate_grid.repeat(torch.Size(repeats))

    # Preprocess kp shape
    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
-    mean = mean.view(*shape)
+    mean = mean.view(torch.Size(shape))

    mean_sub = (coordinate_grid - mean)
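The switch from star-unpacking (*shape) to passing a torch.Size appears aimed at torch.jit.script, which does not support unpacking a tuple of unknown length into call arguments; in eager mode both spellings are equivalent. A small illustration with toy tensors (not from the repo):

import torch

t = torch.randn(6)
shape = (2, 3)

a = t.view(*shape)             # star-unpacking: fine in eager mode, not scriptable
b = t.view(torch.Size(shape))  # same result, and accepted by torch.jit.script
print(torch.equal(a, b))       # True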
@@ -37,7 +38,7 @@ def kp2gaussian(kp, spatial_size, kp_variance):
    return out


-def make_coordinate_grid(spatial_size, ref, **kwargs):
+def make_coordinate_grid(spatial_size: List[int], ref):
    d, h, w = spatial_size
    x = torch.arange(w).type(ref.dtype).to(ref.device)
    y = torch.arange(h).type(ref.dtype).to(ref.device)
@@ -112,7 +113,7 @@ class UpBlock3d(nn.Module):
        self.norm = nn.BatchNorm3d(out_features, affine=True)

    def forward(self, x):
-        out = F.interpolate(x, scale_factor=(1, 2, 2))
+        out = F.interpolate(x, scale_factor=(1.0, 2.0, 2.0))
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
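Under torch.jit.script, F.interpolate's scale_factor is typed as a float (or list of floats), so an integer tuple like (1, 2, 2) fails to type-check even though eager mode accepts it; that appears to be the motivation for the float literals above. A toy check of the float spelling (shapes are illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 2, 8, 8)  # toy 5D input (N, C, D, H, W)
y = F.interpolate(x, scale_factor=(1.0, 2.0, 2.0))
print(y.shape)  # torch.Size([1, 4, 2, 16, 16])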
@@ -224,7 +225,7 @@ class Decoder(nn.Module):
        self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm3d(self.out_filters, affine=True)

-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]):
        out = x.pop()
        for up_block in self.up_blocks:
            out = up_block(out)
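Without the List[torch.Tensor] annotation, torch.jit.script would infer x as a Tensor and reject x.pop(); annotating the argument lets the list-of-feature-maps usage type-check. A standalone sketch of the same idea (toy function, not the Decoder itself):

import torch
from typing import List

@torch.jit.script
def consume_stack(x: List[torch.Tensor]) -> torch.Tensor:
    out = x.pop()      # take the deepest feature map off the stack
    for f in x:
        out = out + f  # toy stand-in for the decoder's up-blocks
    return out

print(consume_stack([torch.ones(2, 2), torch.ones(2, 2)]).sum())  # tensor(8.)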