feat: support image driven mode and regional control (#336)

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: image driven, regional animation

* feat: image driven, regional animation

* feat: image driven, regional animation, doc

* feat: image driven, regional animation

* feat: image driven, regional animation

* feat: image driven, regional animation

* chore: refactor

* doc: update readme

* doc: update changelog

* feat: image driven, regional control

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
This commit is contained in:
Jianzhu Guo 2024-08-19 23:08:57 +08:00 committed by GitHub
parent 8a7682aaa4
commit a19c3e15fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 421 additions and 151 deletions

59
app.py
View File

@ -85,12 +85,12 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
] ]
data_examples_v2v = [ data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7], # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7], # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7], [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
] ]
#################### interface logic #################### #################### interface logic ####################
@ -98,6 +98,7 @@ data_examples_v2v = [
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale") retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale") video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8) driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent")
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
@ -124,9 +125,6 @@ retargeting_output_image = gr.Image(type="numpy")
retargeting_output_image_paste_back = gr.Image(type="numpy") retargeting_output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video(autoplay=False) output_video = gr.Video(autoplay=False)
output_video_paste_back = gr.Video(autoplay=False) output_video_paste_back = gr.Video(autoplay=False)
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md)) gr.HTML(load_description(title_md))
@ -196,6 +194,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[driving_video_input], inputs=[driving_video_input],
cache_examples=False, cache_examples=False,
) )
with gr.TabItem("🖼️ Driving Image") as v_tab_image:
with gr.Accordion(open=True, label="Driving Image"):
driving_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_video_dir, "d30.jpg")],
[osp.join(example_video_dir, "d9.jpg")],
[osp.join(example_video_dir, "d19.jpg")],
[osp.join(example_video_dir, "d8.jpg")],
[osp.join(example_video_dir, "d12.jpg")],
[osp.join(example_video_dir, "d38.jpg")],
],
inputs=[driving_image_input],
cache_examples=False,
)
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle: with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"): with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"]) driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
@ -212,8 +226,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
) )
v_tab_selection = gr.Textbox(visible=False) v_tab_selection = gr.Textbox(visible=False)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
v_tab_video.select(lambda: "Video", None, v_tab_selection) v_tab_video.select(lambda: "Video", None, v_tab_selection)
v_tab_image.select(lambda: "Image", None, v_tab_selection)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
# with gr.Accordion(open=False, label="Animation Instructions"): # with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"): with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@ -229,9 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching") flag_stitching_input = gr.Checkbox(value=True, label="stitching")
animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)") driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02) driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8) driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
@ -239,13 +254,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
process_button_animation = gr.Button("🚀 Animate", variant="primary") process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"): output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space")
output_video_i2v.render()
with gr.Column(): with gr.Column():
with gr.Accordion(open=True, label="The animated video"): output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video")
output_video_concat_i2v.render()
with gr.Row(): with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear") with gr.Column():
output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False)
with gr.Column():
output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False)
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
with gr.Row(): with gr.Row():
# Examples # Examples
@ -279,7 +297,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_do_crop_input, flag_do_crop_input,
flag_remap_input, flag_remap_input,
flag_crop_driving_video_input, flag_crop_driving_video_input,
flag_video_editing_head_rotation,
driving_smooth_observation_variance, driving_smooth_observation_variance,
], ],
outputs=[output_image, output_image_paste_back], outputs=[output_image, output_image_paste_back],
@ -373,6 +390,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
video_retargeting_source_scale.render() video_retargeting_source_scale.render()
video_lip_retargeting_slider.render() video_lip_retargeting_slider.render()
driving_smooth_observation_variance_retargeting.render() driving_smooth_observation_variance_retargeting.render()
video_retargeting_silence.render()
with gr.Row(visible=True): with gr.Row(visible=True):
process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary") process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
with gr.Row(visible=True): with gr.Row(visible=True):
@ -383,9 +401,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
examples=[ examples=[
[osp.join(example_portrait_dir, "s13.mp4")], [osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s18.mp4")], # [osp.join(example_portrait_dir, "s18.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")], # [osp.join(example_portrait_dir, "s20.mp4")],
[osp.join(example_portrait_dir, "s29.mp4")], [osp.join(example_portrait_dir, "s29.mp4")],
[osp.join(example_portrait_dir, "s32.mp4")], [osp.join(example_portrait_dir, "s32.mp4")],
[osp.join(example_video_dir, "d3.mp4")],
], ],
inputs=[retargeting_input_video], inputs=[retargeting_input_video],
cache_examples=False, cache_examples=False,
@ -413,16 +432,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[ inputs=[
source_image_input, source_image_input,
source_video_input, source_video_input,
driving_video_pickle_input,
driving_video_input, driving_video_input,
driving_image_input,
driving_video_pickle_input,
flag_relative_input, flag_relative_input,
flag_do_crop_input, flag_do_crop_input,
flag_remap_input, flag_remap_input,
flag_stitching_input, flag_stitching_input,
animation_region,
driving_option_input, driving_option_input,
driving_multiplier, driving_multiplier,
flag_crop_driving_video_input, flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale, scale,
vx_ratio, vx_ratio,
vy_ratio, vy_ratio,
@ -433,10 +453,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
tab_selection, tab_selection,
v_tab_selection, v_tab_selection,
], ],
outputs=[output_video_i2v, output_video_concat_i2v], outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i],
show_progress=True show_progress=True
) )
retargeting_input_image.change( retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting_image, fn=gradio_pipeline.init_retargeting_image,
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image], inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
@ -458,7 +479,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
process_button_retargeting_video.click( process_button_retargeting_video.click(
fn=gpu_wrapped_execute_video_retargeting, fn=gpu_wrapped_execute_video_retargeting,
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video], inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video],
outputs=[output_video, output_video_paste_back], outputs=[output_video, output_video_paste_back],
show_progress=True show_progress=True
) )

View File

@ -0,0 +1,59 @@
## Image Driven and Regional Control
You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we name it regional control. The control is becoming more and more precise! 🎯
> Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊
> [!Note]
> We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we havent fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖
### CLI Usage
It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image:
```bash
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg
```
To change the `animation_region` option, you can use the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to only drive the lip region, you can run by:
```bash
# only driving the lip region
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip
```
### Gradio Interface
<p align="center">
<img src="../image-driven-portrait-animation-2024-08-19.jpg" alt="LivePortrait" width="960px">
<br>
<strong>Image-driven Portrait Animation and Regional Control</strong>
</p>
### More Detailed Explanation
**flag_relative_motion**:
When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format.
**animation_region**:
This argument offers five options:
- `exp`: Only the expression of the driving input influences the source.
- `pose`: Only the head pose drives the source.
- `lip`: Only lip movement drives the source.
- `eyes`: Only eye movement drives the source.
- `all`: All motions from the driving input are applied.
You can also select these options directly in the Gradio interface.
**Editing the Lip Region of the Source Video to a Neutral Expression**:
In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth.
**Others**:
When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion.
For CLI usage, to retain only the driving video's motion in the output, use:
```bash
python inference.py --no_flag_relative_motion
```
In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video.

Binary file not shown.

After

Width:  |  Height:  |  Size: 544 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@ -38,7 +38,8 @@
## 🔥 Updates ## 🔥 Updates
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, insipred by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md). - **`2024/08/19`**: 🖼️ We support **image driven mode** and **regional control**. For details, see [**here**](./assets/docs/changelog/2024-08-19.md).
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
- **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md). - **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md).
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)! - **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy! - **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy!
@ -247,6 +248,9 @@ And many more amazing contributions from our community!
## Acknowledgements 💐 ## Acknowledgements 💐
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions. We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions.
## Ethics Considerations 🛡️
Portrait animation technologies come with social risks, particularly the potential for misuse in creating deepfakes. To mitigate these risks, its crucial to follow ethical guidelines and adopt responsible usage practices. At present, the synthesized results contain visual artifacts that may help in detecting deepfakes. Please note that we do not assume any legal responsibility for the use of the results generated by this project.
## Citation 💖 ## Citation 💖
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
```bibtex ```bibtex

View File

@ -22,9 +22,8 @@ class ArgumentConfig(PrintableConfig):
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
device_id: int = 0 # gpu device id device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP! flag_force_cpu: bool = False # force cpu inference, WIP!
flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
@ -35,6 +34,7 @@ class ArgumentConfig(PrintableConfig):
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly" driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions
########## source crop arguments ########## ########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger scale: float = 2.3 # the ratio of face area is smaller if scale is larger

View File

@ -6,10 +6,14 @@ config dataclass used for inference
import cv2 import cv2
from numpy import ndarray from numpy import ndarray
import pickle as pkl
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Tuple from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path from .base_config import PrintableConfig, make_abs_path
def load_lip_array():
with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
return pkl.load(f)
@dataclass(repr=False) # use repr from PrintableConfig @dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig): class InferenceConfig(PrintableConfig):
@ -34,7 +38,6 @@ class InferenceConfig(PrintableConfig):
device_id: int = 0 device_id: int = 0
flag_normalize_lip: bool = True flag_normalize_lip: bool = True
flag_source_video_eye_retargeting: bool = False flag_source_video_eye_retargeting: bool = False
flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False flag_lip_retargeting: bool = False
flag_stitching: bool = True flag_stitching: bool = True
@ -49,6 +52,7 @@ class InferenceConfig(PrintableConfig):
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
source_max_dim: int = 1280 # the max dim of height and width of source image or video source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
# NOT EXPORTED PARAMS # NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
output_fps: int = 25 # default output fps output_fps: int = 25 # default output fps
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
lip_array: ndarray = field(default_factory=load_lip_array)
size_gif: int = 256 # default gif size, TO IMPLEMENT size_gif: int = 256 # default gif size, TO IMPLEMENT

View File

@ -146,16 +146,18 @@ class GradioPipeline(LivePortraitPipeline):
self, self,
input_source_image_path=None, input_source_image_path=None,
input_source_video_path=None, input_source_video_path=None,
input_driving_video_pickle_path=None,
input_driving_video_path=None, input_driving_video_path=None,
input_driving_image_path=None,
input_driving_video_pickle_path=None,
flag_relative_input=True, flag_relative_input=True,
flag_do_crop_input=True, flag_do_crop_input=True,
flag_remap_input=True, flag_remap_input=True,
flag_stitching_input=True, flag_stitching_input=True,
animation_region="all",
driving_option_input="pose-friendly", driving_option_input="pose-friendly",
driving_multiplier=1.0, driving_multiplier=1.0,
flag_crop_driving_video_input=True, flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False, # flag_video_editing_head_rotation=False,
scale=2.3, scale=2.3,
vx_ratio=0.0, vx_ratio=0.0,
vy_ratio=-0.125, vy_ratio=-0.125,
@ -177,6 +179,8 @@ class GradioPipeline(LivePortraitPipeline):
if v_tab_selection == 'Video': if v_tab_selection == 'Video':
input_driving_path = input_driving_video_path input_driving_path = input_driving_video_path
elif v_tab_selection == 'Image':
input_driving_path = input_driving_image_path
elif v_tab_selection == 'Pickle': elif v_tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path input_driving_path = input_driving_video_pickle_path
else: else:
@ -195,10 +199,10 @@ class GradioPipeline(LivePortraitPipeline):
'flag_do_crop': flag_do_crop_input, 'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input, 'flag_pasteback': flag_remap_input,
'flag_stitching': flag_stitching_input, 'flag_stitching': flag_stitching_input,
'animation_region': animation_region,
'driving_option': driving_option_input, 'driving_option': driving_option_input,
'driving_multiplier': driving_multiplier, 'driving_multiplier': driving_multiplier,
'flag_crop_driving_video': flag_crop_driving_video_input, 'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale, 'scale': scale,
'vx_ratio': vx_ratio, 'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio, 'vy_ratio': vy_ratio,
@ -211,10 +215,13 @@ class GradioPipeline(LivePortraitPipeline):
self.args = update_args(self.args, args_user) self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__) self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__) self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args) output_path, output_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2) gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat if output_path.endswith(".jpg"):
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
else:
return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
else: else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5) raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
@ -308,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
if input_lip_ratio != self.source_lip_ratio: if input_lip_ratio != self.source_lip_ratio:
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user) combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
print(lip_delta)
x_d_new = x_d_new + \ x_d_new = x_d_new + \
(eyes_delta if eyes_delta is not None else 0) + \ (eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0) (lip_delta if lip_delta is not None else 0)
@ -388,14 +396,15 @@ class GradioPipeline(LivePortraitPipeline):
return source_eye_ratio, source_lip_ratio return source_eye_ratio, source_lip_ratio
@torch.no_grad() @torch.no_grad()
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True): def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
""" retargeting the lip-open ratio of each source frame """ retargeting the lip-open ratio of each source frame
""" """
# disposable feature # disposable feature
device = self.live_portrait_wrapper.device device = self.live_portrait_wrapper.device
if not video_retargeting_silence:
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \ f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video) self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
if input_lip_ratio is None: if input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5) raise gr.Error("Invalid ratio input 💥!", duration=5)
else: else:
@ -416,6 +425,27 @@ class GradioPipeline(LivePortraitPipeline):
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i) I_p_lst.append(I_p_i)
if flag_do_crop_input_retargeting_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video)
I_p_pstbk_lst = None
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Silencing lip...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
x_d_i_new = x_d_i_new_lst[i]
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if flag_do_crop_input_retargeting_video: if flag_do_crop_input_retargeting_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk) I_p_pstbk_lst.append(I_p_pstbk)
@ -503,12 +533,64 @@ class GradioPipeline(LivePortraitPipeline):
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32)) f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting) lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
else: else:
# when press the clear button, go here # when press the clear button, go here
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5) raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
@torch.no_grad()
def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True):
""" for keeping lips in the source video silent
"""
if input_video is not None:
inference_cfg = self.live_portrait_wrapper.inference_cfg
######## process source video ########
source_rgb_lst = load_video(input_video)
source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
source_fps = int(get_fps(input_video))
n_frames = len(source_rgb_lst)
log(f"Load source video from {input_video}. FPS is {source_fps}")
if flag_do_crop:
ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
if len(ret_s["frame_crop_lst"]) != n_frames:
n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
else:
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
source_M_c2o_lst, mask_ori_lst = None, None
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
# save the motion template
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames):
x_s_info = source_template_dct['motion'][i]
x_s_info = dct2device(x_s_info, device)
scale_s = x_s_info['scale']
x_s_user = x_s_info['x_s']
x_c_s = x_s_info['kp']
R_s = x_s_info['R']
t_s = x_s_info['t']
delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
source_lmk = source_lmk_crop_lst[i]
img_crop_256x256 = img_crop_256x256_lst[i]
I_s = I_s_lst[i]
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
else:
# when press the clear button, go here
raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)
class GradioPipelineAnimal(LivePortraitPipelineAnimal): class GradioPipelineAnimal(LivePortraitPipelineAnimal):
"""gradio for animal """gradio for animal
""" """

View File

@ -72,7 +72,6 @@ class LivePortraitPipeline(object):
c_lip = c_lip_lst[i].astype(np.float32) c_lip = c_lip_lst[i].astype(np.float32)
template_dct['c_lip_lst'].append(c_lip) template_dct['c_lip_lst'].append(c_lip)
return template_dct return template_dct
def execute(self, args: ArgumentConfig): def execute(self, args: ArgumentConfig):
@ -111,8 +110,11 @@ class LivePortraitPipeline(object):
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames'] driving_n_frames = driving_template_dct['n_frames']
if flag_is_source_video: flag_is_driving_video = True if driving_n_frames > 1 else False
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else: else:
n_frames = driving_n_frames n_frames = driving_n_frames
@ -123,25 +125,35 @@ class LivePortraitPipeline(object):
if args.flag_crop_driving_video: if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving): elif osp.exists(args.driving):
if is_video(args.driving):
flag_is_driving_video = True
# load from video file, AND make motion template # load from video file, AND make motion template
output_fps = int(get_fps(args.driving)) output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}") log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving) driving_rgb_lst = load_video(args.driving)
driving_n_frames = len(driving_rgb_lst) elif is_image(args.driving):
flag_is_driving_video = False
driving_img_rgb = load_image_rgb(args.driving)
output_fps = 25
log(f"Load driving image from {args.driving}")
driving_rgb_lst = [driving_img_rgb]
else:
raise Exception(f"{args.driving} is not a supported type!")
######## make motion template ######## ######## make motion template ########
log("Start making driving motion template...") log("Start making driving motion template...")
if flag_is_source_video: driving_n_frames = len(driving_rgb_lst)
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames] driving_rgb_lst = driving_rgb_lst[:n_frames]
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else: else:
n_frames = driving_n_frames n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)): if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
ret_d = self.cropper.crop_driving_video(driving_rgb_lst) ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.') log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) is not n_frames: if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"])) n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst'] driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
@ -158,9 +170,11 @@ class LivePortraitPipeline(object):
wfp_template = remove_suffix(args.driving) + '.pkl' wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct) dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}") log(f"Dump motion template to {wfp_template}")
else: else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!") raise Exception(f"{args.driving} does not exist!")
if not flag_is_driving_video:
c_d_eyes_lst = c_d_eyes_lst*n_frames
c_d_lip_lst = c_d_lip_lst*n_frames
######## prepare for pasteback ######## ######## prepare for pasteback ########
I_p_pstbk_lst = None I_p_pstbk_lst = None
@ -196,17 +210,33 @@ class LivePortraitPipeline(object):
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
if inf_cfg.flag_relative_motion: if inf_cfg.flag_relative_motion:
if flag_is_driving_video:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation: else:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if flag_is_driving_video:
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else: else:
x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
else:
if flag_is_driving_video:
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)] x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation: else:
x_d_exp_lst = [driving_template_dct['motion'][0]['exp']]
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if flag_is_driving_video:
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)] x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_r_lst = [driving_template_dct['motion'][0][key_r]]
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]*n_frames
else: # if the input is a source image, process it only once else: # if the input is a source image, process it only once
if inf_cfg.flag_do_crop: if inf_cfg.flag_do_crop:
@ -236,7 +266,10 @@ class LivePortraitPipeline(object):
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
######## animate ######## ######## animate ########
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
log(f"The animated video consists of {n_frames} frames.") log(f"The animated video consists of {n_frames} frames.")
else:
log(f"The output of image-driven portrait animation is an image.")
for i in track(range(n_frames), description='🚀Animating...', total=n_frames): for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
if flag_is_source_video: # source video if flag_is_source_video: # source video
x_s_info = source_template_dct['motion'][i] x_s_info = source_template_dct['motion'][i]
@ -272,43 +305,88 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0])) mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
if flag_is_source_video and not flag_is_driving_video:
x_d_i_info = driving_template_dct['motion'][0]
else:
x_d_i_info = driving_template_dct['motion'][i] x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device) x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
if i == 0: # cache the first frame if i == 0: # cache the first frame
R_d_0 = R_d_i R_d_0 = R_d_i
x_d_0_info = x_d_i_info x_d_0_info = x_d_i_info.copy()
delta_new = x_s_info['exp'].clone()
if inf_cfg.flag_relative_motion: if inf_cfg.flag_relative_motion:
if flag_is_source_video: if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if inf_cfg.flag_video_editing_head_rotation: R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
R_new = x_d_r_lst_smooth[i]
else: else:
R_new = R_s R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
if flag_is_source_video:
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
else: else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s if flag_is_driving_video:
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) else:
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
elif inf_cfg.animation_region == "lip":
for lip_idx in [6, 12, 14, 17, 19, 20]:
if flag_is_source_video:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
elif flag_is_driving_video:
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
else:
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16, 18]:
if flag_is_source_video:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
elif flag_is_driving_video:
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
else:
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :]
if inf_cfg.animation_region == "all":
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
else:
scale_new = x_s_info['scale']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else: else:
if flag_is_source_video: t_new = x_s_info['t']
if inf_cfg.flag_video_editing_head_rotation: else:
R_new = x_d_r_lst_smooth[i] if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i
else: else:
R_new = R_s R_new = R_s
else: if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
R_new = R_d_i for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp'] delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :]
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1]
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] if flag_is_source_video else x_d_i_info['exp'][:, 5, 2]
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2]
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:]
elif inf_cfg.animation_region == "lip":
for lip_idx in [6, 12, 14, 17, 19, 20]:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :]
scale_new = x_s_info['scale'] scale_new = x_s_info['scale']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_d_i_info['t'] t_new = x_d_i_info['t']
else:
t_new = x_s_info['t']
t_new[..., 2].fill_(0) # zero tz t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video: if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
if i == 0: if i == 0:
x_d_0_new = x_d_i_new x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
@ -373,15 +451,22 @@ class LivePortraitPipeline(object):
mkdir(args.output_dir) mkdir(args.output_dir)
wfp_concat = None wfp_concat = None
######### build the final concatenation result #########
# driving frame | source frame | generation
if flag_is_source_video and flag_is_driving_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
elif flag_is_source_video and not flag_is_driving_video:
if flag_load_from_template:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source) flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result #########
# driving frame | source frame | generation, or source frame | generation
if flag_is_source_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps # NOTE: update output fps
@ -418,5 +503,16 @@ class LivePortraitPipeline(object):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green') log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}') log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}') log(f'Animated video with concat: {wfp_concat}')
else:
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
else:
cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
# final log
log(f'Animated image: {wfp}')
log(f'Animated image with concat: {wfp_concat}')
return wfp, wfp_concat return wfp, wfp_concat

View File

@ -100,6 +100,9 @@ def squeeze_tensor_to_numpy(tensor):
def dct2device(dct: dict, device): def dct2device(dct: dict, device):
for key in dct: for key in dct:
if isinstance(dct[key], torch.Tensor):
dct[key] = dct[key].to(device)
else:
dct[key] = torch.tensor(dct[key]).to(device) dct[key] = torch.tensor(dct[key]).to(device)
return dct return dct

Binary file not shown.