Mirror of https://github.com/KwaiVGI/LivePortrait.git (synced 2024-12-22 12:22:38 +00:00)
chore: modify arguments (#249)
Commit 3f394785fb, parent 5d1d71b1e2
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
@@ -3,11 +3,10 @@
 """
 All configs for user
 """

 from dataclasses import dataclass
 import tyro
 from typing_extensions import Annotated
-from typing import Optional
+from typing import Optional, Literal
 from .base_config import PrintableConfig, make_abs_path

@@ -33,13 +32,15 @@ class ArgumentConfig(PrintableConfig):
     flag_pasteback: bool = True  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
     flag_do_crop: bool = True  # whether to crop the source portrait or video to the face-cropping space
     driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+    audio_priority: Literal['source', 'driving'] = 'driving'  # whether to use the audio from source or driving video

     ########## source crop arguments ##########
     det_thresh: float = 0.15  # detection threshold
     scale: float = 2.3  # the ratio of face area is smaller if scale is larger
     vx_ratio: float = 0  # the ratio to move the face to left or right in cropping space
     vy_ratio: float = -0.125  # the ratio to move the face to up or down in cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
+    source_max_dim: int = 1280  # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920
+    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number

     ########## driving crop arguments ##########
     scale_crop_driving_video: float = 2.2  # scale factor for cropping driving video
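Note on the new `audio_priority` field: it uses `typing.Literal`, which is why the import hunk above adds `Literal`. Because the repo parses `ArgumentConfig` with tyro, a `Literal`-typed field surfaces on the command line as an option with a fixed choice set. A minimal stand-alone sketch; the `DemoArgs` dataclass and script name are hypothetical, not the repo's entry point:

from dataclasses import dataclass
from typing import Literal

import tyro

@dataclass
class DemoArgs:
    # Mirrors the new field: only 'source' or 'driving' parse successfully.
    audio_priority: Literal['source', 'driving'] = 'driving'

if __name__ == '__main__':
    args = tyro.cli(DemoArgs)   # e.g. `python demo.py --audio-priority source`
    print(args.audio_priority)  # values outside the Literal set are rejected at parse time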
diff --git a/src/config/inference_config.py b/src/config/inference_config.py
@@ -37,11 +37,13 @@ class InferenceConfig(PrintableConfig):
     flag_do_rot: bool = True
     flag_force_cpu: bool = False
     flag_do_torch_compile: bool = False
+    driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+    source_max_dim: int = 1280  # the max dim of height and width of source image or video
+    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number

     # NOT EXPORTED PARAMS
     lip_normalize_threshold: float = 0.03  # threshold for flag_normalize_lip
     source_video_eye_retargeting_threshold: float = 0.18  # threshold for eyes retargeting if the input is a source video
-    driving_smooth_observation_variance: float = 3e-7  # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
     anchor_frame: int = 0  # TO IMPLEMENT

     input_shape: Tuple[int, int] = (256, 256)  # input shape
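The hunk above promotes `driving_smooth_observation_variance` from the non-exported block to a user-facing parameter. The trade-off its comment describes (larger variance gives a smoother but less accurate result) is the standard behavior of an observation-variance knob in Kalman-style smoothing. A minimal 1D sketch, assuming a constant-position model; the repo's actual `smooth` helper in src/utils/filter.py is not shown in this diff and may differ:

import numpy as np

def kalman_smooth_1d(observations, observation_variance, process_variance=1e-5):
    # A larger observation_variance means measurements are trusted less, so
    # the output is smoother but lags fast motion (loss of motion accuracy).
    x, p = observations[0], 1.0             # state estimate and its variance
    out = [x]
    for z in observations[1:]:
        p = p + process_variance            # predict: uncertainty grows
        k = p / (p + observation_variance)  # Kalman gain
        x = x + k * (z - x)                 # update toward the observation
        p = (1.0 - k) * p
        out.append(x)
    return np.array(out)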
@@ -51,5 +53,3 @@ class InferenceConfig(PrintableConfig):

     mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
     size_gif: int = 256  # default gif size, TO IMPLEMENT
-    source_max_dim: int = 1280  # the max dim of height and width of source image or video
-    source_division: int = 2  # make sure the height and width of source image or video can be divided by this number
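Net effect in this file: `source_max_dim` and `source_division` move from the tail of the class into the exported block above. They bound the source resolution before inference. A rough sketch of what such a limit typically does to a frame; illustrative only, not a copy of the repo's `resize_to_limit` helper from .utils.io, whose body this diff does not show:

import cv2

def resize_to_limit_sketch(img, max_dim=1280, division=2):
    """Downscale so the longer side is at most max_dim, then trim so both
    sides are divisible by `division` (video codecs and strided networks
    often require even dimensions)."""
    h, w = img.shape[:2]
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        h, w = int(h * scale), int(w * scale)
        img = cv2.resize(img, (w, h))  # cv2.resize takes (width, height)
    return img[:h - (h % division), :w - (w % division)]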
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
@@ -19,9 +19,9 @@ from .config.crop_config import CropConfig
 from .utils.cropper import Cropper
 from .utils.camera import get_rotation_matrix
 from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
-from .utils.crop import _transform_img, prepare_paste_back, paste_back
+from .utils.crop import prepare_paste_back, paste_back
 from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load
-from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image
+from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video
 from .utils.filter import smooth
 from .utils.rprint import rlog as log
 # from .utils.viz import viz_lmk
@@ -137,7 +137,7 @@ class LivePortraitPipeline(object):
             driving_rgb_lst = driving_rgb_lst[:n_frames]
         else:
             n_frames = driving_n_frames
-        if inf_cfg.flag_crop_driving_video:
+        if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
             ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
             log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
             if len(ret_d["frame_crop_lst"]) is not n_frames:
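Behavior change: the driving video is now routed through the cropper whenever it is not square, even with `flag_crop_driving_video` disabled. The diff does not show the body of `is_square_video`; a plausible OpenCV sketch of such a check:

import cv2

def is_square_video_sketch(video_path: str) -> bool:
    # Compare the container's frame width and height; a non-square driving
    # video triggers the cropping branch above.
    cap = cv2.VideoCapture(video_path)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    cap.release()
    return width == height

Unrelated to this commit but worth flagging: the unchanged context line `if len(ret_d["frame_crop_lst"]) is not n_frames:` compares integers by identity; `!=` is the correct operator here, since `is not` on ints only works by accident of interning.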
@@ -382,8 +382,7 @@ class LivePortraitPipeline(object):
         if flag_source_has_audio or flag_driving_has_audio:
             # final result with concatenation
             wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
-            # audio_from_which_video = args.source if flag_source_has_audio else args.driving  # default source audio
-            audio_from_which_video = args.driving if flag_driving_has_audio else args.source  # default driving audio
+            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
             log(f"Audio is selected from {audio_from_which_video}, concat mode")
             add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
             os.replace(wfp_concat_with_audio, wfp_concat)
@@ -399,8 +398,7 @@ class LivePortraitPipeline(object):
         ######### build the final result #########
         if flag_source_has_audio or flag_driving_has_audio:
             wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
-            # audio_from_which_video = args.source if flag_source_has_audio else args.driving  # default source audio
-            audio_from_which_video = args.driving if flag_driving_has_audio else args.source  # default driving audio
+            audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
             log(f"Audio is selected from {audio_from_which_video}")
             add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
             os.replace(wfp_with_audio, wfp)
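Both audio-muxing sites replace the old hard-coded "prefer driving audio" line with the same ternary honoring `audio_priority`. Unrolled for readability (a behavior-equivalent sketch, not repo code; both call sites are guarded by `flag_source_has_audio or flag_driving_has_audio`, so at least one track exists):

def select_audio_source(source: str, driving: str,
                        source_has_audio: bool, driving_has_audio: bool,
                        audio_priority: str = 'driving') -> str:
    if driving_has_audio and audio_priority == 'driving':
        return driving  # preferred track exists: use it
    if not source_has_audio:
        return driving  # source is silent: fall back to driving
    return source       # otherwise use the source track

With `audio_priority='source'`, the driving track is therefore used only when the source has no audio at all.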
diff --git a/src/utils/cropper.py b/src/utils/cropper.py
@@ -135,6 +135,7 @@ class Cropper(object):

         return lmk

+    # TODO: support skipping frame with NO FACE
     def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
         """Tracking based landmarks/alignment and cropping"""
         trajectory = Trajectory()

@@ -157,8 +158,10 @@ class Cropper(object):
                 lmk = self.landmark_runner.run(frame_rgb, lmk)
                 trajectory.start, trajectory.end = idx, idx
             else:
+                # TODO: add IOU check for tracking
                 lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
                 trajectory.end = idx

             trajectory.lmk_lst.append(lmk)

             # crop the face
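The two new TODO comments mark known gaps in the frame-to-frame tracking: frames with no detected face are not skipped, and nothing verifies that the landmarks propagated from the previous frame still cover the same face. One plausible shape for the IOU check; box layout and threshold are illustrative, not repo code:

def box_iou(box_a, box_b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0

In the else branch above, a tracker could compare a box around `trajectory.lmk_lst[-1]` with the current frame's detected box and fall back to full detection when, say, `box_iou(prev_box, cur_box) < 0.3`.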