fix: ndarray with default_factory

This commit is contained in:
guojianzhu 2024-07-17 17:00:26 +08:00
parent 0f839844f6
commit bc4aaa44bb

View File

@ -4,10 +4,9 @@
config dataclass used for inference config dataclass used for inference
""" """
import os.path as osp
import cv2 import cv2
from numpy import ndarray from numpy import ndarray
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Literal, Tuple from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path from .base_config import PrintableConfig, make_abs_path
@ -34,8 +33,8 @@ class InferenceConfig(PrintableConfig):
flag_pasteback: bool = True flag_pasteback: bool = True
flag_do_crop: bool = True flag_do_crop: bool = True
flag_do_rot: bool = True flag_do_rot: bool = True
flag_force_cpu: bool = False flag_force_cpu: bool = False
flag_do_torch_compile: bool = False flag_do_torch_compile: bool = False
# NOT EXPORTED PARAMS # NOT EXPORTED PARAMS
lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero
@ -46,7 +45,7 @@ class InferenceConfig(PrintableConfig):
crf: int = 15 # crf for output video crf: int = 15 # crf for output video
output_fps: int = 25 # default output fps output_fps: int = 25 # default output fps
mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR) mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
size_gif: int = 256 # default gif size, TO IMPLEMENT size_gif: int = 256 # default gif size, TO IMPLEMENT
source_max_dim: int = 1280 # the max dim of height and width of source image source_max_dim: int = 1280 # the max dim of height and width of source image
source_division: int = 2 # make sure the height and width of source image can be divided by this number source_division: int = 2 # make sure the height and width of source image can be divided by this number