From f24b6ff5d1e3d49c23dd852c57bd1d8cf7c585cd Mon Sep 17 00:00:00 2001 From: New Bing Date: Fri, 9 Aug 2024 13:00:03 +0800 Subject: [PATCH] feat: support human template upload for Gradio (#302) * support upload motion template to fast infer * Update gradio_pipeline.py --------- Co-authored-by: Mystery099 <164347012+Mystery099@users.noreply.github.com> --- app.py | 52 ++++++++++++++++++++++++++++++------------ src/gradio_pipeline.py | 17 ++++++++++---- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/app.py b/app.py index b0d81dd..35a57f2 100644 --- a/app.py +++ b/app.py @@ -180,20 +180,40 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01) with gr.Column(): - with gr.Accordion(open=True, label="Driving Video"): - driving_video_input = gr.Video() - gr.Examples( - examples=[ - [osp.join(example_video_dir, "d0.mp4")], - [osp.join(example_video_dir, "d18.mp4")], - [osp.join(example_video_dir, "d19.mp4")], - [osp.join(example_video_dir, "d14.mp4")], - [osp.join(example_video_dir, "d6.mp4")], - [osp.join(example_video_dir, "d20.mp4")], - ], - inputs=[driving_video_input], - cache_examples=False, - ) + with gr.Tabs(): + with gr.TabItem("๐ŸŽž๏ธ Driving Video") as v_tab_video: + with gr.Accordion(open=True, label="Driving Video"): + driving_video_input = gr.Video() + gr.Examples( + examples=[ + [osp.join(example_video_dir, "d0.mp4")], + [osp.join(example_video_dir, "d18.mp4")], + [osp.join(example_video_dir, "d19.mp4")], + [osp.join(example_video_dir, "d14.mp4")], + [osp.join(example_video_dir, "d6.mp4")], + [osp.join(example_video_dir, "d20.mp4")], + ], + inputs=[driving_video_input], + cache_examples=False, + ) + with gr.TabItem("๐Ÿ“ Driving Pickle") as v_tab_pickle: + with gr.Accordion(open=True, label="Driving Pickle"): + driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"]) + gr.Examples( + examples=[ + [osp.join(example_video_dir, "d1.pkl")], + [osp.join(example_video_dir, "d2.pkl")], + [osp.join(example_video_dir, "d5.pkl")], + [osp.join(example_video_dir, "d7.pkl")], + [osp.join(example_video_dir, "d8.pkl")], + ], + inputs=[driving_video_pickle_input], + cache_examples=False, + ) + + v_tab_selection = gr.Textbox(visible=False) + v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection) + v_tab_video.select(lambda: "Video", None, v_tab_selection) # with gr.Accordion(open=False, label="Animation Instructions"): # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) with gr.Accordion(open=True, label="Cropping Options for Driving Video"): @@ -225,7 +245,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San with gr.Accordion(open=True, label="The animated video"): output_video_concat_i2v.render() with gr.Row(): - process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="๐Ÿงน Clear") + process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="๐Ÿงน Clear") with gr.Row(): # Examples @@ -393,6 +413,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San inputs=[ source_image_input, source_video_input, + driving_video_pickle_input, driving_video_input, flag_relative_input, flag_do_crop_input, @@ -410,6 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San vy_ratio_crop_driving_video, driving_smooth_observation_variance, tab_selection, + v_tab_selection, ], outputs=[output_video_i2v, output_video_concat_i2v], show_progress=True diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index b3fd405..38dcde5 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -146,6 +146,7 @@ class GradioPipeline(LivePortraitPipeline): self, input_source_image_path=None, input_source_video_path=None, + input_driving_video_pickle_path=None, input_driving_video_path=None, flag_relative_input=True, flag_do_crop_input=True, @@ -163,8 +164,9 @@ class GradioPipeline(LivePortraitPipeline): vy_ratio_crop_driving_video=-0.1, driving_smooth_observation_variance=3e-7, tab_selection=None, + v_tab_selection=None ): - """ for video-driven potrait animation or video editing + """ for video-driven portrait animation or video editing """ if tab_selection == 'Image': input_source_path = input_source_image_path @@ -173,15 +175,22 @@ class GradioPipeline(LivePortraitPipeline): else: input_source_path = input_source_image_path - if input_source_path is not None and input_driving_video_path is not None: - if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False: + if v_tab_selection == 'Video': + input_driving_path = input_driving_video_path + elif v_tab_selection == 'Pickle': + input_driving_path = input_driving_video_pickle_path + else: + input_driving_path = input_driving_video_path + + if input_source_path is not None and input_driving_path is not None: + if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False: flag_crop_driving_video_input = True log("The driving video is not square, it will be cropped to square automatically.") gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2) args_user = { 'source': input_source_path, - 'driving': input_driving_video_path, + 'driving': input_driving_path, 'flag_relative_motion': flag_relative_input, 'flag_do_crop': flag_do_crop_input, 'flag_pasteback': flag_remap_input,