chore: add fast check filepath

This commit is contained in:
guojianzhu 2024-07-11 22:58:45 +08:00
parent cf9a5b0d4c
commit fe43a21d81

View File

@ -1,5 +1,6 @@
# coding: utf-8 # coding: utf-8
import os.path as osp
import tyro import tyro
from src.config.argument_config import ArgumentConfig from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig from src.config.inference_config import InferenceConfig
@ -11,11 +12,21 @@ def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def fast_check_args(args: ArgumentConfig):
if not osp.exists(args.source_image):
raise FileNotFoundError(f"source image not found: {args.source_image}")
if not osp.exists(args.driving_info):
raise FileNotFoundError(f"driving info not found: {args.driving_info}")
def main(): def main():
# set tyro theme # set tyro theme
tyro.extras.set_accent_color("bright_cyan") tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig) args = tyro.cli(ArgumentConfig)
# fast check the args
fast_check_args(args)
# specify configs for inference # specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig