From bc4aaa44bb2d358e5123ac9a702216109ca09b90 Mon Sep 17 00:00:00 2001 From: guojianzhu Date: Wed, 17 Jul 2024 17:00:26 +0800 Subject: [PATCH] fix: ndarray with default_factory --- src/config/inference_config.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/config/inference_config.py b/src/config/inference_config.py index c25e83d..7c6c718 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -4,10 +4,9 @@ config dataclass used for inference """ -import os.path as osp import cv2 from numpy import ndarray -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal, Tuple from .base_config import PrintableConfig, make_abs_path @@ -34,8 +33,8 @@ class InferenceConfig(PrintableConfig): flag_pasteback: bool = True flag_do_crop: bool = True flag_do_rot: bool = True - flag_force_cpu: bool = False - flag_do_torch_compile: bool = False + flag_force_cpu: bool = False + flag_do_torch_compile: bool = False # NOT EXPORTED PARAMS lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero @@ -46,7 +45,7 @@ class InferenceConfig(PrintableConfig): crf: int = 15 # crf for output video output_fps: int = 25 # default output fps - mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR) + mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) size_gif: int = 256 # default gif size, TO IMPLEMENT source_max_dim: int = 1280 # the max dim of height and width of source image source_division: int = 2 # make sure the height and width of source image can be divided by this number