diff --git a/app.py b/app.py
index c1e895a..9946714 100644
--- a/app.py
+++ b/app.py
@@ -97,10 +97,24 @@ video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, labe
 head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
 head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
 head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
+mov_x = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="x-axis movement")
+mov_y = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="y-axis movement")
+mov_z = gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.01, label="z-axis movement")
+lip_variation_zero = gr.Slider(minimum=-0.09, maximum=0.09, value=0, step=0.01, label="pouting")
+lip_variation_one = gr.Slider(minimum=-20.0, maximum=15.0, value=0, step=0.01, label="lip compressed<->pursing")
+lip_variation_two = gr.Slider(minimum=0.0, maximum=15.0, value=0, step=0.01, label="grin😬")
+lip_variation_three = gr.Slider(minimum=-90.0, maximum=120.0, value=0, step=1.0, label="lip close <-> lip open")
+smile = gr.Slider(minimum=-0.3, maximum=1.3, value=0, step=0.01, label="smile")
+wink = gr.Slider(minimum=0, maximum=39, value=0, step=0.01, label="wink")
+eyebrow = gr.Slider(minimum=-30, maximum=30, value=0, step=0.01, label="eyebrow")
+eyeball_direction_x = gr.Slider(minimum=-30.0, maximum=30.0, value=0, step=0.01, label="eye gaze (horizontal)")
+eyeball_direction_y = gr.Slider(minimum=-63.0, maximum=63.0, value=0, step=0.01, label="eye gaze (vertical)")
 retargeting_input_image = gr.Image(type="filepath")
 retargeting_input_video = gr.Video()
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
+retargeting_output_image = gr.Image(type="numpy")
+retargeting_output_image_paste_back = gr.Image(type="numpy")
 output_video = gr.Video(autoplay=False)
 output_video_paste_back = gr.Video(autoplay=False)
 output_video_i2v = gr.Video(autoplay=False)
@@ -250,15 +264,40 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
     gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
     with gr.Row(visible=True):
         flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)")
+        flag_stitching_retargeting_input = gr.Checkbox(value=True, label="stitching")
         retargeting_source_scale.render()
         eye_retargeting_slider.render()
         lip_retargeting_slider.render()
+    gr.Markdown(
+        """
+        <div>
+            <div>Face movement sliders</div>
+        </div>
+        """)
     with gr.Row(visible=True):
         head_pitch_slider.render()
         head_yaw_slider.render()
         head_roll_slider.render()
+        mov_x.render()
+        mov_y.render()
+        mov_z.render()
+    gr.Markdown(
+        """
+        <div>
+            <div>Expression blendshape sliders</div>
+        </div>
+        """)
     with gr.Row(visible=True):
-        process_button_retargeting = gr.Button("πŸš— Retargeting Image", variant="primary")
+        lip_variation_zero.render()
+        lip_variation_one.render()
+        lip_variation_two.render()
+        lip_variation_three.render()
+        smile.render()
+    with gr.Row(visible=True):
+        wink.render()
+        eyebrow.render()
+        eyeball_direction_x.render()
+        eyeball_direction_y.render()
     with gr.Row(visible=True):
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Image Input"):
@@ -279,21 +318,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                 )
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Result"):
-                output_image.render()
+                retargeting_output_image.render()
         with gr.Column():
             with gr.Accordion(open=True, label="Paste-back Result"):
-                output_image_paste_back.render()
+                retargeting_output_image_paste_back.render()
     with gr.Row(visible=True):
         process_button_reset_retargeting = gr.ClearButton(
             [
-                eye_retargeting_slider,
-                lip_retargeting_slider,
-                head_pitch_slider,
-                head_yaw_slider,
-                head_roll_slider,
                 retargeting_input_image,
-                output_image,
-                output_image_paste_back
+                retargeting_output_image,
+                retargeting_output_image_paste_back
             ],
             value="🧹 Clear"
         )
@@ -306,7 +340,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         video_lip_retargeting_slider.render()
         driving_smooth_observation_variance_retargeting.render()
     with gr.Row(visible=True):
-        process_button_retargeting_video = gr.Button("πŸ„ Retargeting Video", variant="primary")
+        process_button_retargeting_video = gr.Button("πŸš— Retargeting Video", variant="primary")
     with gr.Row(visible=True):
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Video Input"):
@@ -369,17 +403,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 
     retargeting_input_image.change(
         fn=gradio_pipeline.init_retargeting_image,
-        inputs=[retargeting_source_scale, retargeting_input_image],
+        inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
        outputs=[eye_retargeting_slider, lip_retargeting_slider]
     )
 
-    process_button_retargeting.click(
-        # fn=gradio_pipeline.execute_image,
-        fn=gpu_wrapped_execute_image_retargeting,
-        inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input_retargeting_image],
-        outputs=[output_image, output_image_paste_back],
-        show_progress=True
-    )
+    sliders = [eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y]
+    for slider in sliders:
+        # NOTE: gradio >= 4.0.0 may cause slow response
+        slider.change(
+            fn=gpu_wrapped_execute_image_retargeting,
+            inputs=[
+                eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
+                lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
+                retargeting_input_image, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
+            ],
+            outputs=[retargeting_output_image, retargeting_output_image_paste_back],
+        )
 
     process_button_retargeting_video.click(
         fn=gpu_wrapped_execute_video_retargeting,
diff --git a/assets/gradio/gradio_description_retargeting.md b/assets/gradio/gradio_description_retargeting.md
index 5978d08..14fc0c4 100644
--- a/assets/gradio/gradio_description_retargeting.md
+++ b/assets/gradio/gradio_description_retargeting.md
@@ -7,7 +7,7 @@
 <div>
 <span style="font-size: 1.2em;">Retargeting Image</span>
 </div>
 
-<span style="font-size: 1.2em;">Upload a Source Portrait as Retargeting Input, then drag the sliders and click the πŸš— Retargeting Image button. You can try running it multiple times.</span>
+<span style="font-size: 1.2em;">Upload a Source Portrait as Retargeting Input, wait for the target eyes-open ratio and target lip-open ratio to be calculated, and then drag the sliders. You can try running it multiple times.</span>
 <span style="font-size: 1.2em;">😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</span>
 </div>
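The init flow referenced by this copy: `retargeting_input_image.change` now calls `init_retargeting_image` with the current slider values, and the callback seeds both sliders with the ratios measured from the uploaded portrait (the pipeline also caches them, see below). A minimal sketch of that callback contract, where `measure_ratios` is a hypothetical stand-in for the crop + `calc_eye_close_ratio` / `calc_lip_close_ratio` path:

```python
def measure_ratios(image_path: str, scale: float) -> tuple[float, float]:
    # Hypothetical stand-in for cropping the portrait and computing its
    # eyes-open and lip-open ratios from the cropped landmarks.
    return 0.34, 0.01

def init_retargeting_image(scale, eye_slider_value, lip_slider_value, image_path=None):
    if image_path is not None:
        eye_ratio, lip_ratio = measure_ratios(image_path, scale)
        # Seed the sliders with the source portrait's own ratios.
        return round(eye_ratio, 2), round(lip_ratio, 2)
    # No image (e.g. after clearing): leave the sliders where they are.
    return eye_slider_value, lip_slider_value

print(init_retargeting_image(2.3, 0.0, 0.0, "portrait.jpg"))  # -> (0.34, 0.01)
```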
diff --git a/assets/gradio/gradio_description_retargeting_video.md b/assets/gradio/gradio_description_retargeting_video.md
index e54be9e..9d6bb8c 100644
--- a/assets/gradio/gradio_description_retargeting_video.md
+++ b/assets/gradio/gradio_description_retargeting_video.md
@@ -2,7 +2,7 @@
 <div>
 <span style="font-size: 1.2em;">Retargeting Video</span>
 </div>
 
-<span style="font-size: 1.2em;">Upload a Source Video as Retargeting Input, then drag the sliders and click the πŸ„ Retargeting Video button. You can try running it multiple times.</span>
+<span style="font-size: 1.2em;">Upload a Source Video as Retargeting Input, then drag the sliders and click the πŸš— Retargeting Video button. You can try running it multiple times.</span>
 <span style="font-size: 1.2em;">🀐 Set target lip-open ratio to 0 to see what's going on!</span>
 </div>
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index 92e72dd..b3fd405 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -43,6 +43,104 @@ class GradioPipeline(LivePortraitPipeline):
         # self.live_portrait_wrapper = self.live_portrait_wrapper
         self.args = args
 
+    @torch.no_grad()
+    def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
+        if eyeball_direction_x > 0:
+            delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
+            delta_new[0, 15, 0] += eyeball_direction_x * 0.001
+        else:
+            delta_new[0, 11, 0] += eyeball_direction_x * 0.001
+            delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
+
+        delta_new[0, 11, 1] += eyeball_direction_y * -0.001
+        delta_new[0, 15, 1] += eyeball_direction_y * -0.001
+        blink = -eyeball_direction_y / 2.
+
+        delta_new[0, 11, 1] += blink * -0.001
+        delta_new[0, 13, 1] += blink * 0.0003
+        delta_new[0, 15, 1] += blink * -0.001
+        delta_new[0, 16, 1] += blink * 0.0003
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_smile(self, smile, delta_new, **kwargs):
+        delta_new[0, 20, 1] += smile * -0.01
+        delta_new[0, 14, 1] += smile * -0.02
+        delta_new[0, 17, 1] += smile * 0.0065
+        delta_new[0, 17, 2] += smile * 0.003
+        delta_new[0, 13, 1] += smile * -0.00275
+        delta_new[0, 16, 1] += smile * -0.00275
+        delta_new[0, 3, 1] += smile * -0.0035
+        delta_new[0, 7, 1] += smile * -0.0035
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_wink(self, wink, delta_new, **kwargs):
+        delta_new[0, 11, 1] += wink * 0.001
+        delta_new[0, 13, 1] += wink * -0.0003
+        delta_new[0, 17, 0] += wink * 0.0003
+        delta_new[0, 17, 1] += wink * 0.0003
+        delta_new[0, 3, 1] += wink * -0.0003
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs):
+        if eyebrow > 0:
+            delta_new[0, 1, 1] += eyebrow * 0.001
+            delta_new[0, 2, 1] += eyebrow * -0.001
+        else:
+            delta_new[0, 1, 0] += eyebrow * -0.001
+            delta_new[0, 2, 0] += eyebrow * 0.001
+            delta_new[0, 1, 1] += eyebrow * 0.0003
+            delta_new[0, 2, 1] += eyebrow * -0.0003
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs):
+        delta_new[0, 19, 0] += lip_variation_zero
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs):
+        delta_new[0, 14, 1] += lip_variation_one * 0.001
+        delta_new[0, 3, 1] += lip_variation_one * -0.0005
+        delta_new[0, 7, 1] += lip_variation_one * -0.0005
+        delta_new[0, 17, 2] += lip_variation_one * -0.0005
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs):
+        delta_new[0, 20, 2] += lip_variation_two * -0.001
+        delta_new[0, 20, 1] += lip_variation_two * -0.001
+        delta_new[0, 14, 1] += lip_variation_two * -0.001
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs):
+        delta_new[0, 19, 1] += lip_variation_three * 0.001
+        delta_new[0, 19, 2] += lip_variation_three * 0.0001
+        delta_new[0, 17, 1] += lip_variation_three * -0.0001
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs):
+        delta_new[0, 5, 0] += mov_x
+
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs):
+        delta_new[0, 5, 1] += mov_y
+
+        return delta_new
+
     @torch.no_grad()
     def execute_video(
         self,
@@ -112,14 +210,37 @@ class GradioPipeline(LivePortraitPipeline):
             raise gr.Error("Please upload the source portrait or source video, and driving video πŸ€—πŸ€—πŸ€—", duration=5)
 
     @torch.no_grad()
-    def execute_image_retargeting(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop_input_retargeting_image=True):
+    def execute_image_retargeting(
+        self,
+        input_eye_ratio: float,
+        input_lip_ratio: float,
+        input_head_pitch_variation: float,
+        input_head_yaw_variation: float,
+        input_head_roll_variation: float,
+        mov_x: float,
+        mov_y: float,
+        mov_z: float,
+        lip_variation_zero: float,
+        lip_variation_one: float,
+        lip_variation_two: float,
+        lip_variation_three: float,
+        smile: float,
+        wink: float,
+        eyebrow: float,
+        eyeball_direction_x: float,
+        eyeball_direction_y: float,
+        input_image,
+        retargeting_source_scale: float,
+        flag_stitching_retargeting_input=True,
+        flag_do_crop_input_retargeting_image=True):
         """ for single image retargeting
         """
         if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
             raise gr.Error("Invalid relative pose input πŸ’₯!", duration=5)
         # disposable feature
         f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
-            self.prepare_retargeting_image(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
+            self.prepare_retargeting_image(
+                input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
 
         if input_eye_ratio is None or input_lip_ratio is None:
             raise gr.Error("Invalid ratio input πŸ’₯!", duration=5)
@@ -130,6 +251,18 @@ class GradioPipeline(LivePortraitPipeline):
             f_s_user = f_s_user.to(device)
             R_s_user = R_s_user.to(device)
             R_d_user = R_d_user.to(device)
+            mov_x = torch.tensor(mov_x).to(device)
+            mov_y = torch.tensor(mov_y).to(device)
+            mov_z = torch.tensor(mov_z).to(device)
+            eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device)
+            eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device)
+            smile = torch.tensor(smile).to(device)
+            wink = torch.tensor(wink).to(device)
+            eyebrow = torch.tensor(eyebrow).to(device)
+            lip_variation_zero = torch.tensor(lip_variation_zero).to(device)
+            lip_variation_one = torch.tensor(lip_variation_one).to(device)
+            lip_variation_two = torch.tensor(lip_variation_two).to(device)
+            lip_variation_three = torch.tensor(lip_variation_three).to(device)
 
             x_c_s = x_s_info['kp'].to(device)
             delta_new = x_s_info['exp'].to(device)
@@ -137,27 +270,56 @@ class GradioPipeline(LivePortraitPipeline):
             t_new = x_s_info['t'].to(device)
 
             R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
-            x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new
-            # βˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
-            eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
-            # βˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
-            lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
-            x_d_new = x_d_new + eyes_delta + lip_delta
-            x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
-            # D(W(f_s; x_s, xβ€²_d))
+            if eyeball_direction_x != 0 or eyeball_direction_y != 0:
+                delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new)
+            if smile != 0:
+                delta_new = self.update_delta_new_smile(smile, delta_new)
+            if wink != 0:
+                delta_new = self.update_delta_new_wink(wink, delta_new)
+            if eyebrow != 0:
+                delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new)
+            if lip_variation_zero != 0:
+                delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new)
+            if lip_variation_one != 0:
+                delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new)
+            if lip_variation_two != 0:
+                delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new)
+            if lip_variation_three != 0:
+                delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new)
+            if mov_x != 0:
+                delta_new = self.update_delta_new_mov_x(-mov_x, delta_new)
+            if mov_y != 0:
+                delta_new = self.update_delta_new_mov_y(mov_y, delta_new)
+
+            x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new
+            eyes_delta, lip_delta = None, None
+            if input_eye_ratio != self.source_eye_ratio:
+                combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
+                eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
+            if input_lip_ratio != self.source_lip_ratio:
+                combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
+                lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+            x_d_new = x_d_new + \
+                (eyes_delta if eyes_delta is not None else 0) + \
+                (lip_delta if lip_delta is not None else 0)
+
+            if flag_stitching_retargeting_input:
+                x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
             out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
             out = self.live_portrait_wrapper.parse_output(out['out'])[0]
             if flag_do_crop_input_retargeting_image:
                 out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
             else:
                 out_to_ori_blend = out
-            gr.Info("Run successfully!", duration=2)
             return out, out_to_ori_blend
 
     @torch.no_grad()
-    def prepare_retargeting_image(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
+    def prepare_retargeting_image(
+        self,
+        input_image,
+        input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation,
+        retargeting_source_scale,
+        flag_do_crop=True):
         """ for single image retargeting
         """
         if input_image is not None:
@@ -168,7 +330,6 @@ class GradioPipeline(LivePortraitPipeline):
             inference_cfg = self.live_portrait_wrapper.inference_cfg
             ######## process source portrait ########
             img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
-            log(f"Load source image from {input_image}.")
             if flag_do_crop:
                 crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
                 I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
@@ -181,27 +342,27 @@ class GradioPipeline(LivePortraitPipeline):
                 crop_M_c2o = None
                 mask_ori = None
             x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-            x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
-            x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
-            x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation
+            x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
+            x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
+            x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation
             R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-            R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll)
+            R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll)
             ############################################
             f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
             x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
             return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
         else:
-            # when press the clear button, go here
             raise gr.Error("Please upload a source portrait as the retargeting input πŸ€—πŸ€—πŸ€—", duration=5)
 
-    def init_retargeting_image(self, retargeting_source_scale: float, input_image = None):
+    @torch.no_grad()
+    def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio: float, input_image=None):
         """ initialize the retargeting slider
         """
         if input_image != None:
             args_user = {'scale': retargeting_source_scale}
             self.args = update_args(self.args, args_user)
             self.cropper.update_config(self.args.__dict__)
-            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            # inference_cfg = self.live_portrait_wrapper.inference_cfg
             ######## process source portrait ########
             img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
             log(f"Load source image from {input_image}.")
@@ -210,9 +371,14 @@ class GradioPipeline(LivePortraitPipeline):
                 raise gr.Error("Source portrait NO face detected", duration=2)
             source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
             source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
-            return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2)
-        return 0., 0.
+            self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2)
+            self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2)
+            log("Calculating eyes-open and lip-open ratios successfully!")
+            return self.source_eye_ratio, self.source_lip_ratio
+        else:
+            return source_eye_ratio, source_lip_ratio
 
+    @torch.no_grad()
     def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
         """ retargeting the lip-open ratio of each source frame
         """
@@ -277,6 +443,7 @@ class GradioPipeline(LivePortraitPipeline):
         gr.Info("Run successfully!", duration=2)
         return wfp_concat, wfp
 
+    @torch.no_grad()
     def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
         """ for video retargeting
         """
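Taken together, the pipeline changes implement every new slider as an additive offset on the implicit-keypoint expression deltas, applied before the driving transform `x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new`. A self-contained sketch of that composition (shapes assume LivePortrait's `(1, 21, 3)` keypoint tensors; the smile offsets reuse two of the hand-tuned index/gain pairs from the diff):

```python
import torch

def apply_smile(delta_new: torch.Tensor, smile: float) -> torch.Tensor:
    # Additive edit on individual implicit keypoints; indices and gains are
    # the empirically tuned values from update_delta_new_smile above.
    delta_new = delta_new.clone()
    delta_new[0, 20, 1] += smile * -0.01
    delta_new[0, 14, 1] += smile * -0.02
    return delta_new

def drive_keypoints(x_c_s, R_d_new, delta_new, scale_new, t_new, mov_z=1.0):
    # Same formula as execute_image_retargeting: rotate the canonical
    # keypoints, add the (edited) expression deltas, then scale, zoom
    # (mov_z) and translate.
    return mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new

# Toy shapes only; real values come from get_kp_info / get_rotation_matrix.
x_c_s = torch.randn(1, 21, 3)
R_d_new = torch.eye(3).unsqueeze(0)
delta_new = apply_smile(torch.zeros(1, 21, 3), smile=0.5)
x_d_new = drive_keypoints(x_c_s, R_d_new, delta_new, scale_new=1.0,
                          t_new=torch.zeros(1, 1, 3), mov_z=1.05)
print(x_d_new.shape)  # torch.Size([1, 21, 3])
```

Gating each edit on `slider != 0` (and each eye/lip retarget on the slider differing from the cached source ratio) keeps the no-edit path identical to the original output.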