feat: support image driven mode and regional control (#336)
* feat: update * feat: update * feat: update * feat: update * feat: update * feat: update * feat: image driven, regional animation * feat: image driven, regional animation * feat: image driven, regional animation, doc * feat: image driven, regional animation * feat: image driven, regional animation * feat: image driven, regional animation * chore: refactor * doc: update readme * doc: update changelog * feat: image driven, regional control --------- Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
59
app.py
@ -85,12 +85,12 @@ data_examples_i2v = [
|
||||
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
|
||||
]
|
||||
data_examples_v2v = [
|
||||
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7],
|
||||
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
|
||||
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
|
||||
]
|
||||
#################### interface logic ####################
|
||||
|
||||
@ -98,6 +98,7 @@ data_examples_v2v = [
|
||||
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
|
||||
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
|
||||
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
|
||||
video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent")
|
||||
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
|
||||
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
||||
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
|
||||
@ -124,9 +125,6 @@ retargeting_output_image = gr.Image(type="numpy")
|
||||
retargeting_output_image_paste_back = gr.Image(type="numpy")
|
||||
output_video = gr.Video(autoplay=False)
|
||||
output_video_paste_back = gr.Video(autoplay=False)
|
||||
output_video_i2v = gr.Video(autoplay=False)
|
||||
output_video_concat_i2v = gr.Video(autoplay=False)
|
||||
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
|
||||
gr.HTML(load_description(title_md))
|
||||
@ -196,6 +194,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
inputs=[driving_video_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
with gr.TabItem("🖼️ Driving Image") as v_tab_image:
|
||||
with gr.Accordion(open=True, label="Driving Image"):
|
||||
driving_image_input = gr.Image(type="filepath")
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[osp.join(example_video_dir, "d30.jpg")],
|
||||
[osp.join(example_video_dir, "d9.jpg")],
|
||||
[osp.join(example_video_dir, "d19.jpg")],
|
||||
[osp.join(example_video_dir, "d8.jpg")],
|
||||
[osp.join(example_video_dir, "d12.jpg")],
|
||||
[osp.join(example_video_dir, "d38.jpg")],
|
||||
],
|
||||
inputs=[driving_image_input],
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
|
||||
with gr.Accordion(open=True, label="Driving Pickle"):
|
||||
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
|
||||
@ -212,8 +226,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
)
|
||||
|
||||
v_tab_selection = gr.Textbox(visible=False)
|
||||
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
|
||||
v_tab_video.select(lambda: "Video", None, v_tab_selection)
|
||||
v_tab_image.select(lambda: "Image", None, v_tab_selection)
|
||||
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
|
||||
# with gr.Accordion(open=False, label="Animation Instructions"):
|
||||
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
|
||||
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
|
||||
@ -229,9 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
|
||||
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
|
||||
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
|
||||
animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
|
||||
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
|
||||
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
|
||||
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
|
||||
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
|
||||
|
||||
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
|
||||
@ -239,13 +254,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
process_button_animation = gr.Button("🚀 Animate", variant="primary")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video in the original image space"):
|
||||
output_video_i2v.render()
|
||||
output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space")
|
||||
with gr.Column():
|
||||
with gr.Accordion(open=True, label="The animated video"):
|
||||
output_video_concat_i2v.render()
|
||||
output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video")
|
||||
with gr.Row():
|
||||
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
|
||||
with gr.Column():
|
||||
output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False)
|
||||
with gr.Column():
|
||||
output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False)
|
||||
with gr.Row():
|
||||
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
|
||||
|
||||
with gr.Row():
|
||||
# Examples
|
||||
@ -279,7 +297,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_crop_driving_video_input,
|
||||
flag_video_editing_head_rotation,
|
||||
driving_smooth_observation_variance,
|
||||
],
|
||||
outputs=[output_image, output_image_paste_back],
|
||||
@ -373,6 +390,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
video_retargeting_source_scale.render()
|
||||
video_lip_retargeting_slider.render()
|
||||
driving_smooth_observation_variance_retargeting.render()
|
||||
video_retargeting_silence.render()
|
||||
with gr.Row(visible=True):
|
||||
process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
|
||||
with gr.Row(visible=True):
|
||||
@ -383,9 +401,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
examples=[
|
||||
[osp.join(example_portrait_dir, "s13.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s18.mp4")],
|
||||
[osp.join(example_portrait_dir, "s20.mp4")],
|
||||
# [osp.join(example_portrait_dir, "s20.mp4")],
|
||||
[osp.join(example_portrait_dir, "s29.mp4")],
|
||||
[osp.join(example_portrait_dir, "s32.mp4")],
|
||||
[osp.join(example_video_dir, "d3.mp4")],
|
||||
],
|
||||
inputs=[retargeting_input_video],
|
||||
cache_examples=False,
|
||||
@ -413,16 +432,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
inputs=[
|
||||
source_image_input,
|
||||
source_video_input,
|
||||
driving_video_pickle_input,
|
||||
driving_video_input,
|
||||
driving_image_input,
|
||||
driving_video_pickle_input,
|
||||
flag_relative_input,
|
||||
flag_do_crop_input,
|
||||
flag_remap_input,
|
||||
flag_stitching_input,
|
||||
animation_region,
|
||||
driving_option_input,
|
||||
driving_multiplier,
|
||||
flag_crop_driving_video_input,
|
||||
flag_video_editing_head_rotation,
|
||||
scale,
|
||||
vx_ratio,
|
||||
vy_ratio,
|
||||
@ -433,10 +453,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
tab_selection,
|
||||
v_tab_selection,
|
||||
],
|
||||
outputs=[output_video_i2v, output_video_concat_i2v],
|
||||
outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i],
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
|
||||
retargeting_input_image.change(
|
||||
fn=gradio_pipeline.init_retargeting_image,
|
||||
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
|
||||
@ -458,7 +479,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
|
||||
|
||||
process_button_retargeting_video.click(
|
||||
fn=gpu_wrapped_execute_video_retargeting,
|
||||
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
|
||||
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video],
|
||||
outputs=[output_video, output_video_paste_back],
|
||||
show_progress=True
|
||||
)
|
||||
|
59
assets/docs/changelog/2024-08-19.md
Normal file
@ -0,0 +1,59 @@
|
||||
## Image Driven and Regional Control
|
||||
|
||||
You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we name it regional control. The control is becoming more and more precise! 🎯
|
||||
|
||||
> Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊
|
||||
|
||||
> [!Note]
|
||||
> We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven’t fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖
|
||||
|
||||
|
||||
### CLI Usage
|
||||
It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image:
|
||||
|
||||
```bash
|
||||
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg
|
||||
```
|
||||
|
||||
To change the `animation_region` option, you can use the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to only drive the lip region, you can run by:
|
||||
|
||||
```bash
|
||||
# only driving the lip region
|
||||
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip
|
||||
```
|
||||
|
||||
### Gradio Interface
|
||||
|
||||
<p align="center">
|
||||
<img src="../image-driven-portrait-animation-2024-08-19.jpg" alt="LivePortrait" width="960px">
|
||||
<br>
|
||||
<strong>Image-driven Portrait Animation and Regional Control</strong>
|
||||
</p>
|
||||
|
||||
### More Detailed Explanation
|
||||
|
||||
**flag_relative_motion**:
|
||||
When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format.
|
||||
|
||||
**animation_region**:
|
||||
This argument offers five options:
|
||||
|
||||
- `exp`: Only the expression of the driving input influences the source.
|
||||
- `pose`: Only the head pose drives the source.
|
||||
- `lip`: Only lip movement drives the source.
|
||||
- `eyes`: Only eye movement drives the source.
|
||||
- `all`: All motions from the driving input are applied.
|
||||
|
||||
You can also select these options directly in the Gradio interface.
|
||||
|
||||
**Editing the Lip Region of the Source Video to a Neutral Expression**:
|
||||
In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth.
|
||||
|
||||
**Others**:
|
||||
When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion.
|
||||
|
||||
For CLI usage, to retain only the driving video's motion in the output, use:
|
||||
```bash
|
||||
python inference.py --no_flag_relative_motion
|
||||
```
|
||||
In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video.
|
BIN
assets/docs/image-driven-portrait-animation-2024-08-19.jpg
Normal file
After Width: | Height: | Size: 544 KiB |
BIN
assets/examples/driving/d12.jpg
Normal file
After Width: | Height: | Size: 97 KiB |
BIN
assets/examples/driving/d19.jpg
Normal file
After Width: | Height: | Size: 67 KiB |
BIN
assets/examples/driving/d30.jpg
Normal file
After Width: | Height: | Size: 75 KiB |
BIN
assets/examples/driving/d38.jpg
Normal file
After Width: | Height: | Size: 74 KiB |
BIN
assets/examples/driving/d8.jpg
Normal file
After Width: | Height: | Size: 91 KiB |
BIN
assets/examples/driving/d9.jpg
Normal file
After Width: | Height: | Size: 82 KiB |
@ -38,7 +38,8 @@
|
||||
|
||||
|
||||
## 🔥 Updates
|
||||
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, insipred by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
|
||||
- **`2024/08/19`**: 🖼️ We support **image driven mode** and **regional control**. For details, see [**here**](./assets/docs/changelog/2024-08-19.md).
|
||||
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
|
||||
- **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md).
|
||||
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
|
||||
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy!
|
||||
@ -247,6 +248,9 @@ And many more amazing contributions from our community!
|
||||
## Acknowledgements 💐
|
||||
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions.
|
||||
|
||||
## Ethics Considerations 🛡️
|
||||
Portrait animation technologies come with social risks, particularly the potential for misuse in creating deepfakes. To mitigate these risks, it’s crucial to follow ethical guidelines and adopt responsible usage practices. At present, the synthesized results contain visual artifacts that may help in detecting deepfakes. Please note that we do not assume any legal responsibility for the use of the results generated by this project.
|
||||
|
||||
## Citation 💖
|
||||
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
|
||||
```bibtex
|
||||
|
@ -22,9 +22,8 @@ class ArgumentConfig(PrintableConfig):
|
||||
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
|
||||
device_id: int = 0 # gpu device id
|
||||
flag_force_cpu: bool = False # force cpu inference, WIP!
|
||||
flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||
flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
||||
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
|
||||
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
|
||||
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
|
||||
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
|
||||
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
|
||||
@ -35,6 +34,7 @@ class ArgumentConfig(PrintableConfig):
|
||||
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
|
||||
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions
|
||||
########## source crop arguments ##########
|
||||
det_thresh: float = 0.15 # detection threshold
|
||||
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
|
||||
|
@ -6,10 +6,14 @@ config dataclass used for inference
|
||||
|
||||
import cv2
|
||||
from numpy import ndarray
|
||||
import pickle as pkl
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Tuple
|
||||
from .base_config import PrintableConfig, make_abs_path
|
||||
|
||||
def load_lip_array():
|
||||
with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
|
||||
return pkl.load(f)
|
||||
|
||||
@dataclass(repr=False) # use repr from PrintableConfig
|
||||
class InferenceConfig(PrintableConfig):
|
||||
@ -34,7 +38,6 @@ class InferenceConfig(PrintableConfig):
|
||||
device_id: int = 0
|
||||
flag_normalize_lip: bool = True
|
||||
flag_source_video_eye_retargeting: bool = False
|
||||
flag_video_editing_head_rotation: bool = False
|
||||
flag_eye_retargeting: bool = False
|
||||
flag_lip_retargeting: bool = False
|
||||
flag_stitching: bool = True
|
||||
@ -49,6 +52,7 @@ class InferenceConfig(PrintableConfig):
|
||||
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
||||
source_max_dim: int = 1280 # the max dim of height and width of source image or video
|
||||
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
|
||||
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
|
||||
|
||||
# NOT EXPORTED PARAMS
|
||||
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
|
||||
@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
|
||||
output_fps: int = 25 # default output fps
|
||||
|
||||
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
|
||||
lip_array: ndarray = field(default_factory=load_lip_array)
|
||||
size_gif: int = 256 # default gif size, TO IMPLEMENT
|
||||
|
@ -146,16 +146,18 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
self,
|
||||
input_source_image_path=None,
|
||||
input_source_video_path=None,
|
||||
input_driving_video_pickle_path=None,
|
||||
input_driving_video_path=None,
|
||||
input_driving_image_path=None,
|
||||
input_driving_video_pickle_path=None,
|
||||
flag_relative_input=True,
|
||||
flag_do_crop_input=True,
|
||||
flag_remap_input=True,
|
||||
flag_stitching_input=True,
|
||||
animation_region="all",
|
||||
driving_option_input="pose-friendly",
|
||||
driving_multiplier=1.0,
|
||||
flag_crop_driving_video_input=True,
|
||||
flag_video_editing_head_rotation=False,
|
||||
# flag_video_editing_head_rotation=False,
|
||||
scale=2.3,
|
||||
vx_ratio=0.0,
|
||||
vy_ratio=-0.125,
|
||||
@ -177,6 +179,8 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
|
||||
if v_tab_selection == 'Video':
|
||||
input_driving_path = input_driving_video_path
|
||||
elif v_tab_selection == 'Image':
|
||||
input_driving_path = input_driving_image_path
|
||||
elif v_tab_selection == 'Pickle':
|
||||
input_driving_path = input_driving_video_pickle_path
|
||||
else:
|
||||
@ -195,10 +199,10 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
'flag_do_crop': flag_do_crop_input,
|
||||
'flag_pasteback': flag_remap_input,
|
||||
'flag_stitching': flag_stitching_input,
|
||||
'animation_region': animation_region,
|
||||
'driving_option': driving_option_input,
|
||||
'driving_multiplier': driving_multiplier,
|
||||
'flag_crop_driving_video': flag_crop_driving_video_input,
|
||||
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
|
||||
'scale': scale,
|
||||
'vx_ratio': vx_ratio,
|
||||
'vy_ratio': vy_ratio,
|
||||
@ -211,10 +215,13 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
self.args = update_args(self.args, args_user)
|
||||
self.live_portrait_wrapper.update_config(self.args.__dict__)
|
||||
self.cropper.update_config(self.args.__dict__)
|
||||
# video driven animation
|
||||
video_path, video_path_concat = self.execute(self.args)
|
||||
|
||||
output_path, output_path_concat = self.execute(self.args)
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return video_path, video_path_concat
|
||||
if output_path.endswith(".jpg"):
|
||||
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
|
||||
else:
|
||||
return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
||||
else:
|
||||
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
|
||||
|
||||
@ -308,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
if input_lip_ratio != self.source_lip_ratio:
|
||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
|
||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
|
||||
print(lip_delta)
|
||||
x_d_new = x_d_new + \
|
||||
(eyes_delta if eyes_delta is not None else 0) + \
|
||||
(lip_delta if lip_delta is not None else 0)
|
||||
@ -388,29 +396,51 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
return source_eye_ratio, source_lip_ratio
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
|
||||
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
|
||||
""" retargeting the lip-open ratio of each source frame
|
||||
"""
|
||||
# disposable feature
|
||||
device = self.live_portrait_wrapper.device
|
||||
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
|
||||
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
|
||||
|
||||
if input_lip_ratio is None:
|
||||
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
||||
if not video_retargeting_silence:
|
||||
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
|
||||
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
|
||||
if input_lip_ratio is None:
|
||||
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
||||
else:
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
|
||||
I_p_pstbk_lst = None
|
||||
if flag_do_crop_input_retargeting_video:
|
||||
I_p_pstbk_lst = []
|
||||
I_p_lst = []
|
||||
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
|
||||
x_s_user_i = x_s_user_lst[i].to(device)
|
||||
f_s_user_i = f_s_user_lst[i].to(device)
|
||||
|
||||
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
|
||||
x_d_i_new = x_s_user_i + lip_delta_retargeting
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
|
||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
I_p_lst.append(I_p_i)
|
||||
|
||||
if flag_do_crop_input_retargeting_video:
|
||||
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
else:
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
|
||||
self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video)
|
||||
|
||||
I_p_pstbk_lst = None
|
||||
if flag_do_crop_input_retargeting_video:
|
||||
I_p_pstbk_lst = []
|
||||
I_p_lst = []
|
||||
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
|
||||
for i in track(range(n_frames), description='Silencing lip...', total=n_frames):
|
||||
x_s_user_i = x_s_user_lst[i].to(device)
|
||||
f_s_user_i = f_s_user_lst[i].to(device)
|
||||
|
||||
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
|
||||
x_d_i_new = x_s_user_i + lip_delta_retargeting
|
||||
x_d_i_new = x_d_i_new_lst[i]
|
||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
|
||||
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
|
||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||
@ -420,37 +450,37 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
|
||||
I_p_pstbk_lst.append(I_p_pstbk)
|
||||
|
||||
mkdir(self.args.output_dir)
|
||||
flag_source_has_audio = has_audio_stream(input_video)
|
||||
mkdir(self.args.output_dir)
|
||||
flag_source_has_audio = has_audio_stream(input_video)
|
||||
|
||||
######### build the final concatenation result #########
|
||||
# source frame | generation
|
||||
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
|
||||
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
|
||||
######### build the final concatenation result #########
|
||||
# source frame | generation
|
||||
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
|
||||
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
|
||||
|
||||
if flag_source_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
|
||||
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
if flag_source_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
|
||||
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=source_fps)
|
||||
# save the animated result
|
||||
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=source_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio:
|
||||
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
|
||||
add_audio_to_video(wfp, input_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return wfp_concat, wfp
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio:
|
||||
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
|
||||
add_audio_to_video(wfp, input_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
gr.Info("Run successfully!", duration=2)
|
||||
return wfp_concat, wfp
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
|
||||
@ -503,12 +533,64 @@ class GradioPipeline(LivePortraitPipeline):
|
||||
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
|
||||
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
|
||||
|
||||
|
||||
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
|
||||
else:
|
||||
# when press the clear button, go here
|
||||
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True):
|
||||
""" for keeping lips in the source video silent
|
||||
"""
|
||||
if input_video is not None:
|
||||
inference_cfg = self.live_portrait_wrapper.inference_cfg
|
||||
######## process source video ########
|
||||
source_rgb_lst = load_video(input_video)
|
||||
source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
|
||||
source_fps = int(get_fps(input_video))
|
||||
n_frames = len(source_rgb_lst)
|
||||
log(f"Load source video from {input_video}. FPS is {source_fps}")
|
||||
|
||||
if flag_do_crop:
|
||||
ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
|
||||
log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_s["frame_crop_lst"]) != n_frames:
|
||||
n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
|
||||
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
|
||||
mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
|
||||
else:
|
||||
source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
|
||||
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
|
||||
source_M_c2o_lst, mask_ori_lst = None, None
|
||||
|
||||
c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
|
||||
# save the motion template
|
||||
I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
|
||||
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
|
||||
|
||||
f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
|
||||
for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames):
|
||||
x_s_info = source_template_dct['motion'][i]
|
||||
x_s_info = dct2device(x_s_info, device)
|
||||
scale_s = x_s_info['scale']
|
||||
x_s_user = x_s_info['x_s']
|
||||
x_c_s = x_s_info['kp']
|
||||
R_s = x_s_info['R']
|
||||
t_s = x_s_info['t']
|
||||
delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)
|
||||
for eyes_idx in [11, 13, 15, 16, 18]:
|
||||
delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
|
||||
source_lmk = source_lmk_crop_lst[i]
|
||||
img_crop_256x256 = img_crop_256x256_lst[i]
|
||||
I_s = I_s_lst[i]
|
||||
f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||
x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s
|
||||
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
|
||||
return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
|
||||
else:
|
||||
# when press the clear button, go here
|
||||
raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)
|
||||
|
||||
class GradioPipelineAnimal(LivePortraitPipelineAnimal):
|
||||
"""gradio for animal
|
||||
"""
|
||||
|
@ -72,7 +72,6 @@ class LivePortraitPipeline(object):
|
||||
c_lip = c_lip_lst[i].astype(np.float32)
|
||||
template_dct['c_lip_lst'].append(c_lip)
|
||||
|
||||
|
||||
return template_dct
|
||||
|
||||
def execute(self, args: ArgumentConfig):
|
||||
@ -111,8 +110,11 @@ class LivePortraitPipeline(object):
|
||||
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
|
||||
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
|
||||
driving_n_frames = driving_template_dct['n_frames']
|
||||
if flag_is_source_video:
|
||||
flag_is_driving_video = True if driving_n_frames > 1 else False
|
||||
if flag_is_source_video and flag_is_driving_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
elif flag_is_source_video and not flag_is_driving_video:
|
||||
n_frames = len(source_rgb_lst)
|
||||
else:
|
||||
n_frames = driving_n_frames
|
||||
|
||||
@ -123,25 +125,35 @@ class LivePortraitPipeline(object):
|
||||
if args.flag_crop_driving_video:
|
||||
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
|
||||
|
||||
elif osp.exists(args.driving) and is_video(args.driving):
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
driving_n_frames = len(driving_rgb_lst)
|
||||
|
||||
elif osp.exists(args.driving):
|
||||
if is_video(args.driving):
|
||||
flag_is_driving_video = True
|
||||
# load from video file, AND make motion template
|
||||
output_fps = int(get_fps(args.driving))
|
||||
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
|
||||
driving_rgb_lst = load_video(args.driving)
|
||||
elif is_image(args.driving):
|
||||
flag_is_driving_video = False
|
||||
driving_img_rgb = load_image_rgb(args.driving)
|
||||
output_fps = 25
|
||||
log(f"Load driving image from {args.driving}")
|
||||
driving_rgb_lst = [driving_img_rgb]
|
||||
else:
|
||||
raise Exception(f"{args.driving} is not a supported type!")
|
||||
######## make motion template ########
|
||||
log("Start making driving motion template...")
|
||||
if flag_is_source_video:
|
||||
driving_n_frames = len(driving_rgb_lst)
|
||||
if flag_is_source_video and flag_is_driving_video:
|
||||
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
|
||||
driving_rgb_lst = driving_rgb_lst[:n_frames]
|
||||
elif flag_is_source_video and not flag_is_driving_video:
|
||||
n_frames = len(source_rgb_lst)
|
||||
else:
|
||||
n_frames = driving_n_frames
|
||||
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
|
||||
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
|
||||
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
|
||||
if len(ret_d["frame_crop_lst"]) is not n_frames:
|
||||
if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
|
||||
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
|
||||
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
|
||||
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
|
||||
@ -158,9 +170,11 @@ class LivePortraitPipeline(object):
|
||||
wfp_template = remove_suffix(args.driving) + '.pkl'
|
||||
dump(wfp_template, driving_template_dct)
|
||||
log(f"Dump motion template to {wfp_template}")
|
||||
|
||||
else:
|
||||
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
|
||||
raise Exception(f"{args.driving} does not exist!")
|
||||
if not flag_is_driving_video:
|
||||
c_d_eyes_lst = c_d_eyes_lst*n_frames
|
||||
c_d_lip_lst = c_d_lip_lst*n_frames
|
||||
|
||||
######## prepare for pasteback ########
|
||||
I_p_pstbk_lst = None
|
||||
@ -196,17 +210,33 @@ class LivePortraitPipeline(object):
|
||||
|
||||
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
|
||||
if inf_cfg.flag_relative_motion:
|
||||
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if flag_is_driving_video:
|
||||
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
else:
|
||||
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
if flag_is_driving_video:
|
||||
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
else:
|
||||
x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
|
||||
else:
|
||||
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
if flag_is_driving_video:
|
||||
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
|
||||
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
else:
|
||||
x_d_exp_lst = [driving_template_dct['motion'][0]['exp']]
|
||||
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
if flag_is_driving_video:
|
||||
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
|
||||
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
|
||||
else:
|
||||
x_d_r_lst = [driving_template_dct['motion'][0][key_r]]
|
||||
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]*n_frames
|
||||
|
||||
else: # if the input is a source image, process it only once
|
||||
if inf_cfg.flag_do_crop:
|
||||
@ -236,7 +266,10 @@ class LivePortraitPipeline(object):
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
|
||||
|
||||
######## animate ########
|
||||
log(f"The animated video consists of {n_frames} frames.")
|
||||
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
|
||||
log(f"The animated video consists of {n_frames} frames.")
|
||||
else:
|
||||
log(f"The output of image-driven portrait animation is an image.")
|
||||
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
|
||||
if flag_is_source_video: # source video
|
||||
x_s_info = source_template_dct['motion'][i]
|
||||
@ -272,43 +305,88 @@ class LivePortraitPipeline(object):
|
||||
|
||||
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
|
||||
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
|
||||
|
||||
x_d_i_info = driving_template_dct['motion'][i]
|
||||
if flag_is_source_video and not flag_is_driving_video:
|
||||
x_d_i_info = driving_template_dct['motion'][0]
|
||||
else:
|
||||
x_d_i_info = driving_template_dct['motion'][i]
|
||||
x_d_i_info = dct2device(x_d_i_info, device)
|
||||
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
|
||||
|
||||
if i == 0: # cache the first frame
|
||||
R_d_0 = R_d_i
|
||||
x_d_0_info = x_d_i_info
|
||||
x_d_0_info = x_d_i_info.copy()
|
||||
|
||||
delta_new = x_s_info['exp'].clone()
|
||||
if inf_cfg.flag_relative_motion:
|
||||
if flag_is_source_video:
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
R_new = x_d_r_lst_smooth[i]
|
||||
else:
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
else:
|
||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
||||
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
|
||||
if flag_is_source_video:
|
||||
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
||||
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
|
||||
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
|
||||
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
|
||||
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
|
||||
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
|
||||
else:
|
||||
if flag_is_driving_video:
|
||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
||||
else:
|
||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
|
||||
elif inf_cfg.animation_region == "lip":
|
||||
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
||||
if flag_is_source_video:
|
||||
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
|
||||
elif flag_is_driving_video:
|
||||
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
|
||||
else:
|
||||
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
|
||||
elif inf_cfg.animation_region == "eyes":
|
||||
for eyes_idx in [11, 13, 15, 16, 18]:
|
||||
if flag_is_source_video:
|
||||
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
|
||||
elif flag_is_driving_video:
|
||||
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
|
||||
else:
|
||||
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :]
|
||||
if inf_cfg.animation_region == "all":
|
||||
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
||||
else:
|
||||
scale_new = x_s_info['scale']
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
||||
else:
|
||||
t_new = x_s_info['t']
|
||||
else:
|
||||
if flag_is_source_video:
|
||||
if inf_cfg.flag_video_editing_head_rotation:
|
||||
R_new = x_d_r_lst_smooth[i]
|
||||
else:
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i
|
||||
else:
|
||||
R_new = R_d_i
|
||||
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
|
||||
R_new = R_s
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
|
||||
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
||||
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :]
|
||||
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1]
|
||||
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] if flag_is_source_video else x_d_i_info['exp'][:, 5, 2]
|
||||
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2]
|
||||
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:]
|
||||
elif inf_cfg.animation_region == "lip":
|
||||
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
||||
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :]
|
||||
elif inf_cfg.animation_region == "eyes":
|
||||
for eyes_idx in [11, 13, 15, 16, 18]:
|
||||
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :]
|
||||
scale_new = x_s_info['scale']
|
||||
t_new = x_d_i_info['t']
|
||||
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
|
||||
t_new = x_d_i_info['t']
|
||||
else:
|
||||
t_new = x_s_info['t']
|
||||
|
||||
t_new[..., 2].fill_(0) # zero tz
|
||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
||||
|
||||
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
|
||||
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
|
||||
if i == 0:
|
||||
x_d_0_new = x_d_i_new
|
||||
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
|
||||
@ -373,50 +451,68 @@ class LivePortraitPipeline(object):
|
||||
|
||||
mkdir(args.output_dir)
|
||||
wfp_concat = None
|
||||
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
######### build the final concatenation result #########
|
||||
# driving frame | source frame | generation, or source frame | generation
|
||||
if flag_is_source_video:
|
||||
# driving frame | source frame | generation
|
||||
if flag_is_source_video and flag_is_driving_video:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
|
||||
elif flag_is_source_video and not flag_is_driving_video:
|
||||
if flag_load_from_template:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
|
||||
else:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
|
||||
else:
|
||||
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
|
||||
# NOTE: update output fps
|
||||
output_fps = source_fps if flag_is_source_video else output_fps
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
|
||||
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
|
||||
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
|
||||
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}, concat mode")
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
# NOTE: update output fps
|
||||
output_fps = source_fps if flag_is_source_video else output_fps
|
||||
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
|
||||
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
# final result with concatenation
|
||||
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}, concat mode")
|
||||
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
|
||||
os.replace(wfp_concat_with_audio, wfp_concat)
|
||||
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
|
||||
|
||||
# save the animated result
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}")
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
else:
|
||||
images2video(I_p_lst, wfp=wfp, fps=output_fps)
|
||||
|
||||
######### build the final result #########
|
||||
if flag_source_has_audio or flag_driving_has_audio:
|
||||
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
|
||||
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
|
||||
log(f"Audio is selected from {audio_from_which_video}")
|
||||
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
|
||||
os.replace(wfp_with_audio, wfp)
|
||||
log(f"Replace {wfp_with_audio} with {wfp}")
|
||||
|
||||
# final log
|
||||
if wfp_template not in (None, ''):
|
||||
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
|
||||
log(f'Animated video: {wfp}')
|
||||
log(f'Animated video with concat: {wfp_concat}')
|
||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
|
||||
cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
|
||||
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
|
||||
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
|
||||
cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
|
||||
else:
|
||||
cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
|
||||
# final log
|
||||
log(f'Animated image: {wfp}')
|
||||
log(f'Animated image with concat: {wfp_concat}')
|
||||
|
||||
return wfp, wfp_concat
|
||||
|
@ -100,7 +100,10 @@ def squeeze_tensor_to_numpy(tensor):
|
||||
|
||||
def dct2device(dct: dict, device):
|
||||
for key in dct:
|
||||
dct[key] = torch.tensor(dct[key]).to(device)
|
||||
if isinstance(dct[key], torch.Tensor):
|
||||
dct[key] = dct[key].to(device)
|
||||
else:
|
||||
dct[key] = torch.tensor(dct[key]).to(device)
|
||||
return dct
|
||||
|
||||
|
||||
|