diff --git a/app.py b/app.py index 35a57f2..e386542 100644 --- a/app.py +++ b/app.py @@ -85,12 +85,12 @@ data_examples_i2v = [ [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], ] data_examples_v2v = [ - [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], + [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7], # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7], # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7], - [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], + [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7], # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], - [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], + [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7], ] #################### interface logic #################### @@ -98,6 +98,7 @@ data_examples_v2v = [ retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale") video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale") driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8) +video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent") eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") @@ -124,9 +125,6 @@ retargeting_output_image = gr.Image(type="numpy") retargeting_output_image_paste_back = gr.Image(type="numpy") output_video = gr.Video(autoplay=False) output_video_paste_back = gr.Video(autoplay=False) -output_video_i2v = gr.Video(autoplay=False) -output_video_concat_i2v = gr.Video(autoplay=False) - with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: gr.HTML(load_description(title_md)) @@ -196,6 +194,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San inputs=[driving_video_input], cache_examples=False, ) + with gr.TabItem("πŸ–ΌοΈ Driving Image") as v_tab_image: + with gr.Accordion(open=True, label="Driving Image"): + driving_image_input = gr.Image(type="filepath") + gr.Examples( + examples=[ + [osp.join(example_video_dir, "d30.jpg")], + [osp.join(example_video_dir, "d9.jpg")], + [osp.join(example_video_dir, "d19.jpg")], + [osp.join(example_video_dir, "d8.jpg")], + [osp.join(example_video_dir, "d12.jpg")], + [osp.join(example_video_dir, "d38.jpg")], + ], + inputs=[driving_image_input], + cache_examples=False, + ) + with gr.TabItem("πŸ“ Driving Pickle") as v_tab_pickle: with gr.Accordion(open=True, label="Driving Pickle"): driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"]) @@ -212,8 +226,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San ) v_tab_selection = gr.Textbox(visible=False) - v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection) v_tab_video.select(lambda: "Video", None, v_tab_selection) + v_tab_image.select(lambda: "Image", None, v_tab_selection) + v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection) # with gr.Accordion(open=False, label="Animation Instructions"): # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) with gr.Accordion(open=True, label="Cropping Options for Driving Video"): @@ -229,9 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_stitching_input = gr.Checkbox(value=True, label="stitching") + animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region") driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)") driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02) - flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)") driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8) gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) @@ -239,13 +254,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San process_button_animation = gr.Button("πŸš€ Animate", variant="primary") with gr.Row(): with gr.Column(): - with gr.Accordion(open=True, label="The animated video in the original image space"): - output_video_i2v.render() + output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space") with gr.Column(): - with gr.Accordion(open=True, label="The animated video"): - output_video_concat_i2v.render() + output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video") with gr.Row(): - process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear") + with gr.Column(): + output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False) + with gr.Column(): + output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False) + with gr.Row(): + process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear") with gr.Row(): # Examples @@ -279,7 +297,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San flag_do_crop_input, flag_remap_input, flag_crop_driving_video_input, - flag_video_editing_head_rotation, driving_smooth_observation_variance, ], outputs=[output_image, output_image_paste_back], @@ -373,6 +390,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San video_retargeting_source_scale.render() video_lip_retargeting_slider.render() driving_smooth_observation_variance_retargeting.render() + video_retargeting_silence.render() with gr.Row(visible=True): process_button_retargeting_video = gr.Button("πŸš— Retargeting Video", variant="primary") with gr.Row(visible=True): @@ -383,9 +401,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San examples=[ [osp.join(example_portrait_dir, "s13.mp4")], # [osp.join(example_portrait_dir, "s18.mp4")], - [osp.join(example_portrait_dir, "s20.mp4")], + # [osp.join(example_portrait_dir, "s20.mp4")], [osp.join(example_portrait_dir, "s29.mp4")], [osp.join(example_portrait_dir, "s32.mp4")], + [osp.join(example_video_dir, "d3.mp4")], ], inputs=[retargeting_input_video], cache_examples=False, @@ -413,16 +432,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San inputs=[ source_image_input, source_video_input, - driving_video_pickle_input, driving_video_input, + driving_image_input, + driving_video_pickle_input, flag_relative_input, flag_do_crop_input, flag_remap_input, flag_stitching_input, + animation_region, driving_option_input, driving_multiplier, flag_crop_driving_video_input, - flag_video_editing_head_rotation, scale, vx_ratio, vy_ratio, @@ -433,10 +453,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San tab_selection, v_tab_selection, ], - outputs=[output_video_i2v, output_video_concat_i2v], + outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i], show_progress=True ) + retargeting_input_image.change( fn=gradio_pipeline.init_retargeting_image, inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image], @@ -458,7 +479,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San process_button_retargeting_video.click( fn=gpu_wrapped_execute_video_retargeting, - inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video], + inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video], outputs=[output_video, output_video_paste_back], show_progress=True ) diff --git a/assets/docs/changelog/2024-08-19.md b/assets/docs/changelog/2024-08-19.md new file mode 100644 index 0000000..828c7ae --- /dev/null +++ b/assets/docs/changelog/2024-08-19.md @@ -0,0 +1,59 @@ +## Image Driven and Regional Control + +You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we name it regional control. The control is becoming more and more precise! 🎯 + +> Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊 + +> [!Note] +> We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven’t fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. πŸ’– + + +### CLI Usage +It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image: + +```bash +python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg +``` + +To change the `animation_region` option, you can use the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to only drive the lip region, you can run by: + +```bash +# only driving the lip region +python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip +``` + +### Gradio Interface + +

+ LivePortrait +
+ Image-driven Portrait Animation and Regional Control +

+ +### More Detailed Explanation + +**flag_relative_motion**: +When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format. + +**animation_region**: +This argument offers five options: + +- `exp`: Only the expression of the driving input influences the source. +- `pose`: Only the head pose drives the source. +- `lip`: Only lip movement drives the source. +- `eyes`: Only eye movement drives the source. +- `all`: All motions from the driving input are applied. + +You can also select these options directly in the Gradio interface. + +**Editing the Lip Region of the Source Video to a Neutral Expression**: +In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth. + +**Others**: +When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion. + +For CLI usage, to retain only the driving video's motion in the output, use: +```bash +python inference.py --no_flag_relative_motion +``` +In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video. diff --git a/assets/docs/image-driven-portrait-animation-2024-08-19.jpg b/assets/docs/image-driven-portrait-animation-2024-08-19.jpg new file mode 100644 index 0000000..db14759 Binary files /dev/null and b/assets/docs/image-driven-portrait-animation-2024-08-19.jpg differ diff --git a/assets/examples/driving/d12.jpg b/assets/examples/driving/d12.jpg new file mode 100644 index 0000000..ad32430 Binary files /dev/null and b/assets/examples/driving/d12.jpg differ diff --git a/assets/examples/driving/d19.jpg b/assets/examples/driving/d19.jpg new file mode 100644 index 0000000..59cb845 Binary files /dev/null and b/assets/examples/driving/d19.jpg differ diff --git a/assets/examples/driving/d30.jpg b/assets/examples/driving/d30.jpg new file mode 100644 index 0000000..d47b480 Binary files /dev/null and b/assets/examples/driving/d30.jpg differ diff --git a/assets/examples/driving/d38.jpg b/assets/examples/driving/d38.jpg new file mode 100644 index 0000000..2a061d8 Binary files /dev/null and b/assets/examples/driving/d38.jpg differ diff --git a/assets/examples/driving/d8.jpg b/assets/examples/driving/d8.jpg new file mode 100644 index 0000000..f7b4ec9 Binary files /dev/null and b/assets/examples/driving/d8.jpg differ diff --git a/assets/examples/driving/d9.jpg b/assets/examples/driving/d9.jpg new file mode 100644 index 0000000..8fe5c16 Binary files /dev/null and b/assets/examples/driving/d9.jpg differ diff --git a/readme.md b/readme.md index e4a6488..b78de8e 100644 --- a/readme.md +++ b/readme.md @@ -38,7 +38,8 @@ ## πŸ”₯ Updates -- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, insipred by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md). +- **`2024/08/19`**: πŸ–ΌοΈ We support **image driven mode** and **regional control**. For details, see [**here**](./assets/docs/changelog/2024-08-19.md). +- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md). - **`2024/08/05`**: πŸ“¦ Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md). - **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)! - **`2024/07/25`**: πŸ“¦ Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy! @@ -247,6 +248,9 @@ And many more amazing contributions from our community! ## Acknowledgements πŸ’ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions. +## Ethics Considerations πŸ›‘οΈ +Portrait animation technologies come with social risks, particularly the potential for misuse in creating deepfakes. To mitigate these risks, it’s crucial to follow ethical guidelines and adopt responsible usage practices. At present, the synthesized results contain visual artifacts that may help in detecting deepfakes. Please note that we do not assume any legal responsibility for the use of the results generated by this project. + ## Citation πŸ’– If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: ```bibtex diff --git a/src/config/argument_config.py b/src/config/argument_config.py index bea2d2f..055f5f8 100644 --- a/src/config/argument_config.py +++ b/src/config/argument_config.py @@ -22,9 +22,8 @@ class ArgumentConfig(PrintableConfig): flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video device_id: int = 0 # gpu device id flag_force_cpu: bool = False # force cpu inference, WIP! - flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False + flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering - flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal @@ -35,6 +34,7 @@ class ArgumentConfig(PrintableConfig): driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly" driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video + animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions ########## source crop arguments ########## det_thresh: float = 0.15 # detection threshold scale: float = 2.3 # the ratio of face area is smaller if scale is larger diff --git a/src/config/inference_config.py b/src/config/inference_config.py index 38f1ecf..c9ed197 100644 --- a/src/config/inference_config.py +++ b/src/config/inference_config.py @@ -6,10 +6,14 @@ config dataclass used for inference import cv2 from numpy import ndarray +import pickle as pkl from dataclasses import dataclass, field from typing import Literal, Tuple from .base_config import PrintableConfig, make_abs_path +def load_lip_array(): + with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f: + return pkl.load(f) @dataclass(repr=False) # use repr from PrintableConfig class InferenceConfig(PrintableConfig): @@ -34,7 +38,6 @@ class InferenceConfig(PrintableConfig): device_id: int = 0 flag_normalize_lip: bool = True flag_source_video_eye_retargeting: bool = False - flag_video_editing_head_rotation: bool = False flag_eye_retargeting: bool = False flag_lip_retargeting: bool = False flag_stitching: bool = True @@ -49,6 +52,7 @@ class InferenceConfig(PrintableConfig): driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy source_max_dim: int = 1280 # the max dim of height and width of source image or video source_division: int = 2 # make sure the height and width of source image or video can be divided by this number + animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose # NOT EXPORTED PARAMS lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip @@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig): output_fps: int = 25 # default output fps mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) + lip_array: ndarray = field(default_factory=load_lip_array) size_gif: int = 256 # default gif size, TO IMPLEMENT diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index 38dcde5..5b63fe1 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -146,16 +146,18 @@ class GradioPipeline(LivePortraitPipeline): self, input_source_image_path=None, input_source_video_path=None, - input_driving_video_pickle_path=None, input_driving_video_path=None, + input_driving_image_path=None, + input_driving_video_pickle_path=None, flag_relative_input=True, flag_do_crop_input=True, flag_remap_input=True, flag_stitching_input=True, + animation_region="all", driving_option_input="pose-friendly", driving_multiplier=1.0, flag_crop_driving_video_input=True, - flag_video_editing_head_rotation=False, + # flag_video_editing_head_rotation=False, scale=2.3, vx_ratio=0.0, vy_ratio=-0.125, @@ -177,6 +179,8 @@ class GradioPipeline(LivePortraitPipeline): if v_tab_selection == 'Video': input_driving_path = input_driving_video_path + elif v_tab_selection == 'Image': + input_driving_path = input_driving_image_path elif v_tab_selection == 'Pickle': input_driving_path = input_driving_video_pickle_path else: @@ -195,10 +199,10 @@ class GradioPipeline(LivePortraitPipeline): 'flag_do_crop': flag_do_crop_input, 'flag_pasteback': flag_remap_input, 'flag_stitching': flag_stitching_input, + 'animation_region': animation_region, 'driving_option': driving_option_input, 'driving_multiplier': driving_multiplier, 'flag_crop_driving_video': flag_crop_driving_video_input, - 'flag_video_editing_head_rotation': flag_video_editing_head_rotation, 'scale': scale, 'vx_ratio': vx_ratio, 'vy_ratio': vy_ratio, @@ -211,10 +215,13 @@ class GradioPipeline(LivePortraitPipeline): self.args = update_args(self.args, args_user) self.live_portrait_wrapper.update_config(self.args.__dict__) self.cropper.update_config(self.args.__dict__) - # video driven animation - video_path, video_path_concat = self.execute(self.args) + + output_path, output_path_concat = self.execute(self.args) gr.Info("Run successfully!", duration=2) - return video_path, video_path_concat + if output_path.endswith(".jpg"): + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True) + else: + return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) else: raise gr.Error("Please upload the source portrait or source video, and driving video πŸ€—πŸ€—πŸ€—", duration=5) @@ -308,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline): if input_lip_ratio != self.source_lip_ratio: combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) + print(lip_delta) x_d_new = x_d_new + \ (eyes_delta if eyes_delta is not None else 0) + \ (lip_delta if lip_delta is not None else 0) @@ -388,29 +396,51 @@ class GradioPipeline(LivePortraitPipeline): return source_eye_ratio, source_lip_ratio @torch.no_grad() - def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True): + def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True): """ retargeting the lip-open ratio of each source frame """ # disposable feature device = self.live_portrait_wrapper.device - f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \ - self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video) - if input_lip_ratio is None: - raise gr.Error("Invalid ratio input πŸ’₯!", duration=5) + if not video_retargeting_silence: + f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \ + self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video) + if input_lip_ratio is None: + raise gr.Error("Invalid ratio input πŸ’₯!", duration=5) + else: + inference_cfg = self.live_portrait_wrapper.inference_cfg + + I_p_pstbk_lst = None + if flag_do_crop_input_retargeting_video: + I_p_pstbk_lst = [] + I_p_lst = [] + for i in track(range(n_frames), description='Retargeting video...', total=n_frames): + x_s_user_i = x_s_user_lst[i].to(device) + f_s_user_i = f_s_user_lst[i].to(device) + + lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i] + x_d_i_new = x_s_user_i + lip_delta_retargeting + x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new) + out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new) + I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] + I_p_lst.append(I_p_i) + + if flag_do_crop_input_retargeting_video: + I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) + I_p_pstbk_lst.append(I_p_pstbk) else: inference_cfg = self.live_portrait_wrapper.inference_cfg + f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \ + self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video) I_p_pstbk_lst = None if flag_do_crop_input_retargeting_video: I_p_pstbk_lst = [] I_p_lst = [] - for i in track(range(n_frames), description='Retargeting video...', total=n_frames): + for i in track(range(n_frames), description='Silencing lip...', total=n_frames): x_s_user_i = x_s_user_lst[i].to(device) f_s_user_i = f_s_user_lst[i].to(device) - - lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i] - x_d_i_new = x_s_user_i + lip_delta_retargeting + x_d_i_new = x_d_i_new_lst[i] x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new) out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new) I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] @@ -420,37 +450,37 @@ class GradioPipeline(LivePortraitPipeline): I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk_lst.append(I_p_pstbk) - mkdir(self.args.output_dir) - flag_source_has_audio = has_audio_stream(input_video) + mkdir(self.args.output_dir) + flag_source_has_audio = has_audio_stream(input_video) - ######### build the final concatenation result ######### - # source frame | generation - frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst) - wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4') - images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps) + ######### build the final concatenation result ######### + # source frame | generation + frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst) + wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4') + images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps) - if flag_source_has_audio: - # final result with concatenation - wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4') - add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio) - os.replace(wfp_concat_with_audio, wfp_concat) - log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") + if flag_source_has_audio: + # final result with concatenation + wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4') + add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio) + os.replace(wfp_concat_with_audio, wfp_concat) + log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") - # save the animated result - wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4') - if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: - images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps) - else: - images2video(I_p_lst, wfp=wfp, fps=source_fps) + # save the animated result + wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4') + if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: + images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps) + else: + images2video(I_p_lst, wfp=wfp, fps=source_fps) - ######### build the final result ######### - if flag_source_has_audio: - wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4') - add_audio_to_video(wfp, input_video, wfp_with_audio) - os.replace(wfp_with_audio, wfp) - log(f"Replace {wfp_with_audio} with {wfp}") - gr.Info("Run successfully!", duration=2) - return wfp_concat, wfp + ######### build the final result ######### + if flag_source_has_audio: + wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4') + add_audio_to_video(wfp, input_video, wfp_with_audio) + os.replace(wfp_with_audio, wfp) + log(f"Replace {wfp_with_audio} with {wfp}") + gr.Info("Run successfully!", duration=2) + return wfp_concat, wfp @torch.no_grad() def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True): @@ -503,12 +533,64 @@ class GradioPipeline(LivePortraitPipeline): f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32)) lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting) - return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames else: # when press the clear button, go here raise gr.Error("Please upload a source video as the retargeting input πŸ€—πŸ€—πŸ€—", duration=5) + @torch.no_grad() + def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True): + """ for keeping lips in the source video silent + """ + if input_video is not None: + inference_cfg = self.live_portrait_wrapper.inference_cfg + ######## process source video ######## + source_rgb_lst = load_video(input_video) + source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst] + source_fps = int(get_fps(input_video)) + n_frames = len(source_rgb_lst) + log(f"Load source video from {input_video}. FPS is {source_fps}") + + if flag_do_crop: + ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg) + log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') + if len(ret_s["frame_crop_lst"]) != n_frames: + n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"])) + img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] + mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst] + else: + source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst) + img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256 + source_M_c2o_lst, mask_ori_lst = None, None + + c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst) + # save the motion template + I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst) + source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps) + + f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], [] + for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames): + x_s_info = source_template_dct['motion'][i] + x_s_info = dct2device(x_s_info, device) + scale_s = x_s_info['scale'] + x_s_user = x_s_info['x_s'] + x_c_s = x_s_info['kp'] + R_s = x_s_info['R'] + t_s = x_s_info['t'] + delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device) + for eyes_idx in [11, 13, 15, 16, 18]: + delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :] + source_lmk = source_lmk_crop_lst[i] + img_crop_256x256 = img_crop_256x256_lst[i] + I_s = I_s_lst[i] + f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s + f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new) + return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames + else: + # when press the clear button, go here + raise gr.Error("Please upload a source video as the input πŸ€—πŸ€—πŸ€—", duration=5) + class GradioPipelineAnimal(LivePortraitPipelineAnimal): """gradio for animal """ diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py index 7f06d2a..396b427 100644 --- a/src/live_portrait_pipeline.py +++ b/src/live_portrait_pipeline.py @@ -72,7 +72,6 @@ class LivePortraitPipeline(object): c_lip = c_lip_lst[i].astype(np.float32) template_dct['c_lip_lst'].append(c_lip) - return template_dct def execute(self, args: ArgumentConfig): @@ -111,8 +110,11 @@ class LivePortraitPipeline(object): c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] driving_n_frames = driving_template_dct['n_frames'] - if flag_is_source_video: + flag_is_driving_video = True if driving_n_frames > 1 else False + if flag_is_source_video and flag_is_driving_video: n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames + elif flag_is_source_video and not flag_is_driving_video: + n_frames = len(source_rgb_lst) else: n_frames = driving_n_frames @@ -123,25 +125,35 @@ class LivePortraitPipeline(object): if args.flag_crop_driving_video: log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") - elif osp.exists(args.driving) and is_video(args.driving): - # load from video file, AND make motion template - output_fps = int(get_fps(args.driving)) - log(f"Load driving video from: {args.driving}, FPS is {output_fps}") - - driving_rgb_lst = load_video(args.driving) - driving_n_frames = len(driving_rgb_lst) - + elif osp.exists(args.driving): + if is_video(args.driving): + flag_is_driving_video = True + # load from video file, AND make motion template + output_fps = int(get_fps(args.driving)) + log(f"Load driving video from: {args.driving}, FPS is {output_fps}") + driving_rgb_lst = load_video(args.driving) + elif is_image(args.driving): + flag_is_driving_video = False + driving_img_rgb = load_image_rgb(args.driving) + output_fps = 25 + log(f"Load driving image from {args.driving}") + driving_rgb_lst = [driving_img_rgb] + else: + raise Exception(f"{args.driving} is not a supported type!") ######## make motion template ######## log("Start making driving motion template...") - if flag_is_source_video: + driving_n_frames = len(driving_rgb_lst) + if flag_is_source_video and flag_is_driving_video: n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames driving_rgb_lst = driving_rgb_lst[:n_frames] + elif flag_is_source_video and not flag_is_driving_video: + n_frames = len(source_rgb_lst) else: n_frames = driving_n_frames if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)): ret_d = self.cropper.crop_driving_video(driving_rgb_lst) log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.') - if len(ret_d["frame_crop_lst"]) is not n_frames: + if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video: n_frames = min(n_frames, len(ret_d["frame_crop_lst"])) driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst'] driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] @@ -158,9 +170,11 @@ class LivePortraitPipeline(object): wfp_template = remove_suffix(args.driving) + '.pkl' dump(wfp_template, driving_template_dct) log(f"Dump motion template to {wfp_template}") - else: - raise Exception(f"{args.driving} not exists or unsupported driving info types!") + raise Exception(f"{args.driving} does not exist!") + if not flag_is_driving_video: + c_d_eyes_lst = c_d_eyes_lst*n_frames + c_d_lip_lst = c_d_lip_lst*n_frames ######## prepare for pasteback ######## I_p_pstbk_lst = None @@ -196,17 +210,33 @@ class LivePortraitPipeline(object): key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys if inf_cfg.flag_relative_motion: - x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] - x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) - if inf_cfg.flag_video_editing_head_rotation: - x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] - x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) + if flag_is_driving_video: + x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] + x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) + else: + x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)] + x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst] + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + if flag_is_driving_video: + x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] + x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) + else: + x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)] + x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst] else: - x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)] - x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) - if inf_cfg.flag_video_editing_head_rotation: - x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)] - x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) + if flag_is_driving_video: + x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)] + x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) + else: + x_d_exp_lst = [driving_template_dct['motion'][0]['exp']] + x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + if flag_is_driving_video: + x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)] + x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) + else: + x_d_r_lst = [driving_template_dct['motion'][0][key_r]] + x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]*n_frames else: # if the input is a source image, process it only once if inf_cfg.flag_do_crop: @@ -236,7 +266,10 @@ class LivePortraitPipeline(object): mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) ######## animate ######## - log(f"The animated video consists of {n_frames} frames.") + if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video): + log(f"The animated video consists of {n_frames} frames.") + else: + log(f"The output of image-driven portrait animation is an image.") for i in track(range(n_frames), description='πŸš€Animating...', total=n_frames): if flag_is_source_video: # source video x_s_info = source_template_dct['motion'][i] @@ -272,43 +305,88 @@ class LivePortraitPipeline(object): if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0])) - - x_d_i_info = driving_template_dct['motion'][i] + if flag_is_source_video and not flag_is_driving_video: + x_d_i_info = driving_template_dct['motion'][0] + else: + x_d_i_info = driving_template_dct['motion'][i] x_d_i_info = dct2device(x_d_i_info, device) R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys if i == 0: # cache the first frame R_d_0 = R_d_i - x_d_0_info = x_d_i_info + x_d_0_info = x_d_i_info.copy() + delta_new = x_s_info['exp'].clone() if inf_cfg.flag_relative_motion: - if flag_is_source_video: - if inf_cfg.flag_video_editing_head_rotation: - R_new = x_d_r_lst_smooth[i] - else: - R_new = R_s + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s else: - R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s - - delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) - scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) - t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) + R_new = R_s + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp": + if flag_is_source_video: + for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: + delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] + delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] + delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] + delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] + delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] + else: + if flag_is_driving_video: + delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) + else: + delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)) + elif inf_cfg.animation_region == "lip": + for lip_idx in [6, 12, 14, 17, 19, 20]: + if flag_is_source_video: + delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] + elif flag_is_driving_video: + delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :] + else: + delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :] + elif inf_cfg.animation_region == "eyes": + for eyes_idx in [11, 13, 15, 16, 18]: + if flag_is_source_video: + delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] + elif flag_is_driving_video: + delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :] + else: + delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :] + if inf_cfg.animation_region == "all": + scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) + else: + scale_new = x_s_info['scale'] + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) + else: + t_new = x_s_info['t'] else: - if flag_is_source_video: - if inf_cfg.flag_video_editing_head_rotation: - R_new = x_d_r_lst_smooth[i] - else: - R_new = R_s + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + R_new = x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i else: - R_new = R_d_i - delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp'] + R_new = R_s + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp": + for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: + delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :] + delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1] + delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] if flag_is_source_video else x_d_i_info['exp'][:, 5, 2] + delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2] + delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:] + elif inf_cfg.animation_region == "lip": + for lip_idx in [6, 12, 14, 17, 19, 20]: + delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :] + elif inf_cfg.animation_region == "eyes": + for eyes_idx in [11, 13, 15, 16, 18]: + delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :] scale_new = x_s_info['scale'] - t_new = x_d_i_info['t'] + if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": + t_new = x_d_i_info['t'] + else: + t_new = x_s_info['t'] t_new[..., 2].fill_(0) # zero tz x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new - if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video: + if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video: if i == 0: x_d_0_new = x_d_i_new motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) @@ -373,50 +451,68 @@ class LivePortraitPipeline(object): mkdir(args.output_dir) wfp_concat = None - flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source) - flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) - ######### build the final concatenation result ######### - # driving frame | source frame | generation, or source frame | generation - if flag_is_source_video: + # driving frame | source frame | generation + if flag_is_source_video and flag_is_driving_video: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) + elif flag_is_source_video and not flag_is_driving_video: + if flag_load_from_template: + frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) + else: + frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst) else: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst) - wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') - # NOTE: update output fps - output_fps = source_fps if flag_is_source_video else output_fps - images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) + if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video): + flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source) + flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) - if flag_source_has_audio or flag_driving_has_audio: - # final result with concatenation - wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4') - audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source - log(f"Audio is selected from {audio_from_which_video}, concat mode") - add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio) - os.replace(wfp_concat_with_audio, wfp_concat) - log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") + wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') - # save the animated result - wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') - if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: - images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) + # NOTE: update output fps + output_fps = source_fps if flag_is_source_video else output_fps + images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) + + if flag_source_has_audio or flag_driving_has_audio: + # final result with concatenation + wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4') + audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source + log(f"Audio is selected from {audio_from_which_video}, concat mode") + add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio) + os.replace(wfp_concat_with_audio, wfp_concat) + log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") + + # save the animated result + wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') + if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: + images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) + else: + images2video(I_p_lst, wfp=wfp, fps=output_fps) + + ######### build the final result ######### + if flag_source_has_audio or flag_driving_has_audio: + wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') + audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source + log(f"Audio is selected from {audio_from_which_video}") + add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) + os.replace(wfp_with_audio, wfp) + log(f"Replace {wfp_with_audio} with {wfp}") + + # final log + if wfp_template not in (None, ''): + log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green') + log(f'Animated video: {wfp}') + log(f'Animated video with concat: {wfp_concat}') else: - images2video(I_p_lst, wfp=wfp, fps=output_fps) - - ######### build the final result ######### - if flag_source_has_audio or flag_driving_has_audio: - wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') - audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source - log(f"Audio is selected from {audio_from_which_video}") - add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) - os.replace(wfp_with_audio, wfp) - log(f"Replace {wfp_with_audio} with {wfp}") - - # final log - if wfp_template not in (None, ''): - log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green') - log(f'Animated video: {wfp}') - log(f'Animated video with concat: {wfp_concat}') + wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg') + cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1]) + wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg') + if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: + cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1]) + else: + cv2.imwrite(wfp, frames_concatenated[0][..., ::-1]) + # final log + log(f'Animated image: {wfp}') + log(f'Animated image with concat: {wfp_concat}') return wfp, wfp_concat diff --git a/src/utils/helper.py b/src/utils/helper.py index 98bcdf8..dda9a69 100644 --- a/src/utils/helper.py +++ b/src/utils/helper.py @@ -100,7 +100,10 @@ def squeeze_tensor_to_numpy(tensor): def dct2device(dct: dict, device): for key in dct: - dct[key] = torch.tensor(dct[key]).to(device) + if isinstance(dct[key], torch.Tensor): + dct[key] = dct[key].to(device) + else: + dct[key] = torch.tensor(dct[key]).to(device) return dct diff --git a/src/utils/resources/lip_array.pkl b/src/utils/resources/lip_array.pkl new file mode 100644 index 0000000..61fe803 Binary files /dev/null and b/src/utils/resources/lip_array.pkl differ