This commit is contained in:
AleD 2024-07-13 21:18:32 +09:00
commit b3ebc3aaba
2 changed files with 15 additions and 14 deletions

View File

@ -17,17 +17,17 @@ from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1):
def initialize_inputs(batch_size=1, device_id=0):
"""
Generate random input tensors and move them to GPU
"""
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
kp_source = torch.randn(batch_size, 21, 3).cuda().half()
kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
feat_stitching = concat_feat(kp_source, kp_driving).half()
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
@ -102,7 +102,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
Measure inference times for each model
"""
times = {name: [] for name in compiled_models.keys()}
times['Retargeting Models'] = []
times['Stitching and Retargeting Modules'] = []
overall_times = []
@ -136,7 +136,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
stitching_retargeting_module['eye'](inputs['feat_eye'])
stitching_retargeting_module['lip'](inputs['feat_lip'])
torch.cuda.synchronize()
times['Retargeting Models'].append(time.time() - start)
times['Stitching and Retargeting Modules'].append(time.time() - start)
overall_times.append(time.time() - overall_start)
@ -169,15 +169,15 @@ def main():
"""
Main function to benchmark speed and model parameters
"""
# Sample input tensors
inputs = initialize_inputs()
# Load configuration
cfg = InferenceConfig(device_id=0)
cfg = InferenceConfig()
model_config_path = cfg.models_config
with open(model_config_path, 'r') as file:
model_config = yaml.safe_load(file)
# Sample input tensors
inputs = initialize_inputs(device_id = cfg.device_id)
# Load and compile models
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

View File

@ -51,6 +51,7 @@ class LivePortraitWrapper(object):
self.stitching_retargeting_module = None
# Optimize for inference
if self.compile:
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')