support uploading a motion template (.pkl) for faster inference

iflamed 2024-08-08 18:22:06 +08:00
parent b95d7b60ec
commit 36b3ec1770
2 changed files with 43 additions and 16 deletions

app.py (52 lines changed)

@@ -180,20 +180,40 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
         with gr.Column():
-            with gr.Accordion(open=True, label="Driving Video"):
-                driving_video_input = gr.Video()
-                gr.Examples(
-                    examples=[
-                        [osp.join(example_video_dir, "d0.mp4")],
-                        [osp.join(example_video_dir, "d18.mp4")],
-                        [osp.join(example_video_dir, "d19.mp4")],
-                        [osp.join(example_video_dir, "d14.mp4")],
-                        [osp.join(example_video_dir, "d6.mp4")],
-                        [osp.join(example_video_dir, "d20.mp4")],
-                    ],
-                    inputs=[driving_video_input],
-                    cache_examples=False,
-                )
+            with gr.Tabs():
+                with gr.TabItem("🎞️ Driving Video") as v_tab_video:
+                    with gr.Accordion(open=True, label="Driving Video"):
+                        driving_video_input = gr.Video()
+                        gr.Examples(
+                            examples=[
+                                [osp.join(example_video_dir, "d0.mp4")],
+                                [osp.join(example_video_dir, "d18.mp4")],
+                                [osp.join(example_video_dir, "d19.mp4")],
+                                [osp.join(example_video_dir, "d14.mp4")],
+                                [osp.join(example_video_dir, "d6.mp4")],
+                                [osp.join(example_video_dir, "d20.mp4")],
+                            ],
+                            inputs=[driving_video_input],
+                            cache_examples=False,
+                        )
+                with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
+                    with gr.Accordion(open=True, label="Driving Pickle"):
+                        driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
+                        gr.Examples(
+                            examples=[
+                                [osp.join(example_video_dir, "d1.pkl")],
+                                [osp.join(example_video_dir, "d2.pkl")],
+                                [osp.join(example_video_dir, "d5.pkl")],
+                                [osp.join(example_video_dir, "d7.pkl")],
+                                [osp.join(example_video_dir, "d8.pkl")],
+                            ],
+                            inputs=[driving_video_pickle_input],
+                            cache_examples=False,
+                        )
+            v_tab_selection = gr.Textbox(visible=False)
+            v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
+            v_tab_video.select(lambda: "Video", None, v_tab_selection)
             # with gr.Accordion(open=False, label="Animation Instructions"):
             # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
             with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@@ -225,7 +245,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
             with gr.Accordion(open=True, label="The animated video"):
                 output_video_concat_i2v.render()
         with gr.Row():
-            process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
+            process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
         with gr.Row():
             # Examples
@@ -393,6 +413,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         inputs=[
             source_image_input,
             source_video_input,
+            driving_video_pickle_input,
             driving_video_input,
             flag_relative_input,
             flag_do_crop_input,
@@ -410,6 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
             vy_ratio_crop_driving_video,
             driving_smooth_observation_variance,
             tab_selection,
+            v_tab_selection,
         ],
         outputs=[output_video_i2v, output_video_concat_i2v],
         show_progress=True


@@ -146,6 +146,7 @@ class GradioPipeline(LivePortraitPipeline):
         self,
         input_source_image_path=None,
         input_source_video_path=None,
+        input_driving_video_pickle_path=None,
         input_driving_video_path=None,
         flag_relative_input=True,
         flag_do_crop_input=True,
@@ -163,6 +164,7 @@ class GradioPipeline(LivePortraitPipeline):
         vy_ratio_crop_driving_video=-0.1,
         driving_smooth_observation_variance=3e-7,
         tab_selection=None,
+        v_tab_selection=None
     ):
         """ for video-driven potrait animation or video editing
         """
@@ -173,8 +175,11 @@ class GradioPipeline(LivePortraitPipeline):
         else:
             input_source_path = input_source_image_path
+        if v_tab_selection == 'Pickle' and input_driving_video_pickle_path is not None:
+            input_driving_video_path = input_driving_video_pickle_path
         if input_source_path is not None and input_driving_video_path is not None:
-            if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
+            if v_tab_selection != 'Pickle' and osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
                 flag_crop_driving_video_input = True
                 log("The driving video is not square, it will be cropped to square automatically.")
                 gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
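
Condensed out of the hunk above, the routing rule is: when the driving tab is 'Pickle' and a template was uploaded, the .pkl path takes the place of the driving video path, and the square-video auto-crop check is skipped because it only applies to real video files. The following is a standalone sketch of that logic; the actual code inlines it in the GradioPipeline method shown above, and is_square_video is passed in as a callable here only to keep the sketch self-contained.

    import os.path as osp

    def resolve_driving_input(v_tab_selection, input_driving_video_path, input_driving_video_pickle_path):
        """Pick the path the pipeline should treat as the driving input."""
        if v_tab_selection == 'Pickle' and input_driving_video_pickle_path is not None:
            return input_driving_video_pickle_path
        return input_driving_video_path

    def should_auto_crop(v_tab_selection, driving_path, is_square_video):
        """Auto-crop applies only to real driving videos, never to .pkl templates."""
        return (
            v_tab_selection != 'Pickle'
            and osp.exists(driving_path)
            and is_square_video(driving_path) is False
        )

    # Example: a template picked on the Pickle tab wins over a stale video path.
    assert resolve_driving_input('Pickle', 'd0.mp4', 'd1.pkl') == 'd1.pkl'
    assert resolve_driving_input('Video', 'd0.mp4', 'd1.pkl') == 'd0.mp4'
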