diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index 5b63fe1..c6956fa 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -23,7 +23,7 @@ from .utils.camera import get_rotation_matrix from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video from .utils.helper import is_square_video, mkdir, dct2device, basename from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio - +from skimage.exposure import match_histograms def update_args(args, user_args): """update the args according to user inputs @@ -426,6 +426,8 @@ class GradioPipeline(LivePortraitPipeline): I_p_lst.append(I_p_i) if flag_do_crop_input_retargeting_video: + I_p_i = match_histograms(I_p_i,img_crop_256x256_lst[i]) + I_p_i = np.clip(I_p_i,0,255).astype(np.uint8) I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk_lst.append(I_p_pstbk) else: @@ -447,6 +449,8 @@ class GradioPipeline(LivePortraitPipeline): I_p_lst.append(I_p_i) if flag_do_crop_input_retargeting_video: + I_p_i = match_histograms(I_p_i,img_crop_256x256_lst[i]) + I_p_i = np.clip(I_p_i,0,255).astype(np.uint8) I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk_lst.append(I_p_pstbk)