diff --git a/app.py b/app.py index af15991..f618263 100644 --- a/app.py +++ b/app.py @@ -26,18 +26,20 @@ def fast_check_ffmpeg(): except: return False + # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) if not fast_check_ffmpeg(): raise ImportError( - "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html" + "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html" ) # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig +# global_tab_selection = None gradio_pipeline = GradioPipeline( inference_cfg=inference_cfg, @@ -55,10 +57,10 @@ def gpu_wrapped_execute_image(*args, **kwargs): # assets -title_md = "assets/gradio_title.md" +title_md = "assets/gradio/gradio_title.md" example_portrait_dir = "assets/examples/source" example_video_dir = "assets/examples/driving" -data_examples = [ +data_examples_i2v = [ [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], @@ -66,6 +68,14 @@ data_examples = [ [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], ] +data_examples_v2v = [ + [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6], + # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-6], + # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-6], + [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6], + # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-6], + [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-6], +] #################### interface logic #################### # Define components first @@ -74,80 +84,148 @@ lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="tar retargeting_input_image = gr.Image(type="filepath") output_image = gr.Image(type="numpy") output_image_paste_back = gr.Image(type="numpy") -output_video = gr.Video() -output_video_concat = gr.Video() +output_video_i2v = gr.Video(autoplay=True) +output_video_concat_i2v = gr.Video(autoplay=True) +output_video_v2v = gr.Video(autoplay=True) +output_video_concat_v2v = gr.Video(autoplay=True) -with gr.Blocks(theme=gr.themes.Soft()) as demo: + +with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: gr.HTML(load_description(title_md)) - gr.Markdown(load_description("assets/gradio_description_upload.md")) + + gr.Markdown(load_description("assets/gradio/gradio_description_upload.md")) with gr.Row(): - with gr.Accordion(open=True, label="Source Portrait"): - image_input = gr.Image(type="filepath") - gr.Examples( - examples=[ - [osp.join(example_portrait_dir, "s9.jpg")], - [osp.join(example_portrait_dir, "s6.jpg")], - [osp.join(example_portrait_dir, "s10.jpg")], - [osp.join(example_portrait_dir, "s5.jpg")], - [osp.join(example_portrait_dir, "s7.jpg")], - [osp.join(example_portrait_dir, "s12.jpg")], - ], - inputs=[image_input], - cache_examples=False, - ) - with gr.Accordion(open=True, label="Driving Video"): - video_input = gr.Video() - gr.Examples( - examples=[ - [osp.join(example_video_dir, "d0.mp4")], - [osp.join(example_video_dir, "d18.mp4")], - [osp.join(example_video_dir, "d19.mp4")], - [osp.join(example_video_dir, "d14.mp4")], - [osp.join(example_video_dir, "d6.mp4")], - ], - inputs=[video_input], - cache_examples=False, - ) + with gr.Column(): + with gr.Tabs(): + with gr.TabItem("πŸ–ΌοΈ Source Image") as tab_image: + with gr.Accordion(open=True, label="Source Image"): + source_image_input = gr.Image(type="filepath") + gr.Examples( + examples=[ + [osp.join(example_portrait_dir, "s9.jpg")], + [osp.join(example_portrait_dir, "s6.jpg")], + [osp.join(example_portrait_dir, "s10.jpg")], + [osp.join(example_portrait_dir, "s5.jpg")], + [osp.join(example_portrait_dir, "s7.jpg")], + [osp.join(example_portrait_dir, "s12.jpg")], + ], + inputs=[source_image_input], + cache_examples=False, + ) + + with gr.TabItem("🎞️ Source Video") as tab_video: + with gr.Accordion(open=True, label="Source Video"): + source_video_input = gr.Video() + gr.Examples( + examples=[ + [osp.join(example_portrait_dir, "s13.mp4")], + # [osp.join(example_portrait_dir, "s14.mp4")], + # [osp.join(example_portrait_dir, "s15.mp4")], + [osp.join(example_portrait_dir, "s18.mp4")], + # [osp.join(example_portrait_dir, "s19.mp4")], + [osp.join(example_portrait_dir, "s20.mp4")], + ], + inputs=[source_video_input], + cache_examples=False, + ) + + tab_selection = gr.Textbox(visible=False) + tab_image.select(lambda: "Image", None, tab_selection) + tab_video.select(lambda: "Video", None, tab_selection) + with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"): + with gr.Row(): + flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)") + scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=2.9, step=0.05) + vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01) + vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01) + + with gr.Column(): + with gr.Accordion(open=True, label="Driving Video"): + driving_video_input = gr.Video() + gr.Examples( + examples=[ + [osp.join(example_video_dir, "d0.mp4")], + [osp.join(example_video_dir, "d18.mp4")], + [osp.join(example_video_dir, "d19.mp4")], + [osp.join(example_video_dir, "d14.mp4")], + [osp.join(example_video_dir, "d6.mp4")], + ], + inputs=[driving_video_input], + cache_examples=False, + ) + # with gr.Accordion(open=False, label="Animation Instructions"): + # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) + with gr.Accordion(open=True, label="Cropping Options for Driving Video"): + with gr.Row(): + flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)") + scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=2.9, step=0.05) + vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01) + vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01) + with gr.Row(): - with gr.Accordion(open=False, label="Animation Instructions and Options"): - gr.Markdown(load_description("assets/gradio_description_animation.md")) + with gr.Accordion(open=True, label="Animation Options"): with gr.Row(): flag_relative_input = gr.Checkbox(value=True, label="relative motion") - flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)") flag_remap_input = gr.Checkbox(value=True, label="paste-back") - flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)") + flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)") + driving_smooth_observation_variance = gr.Number(value=3e-6, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-11) + + gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) with gr.Row(): with gr.Column(): process_button_animation = gr.Button("πŸš€ Animate", variant="primary") with gr.Column(): - process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear") + process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear") with gr.Row(): with gr.Column(): with gr.Accordion(open=True, label="The animated video in the original image space"): - output_video.render() + output_video_i2v.render() with gr.Column(): with gr.Accordion(open=True, label="The animated video"): - output_video_concat.render() + output_video_concat_i2v.render() + with gr.Row(): # Examples gr.Markdown("## You could also choose the examples below by one click ⬇️") with gr.Row(): - gr.Examples( - examples=data_examples, - fn=gpu_wrapped_execute_video, - inputs=[ - image_input, - video_input, - flag_relative_input, - flag_do_crop_input, - flag_remap_input, - flag_crop_driving_video_input - ], - outputs=[output_image, output_image_paste_back], - examples_per_page=len(data_examples), - cache_examples=False, - ) - gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True) + with gr.Tabs(): + with gr.TabItem("πŸ–ΌοΈ Portrait Animation"): + gr.Examples( + examples=data_examples_i2v, + fn=gpu_wrapped_execute_video, + inputs=[ + source_image_input, + driving_video_input, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + flag_crop_driving_video_input, + ], + outputs=[output_image, output_image_paste_back], + examples_per_page=len(data_examples_i2v), + cache_examples=False, + ) + with gr.TabItem("🎞️ Portrait Video Editing"): + gr.Examples( + examples=data_examples_v2v, + fn=gpu_wrapped_execute_video, + inputs=[ + source_video_input, + driving_video_input, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + flag_crop_driving_video_input, + flag_video_editing_head_rotation, + driving_smooth_observation_variance, + ], + outputs=[output_image, output_image_paste_back], + examples_per_page=len(data_examples_v2v), + cache_examples=False, + ) + + # Retargeting + gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True) with gr.Row(visible=True): eye_retargeting_slider.render() lip_retargeting_slider.render() @@ -185,6 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: with gr.Column(): with gr.Accordion(open=True, label="Paste-back Result"): output_image_paste_back.render() + # binding functions for buttons process_button_retargeting.click( # fn=gradio_pipeline.execute_image, @@ -196,18 +275,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: process_button_animation.click( fn=gpu_wrapped_execute_video, inputs=[ - image_input, - video_input, + source_image_input, + source_video_input, + driving_video_input, flag_relative_input, flag_do_crop_input, flag_remap_input, - flag_crop_driving_video_input + flag_crop_driving_video_input, + flag_video_editing_head_rotation, + scale, + vx_ratio, + vy_ratio, + scale_crop_driving_video, + vx_ratio_crop_driving_video, + vy_ratio_crop_driving_video, + driving_smooth_observation_variance, + tab_selection, ], - outputs=[output_video, output_video_concat], + outputs=[output_video_i2v, output_video_concat_i2v], show_progress=True ) - demo.launch( server_port=args.server_port, share=args.share, diff --git a/assets/docs/changelog/2024-07-10.md b/assets/docs/changelog/2024-07-10.md index fe0fa72..8dc4aa6 100644 --- a/assets/docs/changelog/2024-07-10.md +++ b/assets/docs/changelog/2024-07-10.md @@ -9,7 +9,7 @@ The popularity of LivePortrait has exceeded our expectations. If you encounter a - Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`. -- Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving_info` option. +- Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option. ### About driving video diff --git a/assets/docs/changelog/2024-07-19.md b/assets/docs/changelog/2024-07-19.md new file mode 100644 index 0000000..9e1491a --- /dev/null +++ b/assets/docs/changelog/2024-07-19.md @@ -0,0 +1,18 @@ +## 2024/07/19 + +**Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! πŸŽ‰** +We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk). + +### Updates + +- Portrait video editing (v2v): Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with`--flag_source_video_eye_retargeting`. + +- More options in Gradio: We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control. + + +### Community Contributions + +- **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance: + - [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150)) + - [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142)) +- **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119). diff --git a/assets/examples/source/s13.mp4 b/assets/examples/source/s13.mp4 new file mode 100644 index 0000000..bac030c Binary files /dev/null and b/assets/examples/source/s13.mp4 differ diff --git a/assets/examples/source/s18.mp4 b/assets/examples/source/s18.mp4 new file mode 100644 index 0000000..838c5a8 Binary files /dev/null and b/assets/examples/source/s18.mp4 differ diff --git a/assets/examples/source/s20.mp4 b/assets/examples/source/s20.mp4 new file mode 100644 index 0000000..625f7d7 Binary files /dev/null and b/assets/examples/source/s20.mp4 differ diff --git a/assets/gradio/gradio_description_animate_clear.md b/assets/gradio/gradio_description_animate_clear.md new file mode 100644 index 0000000..96d5fee --- /dev/null +++ b/assets/gradio/gradio_description_animate_clear.md @@ -0,0 +1,6 @@ +
+ Step 3: Click the πŸš€ Animate button below to generate, or click 🧹 Clear to erase the results +
+ diff --git a/assets/gradio/gradio_description_animation.md b/assets/gradio/gradio_description_animation.md new file mode 100644 index 0000000..126c4ce --- /dev/null +++ b/assets/gradio/gradio_description_animation.md @@ -0,0 +1,19 @@ +πŸ”₯ To animate the source image or video with the driving video, please follow these steps: +
+1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source) option if faces occupy a small portion of your source image or video. +
+
+2. In the Animation Options for Driving Video section, the relative head rotation and smooth strength options only take effect if the source input is a video. +
+
+3. Press the πŸš€ Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video. +
+
+4. If you want to upload your own driving video, the best practice: + + - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`. + - Focus on the head area, similar to the example videos. + - Minimize shoulder movement. + - Make sure the first frame of driving video is a frontal face with **neutral expression**. + +
diff --git a/assets/gradio_description_retargeting.md b/assets/gradio/gradio_description_retargeting.md similarity index 100% rename from assets/gradio_description_retargeting.md rename to assets/gradio/gradio_description_retargeting.md diff --git a/assets/gradio/gradio_description_upload.md b/assets/gradio/gradio_description_upload.md new file mode 100644 index 0000000..5c3907d --- /dev/null +++ b/assets/gradio/gradio_description_upload.md @@ -0,0 +1,13 @@ +
+
+
+
+ Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️ +
+
+
+
+ Step 2: Upload a Driving Video (any aspect ratio) ⬇️ +
+
+
diff --git a/assets/gradio/gradio_title.md b/assets/gradio/gradio_title.md new file mode 100644 index 0000000..ad9e4ca --- /dev/null +++ b/assets/gradio/gradio_title.md @@ -0,0 +1,20 @@ +
+
+

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

+ + + +
+ +   + Project Page +   + +   + +   + +
+
+
diff --git a/assets/gradio_description_animation.md b/assets/gradio_description_animation.md deleted file mode 100644 index cad1ad6..0000000 --- a/assets/gradio_description_animation.md +++ /dev/null @@ -1,16 +0,0 @@ -πŸ”₯ To animate the source portrait with the driving video, please follow these steps: -
-1. In the Animation Options section, we recommend enabling the do crop (source) option if faces occupy a small portion of your image. -
-
-2. Press the πŸš€ Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. -
-
-3. If you want to upload your own driving video, the best practice: - - - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`. - - Focus on the head area, similar to the example videos. - - Minimize shoulder movement. - - Make sure the first frame of driving video is a frontal face with **neutral expression**. - -
diff --git a/assets/gradio_description_upload.md b/assets/gradio_description_upload.md deleted file mode 100644 index 035a6c2..0000000 --- a/assets/gradio_description_upload.md +++ /dev/null @@ -1,2 +0,0 @@ -## πŸ€— This is the official gradio demo for **LivePortrait**. -
Please upload or use a webcam to get a Source Portrait (any aspect ratio) and upload a Driving Video (1:1 aspect ratio, or any aspect ratio with do crop (driving video) checked).
diff --git a/assets/gradio_title.md b/assets/gradio_title.md deleted file mode 100644 index c9bbfc2..0000000 --- a/assets/gradio_title.md +++ /dev/null @@ -1,11 +0,0 @@ -
-
-

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

-
- Project Page - - -
-
-
diff --git a/inference.py b/inference.py index aeffaf6..9480523 100644 --- a/inference.py +++ b/inference.py @@ -22,10 +22,10 @@ def fast_check_ffmpeg(): def fast_check_args(args: ArgumentConfig): - if not osp.exists(args.source_image): - raise FileNotFoundError(f"source image not found: {args.source_image}") - if not osp.exists(args.driving_info): - raise FileNotFoundError(f"driving info not found: {args.driving_info}") + if not osp.exists(args.source): + raise FileNotFoundError(f"source info not found: {args.source}") + if not osp.exists(args.driving): + raise FileNotFoundError(f"driving info not found: {args.driving}") def main(): @@ -35,10 +35,9 @@ def main(): if not fast_check_ffmpeg(): raise ImportError( - "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html" + "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html" ) - # fast check the args fast_check_args(args) # specify configs for inference diff --git a/readme.md b/readme.md index 8c79a5e..c96975c 100644 --- a/readme.md +++ b/readme.md @@ -39,6 +39,7 @@ ## πŸ”₯ Updates +- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing (aka v2v)! More to see [here](assets/docs/changelog/2024-07-19.md). - **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KwaiVGI/LivePortrait/pull/143). - **`2024/07/10`**: πŸ’ͺ We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md). - **`2024/07/09`**: πŸ€— We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)! @@ -47,11 +48,11 @@ -## Introduction +## Introduction πŸ“– This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168). We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) πŸ’–. -## πŸ”₯ Getting Started +## Getting Started 🏁 ### 1. Clone the code and prepare the environment ```bash git clone https://github.com/KwaiVGI/LivePortrait @@ -61,9 +62,10 @@ cd LivePortrait conda create -n LivePortrait python==3.9 conda activate LivePortrait -# install dependencies with pip (for Linux and Windows) +# install dependencies with pip +# for Linux and Windows users pip install -r requirements.txt -# for macOS with Apple Silicon +# for macOS with Apple Silicon users pip install -r requirements_macOS.txt ``` @@ -113,7 +115,7 @@ python inference.py PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py ``` -If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result. +If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.

image @@ -122,18 +124,18 @@ If the script runs successfully, you will get an output mp4 file named `animatio Or, you can change the input by specifying the `-s` and `-d` arguments: ```bash +# source input is an image python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 -# disable pasting back to run faster -python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback +# source input is a video ✨ +python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4 # more options to see python inference.py -h ``` -#### Driving video auto-cropping - -πŸ“• To use your own driving video, we **recommend**: +#### Driving video auto-cropping πŸ“’πŸ“’πŸ“’ +To use your own driving video, we **recommend**: ⬇️ - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`. - Focus on the head area, similar to the example videos. - Minimize shoulder movement. @@ -144,25 +146,24 @@ Below is a auto-cropping case by `--flag_crop_driving_video`: python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video ``` -If you find the results of auto-cropping is not well, you can modify the `--scale_crop_video`, `--vy_ratio_crop_video` options to adjust the scale and offset, or do it manually. +If you find the results of auto-cropping is not well, you can modify the `--scale_crop_driving_video`, `--vy_ratio_crop_driving_video` options to adjust the scale and offset, or do it manually. #### Motion template making You can also use the auto-generated motion template files ending with `.pkl` to speed up inference, and **protect privacy**, such as: ```bash -python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation +python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing ``` -**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊 - ### 4. Gradio interface πŸ€— We also provide a Gradio interface for a better experience, just run by: ```bash -# For Linux and Windows: +# For Linux and Windows users (and macOS with Intel??) python app.py -# For macOS with Apple Silicon, Intel not supported, this maybe 20x slower than RTX 4090 +# For macOS with Apple Silicon users, Intel not supported, this maybe 20x slower than RTX 4090 PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py ``` @@ -210,7 +211,7 @@ Discover the invaluable resources contributed by our community to enhance your L And many more amazing contributions from our community! -## Acknowledgements +## Acknowledgements πŸ’ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions. ## Citation πŸ’– diff --git a/requirements_base.txt b/requirements_base.txt index 59fad00..b071f2c 100644 --- a/requirements_base.txt +++ b/requirements_base.txt @@ -19,3 +19,4 @@ matplotlib==3.9.0 imageio-ffmpeg==0.5.1 tyro==0.8.5 gradio==4.37.1 +pykalman==0.9.7 diff --git a/src/config/argument_config.py b/src/config/argument_config.py index aa86713..7a130e8 100644 --- a/src/config/argument_config.py +++ b/src/config/argument_config.py @@ -14,35 +14,39 @@ from .base_config import PrintableConfig, make_abs_path @dataclass(repr=False) # use repr from PrintableConfig class ArgumentConfig(PrintableConfig): ########## input arguments ########## - source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait - driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) + source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait or video + driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video ########## inference arguments ########## flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False. flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video - device_id: int = 0 # gpu device id - flag_force_cpu: bool = False # force cpu inference, WIP! - flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False - flag_eye_retargeting: bool = False # not recommend to be True, WIP - flag_lip_retargeting: bool = False # not recommend to be True, WIP + device_id: int = 0 # gpu device id + flag_force_cpu: bool = False # force cpu inference, WIP! + flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False + flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering + flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video + flag_eye_retargeting: bool = False # not recommend to be True, WIP + flag_lip_retargeting: bool = False # not recommend to be True, WIP flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large flag_relative_motion: bool = True # whether to use relative motion flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space - flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space - flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True + flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space + driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy - ########## crop arguments ########## - scale: float = 2.3 # the ratio of face area is smaller if scale is larger + ########## source crop arguments ########## + scale: float = 2.3 # the ratio of face area is smaller if scale is larger vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space + flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True - scale_crop_video: float = 2.2 # scale factor for cropping video - vx_ratio_crop_video: float = 0. # adjust y offset - vy_ratio_crop_video: float = -0.1 # adjust x offset + ########## driving crop arguments ########## + scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video + vx_ratio_crop_driving_video: float = 0. # adjust y offset + vy_ratio_crop_driving_video: float = -0.1 # adjust x offset ########## gradio arguments ########## - server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server - share: bool = False # whether to share the server to public + server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server + share: bool = False # whether to share the server to public server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation diff --git a/src/config/crop_config.py b/src/config/crop_config.py index f3b12ef..c7d64a5 100644 --- a/src/config/crop_config.py +++ b/src/config/crop_config.py @@ -15,15 +15,16 @@ class CropConfig(PrintableConfig): landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx" device_id: int = 0 # gpu device id flag_force_cpu: bool = False # force cpu inference, WIP - ########## source image cropping option ########## + ########## source image or video cropping option ########## dsize: int = 512 # crop size - scale: float = 2.5 # scale factor + scale: float = 2.8 # scale factor vx_ratio: float = 0 # vx ratio vy_ratio: float = -0.125 # vy ratio +up, -down max_face_num: int = 0 # max face number, 0 mean no limit + flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True ########## driving video auto cropping option ########## - scale_crop_video: float = 2.2 # 2.0 # scale factor for cropping video - vx_ratio_crop_video: float = 0.0 # adjust y offset - vy_ratio_crop_video: float = -0.1 # adjust x offset + scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video + vx_ratio_crop_driving_video: float = 0.0 # adjust y offset + vy_ratio_crop_driving_video: float = -0.1 # adjust x offset direction: str = "large-small" # direction of cropping diff --git a/src/config/inference_config.py b/src/config/inference_config.py index 7c6c718..adb313f 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -25,7 +25,9 @@ class InferenceConfig(PrintableConfig): flag_use_half_precision: bool = True flag_crop_driving_video: bool = False device_id: int = 0 - flag_lip_zero: bool = True + flag_normalize_lip: bool = True + flag_source_video_eye_retargeting: bool = False + flag_video_editing_head_rotation: bool = False flag_eye_retargeting: bool = False flag_lip_retargeting: bool = False flag_stitching: bool = True @@ -37,7 +39,9 @@ class InferenceConfig(PrintableConfig): flag_do_torch_compile: bool = False # NOT EXPORTED PARAMS - lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero + lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip + source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video + driving_smooth_observation_variance: float = 3e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy anchor_frame: int = 0 # TO IMPLEMENT input_shape: Tuple[int, int] = (256, 256) # input shape @@ -47,5 +51,5 @@ class InferenceConfig(PrintableConfig): mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) size_gif: int = 256 # default gif size, TO IMPLEMENT - source_max_dim: int = 1280 # the max dim of height and width of source image - source_division: int = 2 # make sure the height and width of source image can be divided by this number + source_max_dim: int = 1280 # the max dim of height and width of source image or video + source_division: int = 2 # make sure the height and width of source image or video can be divided by this number diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index f7343f7..7b01370 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -3,6 +3,8 @@ """ Pipeline for gradio """ + +import os.path as osp import gradio as gr from .config.argument_config import ArgumentConfig @@ -11,6 +13,7 @@ from .utils.io import load_img_online from .utils.rprint import rlog as log from .utils.crop import prepare_paste_back, paste_back from .utils.camera import get_rotation_matrix +from .utils.helper import is_square_video def update_args(args, user_args): @@ -31,23 +34,53 @@ class GradioPipeline(LivePortraitPipeline): def execute_video( self, - input_image_path, - input_video_path, - flag_relative_input, - flag_do_crop_input, - flag_remap_input, - flag_crop_driving_video_input + input_source_image_path=None, + input_source_video_path=None, + input_driving_video_path=None, + flag_relative_input=True, + flag_do_crop_input=True, + flag_remap_input=True, + flag_crop_driving_video_input=True, + flag_video_editing_head_rotation=False, + scale=2.3, + vx_ratio=0.0, + vy_ratio=-0.125, + scale_crop_driving_video=2.2, + vx_ratio_crop_driving_video=0.0, + vy_ratio_crop_driving_video=-0.1, + driving_smooth_observation_variance=3e-6, + tab_selection=None, ): - """ for video driven potrait animation + """ for video-driven potrait animation or video editing """ - if input_image_path is not None and input_video_path is not None: + if tab_selection == 'Image': + input_source_path = input_source_image_path + elif tab_selection == 'Video': + input_source_path = input_source_video_path + else: + input_source_path = input_source_image_path + + if input_source_path is not None and input_driving_video_path is not None: + if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False: + flag_crop_driving_video_input = True + log("The source video is not square, the driving video will be cropped to square automatically.") + gr.Info("The source video is not square, the driving video will be cropped to square automatically.", duration=2) + args_user = { - 'source_image': input_image_path, - 'driving_info': input_video_path, - 'flag_relative': flag_relative_input, + 'source': input_source_path, + 'driving': input_driving_video_path, + 'flag_relative_motion': flag_relative_input, 'flag_do_crop': flag_do_crop_input, 'flag_pasteback': flag_remap_input, - 'flag_crop_driving_video': flag_crop_driving_video_input + 'flag_crop_driving_video': flag_crop_driving_video_input, + 'flag_video_editing_head_rotation': flag_video_editing_head_rotation, + 'scale': scale, + 'vx_ratio': vx_ratio, + 'vy_ratio': vy_ratio, + 'scale_crop_driving_video': scale_crop_driving_video, + 'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video, + 'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video, + 'driving_smooth_observation_variance': driving_smooth_observation_variance, } # update config from user input self.args = update_args(self.args, args_user) @@ -58,7 +91,7 @@ class GradioPipeline(LivePortraitPipeline): gr.Info("Run successfully!", duration=2) return video_path, video_path_concat, else: - raise gr.Error("The input source portrait or driving video hasn't been prepared yet πŸ’₯!", duration=5) + raise gr.Error("Please upload the source portrait or source video, and driving video πŸ€—πŸ€—πŸ€—", duration=5) def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True): """ for single image retargeting @@ -79,9 +112,8 @@ class GradioPipeline(LivePortraitPipeline): # βˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) - num_kp = x_s_user.shape[1] # default: use x_s - x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) + x_d_new = x_s_user + eyes_delta + lip_delta # D(W(f_s; x_s, xβ€²_d)) out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new) out = self.live_portrait_wrapper.parse_output(out['out'])[0] @@ -114,4 +146,4 @@ class GradioPipeline(LivePortraitPipeline): return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb else: # when press the clear button, go here - raise gr.Error("The retargeting input hasn't been prepared yet πŸ’₯!", duration=5) + raise gr.Error("Please upload a source portrait as the retargeting input πŸ€—πŸ€—πŸ€—", duration=5) diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py index dc5ea01..59fc9e1 100644 --- a/src/live_portrait_pipeline.py +++ b/src/live_portrait_pipeline.py @@ -20,8 +20,9 @@ from .utils.cropper import Cropper from .utils.camera import get_rotation_matrix from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream from .utils.crop import _transform_img, prepare_paste_back, paste_back -from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load -from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix +from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load +from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image +from .utils.filter import smooth from .utils.rprint import rlog as log # from .utils.viz import viz_lmk from .live_portrait_wrapper import LivePortraitWrapper @@ -37,125 +38,252 @@ class LivePortraitPipeline(object): self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg) self.cropper: Cropper = Cropper(crop_cfg=crop_cfg) + def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs): + n_frames = I_lst.shape[0] + template_dct = { + 'n_frames': n_frames, + 'output_fps': kwargs.get('output_fps', 25), + 'motion': [], + 'c_eyes_lst': [], + 'c_lip_lst': [], + 'x_i_info_lst': [], + } + + for i in track(range(n_frames), description='Making motion templates...', total=n_frames): + # collect s, R, Ξ΄ and t for inference + I_i = I_lst[i] + x_i_info = self.live_portrait_wrapper.get_kp_info(I_i) + R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll']) + + item_dct = { + 'scale': x_i_info['scale'].cpu().numpy().astype(np.float32), + 'R': R_i.cpu().numpy().astype(np.float32), + 'exp': x_i_info['exp'].cpu().numpy().astype(np.float32), + 't': x_i_info['t'].cpu().numpy().astype(np.float32), + } + + template_dct['motion'].append(item_dct) + + c_eyes = c_eyes_lst[i].astype(np.float32) + template_dct['c_eyes_lst'].append(c_eyes) + + c_lip = c_lip_lst[i].astype(np.float32) + template_dct['c_lip_lst'].append(c_lip) + + template_dct['x_i_info_lst'].append(x_i_info) + + return template_dct + def execute(self, args: ArgumentConfig): # for convenience inf_cfg = self.live_portrait_wrapper.inference_cfg - device = self.live_portrait_wrapper.device + device = self.live_portrait_wrapper.device crop_cfg = self.cropper.crop_cfg - ######## process source portrait ######## - img_rgb = load_image_rgb(args.source_image) - img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division) - log(f"Load source image from {args.source_image}") - - crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg) - if crop_info is None: - raise Exception("No face detected in the source image!") - source_lmk = crop_info['lmk_crop'] - img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] - - if inf_cfg.flag_do_crop: - I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) - else: - img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256 - I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) - x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) - x_c_s = x_s_info['kp'] - R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) - f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) - x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) - - flag_lip_zero = inf_cfg.flag_lip_zero # not overwrite - if flag_lip_zero: - # let lip-open scalar to be 0 at first - c_d_lip_before_animation = [0.] - combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) - if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold: - flag_lip_zero = False - else: - lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) - ############################################ + ######## load source input ######## + flag_is_source_video = False + source_fps = None + if is_image(args.source): + flag_is_source_video = False + img_rgb = load_image_rgb(args.source) + img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division) + log(f"Load source image from {args.source}") + source_rgb_lst = [img_rgb] + elif is_video(args.source): + flag_is_source_video = True + source_rgb_lst = load_video(args.source) + source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst] + source_fps = int(get_fps(args.source)) + log(f"Load source video from {args.source}, FPS is {source_fps}") + else: # source input is an unknown format + raise Exception(f"Unknown source format: {args.source}") ######## process driving info ######## - flag_load_from_template = is_template(args.driving_info) + flag_load_from_template = is_template(args.driving) driving_rgb_crop_256x256_lst = None wfp_template = None if flag_load_from_template: # NOTE: load from template, it is fast, but the cropping video is None - log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') - template_dct = load(args.driving_info) - n_frames = template_dct['n_frames'] + log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') + driving_template_dct = load(args.driving) + c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys + c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] + driving_n_frames = driving_template_dct['n_frames'] + if flag_is_source_video: + n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames + else: + n_frames = driving_n_frames # set output_fps - output_fps = template_dct.get('output_fps', inf_cfg.output_fps) + output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps) log(f'The FPS of template: {output_fps}') if args.flag_crop_driving_video: log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") - elif osp.exists(args.driving_info) and is_video(args.driving_info): + elif osp.exists(args.driving) and is_video(args.driving): # load from video file, AND make motion template - log(f"Load video: {args.driving_info}") - if osp.isdir(args.driving_info): - output_fps = inf_cfg.output_fps - else: - output_fps = int(get_fps(args.driving_info)) - log(f'The FPS of {args.driving_info} is: {output_fps}') + output_fps = int(get_fps(args.driving)) + log(f"Load driving video from: {args.driving}, FPS is {output_fps}") - log(f"Load video file (mp4 mov avi etc...): {args.driving_info}") - driving_rgb_lst = load_driving_info(args.driving_info) + driving_rgb_lst = load_video(args.driving) + driving_n_frames = len(driving_rgb_lst) ######## make motion template ######## - log("Start making motion template...") + log("Start making driving motion template...") + if flag_is_source_video: + n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames + driving_rgb_lst = driving_rgb_lst[:n_frames] + else: + n_frames = driving_n_frames if inf_cfg.flag_crop_driving_video: - ret = self.cropper.crop_driving_video(driving_rgb_lst) - log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.') - driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst'] + ret_d = self.cropper.crop_driving_video(driving_rgb_lst) + log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.') + if len(ret_d["frame_crop_lst"]) is not n_frames: + n_frames = min(n_frames, len(ret_d["frame_crop_lst"])) + driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst'] driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] else: driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst) driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256 + ####################################### - c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst) + c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst) # save the motion template - I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst) - template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps) + I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst) + driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps) - wfp_template = remove_suffix(args.driving_info) + '.pkl' - dump(wfp_template, template_dct) + wfp_template = remove_suffix(args.driving) + '.pkl' + dump(wfp_template, driving_template_dct) log(f"Dump motion template to {wfp_template}") - n_frames = I_d_lst.shape[0] else: - raise Exception(f"{args.driving_info} not exists or unsupported driving info types!") - ######################################### + raise Exception(f"{args.driving} not exists or unsupported driving info types!") ######## prepare for pasteback ######## I_p_pstbk_lst = None if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: - mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) I_p_pstbk_lst = [] log("Prepared pasteback mask done.") - ######################################### I_p_lst = [] R_d_0, x_d_0_info = None, None + flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite + flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite + lip_delta_before_animation, eye_delta_before_animation = None, None + ######## process source info ######## + if flag_is_source_video: + log(f"Start making source motion template...") + + source_rgb_lst = source_rgb_lst[:n_frames] + if inf_cfg.flag_do_crop: + ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg) + log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') + if len(ret_s["frame_crop_lst"]) is not n_frames: + n_frames = min(n_frames, len(ret_s["frame_crop_lst"])) + img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] + else: + source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst) + img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256 + + c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst) + # save the motion template + I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst) + source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps) + + x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] + x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) + if inf_cfg.flag_video_editing_head_rotation: + key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys + x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] + x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) + else: # if the input is a source image, process it only once + crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg) + if crop_info is None: + raise Exception("No face detected in the source image!") + source_lmk = crop_info['lmk_crop'] + img_crop_256x256 = crop_info['img_crop_256x256'] + + if inf_cfg.flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + else: + img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256 + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + x_c_s = x_s_info['kp'] + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) + + # let lip-open scalar to be 0 at first + if flag_normalize_lip: + c_d_lip_before_animation = [0.] + combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) + if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: + lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) + + if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: + mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) + + ######## animate ######## + log(f"The animated video consists of {n_frames} frames.") for i in track(range(n_frames), description='πŸš€Animating...', total=n_frames): - x_d_i_info = template_dct['motion'][i] - x_d_i_info = dct2device(x_d_i_info, device) - R_d_i = x_d_i_info['R_d'] + if flag_is_source_video: # source video + x_s_info_tiny = source_template_dct['motion'][i] + x_s_info_tiny = dct2device(x_s_info_tiny, device) - if i == 0: + source_lmk = source_lmk_crop_lst[i] + img_crop_256x256 = img_crop_256x256_lst[i] + I_s = I_s_lst[i] + + x_s_info = source_template_dct['x_i_info_lst'][i] + x_c_s = x_s_info['kp'] + R_s = x_s_info_tiny['R'] + f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) + + # let lip-open scalar to be 0 at first if the input is a video + if flag_normalize_lip: + c_d_lip_before_animation = [0.] + combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) + if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: + lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) + + # let eye-open scalar to be the same as the first frame if the latter is eye-open state + if flag_source_video_eye_retargeting: + if i == 0: + combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] + c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] + if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold: + c_d_eye_before_animation_frame_zero = [[0.39]] + combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk) + eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) + + if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back + mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0])) + + x_d_i_info = driving_template_dct['motion'][i] + x_d_i_info = dct2device(x_d_i_info, device) + R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys + + if i == 0: # cache the first frame R_d_0 = R_d_i x_d_0_info = x_d_i_info if inf_cfg.flag_relative_motion: - R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s - delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) - scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) - t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) + if flag_is_source_video: + if inf_cfg.flag_video_editing_head_rotation: + R_new = x_d_r_lst_smooth[i] + else: + R_new = R_s + else: + R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s + + delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) + scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) + t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) else: R_new = R_d_i delta_new = x_d_i_info['exp'] @@ -168,16 +296,20 @@ class LivePortraitPipeline(object): # Algorithm 1: if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: # without stitching or retargeting - if flag_lip_zero: - x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) + if flag_normalize_lip and lip_delta_before_animation is not None: + x_d_i_new += lip_delta_before_animation + if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: + x_d_i_new += eye_delta_before_animation else: pass elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: # with stitching and without retargeting - if flag_lip_zero: - x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) + if flag_normalize_lip and lip_delta_before_animation is not None: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation else: x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: + x_d_i_new += eye_delta_before_animation else: eyes_delta, lip_delta = None, None if inf_cfg.flag_eye_retargeting: @@ -193,12 +325,12 @@ class LivePortraitPipeline(object): if inf_cfg.flag_relative_motion: # use x_s x_d_i_new = x_s + \ - (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ - (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + (eyes_delta if eyes_delta is not None else 0) + \ + (lip_delta if lip_delta is not None else 0) else: # use x_d,i x_d_i_new = x_d_i_new + \ - (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ - (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + (eyes_delta if eyes_delta is not None else 0) + \ + (lip_delta if lip_delta is not None else 0) if inf_cfg.flag_stitching: x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) @@ -208,38 +340,52 @@ class LivePortraitPipeline(object): I_p_lst.append(I_p_i) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: - # TODO: pasteback is slow, considering optimize it using multi-threading or GPU - I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float) + # TODO: the paste back procedure is slow, considering optimize it using multi-threading or GPU + if flag_is_source_video: + I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float) + else: + I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float) I_p_pstbk_lst.append(I_p_pstbk) mkdir(args.output_dir) wfp_concat = None - flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info) + flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source) + flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) - ######### build final concat result ######### - # driving frame | source image | generation, or source image | generation - frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst) - wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4') + ######### build the final concatenation result ######### + # driving frame | source frame | generation, or source frame | generation + if flag_is_source_video: + frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) + else: + frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst) + wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') + + # NOTE: update output fps + output_fps = source_fps if flag_is_source_video else output_fps images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) - if flag_has_audio: - # final result with concat - wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4') - add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio) + if flag_source_has_audio or flag_driving_has_audio: + # final result with concatenation + wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4') + audio_from_which_video = args.source if flag_source_has_audio else args.driving + log(f"Audio is selected from {audio_from_which_video}, concat mode") + add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio) os.replace(wfp_concat_with_audio, wfp_concat) log(f"Replace {wfp_concat} with {wfp_concat_with_audio}") - # save drived result - wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4') + # save the animated result + wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) else: images2video(I_p_lst, wfp=wfp, fps=output_fps) - ######### build final result ######### - if flag_has_audio: - wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4') - add_audio_to_video(wfp, args.driving_info, wfp_with_audio) + ######### build the final result ######### + if flag_source_has_audio or flag_driving_has_audio: + wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') + audio_from_which_video = args.source if flag_source_has_audio else args.driving + log(f"Audio is selected from {audio_from_which_video}") + add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) os.replace(wfp_with_audio, wfp) log(f"Replace {wfp} with {wfp_with_audio}") @@ -250,36 +396,3 @@ class LivePortraitPipeline(object): log(f'Animated video with concat: {wfp_concat}') return wfp, wfp_concat - - def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs): - n_frames = I_d_lst.shape[0] - template_dct = { - 'n_frames': n_frames, - 'output_fps': kwargs.get('output_fps', 25), - 'motion': [], - 'c_d_eyes_lst': [], - 'c_d_lip_lst': [], - } - - for i in track(range(n_frames), description='Making motion templates...', total=n_frames): - # collect s_d, R_d, Ξ΄_d and t_d for inference - I_d_i = I_d_lst[i] - x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i) - R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll']) - - item_dct = { - 'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32), - 'R_d': R_d_i.cpu().numpy().astype(np.float32), - 'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32), - 't': x_d_i_info['t'].cpu().numpy().astype(np.float32), - } - - template_dct['motion'].append(item_dct) - - c_d_eyes = c_d_eyes_lst[i].astype(np.float32) - template_dct['c_d_eyes_lst'].append(c_d_eyes) - - c_d_lip = c_d_lip_lst[i].astype(np.float32) - template_dct['c_d_lip_lst'].append(c_d_lip) - - return template_dct diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py index d5e9916..1298f37 100644 --- a/src/live_portrait_wrapper.py +++ b/src/live_portrait_wrapper.py @@ -95,7 +95,7 @@ class LivePortraitWrapper(object): x = x.to(self.device) return x - def prepare_driving_videos(self, imgs) -> torch.Tensor: + def prepare_videos(self, imgs) -> torch.Tensor: """ construct the input as standard imgs: NxBxHxWx3, uint8 """ @@ -216,7 +216,7 @@ class LivePortraitWrapper(object): with torch.no_grad(): delta = self.stitching_retargeting_module['eye'](feat_eye) - return delta + return delta.reshape(-1, kp_source.shape[1], 3) def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: """ @@ -229,7 +229,7 @@ class LivePortraitWrapper(object): with torch.no_grad(): delta = self.stitching_retargeting_module['lip'](feat_lip) - return delta + return delta.reshape(-1, kp_source.shape[1], 3) def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ @@ -301,10 +301,10 @@ class LivePortraitWrapper(object): return out - def calc_driving_ratio(self, driving_lmk_lst): + def calc_ratio(self, lmk_lst): input_eye_ratio_lst = [] input_lip_ratio_lst = [] - for lmk in driving_lmk_lst: + for lmk in lmk_lst: # for eyes retargeting input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) # for lip retargeting diff --git a/src/utils/cropper.py b/src/utils/cropper.py index 7cc4152..8c17bfb 100644 --- a/src/utils/cropper.py +++ b/src/utils/cropper.py @@ -31,6 +31,7 @@ class Trajectory: end: int = -1 # end frame lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list + M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list @@ -104,6 +105,7 @@ class Cropper(object): scale=crop_cfg.scale, vx_ratio=crop_cfg.vx_ratio, vy_ratio=crop_cfg.vy_ratio, + flag_do_rot=crop_cfg.flag_do_rot, ) lmk = self.landmark_runner.run(img_rgb, lmk) @@ -115,6 +117,58 @@ class Cropper(object): return ret_dct + def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig): + """Tracking based landmarks/alignment and cropping""" + trajectory = Trajectory() + for idx, frame_rgb in enumerate(source_rgb_lst): + if idx == 0 or trajectory.start == -1: + src_face = self.face_analysis_wrapper.get( + contiguous(frame_rgb[..., ::-1]), + flag_do_landmark_2d_106=True, + direction=crop_cfg.direction, + max_face_num=crop_cfg.max_face_num, + ) + if len(src_face) == 0: + log(f"No face detected in the frame #{idx}") + continue + elif len(src_face) > 1: + log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.") + src_face = src_face[0] + lmk = src_face.landmark_2d_106 + lmk = self.landmark_runner.run(frame_rgb, lmk) + trajectory.start, trajectory.end = idx, idx + else: + lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1]) + trajectory.end = idx + trajectory.lmk_lst.append(lmk) + + # crop the face + ret_dct = crop_image( + frame_rgb, # ndarray + lmk, # 106x2 or Nx2 + dsize=crop_cfg.dsize, + scale=crop_cfg.scale, + vx_ratio=crop_cfg.vx_ratio, + vy_ratio=crop_cfg.vy_ratio, + flag_do_rot=crop_cfg.flag_do_rot, + ) + lmk = self.landmark_runner.run(frame_rgb, lmk) + ret_dct["lmk_crop"] = lmk + + # update a 256x256 version for network input + ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA) + ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize + + trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"]) + trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"]) + trajectory.M_c2o_lst.append(ret_dct['M_c2o']) + + return { + "frame_crop_lst": trajectory.frame_rgb_crop_lst, + "lmk_crop_lst": trajectory.lmk_crop_lst, + "M_c2o_lst": trajectory.M_c2o_lst, + } + def crop_driving_video(self, driving_rgb_lst, **kwargs): """Tracking based landmarks/alignment and cropping""" trajectory = Trajectory() @@ -142,9 +196,9 @@ class Cropper(object): trajectory.lmk_lst.append(lmk) ret_bbox = parse_bbox_from_landmark( lmk, - scale=self.crop_cfg.scale_crop_video, - vx_ratio_crop_video=self.crop_cfg.vx_ratio_crop_video, - vy_ratio=self.crop_cfg.vy_ratio_crop_video, + scale=self.crop_cfg.scale_crop_driving_video, + vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video, + vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video, )["bbox"] bbox = [ ret_bbox[0, 0], @@ -174,6 +228,7 @@ class Cropper(object): "lmk_crop_lst": trajectory.lmk_crop_lst, } + def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs): """Tracking based landmarks/alignment""" trajectory = Trajectory() diff --git a/src/utils/filter.py b/src/utils/filter.py new file mode 100644 index 0000000..2ee6abc --- /dev/null +++ b/src/utils/filter.py @@ -0,0 +1,19 @@ +# coding: utf-8 + +import torch +import numpy as np +from pykalman import KalmanFilter + + +def smooth(x_d_lst, shape, device, observation_variance=3e-6, process_variance=1e-5): + x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst] + x_d_stacked = np.vstack(x_d_lst_reshape) + kf = KalmanFilter( + initial_state_mean=x_d_stacked[0], + n_dim_obs=x_d_stacked.shape[1], + transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]), + observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1]) + ) + smoothed_state_means, _ = kf.smooth(x_d_stacked) + x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means] + return x_d_lst_smooth diff --git a/src/utils/helper.py b/src/utils/helper.py index 0e2af94..777bf09 100644 --- a/src/utils/helper.py +++ b/src/utils/helper.py @@ -8,6 +8,8 @@ import os import os.path as osp import torch from collections import OrderedDict +import numpy as np +import cv2 from ..modules.spade_generator import SPADEDecoder from ..modules.warping_network import WarpingNetwork @@ -42,6 +44,11 @@ def remove_suffix(filepath): return osp.join(osp.dirname(filepath), basename(filepath)) +def is_image(file_path): + image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff') + return file_path.lower().endswith(image_extensions) + + def is_video(file_path): if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path): return True @@ -143,3 +150,16 @@ def load_description(fp): with open(fp, 'r', encoding='utf-8') as f: content = f.read() return content + + +def is_square_video(video_path): + video = cv2.VideoCapture(video_path) + + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + video.release() + # if width != height: + # gr.Info(f"Uploaded video is not square, force do crop (driving) to be True") + + return width == height diff --git a/src/utils/io.py b/src/utils/io.py index 28c2d99..9e4bc69 100644 --- a/src/utils/io.py +++ b/src/utils/io.py @@ -1,7 +1,5 @@ # coding: utf-8 -import os -from glob import glob import os.path as osp import imageio import numpy as np @@ -18,23 +16,17 @@ def load_image_rgb(image_path: str): return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) -def load_driving_info(driving_info): - driving_video_ori = [] +def load_video(video_info, n_frames=-1): + reader = imageio.get_reader(video_info, "ffmpeg") - def load_images_from_directory(directory): - image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg'))) - return [load_image_rgb(im_path) for im_path in image_paths] + ret = [] + for idx, frame_rgb in enumerate(reader): + if n_frames > 0 and idx >= n_frames: + break + ret.append(frame_rgb) - def load_images_from_video(file_path): - reader = imageio.get_reader(file_path, "ffmpeg") - return [image for _, image in enumerate(reader)] - - if osp.isdir(driving_info): - driving_video_ori = load_images_from_directory(driving_info) - elif osp.isfile(driving_info): - driving_video_ori = load_images_from_video(driving_info) - - return driving_video_ori + reader.close() + return ret def contiguous(obj): diff --git a/src/utils/video.py b/src/utils/video.py index d9bf8cb..a6d06fa 100644 --- a/src/utils/video.py +++ b/src/utils/video.py @@ -80,14 +80,15 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)): return img -def concat_frames(driving_image_lst, source_image, I_p_lst): +def concat_frames(driving_image_lst, source_image_lst, I_p_lst): # TODO: add more concat style, e.g., left-down corner driving out_lst = [] h, w, _ = I_p_lst[0].shape + source_image_resized_lst = [cv2.resize(img, (w, h)) for img in source_image_lst] for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'): I_p = I_p_lst[idx] - source_image_resized = cv2.resize(source_image, (w, h)) + source_image_resized = source_image_resized_lst[idx] if len(source_image_lst) > 1 else source_image_resized_lst[0] if driving_image_lst is None: out = np.hstack((source_image_resized, I_p))