feat: support image driven mode and regional control (#336)

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: update

* feat: image driven, regional animation

* feat: image driven, regional animation

* feat: image driven, regional animation, doc

* feat: image driven, regional animation

* feat: image driven, regional animation

* feat: image driven, regional animation

* chore: refactor

* doc: update readme

* doc: update changelog

* feat: image driven, regional control

---------

Co-authored-by: zhangdingyun <zhangdingyun@kuaishou.com>
This commit is contained in:
Jianzhu Guo 2024-08-19 23:08:57 +08:00 committed by GitHub
parent 8a7682aaa4
commit a19c3e15fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 421 additions and 151 deletions

59
app.py
View File

@ -85,12 +85,12 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
]
#################### interface logic ####################
@ -98,6 +98,7 @@ data_examples_v2v = [
retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent")
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
@ -124,9 +125,6 @@ retargeting_output_image = gr.Image(type="numpy")
retargeting_output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video(autoplay=False)
output_video_paste_back = gr.Video(autoplay=False)
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
gr.HTML(load_description(title_md))
@ -196,6 +194,22 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[driving_video_input],
cache_examples=False,
)
with gr.TabItem("🖼️ Driving Image") as v_tab_image:
with gr.Accordion(open=True, label="Driving Image"):
driving_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_video_dir, "d30.jpg")],
[osp.join(example_video_dir, "d9.jpg")],
[osp.join(example_video_dir, "d19.jpg")],
[osp.join(example_video_dir, "d8.jpg")],
[osp.join(example_video_dir, "d12.jpg")],
[osp.join(example_video_dir, "d38.jpg")],
],
inputs=[driving_image_input],
cache_examples=False,
)
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
@ -212,8 +226,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
)
v_tab_selection = gr.Textbox(visible=False)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
v_tab_video.select(lambda: "Video", None, v_tab_selection)
v_tab_image.select(lambda: "Image", None, v_tab_selection)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
# with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@ -229,9 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
@ -239,13 +254,16 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video_i2v.render()
output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space")
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video")
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
with gr.Column():
output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False)
with gr.Column():
output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False)
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
with gr.Row():
# Examples
@ -279,7 +297,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
driving_smooth_observation_variance,
],
outputs=[output_image, output_image_paste_back],
@ -373,6 +390,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
video_retargeting_source_scale.render()
video_lip_retargeting_slider.render()
driving_smooth_observation_variance_retargeting.render()
video_retargeting_silence.render()
with gr.Row(visible=True):
process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
with gr.Row(visible=True):
@ -383,9 +401,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
examples=[
[osp.join(example_portrait_dir, "s13.mp4")],
# [osp.join(example_portrait_dir, "s18.mp4")],
[osp.join(example_portrait_dir, "s20.mp4")],
# [osp.join(example_portrait_dir, "s20.mp4")],
[osp.join(example_portrait_dir, "s29.mp4")],
[osp.join(example_portrait_dir, "s32.mp4")],
[osp.join(example_video_dir, "d3.mp4")],
],
inputs=[retargeting_input_video],
cache_examples=False,
@ -413,16 +432,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[
source_image_input,
source_video_input,
driving_video_pickle_input,
driving_video_input,
driving_image_input,
driving_video_pickle_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_stitching_input,
animation_region,
driving_option_input,
driving_multiplier,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale,
vx_ratio,
vy_ratio,
@ -433,10 +453,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
tab_selection,
v_tab_selection,
],
outputs=[output_video_i2v, output_video_concat_i2v],
outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i],
show_progress=True
)
retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting_image,
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],
@ -458,7 +479,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
process_button_retargeting_video.click(
fn=gpu_wrapped_execute_video_retargeting,
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video],
outputs=[output_video, output_video_paste_back],
show_progress=True
)

View File

@ -0,0 +1,59 @@
## Image Driven and Regional Control
You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we call regional control. The control is becoming more and more precise! 🎯
> Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊
> [!Note]
> We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven't fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖
### CLI Usage
It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image:
```bash
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg
```
To change the `animation_region` option, set the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to drive only the lip region, run:
```bash
# only driving the lip region
python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip
```
### Gradio Interface
<p align="center">
<img src="../image-driven-portrait-animation-2024-08-19.jpg" alt="LivePortrait" width="960px">
<br>
<strong>Image-driven Portrait Animation and Regional Control</strong>
</p>
### More Detailed Explanation
**flag_relative_motion**:
When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format.
**animation_region**:
This argument offers five options:
- `exp`: Only the expression of the driving input influences the source.
- `pose`: Only the head pose drives the source.
- `lip`: Only lip movement drives the source.
- `eyes`: Only eye movement drives the source.
- `all`: All motions from the driving input are applied.
You can also select these options directly in the Gradio interface.
**Editing the Lip Region of the Source Video to a Neutral Expression**:
In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth.
**Others**:
When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion.
For CLI usage, to retain only the driving video's motion in the output, use:
```bash
python inference.py --no_flag_relative_motion
```
In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video.

Binary file not shown.

After

Width:  |  Height:  |  Size: 544 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@ -38,7 +38,8 @@
## 🔥 Updates
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, insipred by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
- **`2024/08/19`**: 🖼️ We support **image driven mode** and **regional control**. For details, see [**here**](./assets/docs/changelog/2024-08-19.md).
- **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md).
- **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md).
- **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)!
- **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy!
@ -247,6 +248,9 @@ And many more amazing contributions from our community!
## Acknowledgements 💐
We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions.
## Ethics Considerations 🛡️
Portrait animation technologies come with social risks, particularly the potential for misuse in creating deepfakes. To mitigate these risks, it's crucial to follow ethical guidelines and adopt responsible usage practices. At present, the synthesized results contain visual artifacts that may help in detecting deepfakes. Please note that we do not assume any legal responsibility for the use of the results generated by this project.
## Citation 💖
If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
```bibtex

View File

@ -22,9 +22,8 @@ class ArgumentConfig(PrintableConfig):
flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
device_id: int = 0 # gpu device id
flag_force_cpu: bool = False # force cpu inference, WIP!
flag_normalize_lip: bool = True # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
@ -35,6 +34,7 @@ class ArgumentConfig(PrintableConfig):
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions
########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger

View File

@ -6,10 +6,14 @@ config dataclass used for inference
import cv2
from numpy import ndarray
import pickle as pkl
from dataclasses import dataclass, field
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
def load_lip_array():
    """Read and return the object pickled in ``../utils/resources/lip_array.pkl``.

    The path is resolved through ``make_abs_path``; the file is opened in
    binary mode and closed again before the unpickled object is returned.
    """
    lip_array_path = make_abs_path('../utils/resources/lip_array.pkl')
    with open(lip_array_path, 'rb') as fh:
        lip_array = pkl.load(fh)
    return lip_array
@dataclass(repr=False) # use repr from PrintableConfig
class InferenceConfig(PrintableConfig):
@ -34,7 +38,6 @@ class InferenceConfig(PrintableConfig):
device_id: int = 0
flag_normalize_lip: bool = True
flag_source_video_eye_retargeting: bool = False
flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True
@ -49,6 +52,7 @@ class InferenceConfig(PrintableConfig):
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
source_max_dim: int = 1280 # the max dim of height and width of source image or video
source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
# NOT EXPORTED PARAMS
lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
@ -61,4 +65,5 @@ class InferenceConfig(PrintableConfig):
output_fps: int = 25 # default output fps
mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
lip_array: ndarray = field(default_factory=load_lip_array)
size_gif: int = 256 # default gif size, TO IMPLEMENT

View File

@ -146,16 +146,18 @@ class GradioPipeline(LivePortraitPipeline):
self,
input_source_image_path=None,
input_source_video_path=None,
input_driving_video_pickle_path=None,
input_driving_video_path=None,
input_driving_image_path=None,
input_driving_video_pickle_path=None,
flag_relative_input=True,
flag_do_crop_input=True,
flag_remap_input=True,
flag_stitching_input=True,
animation_region="all",
driving_option_input="pose-friendly",
driving_multiplier=1.0,
flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False,
# flag_video_editing_head_rotation=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
@ -177,6 +179,8 @@ class GradioPipeline(LivePortraitPipeline):
if v_tab_selection == 'Video':
input_driving_path = input_driving_video_path
elif v_tab_selection == 'Image':
input_driving_path = input_driving_image_path
elif v_tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
@ -195,10 +199,10 @@ class GradioPipeline(LivePortraitPipeline):
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_stitching': flag_stitching_input,
'animation_region': animation_region,
'driving_option': driving_option_input,
'driving_multiplier': driving_multiplier,
'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
@ -211,10 +215,13 @@ class GradioPipeline(LivePortraitPipeline):
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
output_path, output_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat
if output_path.endswith(".jpg"):
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
else:
return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
@ -308,6 +315,7 @@ class GradioPipeline(LivePortraitPipeline):
if input_lip_ratio != self.source_lip_ratio:
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
print(lip_delta)
x_d_new = x_d_new + \
(eyes_delta if eyes_delta is not None else 0) + \
(lip_delta if lip_delta is not None else 0)
@ -388,29 +396,51 @@ class GradioPipeline(LivePortraitPipeline):
return source_eye_ratio, source_lip_ratio
@torch.no_grad()
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
""" retargeting the lip-open ratio of each source frame
"""
# disposable feature
device = self.live_portrait_wrapper.device
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
if input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
if not video_retargeting_silence:
f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
if input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
I_p_pstbk_lst = None
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
x_d_i_new = x_s_user_i + lip_delta_retargeting
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if flag_do_crop_input_retargeting_video:
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk)
else:
inference_cfg = self.live_portrait_wrapper.inference_cfg
f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video)
I_p_pstbk_lst = None
if flag_do_crop_input_retargeting_video:
I_p_pstbk_lst = []
I_p_lst = []
for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
for i in track(range(n_frames), description='Silencing lip...', total=n_frames):
x_s_user_i = x_s_user_lst[i].to(device)
f_s_user_i = f_s_user_lst[i].to(device)
lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
x_d_i_new = x_s_user_i + lip_delta_retargeting
x_d_i_new = x_d_i_new_lst[i]
x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
@ -420,37 +450,37 @@ class GradioPipeline(LivePortraitPipeline):
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
I_p_pstbk_lst.append(I_p_pstbk)
mkdir(self.args.output_dir)
flag_source_has_audio = has_audio_stream(input_video)
mkdir(self.args.output_dir)
flag_source_has_audio = has_audio_stream(input_video)
######### build the final concatenation result #########
# source frame | generation
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
######### build the final concatenation result #########
# source frame | generation
frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
if flag_source_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
if flag_source_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=source_fps)
# save the animated result
wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=source_fps)
######### build the final result #########
if flag_source_has_audio:
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
add_audio_to_video(wfp, input_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
gr.Info("Run successfully!", duration=2)
return wfp_concat, wfp
######### build the final result #########
if flag_source_has_audio:
wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
add_audio_to_video(wfp, input_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
gr.Info("Run successfully!", duration=2)
return wfp_concat, wfp
@torch.no_grad()
def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
@ -503,12 +533,64 @@ class GradioPipeline(LivePortraitPipeline):
f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
else:
# when press the clear button, go here
raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
@torch.no_grad()
def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True):
    """Prepare per-frame source features and lip-silenced keypoints for a source video.

    The video is loaded, each frame is resized via ``resize_to_limit`` and
    (optionally) cropped, a motion template is built for the frames, and for
    every frame the appearance feature plus a re-targeted keypoint set are
    collected. The re-targeted expression is ``inference_cfg.lip_array``
    everywhere except at the keypoint indices 11, 13, 15, 16 and 18, which
    keep the source expression (the original code labels these as eye indices).

    Args:
        input_video: Path of the source video, or ``None`` when nothing is uploaded.
        device: Torch device the per-frame motion tensors are moved to.
        flag_do_crop: When true, frames are cropped with ``self.cropper`` and
            paste-back masks are prepared; otherwise frames are resized to
            256x256 directly and no paste-back data is produced.

    Returns:
        A 9-tuple ``(f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst,
        mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames)``.
        ``source_M_c2o_lst`` and ``mask_ori_lst`` are ``None`` when
        ``flag_do_crop`` is false.

    Raises:
        gr.Error: If ``input_video`` is ``None``.
    """
    if input_video is None:
        # when press the clear button, go here
        raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)

    inference_cfg = self.live_portrait_wrapper.inference_cfg

    ######## process source video ########
    source_rgb_lst = load_video(input_video)
    source_rgb_lst = [
        resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division)
        for img in source_rgb_lst
    ]
    source_fps = int(get_fps(input_video))
    n_frames = len(source_rgb_lst)
    log(f"Load source video from {input_video}. FPS is {source_fps}")

    if flag_do_crop:
        ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
        n_cropped = len(ret_s["frame_crop_lst"])
        log(f'Source video is cropped, {n_cropped} frames are processed.')
        # The cropper may return fewer frames than were loaded; only animate
        # the frames available on both sides.
        n_frames = min(n_frames, n_cropped)
        img_crop_256x256_lst = ret_s['frame_crop_lst']
        source_lmk_crop_lst = ret_s['lmk_crop_lst']
        source_M_c2o_lst = ret_s['M_c2o_lst']
        mask_ori_lst = [
            prepare_paste_back(
                inference_cfg.mask_crop,
                source_M_c2o,
                dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]),
            )
            for source_M_c2o in source_M_c2o_lst
        ]
    else:
        source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
        # force to resize to 256x256
        img_crop_256x256_lst = [cv2.resize(frame, (256, 256)) for frame in source_rgb_lst]
        source_M_c2o_lst, mask_ori_lst = None, None

    c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
    # save the motion template
    I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
    source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)

    # The lip expression tensor does not depend on the frame, so build it once
    # instead of converting it from numpy on every iteration.
    # NOTE(review): presumably `lip_array` encodes a neutral/closed-lip expression - confirm.
    lip_exp = torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)

    f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
    for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames):
        x_s_info = dct2device(source_template_dct['motion'][i], device)
        scale_s = x_s_info['scale']
        x_s_user = x_s_info['x_s']
        x_c_s = x_s_info['kp']
        R_s = x_s_info['R']
        t_s = x_s_info['t']

        # Start from the lip expression and copy back the source expression
        # at the indices the original code treats as eye keypoints.
        delta_new = torch.zeros_like(x_s_info['exp']) + lip_exp
        for eyes_idx in [11, 13, 15, 16, 18]:
            delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]

        f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s_lst[i])
        x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s

        f_s_user_lst.append(f_s_user)
        x_s_user_lst.append(x_s_user)
        x_d_i_new_lst.append(x_d_i_new)

    return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
class GradioPipelineAnimal(LivePortraitPipelineAnimal):
"""gradio for animal
"""

View File

@ -72,7 +72,6 @@ class LivePortraitPipeline(object):
c_lip = c_lip_lst[i].astype(np.float32)
template_dct['c_lip_lst'].append(c_lip)
return template_dct
def execute(self, args: ArgumentConfig):
@ -111,8 +110,11 @@ class LivePortraitPipeline(object):
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
if flag_is_source_video:
flag_is_driving_video = True if driving_n_frames > 1 else False
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else:
n_frames = driving_n_frames
@ -123,25 +125,35 @@ class LivePortraitPipeline(object):
if args.flag_crop_driving_video:
log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.")
elif osp.exists(args.driving) and is_video(args.driving):
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
driving_n_frames = len(driving_rgb_lst)
elif osp.exists(args.driving):
if is_video(args.driving):
flag_is_driving_video = True
# load from video file, AND make motion template
output_fps = int(get_fps(args.driving))
log(f"Load driving video from: {args.driving}, FPS is {output_fps}")
driving_rgb_lst = load_video(args.driving)
elif is_image(args.driving):
flag_is_driving_video = False
driving_img_rgb = load_image_rgb(args.driving)
output_fps = 25
log(f"Load driving image from {args.driving}")
driving_rgb_lst = [driving_img_rgb]
else:
raise Exception(f"{args.driving} is not a supported type!")
######## make motion template ########
log("Start making driving motion template...")
if flag_is_source_video:
driving_n_frames = len(driving_rgb_lst)
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames]
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else:
n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
ret_d = self.cropper.crop_driving_video(driving_rgb_lst)
log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
if len(ret_d["frame_crop_lst"]) is not n_frames:
if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video:
n_frames = min(n_frames, len(ret_d["frame_crop_lst"]))
driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst]
@ -158,9 +170,11 @@ class LivePortraitPipeline(object):
wfp_template = remove_suffix(args.driving) + '.pkl'
dump(wfp_template, driving_template_dct)
log(f"Dump motion template to {wfp_template}")
else:
raise Exception(f"{args.driving} not exists or unsupported driving info types!")
raise Exception(f"{args.driving} does not exist!")
if not flag_is_driving_video:
c_d_eyes_lst = c_d_eyes_lst*n_frames
c_d_lip_lst = c_d_lip_lst*n_frames
######## prepare for pasteback ########
I_p_pstbk_lst = None
@ -196,17 +210,33 @@ class LivePortraitPipeline(object):
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys
if inf_cfg.flag_relative_motion:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
if flag_is_driving_video:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if flag_is_driving_video:
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]
else:
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
if flag_is_driving_video:
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_exp_lst = [driving_template_dct['motion'][0]['exp']]
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if flag_is_driving_video:
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_r_lst = [driving_template_dct['motion'][0][key_r]]
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]*n_frames
else: # if the input is a source image, process it only once
if inf_cfg.flag_do_crop:
@ -236,7 +266,10 @@ class LivePortraitPipeline(object):
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0]))
######## animate ########
log(f"The animated video consists of {n_frames} frames.")
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
log(f"The animated video consists of {n_frames} frames.")
else:
log(f"The output of image-driven portrait animation is an image.")
for i in track(range(n_frames), description='🚀Animating...', total=n_frames):
if flag_is_source_video: # source video
x_s_info = source_template_dct['motion'][i]
@ -272,43 +305,88 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
x_d_i_info = driving_template_dct['motion'][i]
if flag_is_source_video and not flag_is_driving_video:
x_d_i_info = driving_template_dct['motion'][0]
else:
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
if i == 0: # cache the first frame
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
x_d_0_info = x_d_i_info.copy()
delta_new = x_s_info['exp'].clone()
if inf_cfg.flag_relative_motion:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
else:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
if flag_is_source_video:
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
else:
if flag_is_driving_video:
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
else:
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
elif inf_cfg.animation_region == "lip":
for lip_idx in [6, 12, 14, 17, 19, 20]:
if flag_is_source_video:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
elif flag_is_driving_video:
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
else:
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16, 18]:
if flag_is_source_video:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
elif flag_is_driving_video:
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
else:
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :]
if inf_cfg.animation_region == "all":
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
else:
scale_new = x_s_info['scale']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
t_new = x_s_info['t']
else:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i
else:
R_new = R_d_i
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :]
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1]
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] if flag_is_source_video else x_d_i_info['exp'][:, 5, 2]
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2]
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:]
elif inf_cfg.animation_region == "lip":
for lip_idx in [6, 12, 14, 17, 19, 20]:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :]
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
t_new = x_d_i_info['t']
else:
t_new = x_s_info['t']
t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video:
if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
if i == 0:
x_d_0_new = x_d_i_new
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
@ -373,50 +451,68 @@ class LivePortraitPipeline(object):
mkdir(args.output_dir)
wfp_concat = None
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
######### build the final concatenation result #########
# driving frame | source frame | generation, or source frame | generation
if flag_is_source_video:
# driving frame | source frame | generation
if flag_is_source_video and flag_is_driving_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
elif flag_is_source_video and not flag_is_driving_video:
if flag_load_from_template:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# NOTE: update output fps
output_fps = source_fps if flag_is_source_video else output_fps
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4')
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
# NOTE: update output fps
output_fps = source_fps if flag_is_source_video else output_fps
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
if flag_source_has_audio or flag_driving_has_audio:
# final result with concatenation
wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}, concat mode")
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio)
os.replace(wfp_concat_with_audio, wfp_concat)
log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
# save the animated result
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
else:
images2video(I_p_lst, wfp=wfp, fps=output_fps)
######### build the final result #########
if flag_source_has_audio or flag_driving_has_audio:
wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4')
audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source
log(f"Audio is selected from {audio_from_which_video}")
add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio)
os.replace(wfp_with_audio, wfp)
log(f"Replace {wfp_with_audio} with {wfp}")
# final log
if wfp_template not in (None, ''):
log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
log(f'Animated video: {wfp}')
log(f'Animated video with concat: {wfp_concat}')
wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg')
cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg')
if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])
else:
cv2.imwrite(wfp, frames_concatenated[0][..., ::-1])
# final log
log(f'Animated image: {wfp}')
log(f'Animated image with concat: {wfp_concat}')
return wfp, wfp_concat

View File

@ -100,7 +100,10 @@ def squeeze_tensor_to_numpy(tensor):
def dct2device(dct: dict, device):
    """Move every value of ``dct`` to ``device`` as a torch tensor, in place.

    Values that already are ``torch.Tensor`` objects are transferred with
    ``Tensor.to`` so they are not re-wrapped by ``torch.tensor`` (which would
    copy them and emit a copy-construct warning). Any other value is first
    converted with ``torch.tensor``.

    Args:
        dct: Mapping whose values are tensors or data accepted by ``torch.tensor``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The same ``dct`` object, with its values replaced.
    """
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for key, value in dct.items():
        if isinstance(value, torch.Tensor):
            dct[key] = value.to(device)
        else:
            dct[key] = torch.tensor(value).to(device)
    return dct

Binary file not shown.