diff --git a/.gitignore b/.gitignore index c7612d0..a6a28fc 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ animations/* tmp/* .vscode/launch.json **/*.DS_Store +gradio_temp/** diff --git a/app.py b/app.py index 9b59a4d..ad2b4f8 100644 --- a/app.py +++ b/app.py @@ -4,6 +4,7 @@ The entrance of the gradio """ +import os import tyro import subprocess import gradio as gr @@ -47,6 +48,9 @@ gradio_pipeline = GradioPipeline( args=args ) +if args.gradio_temp_dir not in (None, ''): + os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir + os.makedirs(args.gradio_temp_dir, exist_ok=True) def gpu_wrapped_execute_video(*args, **kwargs): return gradio_pipeline.execute_video(*args, **kwargs) @@ -69,25 +73,27 @@ data_examples_i2v = [ [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], ] data_examples_v2v = [ - [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7], - # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 1e-7], - # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 1e-7], - [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7], - # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 1e-7], - [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 1e-7], + [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], + # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7], + # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7], + [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], + # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], + [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], ] #################### interface logic #################### # Define components first +retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale") eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") +head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch") +head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw") +head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll") retargeting_input_image = gr.Image(type="filepath") output_image = gr.Image(type="numpy") output_image_paste_back = gr.Image(type="numpy") output_video_i2v = gr.Video(autoplay=False) output_video_concat_i2v = gr.Video(autoplay=False) -# output_video_v2v = gr.Video(autoplay=False) -# output_video_concat_v2v = gr.Video(autoplay=False) with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: @@ -108,6 +114,8 @@ with 
gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San [osp.join(example_portrait_dir, "s5.jpg")], [osp.join(example_portrait_dir, "s7.jpg")], [osp.join(example_portrait_dir, "s12.jpg")], + [osp.join(example_portrait_dir, "s22.jpg")], + [osp.join(example_portrait_dir, "s23.jpg")], ], inputs=[source_image_input], cache_examples=False, @@ -149,6 +157,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San [osp.join(example_video_dir, "d19.mp4")], [osp.join(example_video_dir, "d14.mp4")], [osp.join(example_video_dir, "d6.mp4")], + [osp.join(example_video_dir, "d20.mp4")], ], inputs=[driving_video_input], cache_examples=False, @@ -168,14 +177,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)") - driving_smooth_observation_variance = gr.Number(value=1e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8) + driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8) gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) with gr.Row(): - with gr.Column(): - process_button_animation = gr.Button("๐Ÿš€ Animate", variant="primary") - with gr.Column(): - process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="๐Ÿงน Clear") + process_button_animation = gr.Button("๐Ÿš€ Animate", variant="primary") with gr.Row(): with gr.Column(): with gr.Accordion(open=True, label="The animated video in the original image space"): @@ -183,6 +189,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San with gr.Column(): with gr.Accordion(open=True, label="The animated video"): output_video_concat_i2v.render() + with gr.Row(): + process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="๐Ÿงน Clear") with gr.Row(): # Examples @@ -227,20 +235,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San # Retargeting gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True) with gr.Row(visible=True): + retargeting_source_scale.render() eye_retargeting_slider.render() lip_retargeting_slider.render() + with gr.Row(visible=True): + head_pitch_slider.render() + head_yaw_slider.render() + head_roll_slider.render() with gr.Row(visible=True): process_button_retargeting = gr.Button("๐Ÿš— Retargeting", variant="primary") - process_button_reset_retargeting = gr.ClearButton( - [ - eye_retargeting_slider, - lip_retargeting_slider, - retargeting_input_image, - output_image, - output_image_paste_back - ], - value="๐Ÿงน Clear" - ) with gr.Row(visible=True): with gr.Column(): with gr.Accordion(open=True, label="Retargeting Input"): @@ -253,6 +256,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San [osp.join(example_portrait_dir, "s5.jpg")], [osp.join(example_portrait_dir, "s7.jpg")], [osp.join(example_portrait_dir, "s12.jpg")], + [osp.join(example_portrait_dir, "s22.jpg")], + [osp.join(example_portrait_dir, "s23.jpg")], ], inputs=[retargeting_input_image], cache_examples=False, @@ -263,15 
+268,30 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San with gr.Column(): with gr.Accordion(open=True, label="Paste-back Result"): output_image_paste_back.render() + with gr.Row(visible=True): + process_button_reset_retargeting = gr.ClearButton( + [ + eye_retargeting_slider, + lip_retargeting_slider, + head_pitch_slider, + head_yaw_slider, + head_roll_slider, + retargeting_input_image, + output_image, + output_image_paste_back + ], + value="๐Ÿงน Clear" + ) # binding functions for buttons process_button_retargeting.click( # fn=gradio_pipeline.execute_image, fn=gpu_wrapped_execute_image, - inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input], + inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input], outputs=[output_image, output_image_paste_back], show_progress=True ) + process_button_animation.click( fn=gpu_wrapped_execute_video, inputs=[ @@ -296,6 +316,12 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San show_progress=True ) + retargeting_input_image.change( + fn=gradio_pipeline.init_retargeting, + inputs=[retargeting_source_scale, retargeting_input_image], + outputs=[eye_retargeting_slider, lip_retargeting_slider] + ) + demo.launch( server_port=args.server_port, share=args.share, diff --git a/assets/docs/changelog/2024-07-24.md b/assets/docs/changelog/2024-07-24.md new file mode 100644 index 0000000..e54aa45 --- /dev/null +++ b/assets/docs/changelog/2024-07-24.md @@ -0,0 +1,5 @@ +

+[image placeholder: assets/docs/pose-edit-2024-07-24.jpg, alt text "LivePortrait"]
+Pose Editing Interface in the Gradio Interface
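Editor's note on the pose-editing entry above: in the `src/gradio_pipeline.py` hunks further down, the new head-pose sliders are applied in keypoint space before the eye/lip retargeting deltas. The sketch below condenses that path under stated assumptions: the key names (`kp`, `exp`, `scale`, `t`, `pitch`, `yaw`, `roll`) and `get_rotation_matrix` come from the diff itself, while the wrapper name `apply_pose_edit` and the `src.utils.camera` import path are illustrative and not part of this PR.

```python
import torch
from src.utils.camera import get_rotation_matrix  # same helper the diff imports

@torch.no_grad()
def apply_pose_edit(x_s_info: dict, d_pitch: float, d_yaw: float, d_roll: float) -> torch.Tensor:
    # rotation implied by the source portrait, and the user-edited rotation
    R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
    R_d = get_rotation_matrix(x_s_info['pitch'] + d_pitch,
                              x_s_info['yaw'] + d_yaw,
                              x_s_info['roll'] + d_roll)
    # relative-rotation convention used elsewhere in the pipeline
    R_d_new = (R_d @ R_s.permute(0, 2, 1)) @ R_s
    # re-pose the canonical keypoints, keeping expression, scale and translation
    return x_s_info['scale'] * (x_s_info['kp'] @ R_d_new + x_s_info['exp']) + x_s_info['t']
```

In `execute_image` the eye and lip deltas from `retarget_eye`/`retarget_lip` are then added to this re-posed `x_d_new`, which is stitched against `x_s` and decoded, as shown in the full diff below.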
diff --git a/assets/docs/pose-edit-2024-07-24.jpg b/assets/docs/pose-edit-2024-07-24.jpg
new file mode 100644
index 0000000..74650bc
Binary files /dev/null and b/assets/docs/pose-edit-2024-07-24.jpg differ
diff --git a/assets/examples/driving/d20.mp4 b/assets/examples/driving/d20.mp4
new file mode 100644
index 0000000..30822f9
Binary files /dev/null and b/assets/examples/driving/d20.mp4 differ
diff --git a/assets/examples/source/s22.jpg b/assets/examples/source/s22.jpg
new file mode 100644
index 0000000..9ca08bb
Binary files /dev/null and b/assets/examples/source/s22.jpg differ
diff --git a/assets/examples/source/s23.jpg b/assets/examples/source/s23.jpg
new file mode 100644
index 0000000..4e1373a
Binary files /dev/null and b/assets/examples/source/s23.jpg differ
diff --git a/assets/gradio/gradio_description_retargeting.md b/assets/gradio/gradio_description_retargeting.md
index 64f1a7c..bd7b2bb 100644
--- a/assets/gradio/gradio_description_retargeting.md
+++ b/assets/gradio/gradio_description_retargeting.md
@@ -9,6 +9,6 @@

 Retargeting
 
 Upload a Source Portrait as Retargeting Input, then drag the sliders and click the 🚗 Retargeting button. You can try running it multiple times.
-😊 Set both ratios to 0.8 to see what's going on!
+😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!
 
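A note for the description above: with this PR the two sliders are no longer left at zero. Uploading a retargeting input triggers the new `init_retargeting` handler (wired via `retargeting_input_image.change` in `app.py`), which seeds the sliders with the portrait's own ratios. A hedged condensation follows, using the helpers imported in the `src/gradio_pipeline.py` hunk; the function name `init_sliders_from_source` is illustrative and not code from this PR.

```python
import gradio as gr
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio

def init_sliders_from_source(cropper, img_rgb):
    """Estimate the source portrait's current eyes-open and lip-open ratios."""
    crop_info = cropper.crop_source_image(img_rgb, cropper.crop_cfg)
    if crop_info is None:
        raise gr.Error("Source portrait NO face detected", duration=2)
    lmk = crop_info['lmk_crop'][None]                      # add a batch dimension
    eye_ratio = float(calc_eye_close_ratio(lmk).mean())    # seeds the eyes-open slider
    lip_ratio = float(calc_lip_close_ratio(lmk)[0][0])     # seeds the lip-open slider
    return round(eye_ratio, 2), round(lip_ratio, 2)
```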
diff --git a/assets/gradio/gradio_description_upload.md b/assets/gradio/gradio_description_upload.md
index f5a018a..f6b975f 100644
--- a/assets/gradio/gradio_description_upload.md
+++ b/assets/gradio/gradio_description_upload.md
@@ -4,6 +4,9 @@
 Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️
+
+Note: For best results, the Source Video should have the same FPS as the Driving Video.
+
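The FPS note above is only a hint in the UI; nothing in this diff enforces it. For users who want to check, here is a small illustrative helper (not part of the PR) using OpenCV, which the project already imports elsewhere.

```python
import cv2

def video_fps(path: str) -> float:
    """Return the FPS reported by the video container (0.0 if unreadable)."""
    cap = cv2.VideoCapture(path)
    try:
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()

# e.g. compare video_fps(source_path) with video_fps(driving_path) before animating
```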
diff --git a/readme.md b/readme.md index b378eeb..6518aa1 100644 --- a/readme.md +++ b/readme.md @@ -39,7 +39,8 @@ ## ๐Ÿ”ฅ Updates -- **`2024/07/19`**: โœจ We support ๐ŸŽž๏ธ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md). +- **`2024/07/24`**: ๐ŸŽจ We support pose editing for source portraits in the Gradio interface. We've also lowered the default detection threshold to support more input detections. [Have fun](assets/docs/changelog/2024-07-24.md)! +- **`2024/07/19`**: โœจ We support ๐ŸŽž๏ธ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md). - **`2024/07/17`**: ๐ŸŽ We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143). - **`2024/07/10`**: ๐Ÿ’ช We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md). - **`2024/07/09`**: ๐Ÿค— We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)! diff --git a/src/config/argument_config.py b/src/config/argument_config.py index 08d17a7..6653f9c 100644 --- a/src/config/argument_config.py +++ b/src/config/argument_config.py @@ -32,9 +32,10 @@ class ArgumentConfig(PrintableConfig): flag_relative_motion: bool = True # whether to use relative motion flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space - driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy + driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy ########## source crop arguments ########## + det_thresh: float = 0.15 # detection threshold scale: float = 2.3 # the ratio of face area is smaller if scale is larger vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space @@ -50,3 +51,4 @@ class ArgumentConfig(PrintableConfig): share: bool = False # whether to share the server to public server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation + gradio_temp_dir: Optional[str] = None # directory to save gradio temp files diff --git a/src/config/crop_config.py b/src/config/crop_config.py index c7d64a5..6c1f8f2 100644 --- a/src/config/crop_config.py +++ b/src/config/crop_config.py @@ -15,6 +15,7 @@ class CropConfig(PrintableConfig): landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx" device_id: int = 0 # gpu device id flag_force_cpu: bool = False # force cpu inference, WIP + det_thresh: float = 0.1 # detection threshold ########## source image or video cropping option ########## dsize: int = 512 # crop size scale: float = 2.8 # scale factor diff --git 
a/src/config/inference_config.py b/src/config/inference_config.py index 48bf88c..c1f8653 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -41,7 +41,7 @@ class InferenceConfig(PrintableConfig): # NOT EXPORTED PARAMS lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video - driving_smooth_observation_variance: float = 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy + driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy anchor_frame: int = 0 # TO IMPLEMENT input_shape: Tuple[int, int] = (256, 256) # input shape diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index cbe898e..003c5ca 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -6,6 +6,7 @@ Pipeline for gradio import os.path as osp import gradio as gr +import torch from .config.argument_config import ArgumentConfig from .live_portrait_pipeline import LivePortraitPipeline @@ -14,6 +15,7 @@ from .utils.rprint import rlog as log from .utils.crop import prepare_paste_back, paste_back from .utils.camera import get_rotation_matrix from .utils.helper import is_square_video +from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio def update_args(args, user_args): @@ -32,6 +34,7 @@ class GradioPipeline(LivePortraitPipeline): # self.live_portrait_wrapper = self.live_portrait_wrapper self.args = args + @torch.no_grad() def execute_video( self, input_source_image_path=None, @@ -48,7 +51,7 @@ class GradioPipeline(LivePortraitPipeline): scale_crop_driving_video=2.2, vx_ratio_crop_driving_video=0.0, vy_ratio_crop_driving_video=-0.1, - driving_smooth_observation_variance=1e-7, + driving_smooth_observation_variance=3e-7, tab_selection=None, ): """ for video-driven potrait animation or video editing @@ -93,27 +96,41 @@ class GradioPipeline(LivePortraitPipeline): else: raise gr.Error("Please upload the source portrait or source video, and driving video ๐Ÿค—๐Ÿค—๐Ÿค—", duration=5) - def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True): + @torch.no_grad() + def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop=True): """ for single image retargeting """ + if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None: + raise gr.Error("Invalid relative pose input ๐Ÿ’ฅ!", duration=5) # disposable feature - f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \ - self.prepare_retargeting(input_image, flag_do_crop) + f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \ + self.prepare_retargeting(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop) if input_eye_ratio is None or input_lip_ratio is None: raise gr.Error("Invalid ratio input ๐Ÿ’ฅ!", duration=5) else: - inference_cfg = 
self.live_portrait_wrapper.inference_cfg - x_s_user = x_s_user.to(self.live_portrait_wrapper.device) - f_s_user = f_s_user.to(self.live_portrait_wrapper.device) + device = self.live_portrait_wrapper.device + # inference_cfg = self.live_portrait_wrapper.inference_cfg + x_s_user = x_s_user.to(device) + f_s_user = f_s_user.to(device) + R_s_user = R_s_user.to(device) + R_d_user = R_d_user.to(device) + + x_c_s = x_s_info['kp'].to(device) + delta_new = x_s_info['exp'].to(device) + scale_new = x_s_info['scale'].to(device) + t_new = x_s_info['t'].to(device) + R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user + + x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new # โˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) - combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user) + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user) eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor) # โˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) - combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user) + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) - # default: use x_s - x_d_new = x_s_user + eyes_delta + lip_delta + x_d_new = x_d_new + eyes_delta + lip_delta + x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new) # D(W(f_s; x_s, xโ€ฒ_d)) out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new) out = self.live_portrait_wrapper.parse_output(out['out'])[0] @@ -121,14 +138,18 @@ class GradioPipeline(LivePortraitPipeline): gr.Info("Run successfully!", duration=2) return out, out_to_ori_blend - def prepare_retargeting(self, input_image, flag_do_crop=True): + @torch.no_grad() + def prepare_retargeting(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True): """ for single image retargeting """ if input_image is not None: # gr.Info("Upload successfully!", duration=2) + args_user = {'scale': retargeting_source_scale} + self.args = update_args(self.args, args_user) + self.cropper.update_config(self.args.__dict__) inference_cfg = self.live_portrait_wrapper.inference_cfg ######## process source portrait ######## - img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16) + img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2) log(f"Load source image from {input_image}.") crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg) if flag_do_crop: @@ -136,14 +157,37 @@ class GradioPipeline(LivePortraitPipeline): else: I_s = self.live_portrait_wrapper.prepare_source(img_rgb) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) - R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation + x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation + x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation + R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll) ############################################ f_s_user = 
self.live_portrait_wrapper.extract_feature_3d(I_s) x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) source_lmk_user = crop_info['lmk_crop'] crop_M_c2o = crop_info['M_c2o'] mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) - return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb + return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb else: # when press the clear button, go here raise gr.Error("Please upload a source portrait as the retargeting input ๐Ÿค—๐Ÿค—๐Ÿค—", duration=5) + + def init_retargeting(self, retargeting_source_scale: float, input_image = None): + """ initialize the retargeting slider + """ + if input_image != None: + args_user = {'scale': retargeting_source_scale} + self.args = update_args(self.args, args_user) + self.cropper.update_config(self.args.__dict__) + inference_cfg = self.live_portrait_wrapper.inference_cfg + ######## process source portrait ######## + img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16) + log(f"Load source image from {input_image}.") + crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg) + if crop_info is None: + raise gr.Error("Source portrait NO face detected", duration=2) + source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None]) + source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None]) + return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2) + return 0., 0. diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py index e344ffc..9eccc8e 100644 --- a/src/live_portrait_pipeline.py +++ b/src/live_portrait_pipeline.py @@ -200,17 +200,16 @@ class LivePortraitPipeline(object): x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) else: # if the input is a source image, process it only once - crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg) - if crop_info is None: - raise Exception("No face detected in the source image!") - source_lmk = crop_info['lmk_crop'] - img_crop_256x256 = crop_info['img_crop_256x256'] - if inf_cfg.flag_do_crop: - I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg) + if crop_info is None: + raise Exception("No face detected in the source image!") + source_lmk = crop_info['lmk_crop'] + img_crop_256x256 = crop_info['img_crop_256x256'] else: + source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0]) img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256 - I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_c_s = x_s_info['kp'] R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) @@ -218,7 +217,7 @@ class LivePortraitPipeline(object): x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) # let lip-open scalar to be 0 at first - if flag_normalize_lip: + if flag_normalize_lip and source_lmk is not None: c_d_lip_before_animation = [0.] 
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: @@ -245,14 +244,14 @@ class LivePortraitPipeline(object): x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) # let lip-open scalar to be 0 at first if the input is a video - if flag_normalize_lip: + if flag_normalize_lip and source_lmk is not None: c_d_lip_before_animation = [0.] combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) # let eye-open scalar to be the same as the first frame if the latter is eye-open state - if flag_source_video_eye_retargeting: + if flag_source_video_eye_retargeting and source_lmk is not None: if i == 0: combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] @@ -312,12 +311,12 @@ class LivePortraitPipeline(object): x_d_i_new += eye_delta_before_animation else: eyes_delta, lip_delta = None, None - if inf_cfg.flag_eye_retargeting: + if inf_cfg.flag_eye_retargeting and source_lmk is not None: c_d_eyes_i = c_d_eyes_lst[i] combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) # โˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor) - if inf_cfg.flag_lip_retargeting: + if inf_cfg.flag_lip_retargeting and source_lmk is not None: c_d_lip_i = c_d_lip_lst[i] combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk) # โˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) diff --git a/src/utils/cropper.py b/src/utils/cropper.py index e0e3789..c42e74b 100644 --- a/src/utils/cropper.py +++ b/src/utils/cropper.py @@ -67,7 +67,7 @@ class Cropper(object): root=make_abs_path(self.crop_cfg.insightface_root), providers=face_analysis_wrapper_provider, ) - self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512)) + self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh) self.face_analysis_wrapper.warmup() def update_config(self, user_args): @@ -117,6 +117,24 @@ class Cropper(object): return ret_dct + def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs): + direction = kwargs.get("direction", "large-small") + src_face = self.face_analysis_wrapper.get( + contiguous(img_rgb_[..., ::-1]), # convert to BGR + flag_do_landmark_2d_106=True, + direction=direction, + ) + if len(src_face) == 0: + log("No face detected in the source image.") + return None + elif len(src_face) > 1: + log(f"More than one face detected in the image, only pick one face by rule {direction}.") + src_face = src_face[0] + lmk = src_face.landmark_2d_106 + lmk = self.landmark_runner.run(img_rgb_, lmk) + + return lmk + def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs): """Tracking based landmarks/alignment and cropping""" trajectory = Trajectory() diff --git a/src/utils/filter.py b/src/utils/filter.py index 5238f49..a8e27ca 100644 --- a/src/utils/filter.py +++ b/src/utils/filter.py @@ -5,7 +5,7 @@ import numpy as np from pykalman import KalmanFilter -def 
smooth(x_d_lst, shape, device, observation_variance=1e-7, process_variance=1e-5):
+def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5):
     x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
     x_d_stacked = np.vstack(x_d_lst_reshape)
     kf = KalmanFilter(
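The diff is truncated at the `KalmanFilter(` constructor, so only the changed default (1e-7 to 3e-7) is visible. For context on what that knob tunes, here is a minimal, self-contained sketch of a pykalman smoother configured the way the two keyword arguments suggest; it is an illustration under those assumptions, not necessarily the repository's exact constructor arguments.

```python
import numpy as np
from pykalman import KalmanFilter

def smooth_flat_motion(x_d_stacked: np.ndarray,
                       observation_variance: float = 3e-7,
                       process_variance: float = 1e-5) -> np.ndarray:
    """Kalman-smooth an (n_frames, dim) stack of flattened motion parameters.

    A larger observation_variance trusts the per-frame observations less,
    giving a smoother but potentially less accurate trajectory; that trade-off
    is why the default moves from 1e-7 to 3e-7 in this PR.
    """
    dim = x_d_stacked.shape[1]
    kf = KalmanFilter(
        initial_state_mean=x_d_stacked[0],
        n_dim_obs=dim,
        transition_covariance=process_variance * np.eye(dim),
        observation_covariance=observation_variance * np.eye(dim),
    )
    smoothed_means, _ = kf.smooth(x_d_stacked)
    return smoothed_means
```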