diff --git a/inference.py b/inference.py index 8387e7f..24a9fd7 100644 --- a/inference.py +++ b/inference.py @@ -1,5 +1,6 @@ # coding: utf-8 +import os.path as osp import tyro from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig @@ -11,11 +12,21 @@ def partial_fields(target_class, kwargs): return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) +def fast_check_args(args: ArgumentConfig): + if not osp.exists(args.source_image): + raise FileNotFoundError(f"source image not found: {args.source_image}") + if not osp.exists(args.driving_info): + raise FileNotFoundError(f"driving info not found: {args.driving_info}") + + def main(): # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) + # fast check the args + fast_check_args(args) + # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig