diff --git a/speed.py b/speed.py index 3cad248..4dc4d10 100644 --- a/speed.py +++ b/speed.py @@ -47,11 +47,14 @@ def load_and_compile_models(cfg, model_config): """ Load and compile models for inference """ - appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor') - motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor') - warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module') - spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator') - stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module') + import torch._dynamo + torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution + + appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor') + motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor') + warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module') + spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator') + stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module') models_with_params = [ ('Appearance Feature Extractor', appearance_feature_extractor),