feat: support human template upload for Gradio (#302)

* support upload motion template to fast infer

* Update gradio_pipeline.py

---------

Co-authored-by: Mystery099 <164347012+Mystery099@users.noreply.github.com>
This commit is contained in:
New Bing 2024-08-09 13:00:03 +08:00 committed by GitHub
parent b95d7b60ec
commit f24b6ff5d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 19 deletions

52
app.py
View File

@ -180,20 +180,40 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01) vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
with gr.Column(): with gr.Column():
with gr.Accordion(open=True, label="Driving Video"): with gr.Tabs():
driving_video_input = gr.Video() with gr.TabItem("🎞️ Driving Video") as v_tab_video:
gr.Examples( with gr.Accordion(open=True, label="Driving Video"):
examples=[ driving_video_input = gr.Video()
[osp.join(example_video_dir, "d0.mp4")], gr.Examples(
[osp.join(example_video_dir, "d18.mp4")], examples=[
[osp.join(example_video_dir, "d19.mp4")], [osp.join(example_video_dir, "d0.mp4")],
[osp.join(example_video_dir, "d14.mp4")], [osp.join(example_video_dir, "d18.mp4")],
[osp.join(example_video_dir, "d6.mp4")], [osp.join(example_video_dir, "d19.mp4")],
[osp.join(example_video_dir, "d20.mp4")], [osp.join(example_video_dir, "d14.mp4")],
], [osp.join(example_video_dir, "d6.mp4")],
inputs=[driving_video_input], [osp.join(example_video_dir, "d20.mp4")],
cache_examples=False, ],
) inputs=[driving_video_input],
cache_examples=False,
)
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
gr.Examples(
examples=[
[osp.join(example_video_dir, "d1.pkl")],
[osp.join(example_video_dir, "d2.pkl")],
[osp.join(example_video_dir, "d5.pkl")],
[osp.join(example_video_dir, "d7.pkl")],
[osp.join(example_video_dir, "d8.pkl")],
],
inputs=[driving_video_pickle_input],
cache_examples=False,
)
v_tab_selection = gr.Textbox(visible=False)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
v_tab_video.select(lambda: "Video", None, v_tab_selection)
# with gr.Accordion(open=False, label="Animation Instructions"): # with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"): with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@ -225,7 +245,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Accordion(open=True, label="The animated video"): with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render() output_video_concat_i2v.render()
with gr.Row(): with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear") process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Row(): with gr.Row():
# Examples # Examples
@ -393,6 +413,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[ inputs=[
source_image_input, source_image_input,
source_video_input, source_video_input,
driving_video_pickle_input,
driving_video_input, driving_video_input,
flag_relative_input, flag_relative_input,
flag_do_crop_input, flag_do_crop_input,
@ -410,6 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
vy_ratio_crop_driving_video, vy_ratio_crop_driving_video,
driving_smooth_observation_variance, driving_smooth_observation_variance,
tab_selection, tab_selection,
v_tab_selection,
], ],
outputs=[output_video_i2v, output_video_concat_i2v], outputs=[output_video_i2v, output_video_concat_i2v],
show_progress=True show_progress=True

View File

@ -146,6 +146,7 @@ class GradioPipeline(LivePortraitPipeline):
self, self,
input_source_image_path=None, input_source_image_path=None,
input_source_video_path=None, input_source_video_path=None,
input_driving_video_pickle_path=None,
input_driving_video_path=None, input_driving_video_path=None,
flag_relative_input=True, flag_relative_input=True,
flag_do_crop_input=True, flag_do_crop_input=True,
@ -163,8 +164,9 @@ class GradioPipeline(LivePortraitPipeline):
vy_ratio_crop_driving_video=-0.1, vy_ratio_crop_driving_video=-0.1,
driving_smooth_observation_variance=3e-7, driving_smooth_observation_variance=3e-7,
tab_selection=None, tab_selection=None,
v_tab_selection=None
): ):
""" for video-driven potrait animation or video editing """ for video-driven portrait animation or video editing
""" """
if tab_selection == 'Image': if tab_selection == 'Image':
input_source_path = input_source_image_path input_source_path = input_source_image_path
@ -173,15 +175,22 @@ class GradioPipeline(LivePortraitPipeline):
else: else:
input_source_path = input_source_image_path input_source_path = input_source_image_path
if input_source_path is not None and input_driving_video_path is not None: if v_tab_selection == 'Video':
if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False: input_driving_path = input_driving_video_path
elif v_tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
input_driving_path = input_driving_video_path
if input_source_path is not None and input_driving_path is not None:
if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False:
flag_crop_driving_video_input = True flag_crop_driving_video_input = True
log("The driving video is not square, it will be cropped to square automatically.") log("The driving video is not square, it will be cropped to square automatically.")
gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2) gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
args_user = { args_user = {
'source': input_source_path, 'source': input_source_path,
'driving': input_driving_video_path, 'driving': input_driving_path,
'flag_relative_motion': flag_relative_input, 'flag_relative_motion': flag_relative_input,
'flag_do_crop': flag_do_crop_input, 'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input, 'flag_pasteback': flag_remap_input,