diff --git a/speed.py b/speed.py
index c4ed93e..20a3482 100644
--- a/speed.py
+++ b/speed.py
@@ -17,17 +17,17 @@ from src.utils.helper import load_model, concat_feat
 from src.config.inference_config import InferenceConfig
 
 
-def initialize_inputs(batch_size=1):
+def initialize_inputs(batch_size=1, device_id=0):
     """
     Generate random input tensors and move them to GPU
     """
-    feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
-    kp_source = torch.randn(batch_size, 21, 3).cuda().half()
-    kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
-    source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
-    generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
-    eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
-    lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
+    feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
+    kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
+    kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
+    source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
+    generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
+    eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
+    lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
     feat_stitching = concat_feat(kp_source, kp_driving).half()
     feat_eye = concat_feat(kp_source, eye_close_ratio).half()
     feat_lip = concat_feat(kp_source, lip_close_ratio).half()
@@ -102,7 +102,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input
     Measure inference times for each model
     """
     times = {name: [] for name in compiled_models.keys()}
-    times['Retargeting Models'] = []
+    times['Stitching and Retargeting Modules'] = []
 
     overall_times = []
 
@@ -136,7 +136,7 @@
             stitching_retargeting_module['eye'](inputs['feat_eye'])
             stitching_retargeting_module['lip'](inputs['feat_lip'])
             torch.cuda.synchronize()
-            times['Retargeting Models'].append(time.time() - start)
+            times['Stitching and Retargeting Modules'].append(time.time() - start)
 
         overall_times.append(time.time() - overall_start)
 
@@ -169,15 +169,15 @@ def main():
     """
     Main function to benchmark speed and model parameters
     """
-    # Sample input tensors
-    inputs = initialize_inputs()
-
     # Load configuration
-    cfg = InferenceConfig(device_id=0)
+    cfg = InferenceConfig()
     model_config_path = cfg.models_config
     with open(model_config_path, 'r') as file:
         model_config = yaml.safe_load(file)
 
+    # Sample input tensors
+    inputs = initialize_inputs(device_id=cfg.device_id)
+
     # Load and compile models
     compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
 
diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py
index ca4bd1b..d2a678d 100644
--- a/src/live_portrait_wrapper.py
+++ b/src/live_portrait_wrapper.py
@@ -51,6 +51,7 @@ class LivePortraitWrapper(object):
             self.stitching_retargeting_module = None
 
         # Optimize for inference
         if self.compile:
+            torch._dynamo.config.suppress_errors = True  # Suppress errors and fall back to eager execution
             self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
             self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
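
Note on the `torch._dynamo.config.suppress_errors` line: with this flag set, a failure inside Dynamo/Inductor is logged and the affected module silently runs eagerly instead of raising, so the benchmark keeps going at the cost of possibly timing uncompiled code. A minimal sketch of that behavior, assuming PyTorch 2.x with `torch.compile` available (the Linear model and shapes are illustrative, not from this repo):

    import torch

    torch._dynamo.config.suppress_errors = True  # on compile failure, log and fall back to eager

    model = torch.nn.Linear(64, 64).eval()
    compiled = torch.compile(model, mode='max-autotune')

    x = torch.randn(1, 64)
    with torch.no_grad():
        y = compiled(x)  # still returns the eager result if tracing/compilation fails

Relatedly, `Tensor.to()` accepts a bare integer as a CUDA device index, so `.to(device_id)` behaves like `.to(f'cuda:{device_id}')`; this is what lets the benchmark honor `cfg.device_id` instead of the hard-coded `.cuda()` (always device 0).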