fix: compile on v100

LGTM.
commit 6bcbb014fb
ZhizhouZhong, 2024-07-13 14:51:32 +08:00 (committed by GitHub)


@@ -6,10 +6,13 @@ Benchmark the inference speed of each module in LivePortrait.
 TODO: heavy GPT style, need to refactor
 """
-import yaml
 import torch
+torch._dynamo.config.suppress_errors = True  # Suppress errors and fall back to eager execution
+import yaml
 import time
 import numpy as np
 from src.utils.helper import load_model, concat_feat
 from src.config.inference_config import InferenceConfig
@@ -47,11 +50,11 @@ def load_and_compile_models(cfg, model_config):
     """
     Load and compile models for inference
     """
-    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
-    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
-    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
-    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
-    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')
+    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
+    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
+    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
+    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
+    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
     models_with_params = [
         ('Appearance Feature Extractor', appearance_feature_extractor),
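
For context: torch._dynamo.config.suppress_errors = True tells torch.compile to log backend failures and fall back to eager execution instead of raising, which is what lets the benchmark script run on a V100 where the inductor backend may fail to compile some graphs. The second hunk passes cfg.device_id to load_model instead of cfg.device, presumably matching the field name InferenceConfig actually defines. Below is a minimal, self-contained sketch of the fallback behavior only; the toy model and the benchmark helper are illustrative and are not part of LivePortrait.

import time

import torch

# Suppress torch.compile/TorchDynamo errors and fall back to eager execution,
# mirroring the flag added in this commit.
torch._dynamo.config.suppress_errors = True


def benchmark(fn, x, warmup=3, iters=10):
    # Hypothetical timing helper (not from the repo): average CUDA latency in seconds.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters


if __name__ == "__main__":
    # Requires a CUDA GPU; on a V100, graphs that the compile backend cannot
    # handle now run eagerly instead of crashing the benchmark.
    device = torch.device("cuda:0")
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512),
        torch.nn.GELU(),
        torch.nn.Linear(512, 512),
    ).to(device).eval()

    compiled = torch.compile(model)
    x = torch.randn(8, 512, device=device)
    with torch.no_grad():
        print(f"avg latency: {benchmark(compiled, x) * 1e3:.2f} ms")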