mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2024-12-22 12:22:38 +00:00
Merge branch 'main' of https://github.com/aled93/LivePortrait
This commit is contained in:
commit
b3ebc3aaba
28
speed.py
28
speed.py
@ -17,17 +17,17 @@ from src.utils.helper import load_model, concat_feat
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
|
||||
def initialize_inputs(batch_size=1):
|
||||
def initialize_inputs(batch_size=1, device_id=0):
|
||||
"""
|
||||
Generate random input tensors and move them to GPU
|
||||
"""
|
||||
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
|
||||
kp_source = torch.randn(batch_size, 21, 3).cuda().half()
|
||||
kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
|
||||
source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
|
||||
generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
|
||||
eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
|
||||
lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
|
||||
feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
|
||||
kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
|
||||
source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
|
||||
generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
|
||||
eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
|
||||
lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
|
||||
feat_stitching = concat_feat(kp_source, kp_driving).half()
|
||||
feat_eye = concat_feat(kp_source, eye_close_ratio).half()
|
||||
feat_lip = concat_feat(kp_source, lip_close_ratio).half()
|
||||
@ -102,7 +102,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
|
||||
Measure inference times for each model
|
||||
"""
|
||||
times = {name: [] for name in compiled_models.keys()}
|
||||
times['Retargeting Models'] = []
|
||||
times['Stitching and Retargeting Modules'] = []
|
||||
|
||||
overall_times = []
|
||||
|
||||
@ -136,7 +136,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
|
||||
stitching_retargeting_module['eye'](inputs['feat_eye'])
|
||||
stitching_retargeting_module['lip'](inputs['feat_lip'])
|
||||
torch.cuda.synchronize()
|
||||
times['Retargeting Models'].append(time.time() - start)
|
||||
times['Stitching and Retargeting Modules'].append(time.time() - start)
|
||||
|
||||
overall_times.append(time.time() - overall_start)
|
||||
|
||||
@ -169,15 +169,15 @@ def main():
|
||||
"""
|
||||
Main function to benchmark speed and model parameters
|
||||
"""
|
||||
# Sample input tensors
|
||||
inputs = initialize_inputs()
|
||||
|
||||
# Load configuration
|
||||
cfg = InferenceConfig(device_id=0)
|
||||
cfg = InferenceConfig()
|
||||
model_config_path = cfg.models_config
|
||||
with open(model_config_path, 'r') as file:
|
||||
model_config = yaml.safe_load(file)
|
||||
|
||||
# Sample input tensors
|
||||
inputs = initialize_inputs(device_id = cfg.device_id)
|
||||
|
||||
# Load and compile models
|
||||
compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
|
||||
|
||||
|
@ -51,6 +51,7 @@ class LivePortraitWrapper(object):
|
||||
self.stitching_retargeting_module = None
|
||||
# Optimize for inference
|
||||
if self.compile:
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
|
||||
self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user