fix: ndarray with default_factory

This commit is contained in:
guojianzhu 2024-07-17 17:00:26 +08:00
parent 0f839844f6
commit bc4aaa44bb

View File

@ -4,10 +4,9 @@
config dataclass used for inference
"""
import os.path as osp
import cv2
from numpy import ndarray
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
@ -34,8 +33,8 @@ class InferenceConfig(PrintableConfig):
flag_pasteback: bool = True
flag_do_crop: bool = True
flag_do_rot: bool = True
flag_force_cpu: bool = False
flag_do_torch_compile: bool = False
flag_force_cpu: bool = False
flag_do_torch_compile: bool = False
# NOT EXPORTED PARAMS
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
@ -46,7 +45,7 @@ class InferenceConfig(PrintableConfig):
crf: int = 15 # crf for output video
output_fps: int = 25 # default output fps
mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
size_gif: int = 256 # default gif size, TO IMPLEMENT
source_max_dim: int = 1280 # the max dim of height and width of source image
source_division: int = 2 # make sure the height and width of source image can be divided by this number