chore: modify arguments (#249)

Jianzhu Guo 2024-07-30 18:56:16 +08:00 committed by GitHub
parent 5d1d71b1e2
commit 3f394785fb
4 changed files with 15 additions and 13 deletions

src/config/argument_config.py

@@ -3,11 +3,10 @@
 """
 All configs for user
 """
 from dataclasses import dataclass
 import tyro
 from typing_extensions import Annotated
-from typing import Optional
+from typing import Optional, Literal
 from .base_config import PrintableConfig, make_abs_path
@@ -33,13 +32,15 @@ class ArgumentConfig(PrintableConfig):
     flag_pasteback: bool = True  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
     flag_do_crop: bool = True  # whether to crop the source portrait or video to the face-cropping space
     driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+    audio_priority: Literal['source', 'driving'] = 'driving'  # whether to use the audio from source or driving video
     ########## source crop arguments ##########
     det_thresh: float = 0.15  # detection threshold
     scale: float = 2.3  # the ratio of face area is smaller if scale is larger
     vx_ratio: float = 0  # the ratio to move the face to left or right in cropping space
     vy_ratio: float = -0.125  # the ratio to move the face to up or down in cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
+    source_max_dim: int = 1280  # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920
+    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number
     ########## driving crop arguments ##########
     scale_crop_driving_video: float = 2.2  # scale factor for cropping driving video
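
Note: ArgumentConfig is consumed by tyro (imported above), so the Literal-typed audio_priority surfaces as a validated choice on the command line, and the two new crop knobs become plain integer flags. A minimal, self-contained sketch of that behavior (a stand-in dataclass, not the repo's entry point; tyro may render underscores as dashes, e.g. --audio-priority):

    from dataclasses import dataclass
    from typing import Literal
    import tyro

    @dataclass
    class Demo:  # stand-in mirroring the new ArgumentConfig fields
        audio_priority: Literal['source', 'driving'] = 'driving'
        source_max_dim: int = 1280
        source_division: int = 2

    args = tyro.cli(Demo)  # builds a CLI from the dataclass fields
    # hypothetical invocations:
    #   python demo.py --audio_priority source
    #   python demo.py --source_max_dim 1920 --source_division 4
    print(args.audio_priority)  # 'driving' unless overridden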

src/config/inference_config.py

@@ -37,11 +37,13 @@ class InferenceConfig(PrintableConfig):
     flag_do_rot: bool = True
     flag_force_cpu: bool = False
     flag_do_torch_compile: bool = False
+    driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+    source_max_dim: int = 1280  # the max dim of height and width of source image or video
+    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number
     # NOT EXPORTED PARAMS
     lip_normalize_threshold: float = 0.03  # threshold for flag_normalize_lip
     source_video_eye_retargeting_threshold: float = 0.18  # threshold for eyes retargeting if the input is a source video
-    driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
     anchor_frame: int = 0  # TO IMPLEMENT
     input_shape: Tuple[int, int] = (256, 256)  # input shape
@@ -51,5 +53,3 @@ class InferenceConfig(PrintableConfig):
     mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
     size_gif: int = 256  # default gif size, TO IMPLEMENT
-    source_max_dim: int = 1280  # the max dim of height and width of source image or video
-    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number
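
Note: the pipeline file below imports resize_to_limit from .utils.io, which is presumably where source_max_dim and source_division are applied — cap the longer side of the source, then trim both dimensions to multiples of source_division so downstream codecs get even sizes. A sketch of that behavior, assuming the obvious implementation rather than quoting the repo's:

    import cv2
    import numpy as np

    def resize_to_limit_sketch(img: np.ndarray, max_dim: int = 1280, division: int = 2) -> np.ndarray:
        # cap the longer side at max_dim, preserving aspect ratio
        h, w = img.shape[:2]
        if max(h, w) > max_dim:
            scale = max_dim / max(h, w)
            img = cv2.resize(img, (int(w * scale), int(h * scale)))
            h, w = img.shape[:2]
        # trim so both dims divide evenly by `division`
        new_h, new_w = h - (h % division), w - (w % division)
        return img[:new_h, :new_w]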

src/live_portrait_pipeline.py

@@ -19,9 +19,9 @@ from .config.crop_config import CropConfig
 from .utils.cropper import Cropper
 from .utils.camera import get_rotation_matrix
 from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
-from .utils.crop import _transform_img, prepare_paste_back, paste_back
+from .utils.crop import prepare_paste_back, paste_back
 from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
-from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
+from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video
 from .utils.filter import smooth
 from .utils.rprint import rlog as log
 # from .utils.viz import viz_lmk
@@ -137,7 +137,7 @@ class LivePortraitPipeline(object):
                 driving_rgb_lst = driving_rgb_lst[:n_frames]
             else:
                 n_frames = driving_n_frames
-            if inf_cfg.flag_crop_driving_video:
+            if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
                 ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
                 log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
                 if len(ret_d["frame_crop_lst"]) is not n_frames:
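
Note: with this change, non-square driving videos are always routed through crop_driving_video even when flag_crop_driving_video is off, presumably because downstream processing expects square frames. is_square_video is newly imported but its body is not part of this diff; a plausible sketch, assuming it simply compares the container's reported width and height via OpenCV:

    import cv2

    def is_square_video_sketch(video_path: str) -> bool:
        # open the container and compare the reported frame dimensions
        cap = cv2.VideoCapture(video_path)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        return w == h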
@@ -382,8 +382,7 @@ class LivePortraitPipeline(object):
         if flag_source_has_audio or flag_driving_has_audio:
             # final result with concatenation
             wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
-            # audio_from_which_video = args.source if flag_source_has_audio else args.driving  # default source audio
-            audio_from_which_video = args.driving if flag_driving_has_audio else args.source  # default driving audio
+            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
             log(f"Audio is selected from {audio_from_which_video}, concat mode")
             add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
             os.replace(wfp_concat_with_audio, wfp_concat)
@@ -399,8 +398,7 @@ class LivePortraitPipeline(object):
         ######### build the final result #########
         if flag_source_has_audio or flag_driving_has_audio:
             wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
-            # audio_from_which_video = args.source if flag_source_has_audio else args.driving  # default source audio
-            audio_from_which_video = args.driving if flag_driving_has_audio else args.source  # default driving audio
+            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
             log(f"Audio is selected from {audio_from_which_video}")
             add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
             os.replace(wfp_with_audio, wfp)
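
Note: the new selector replaces the hard-coded "default driving audio" rule with one that honors audio_priority but still falls back to whichever stream exists. Restated as a standalone function with its four reachable cases (a restatement of the expression above, not new behavior):

    # names mirror the diff; the outer `if` guarantees at least one stream has audio
    def pick_audio(driving_has_audio: bool, source_has_audio: bool, audio_priority: str) -> str:
        if (driving_has_audio and audio_priority == 'driving') or (not source_has_audio):
            return 'driving'
        return 'source'

    assert pick_audio(True,  True,  'driving') == 'driving'  # priority wins
    assert pick_audio(True,  True,  'source')  == 'source'   # priority wins
    assert pick_audio(True,  False, 'source')  == 'driving'  # source has no audio: fall back
    assert pick_audio(False, True,  'driving') == 'source'   # driving has no audio: fall back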

src/utils/cropper.py

@@ -135,6 +135,7 @@ class Cropper(object):
         return lmk

+    # TODO: support skipping frame with NO FACE
     def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
         """Tracking based landmarks/alignment and cropping"""
         trajectory = Trajectory()
@@ -157,8 +158,10 @@ class Cropper(object):
                 lmk = self.landmark_runner.run(frame_rgb, lmk)
                 trajectory.start, trajectory.end = idx, idx
             else:
+                # TODO: add IOU check for tracking
                 lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
                 trajectory.end = idx
             trajectory.lmk_lst.append(lmk)
             # crop the face
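
Note on the new TODO: an IOU check would compare the face box implied by the tracked landmarks against a fresh detection, re-initializing the tracker when overlap collapses. A sketch of the standard box-IoU computation such a check would need (box format and the re-detect threshold are assumptions):

    import numpy as np

    def bbox_iou(box_a: np.ndarray, box_b: np.ndarray) -> float:
        # boxes as (x1, y1, x2, y2); returns intersection-over-union in [0, 1]
        x1 = max(box_a[0], box_b[0])
        y1 = max(box_a[1], box_b[1])
        x2 = min(box_a[2], box_b[2])
        y2 = min(box_a[3], box_b[3])
        inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter + 1e-8)

    # e.g., re-detect when IoU between the tracked box and a fresh detection drops below ~0.3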