mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2024-12-22 20:42:38 +00:00
fix: compile on v100
LGTM.
This commit is contained in:
commit
6bcbb014fb
15
speed.py
15
speed.py
@ -6,10 +6,13 @@ Benchmark the inference speed of each module in LivePortrait.
|
|||||||
TODO: heavy GPT style, need to refactor
|
TODO: heavy GPT style, need to refactor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import yaml
|
|
||||||
import torch
|
import torch
|
||||||
|
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||||
|
|
||||||
|
import yaml
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.utils.helper import load_model, concat_feat
|
from src.utils.helper import load_model, concat_feat
|
||||||
from src.config.inference_config import InferenceConfig
|
from src.config.inference_config import InferenceConfig
|
||||||
|
|
||||||
@ -47,11 +50,11 @@ def load_and_compile_models(cfg, model_config):
|
|||||||
"""
|
"""
|
||||||
Load and compile models for inference
|
Load and compile models for inference
|
||||||
"""
|
"""
|
||||||
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor')
|
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
||||||
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor')
|
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
||||||
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module')
|
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
||||||
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator')
|
spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
||||||
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module')
|
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
||||||
|
|
||||||
models_with_params = [
|
models_with_params = [
|
||||||
('Appearance Feature Extractor', appearance_feature_extractor),
|
('Appearance Feature Extractor', appearance_feature_extractor),
|
||||||
|
Loading…
Reference in New Issue
Block a user