feat: update crop configuration parameters for clarify (#97)

The crop configuration parameters in `crop_config.py` have been updated. The changes include:
- Updating the paths for insightface_root and landmark_ckpt_path

These changes aim to improve the cropping functionality of the application.
This commit is contained in:
longredzhong 2024-07-11 16:30:57 +08:00 committed by GitHub
parent 1472379e77
commit 470c58fe5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 86 additions and 52 deletions

View File

@ -4,25 +4,26 @@
parameters used for crop faces parameters used for crop faces
""" """
import os.path as osp
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, List
from .base_config import PrintableConfig from .base_config import PrintableConfig
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class CropConfig(PrintableConfig): class CropConfig(PrintableConfig):
device_id: int = 0 # gpu device id insightface_root: str = "../../pretrained_weights/insightface"
flag_force_cpu: bool = False # force cpu inference, WIP landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP
########## source image cropping option ########## ########## source image cropping option ##########
dsize: int = 512 # crop size dsize: int = 512 # crop size
scale: float = 2.5 # scale factor scale: float = 2.5 # scale factor
vx_ratio: float = 0 # vx ratio vx_ratio: float = 0 # vx ratio
vy_ratio: float = -0.125 # vy ratio +up, -down vy_ratio: float = -0.125 # vy ratio +up, -down
max_face_num: int = 0 # max face number, 0 mean no limit max_face_num: int = 0 # max face number, 0 mean no limit
########## driving video auto cropping option ########## ########## driving video auto cropping option ##########
scale_crop_video: float = 2.2 #2.0 # scale factor for cropping video scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video
vx_ratio_crop_video: float = 0. # adjust y offset vx_ratio_crop_video: float = 0.0 # adjust y offset
vy_ratio_crop_video: float = -0.1 # adjust x offset vy_ratio_crop_video: float = -0.1 # adjust x offset
direction: str = 'large-small' # direction of cropping direction: str = "large-small" # direction of cropping

View File

@ -1,17 +1,26 @@
# coding: utf-8 # coding: utf-8
import numpy as np
import os.path as osp import os.path as osp
from typing import List, Union, Tuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) from typing import List, Tuple, Union
import cv2
import numpy as np
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from ..config.crop_config import CropConfig from ..config.crop_config import CropConfig
from .landmark_runner import LandmarkRunner from .crop import (
average_bbox_lst,
crop_image,
crop_image_by_bbox,
parse_bbox_from_landmark,
)
from .face_analysis_diy import FaceAnalysisDIY from .face_analysis_diy import FaceAnalysisDIY
from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
from .rprint import rlog as log
from .io import contiguous from .io import contiguous
from .landmark_runner import LandmarkRunner
from .rprint import rlog as log
def make_abs_path(fn): def make_abs_path(fn):
@ -25,40 +34,44 @@ class Trajectory:
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(
default_factory=list
) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list default_factory=list
) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(
default_factory=list
) # frame crop list
class Cropper(object): class Cropper(object):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
device_id = kwargs.get('device_id', 0) self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)
flag_force_cpu = kwargs.get('flag_force_cpu', False) device_id = kwargs.get("device_id", 0)
flag_force_cpu = kwargs.get("flag_force_cpu", False)
if flag_force_cpu: if flag_force_cpu:
device = 'cpu' device = "cpu"
face_analysis_wrapper_provicer = ['CPUExecutionProvider'] face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
else: else:
device = 'cuda' device = "cuda"
face_analysis_wrapper_provicer = ["CUDAExecutionProvider"] face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
self.landmark_runner = LandmarkRunner( self.landmark_runner = LandmarkRunner(
ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'), ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
onnx_provider=device, onnx_provider=device,
device_id=device_id device_id=device_id,
) )
self.landmark_runner.warmup() self.landmark_runner.warmup()
self.face_analysis_wrapper = FaceAnalysisDIY( self.face_analysis_wrapper = FaceAnalysisDIY(
name='buffalo_l', name="buffalo_l",
root=make_abs_path('../../pretrained_weights/insightface'), root=make_abs_path(self.crop_cfg.insightface_root),
providers=face_analysis_wrapper_provicer providers=face_analysis_wrapper_provicer,
) )
self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512)) self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
self.face_analysis_wrapper.warmup() self.face_analysis_wrapper.warmup()
self.crop_cfg: CropConfig = kwargs.get('crop_cfg', None)
def update_config(self, user_args): def update_config(self, user_args):
for k, v in user_args.items(): for k, v in user_args.items():
if hasattr(self.crop_cfg, k): if hasattr(self.crop_cfg, k):
@ -77,10 +90,12 @@ class Cropper(object):
) )
if len(src_face) == 0: if len(src_face) == 0:
log('No face detected in the source image.') log("No face detected in the source image.")
return None return None
elif len(src_face) > 1: elif len(src_face) > 1:
log(f'More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.') log(
f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}."
)
# NOTE: temporarily only pick the first face, to support multiple face in the future # NOTE: temporarily only pick the first face, to support multiple face in the future
src_face = src_face[0] src_face = src_face[0]
@ -97,30 +112,34 @@ class Cropper(object):
) )
lmk = self.landmark_runner.run(img_rgb, lmk) lmk = self.landmark_runner.run(img_rgb, lmk)
ret_dct['lmk_crop'] = lmk ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input # update a 256x256 version for network input
ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA) ret_dct["img_crop_256x256"] = cv2.resize(
ret_dct['lmk_crop_256x256'] = ret_dct['lmk_crop'] * 256 / crop_cfg.dsize ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
return ret_dct return ret_dct
def crop_driving_video(self, driving_rgb_lst, **kwargs): def crop_driving_video(self, driving_rgb_lst, **kwargs):
"""Tracking based landmarks/alignment and cropping""" """Tracking based landmarks/alignment and cropping"""
trajectory = Trajectory() trajectory = Trajectory()
direction = kwargs.get('direction', 'large-small') direction = kwargs.get("direction", "large-small")
for idx, frame_rgb in enumerate(driving_rgb_lst): for idx, frame_rgb in enumerate(driving_rgb_lst):
if idx == 0 or trajectory.start == -1: if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get( src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb[..., ::-1]), contiguous(frame_rgb[..., ::-1]),
flag_do_landmark_2d_106=True, flag_do_landmark_2d_106=True,
direction=direction direction=direction,
) )
if len(src_face) == 0: if len(src_face) == 0:
log(f'No face detected in the frame #{idx}') log(f"No face detected in the frame #{idx}")
continue continue
elif len(src_face) > 1: elif len(src_face) > 1:
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.') log(
f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
)
src_face = src_face[0] src_face = src_face[0]
lmk = src_face.landmark_2d_106 lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk) lmk = self.landmark_runner.run(frame_rgb, lmk)
@ -130,47 +149,61 @@ class Cropper(object):
trajectory.end = idx trajectory.end = idx
trajectory.lmk_lst.append(lmk) trajectory.lmk_lst.append(lmk)
ret_bbox = parse_bbox_from_landmark(lmk, scale=self.crop_cfg.scale_crop_video, vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video, vy_ratio=self.crop_cfg.vy_ratio_crop_video)['bbox'] ret_bbox = parse_bbox_from_landmark(
bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]] # 4, lmk,
scale=self.crop_cfg.scale_crop_video,
vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video,
vy_ratio=self.crop_cfg.vy_ratio_crop_video,
)["bbox"]
bbox = [
ret_bbox[0, 0],
ret_bbox[0, 1],
ret_bbox[2, 0],
ret_bbox[2, 1],
] # 4,
trajectory.bbox_lst.append(bbox) # bbox trajectory.bbox_lst.append(bbox) # bbox
trajectory.frame_rgb_lst.append(frame_rgb) trajectory.frame_rgb_lst.append(frame_rgb)
global_bbox = average_bbox_lst(trajectory.bbox_lst) global_bbox = average_bbox_lst(trajectory.bbox_lst)
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)): for idx, (frame_rgb, lmk) in enumerate(
zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)
):
ret_dct = crop_image_by_bbox( ret_dct = crop_image_by_bbox(
frame_rgb, frame_rgb,
global_bbox, global_bbox,
lmk=lmk, lmk=lmk,
dsize=kwargs.get('dsize', 512), dsize=kwargs.get("dsize", 512),
flag_rot=False, flag_rot=False,
borderValue=(0, 0, 0), borderValue=(0, 0, 0),
) )
trajectory.frame_rgb_crop_lst.append(ret_dct['img_crop']) trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
trajectory.lmk_crop_lst.append(ret_dct['lmk_crop']) trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])
return { return {
'frame_crop_lst': trajectory.frame_rgb_crop_lst, "frame_crop_lst": trajectory.frame_rgb_crop_lst,
'lmk_crop_lst': trajectory.lmk_crop_lst, "lmk_crop_lst": trajectory.lmk_crop_lst,
} }
def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs): def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
"""Tracking based landmarks/alignment""" """Tracking based landmarks/alignment"""
trajectory = Trajectory() trajectory = Trajectory()
direction = kwargs.get('direction', 'large-small') direction = kwargs.get("direction", "large-small")
for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst): for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
if idx == 0 or trajectory.start == -1: if idx == 0 or trajectory.start == -1:
src_face = self.face_analysis_wrapper.get( src_face = self.face_analysis_wrapper.get(
contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR
flag_do_landmark_2d_106=True, flag_do_landmark_2d_106=True,
direction=direction direction=direction,
) )
if len(src_face) == 0: if len(src_face) == 0:
log(f'No face detected in the frame #{idx}') log(f"No face detected in the frame #{idx}")
raise Exception(f'No face detected in the frame #{idx}') raise Exception(f"No face detected in the frame #{idx}")
elif len(src_face) > 1: elif len(src_face) > 1:
log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.') log(
f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
)
src_face = src_face[0] src_face = src_face[0]
lmk = src_face.landmark_2d_106 lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb_crop, lmk) lmk = self.landmark_runner.run(frame_rgb_crop, lmk)