feat: edit

zhangdingyun 2024-08-06 20:23:57 +08:00 committed by guojianzhu
parent f01202e9a9
commit 7c42708695
4 changed files with 251 additions and 45 deletions

app.py

@@ -97,10 +97,24 @@ video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, labe
 head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch")
 head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw")
 head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll")
+mov_x = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="x-axis movement")
+mov_y = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="y-axis movement")
+mov_z = gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.01, label="z-axis movement")
+lip_variation_zero = gr.Slider(minimum=-0.09, maximum=0.09, value=0, step=0.01, label="pouting")
+lip_variation_one = gr.Slider(minimum=-20.0, maximum=15.0, value=0, step=0.01, label="lip compressed<->pursing")
+lip_variation_two = gr.Slider(minimum=0.0, maximum=15.0, value=0, step=0.01, label="grin😬")
+lip_variation_three = gr.Slider(minimum=-90.0, maximum=120.0, value=0, step=1.0, label="lip close <-> lip open")
+smile = gr.Slider(minimum=-0.3, maximum=1.3, value=0, step=0.01, label="smile")
+wink = gr.Slider(minimum=0, maximum=39, value=0, step=0.01, label="wink")
+eyebrow = gr.Slider(minimum=-30, maximum=30, value=0, step=0.01, label="eyebrow")
+eyeball_direction_x = gr.Slider(minimum=-30.0, maximum=30.0, value=0, step=0.01, label="eye gaze (horizontal)")
+eyeball_direction_y = gr.Slider(minimum=-63.0, maximum=63.0, value=0, step=0.01, label="eye gaze (vertical)")
 retargeting_input_image = gr.Image(type="filepath")
 retargeting_input_video = gr.Video()
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
+retargeting_output_image = gr.Image(type="numpy")
+retargeting_output_image_paste_back = gr.Image(type="numpy")
 output_video = gr.Video(autoplay=False)
 output_video_paste_back = gr.Video(autoplay=False)
 output_video_i2v = gr.Video(autoplay=False)
@@ -250,15 +264,40 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
     gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True)
     with gr.Row(visible=True):
         flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)")
+        flag_stitching_retargeting_input = gr.Checkbox(value=True, label="stitching")
         retargeting_source_scale.render()
         eye_retargeting_slider.render()
         lip_retargeting_slider.render()
+    gr.Markdown(
+        """
+        <div style="text-align: center;">
+        <h5>Face movement sliders</h5>
+        </div>
+        """)
     with gr.Row(visible=True):
         head_pitch_slider.render()
         head_yaw_slider.render()
         head_roll_slider.render()
+        mov_x.render()
+        mov_y.render()
+        mov_z.render()
+    gr.Markdown(
+        """
+        <div style="text-align: center;">
+        <h5>Expression blendshape sliders</h5>
+        </div>
+        """)
     with gr.Row(visible=True):
-        process_button_retargeting = gr.Button("🚗 Retargeting Image", variant="primary")
+        lip_variation_zero.render()
+        lip_variation_one.render()
+        lip_variation_two.render()
+        lip_variation_three.render()
+        smile.render()
+    with gr.Row(visible=True):
+        wink.render()
+        eyebrow.render()
+        eyeball_direction_x.render()
+        eyeball_direction_y.render()
     with gr.Row(visible=True):
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Image Input"):
@@ -279,21 +318,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                 )
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Result"):
-                output_image.render()
+                retargeting_output_image.render()
         with gr.Column():
             with gr.Accordion(open=True, label="Paste-back Result"):
-                output_image_paste_back.render()
+                retargeting_output_image_paste_back.render()
     with gr.Row(visible=True):
         process_button_reset_retargeting = gr.ClearButton(
             [
-                eye_retargeting_slider,
-                lip_retargeting_slider,
-                head_pitch_slider,
-                head_yaw_slider,
-                head_roll_slider,
                 retargeting_input_image,
-                output_image,
-                output_image_paste_back
+                retargeting_output_image,
+                retargeting_output_image_paste_back
             ],
             value="🧹 Clear"
         )
@@ -306,7 +340,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         video_lip_retargeting_slider.render()
         driving_smooth_observation_variance_retargeting.render()
     with gr.Row(visible=True):
-        process_button_retargeting_video = gr.Button("🍄 Retargeting Video", variant="primary")
+        process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
     with gr.Row(visible=True):
         with gr.Column():
             with gr.Accordion(open=True, label="Retargeting Video Input"):
@@ -369,17 +403,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
     retargeting_input_image.change(
         fn=gradio_pipeline.init_retargeting_image,
-        inputs=[retargeting_source_scale, retargeting_input_image],
+        inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
         outputs=[eye_retargeting_slider, lip_retargeting_slider]
     )

-    process_button_retargeting.click(
-        # fn=gradio_pipeline.execute_image,
-        fn=gpu_wrapped_execute_image_retargeting,
-        inputs=[eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, retargeting_input_image, retargeting_source_scale, flag_do_crop_input_retargeting_image],
-        outputs=[output_image, output_image_paste_back],
-        show_progress=True
-    )
+    sliders = [eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y]
+    for slider in sliders:
+        # NOTE: gradio >= 4.0.0 may cause slow response
+        slider.change(
+            fn=gpu_wrapped_execute_image_retargeting,
+            inputs=[
+                eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z,
+                lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y,
+                retargeting_input_image, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image
+            ],
+            outputs=[retargeting_output_image, retargeting_output_image_paste_back],
+        )

     process_button_retargeting_video.click(
         fn=gpu_wrapped_execute_video_retargeting,

assets/gradio/gradio_description_retargeting.md

@@ -7,7 +7,7 @@
 <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
 <div>
 <h2>Retargeting Image</h2>
-<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting Image</strong> button. You can try running it multiple times.
+<p>Upload a Source Portrait as Retargeting Input, wait for the <code>target eyes-open ratio</code> and <code>target lip-open ratio</code> to be calculated, and then drag the sliders. You can try running it multiple times.
 <br>
 <strong>😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!</strong></p>
 </div>

assets/gradio/gradio_description_retargeting_video.md

@@ -2,7 +2,7 @@
 <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
 <div>
 <h2>Retargeting Video</h2>
-<p>Upload a Source Video as Retargeting Input, then drag the sliders and click the <strong>🍄 Retargeting Video</strong> button. You can try running it multiple times.
+<p>Upload a Source Video as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting Video</strong> button. You can try running it multiple times.
 <br>
 <strong>🤐 Set target lip-open ratio to 0 to see what's going on!</strong></p>
 </div>

src/gradio_pipeline.py

@@ -43,6 +43,104 @@ class GradioPipeline(LivePortraitPipeline):
         # self.live_portrait_wrapper = self.live_portrait_wrapper
         self.args = args

+    @torch.no_grad()
+    def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
+        if eyeball_direction_x > 0:
+            delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
+            delta_new[0, 15, 0] += eyeball_direction_x * 0.001
+        else:
+            delta_new[0, 11, 0] += eyeball_direction_x * 0.001
+            delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
+        delta_new[0, 11, 1] += eyeball_direction_y * -0.001
+        delta_new[0, 15, 1] += eyeball_direction_y * -0.001
+        blink = -eyeball_direction_y / 2.
+        delta_new[0, 11, 1] += blink * -0.001
+        delta_new[0, 13, 1] += blink * 0.0003
+        delta_new[0, 15, 1] += blink * -0.001
+        delta_new[0, 16, 1] += blink * 0.0003
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_smile(self, smile, delta_new, **kwargs):
+        delta_new[0, 20, 1] += smile * -0.01
+        delta_new[0, 14, 1] += smile * -0.02
+        delta_new[0, 17, 1] += smile * 0.0065
+        delta_new[0, 17, 2] += smile * 0.003
+        delta_new[0, 13, 1] += smile * -0.00275
+        delta_new[0, 16, 1] += smile * -0.00275
+        delta_new[0, 3, 1] += smile * -0.0035
+        delta_new[0, 7, 1] += smile * -0.0035
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_wink(self, wink, delta_new, **kwargs):
+        delta_new[0, 11, 1] += wink * 0.001
+        delta_new[0, 13, 1] += wink * -0.0003
+        delta_new[0, 17, 0] += wink * 0.0003
+        delta_new[0, 17, 1] += wink * 0.0003
+        delta_new[0, 3, 1] += wink * -0.0003
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs):
+        if eyebrow > 0:
+            delta_new[0, 1, 1] += eyebrow * 0.001
+            delta_new[0, 2, 1] += eyebrow * -0.001
+        else:
+            delta_new[0, 1, 0] += eyebrow * -0.001
+            delta_new[0, 2, 0] += eyebrow * 0.001
+            delta_new[0, 1, 1] += eyebrow * 0.0003
+            delta_new[0, 2, 1] += eyebrow * -0.0003
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs):
+        delta_new[0, 19, 0] += lip_variation_zero
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs):
+        delta_new[0, 14, 1] += lip_variation_one * 0.001
+        delta_new[0, 3, 1] += lip_variation_one * -0.0005
+        delta_new[0, 7, 1] += lip_variation_one * -0.0005
+        delta_new[0, 17, 2] += lip_variation_one * -0.0005
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs):
+        delta_new[0, 20, 2] += lip_variation_two * -0.001
+        delta_new[0, 20, 1] += lip_variation_two * -0.001
+        delta_new[0, 14, 1] += lip_variation_two * -0.001
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs):
+        delta_new[0, 19, 1] += lip_variation_three * 0.001
+        delta_new[0, 19, 2] += lip_variation_three * 0.0001
+        delta_new[0, 17, 1] += lip_variation_three * -0.0001
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs):
+        delta_new[0, 5, 0] += mov_x
+        return delta_new
+
+    @torch.no_grad()
+    def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs):
+        delta_new[0, 5, 1] += mov_y
+        return delta_new
+
     @torch.no_grad()
     def execute_video(
         self,
@@ -112,14 +210,37 @@ class GradioPipeline(LivePortraitPipeline):
             raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)

     @torch.no_grad()
-    def execute_image_retargeting(self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, input_image, retargeting_source_scale: float, flag_do_crop_input_retargeting_image=True):
+    def execute_image_retargeting(
+            self,
+            input_eye_ratio: float,
+            input_lip_ratio: float,
+            input_head_pitch_variation: float,
+            input_head_yaw_variation: float,
+            input_head_roll_variation: float,
+            mov_x: float,
+            mov_y: float,
+            mov_z: float,
+            lip_variation_zero: float,
+            lip_variation_one: float,
+            lip_variation_two: float,
+            lip_variation_three: float,
+            smile: float,
+            wink: float,
+            eyebrow: float,
+            eyeball_direction_x: float,
+            eyeball_direction_y: float,
+            input_image,
+            retargeting_source_scale: float,
+            flag_stitching_retargeting_input=True,
+            flag_do_crop_input_retargeting_image=True):
         """ for single image retargeting
         """
         if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
             raise gr.Error("Invalid relative pose input 💥!", duration=5)
         # disposable feature
         f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
-            self.prepare_retargeting_image(input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
+            self.prepare_retargeting_image(
+                input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
         if input_eye_ratio is None or input_lip_ratio is None:
             raise gr.Error("Invalid ratio input 💥!", duration=5)
@@ -130,6 +251,18 @@ class GradioPipeline(LivePortraitPipeline):
         f_s_user = f_s_user.to(device)
         R_s_user = R_s_user.to(device)
         R_d_user = R_d_user.to(device)
+        mov_x = torch.tensor(mov_x).to(device)
+        mov_y = torch.tensor(mov_y).to(device)
+        mov_z = torch.tensor(mov_z).to(device)
+        eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device)
+        eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device)
+        smile = torch.tensor(smile).to(device)
+        wink = torch.tensor(wink).to(device)
+        eyebrow = torch.tensor(eyebrow).to(device)
+        lip_variation_zero = torch.tensor(lip_variation_zero).to(device)
+        lip_variation_one = torch.tensor(lip_variation_one).to(device)
+        lip_variation_two = torch.tensor(lip_variation_two).to(device)
+        lip_variation_three = torch.tensor(lip_variation_three).to(device)
         x_c_s = x_s_info['kp'].to(device)
         delta_new = x_s_info['exp'].to(device)
@@ -137,27 +270,56 @@ class GradioPipeline(LivePortraitPipeline):
         t_new = x_s_info['t'].to(device)

         R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
-        x_d_new = scale_new * (x_c_s @ R_d_new + delta_new) + t_new
-        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
-        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
-        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
-        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
-        x_d_new = x_d_new + eyes_delta + lip_delta
-        x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
-        # D(W(f_s; x_s, x_d))
+        if eyeball_direction_x != 0 or eyeball_direction_y != 0:
+            delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new)
+        if smile != 0:
+            delta_new = self.update_delta_new_smile(smile, delta_new)
+        if wink != 0:
+            delta_new = self.update_delta_new_wink(wink, delta_new)
+        if eyebrow != 0:
+            delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new)
+        if lip_variation_zero != 0:
+            delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new)
+        if lip_variation_one != 0:
+            delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new)
+        if lip_variation_two != 0:
+            delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new)
+        if lip_variation_three != 0:
+            delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new)
+        if mov_x != 0:
+            delta_new = self.update_delta_new_mov_x(-mov_x, delta_new)
+        if mov_y != 0:
+            delta_new = self.update_delta_new_mov_y(mov_y, delta_new)
+
+        x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new
+        eyes_delta, lip_delta = None, None
+        if input_eye_ratio != self.source_eye_ratio:
+            combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
+            eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
+        if input_lip_ratio != self.source_lip_ratio:
+            combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
+            lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
+        x_d_new = x_d_new + \
+            (eyes_delta if eyes_delta is not None else 0) + \
+            (lip_delta if lip_delta is not None else 0)
+        if flag_stitching_retargeting_input:
+            x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
         out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
         out = self.live_portrait_wrapper.parse_output(out['out'])[0]
         if flag_do_crop_input_retargeting_image:
             out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
         else:
             out_to_ori_blend = out
+        gr.Info("Run successfully!", duration=2)
         return out, out_to_ori_blend

     @torch.no_grad()
-    def prepare_retargeting_image(self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True):
+    def prepare_retargeting_image(
+            self,
+            input_image,
+            input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation,
+            retargeting_source_scale,
+            flag_do_crop=True):
         """ for single image retargeting
         """
         if input_image is not None:
@@ -168,7 +330,6 @@ class GradioPipeline(LivePortraitPipeline):
             inference_cfg = self.live_portrait_wrapper.inference_cfg
             ######## process source portrait ########
             img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
-            log(f"Load source image from {input_image}.")
             if flag_do_crop:
                 crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
                 I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
@@ -181,27 +342,27 @@ class GradioPipeline(LivePortraitPipeline):
                 crop_M_c2o = None
                 mask_ori = None
             x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
-            x_s_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
-            x_s_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
-            x_s_info_user_roll = x_s_info['roll'] + input_head_roll_variation
+            x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
+            x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
+            x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation
             R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-            R_d_user = get_rotation_matrix(x_s_info_user_pitch, x_s_info_user_yaw, x_s_info_user_roll)
+            R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll)
             ############################################
             f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
             x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
             return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
         else:
+            # when press the clear button, go here
             raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

-    def init_retargeting_image(self, retargeting_source_scale: float, input_image = None):
+    @torch.no_grad()
+    def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio: float, input_image = None):
         """ initialize the retargeting slider
         """
         if input_image != None:
             args_user = {'scale': retargeting_source_scale}
             self.args = update_args(self.args, args_user)
             self.cropper.update_config(self.args.__dict__)
-            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            # inference_cfg = self.live_portrait_wrapper.inference_cfg
             ######## process source portrait ########
             img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
             log(f"Load source image from {input_image}.")
@@ -210,9 +371,14 @@ class GradioPipeline(LivePortraitPipeline):
                 raise gr.Error("Source portrait NO face detected", duration=2)
             source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
             source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
-            return round(float(source_eye_ratio.mean()), 2), round(source_lip_ratio[0][0], 2)
-        return 0., 0.
+            self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2)
+            self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2)
+            log("Calculating eyes-open and lip-open ratios successfully!")
+            return self.source_eye_ratio, self.source_lip_ratio
+        else:
+            return source_eye_ratio, source_lip_ratio

+    @torch.no_grad()
     def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
         """ retargeting the lip-open ratio of each source frame
         """
@@ -277,6 +443,7 @@ class GradioPipeline(LivePortraitPipeline):
         gr.Info("Run successfully!", duration=2)
         return wfp_concat, wfp

+    @torch.no_grad()
     def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
         """ for video retargeting
         """