fix run speed.py

This commit is contained in:
fengfusen 2024-07-12 15:03:49 +08:00
parent 6275173411
commit a96532752d

View File

@ -47,11 +47,14 @@ def load_and_compile_models(cfg, model_config):
"""
Load and compile models for inference
"""
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')
import torch._dynamo
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
models_with_params = [
('Appearance Feature Extractor', appearance_feature_extractor),