mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2024-12-22 12:22:38 +00:00
Update speed.py
This commit is contained in:
parent
a96532752d
commit
286d1fd35a
8
speed.py
8
speed.py
@ -6,10 +6,13 @@ Benchmark the inference speed of each module in LivePortrait.
|
||||
TODO: heavy GPT style, need to refactor
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
|
||||
import yaml
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from src.utils.helper import load_model, concat_feat
|
||||
from src.config.inference_config import InferenceConfig
|
||||
|
||||
@ -47,9 +50,6 @@ def load_and_compile_models(cfg, model_config):
|
||||
"""
|
||||
Load and compile models for inference
|
||||
"""
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
|
||||
|
||||
appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
||||
motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
||||
warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
||||
|
Loading…
Reference in New Issue
Block a user