Mirror of https://github.com/KwaiVGI/LivePortrait.git, synced 2024-12-22 20:42:38 +00:00
chore: slightly refine the codebase
This commit is contained in:
parent 669487a2fe · commit d09527c762
app.py (8 changed lines)
app.py
@@ -44,8 +44,8 @@ data_examples = [
 #################### interface logic ####################
 
 # Define components first
-eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
-lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
+eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
+lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
 retargeting_input_image = gr.Image(type="numpy")
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
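For orientation, a minimal sketch of how sliders like these are typically wired up in a Gradio Blocks app. The button variable process_button_retargeting is a hypothetical name for illustration; the callback and its two ratio inputs match the execute_image signature that appears in the GradioPipeline hunks further down.

# Hedged wiring sketch; the button name is an assumption, not part of this diff.
process_button_retargeting = gr.Button("🚗 Retargeting")
process_button_retargeting.click(
    fn=gradio_pipeline.execute_image,                         # retargeting entry point (see hunks below)
    inputs=[eye_retargeting_slider, lip_retargeting_slider],  # the two renamed sliders
    outputs=[output_image, output_image_paste_back],          # result and paste-back result
)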
@@ -56,7 +56,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.HTML(load_description(title_md))
     gr.Markdown(load_description("assets/gradio_description_upload.md"))
     with gr.Row():
-        with gr.Accordion(open=True, label="Reference Portrait"):
+        with gr.Accordion(open=True, label="Source Portrait"):
             image_input = gr.Image(type="filepath")
         with gr.Accordion(open=True, label="Driving Video"):
             video_input = gr.Video()
@@ -64,7 +64,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         with gr.Accordion(open=True, label="Animation Options"):
             with gr.Row():
-                flag_relative_input = gr.Checkbox(value=True, label="relative pose")
+                flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                 flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
             with gr.Row():
assets/gradio_description_animation.md
@@ -1,4 +1,4 @@
-<span style="font-size: 1.2em;">🔥 To animate the reference portrait with the driving video, please follow these steps:</span>
+<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
 <div style="font-size: 1.2em; margin-left: 20px;">
 1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
 </div>
assets/gradio_description_retargeting.md
@@ -1 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
+<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
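The action this description asks the user to perform maps to a single method call; a hedged sketch using the execute_image signature shown in the GradioPipeline hunks below (the gradio_pipeline instance name is an assumption):

# Equivalent of dragging both sliders to 0.8 and clicking 🚗 Retargeting;
# gradio_pipeline is a hypothetical GradioPipeline instance.
gradio_pipeline.execute_image(input_eye_ratio=0.8, input_lip_ratio=0.8)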
assets/gradio_description_upload.md
@@ -1,4 +1,2 @@
 ## 🤗 This is the official gradio demo for **LivePortrait**.
-### Guidance for the gradio page:
-<div style="font-size: 1.2em;">Please upload or use the webcam to get a reference portrait to the <strong>Reference Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
-
+<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
src/config/argument_config.py
@@ -14,7 +14,7 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg')  # path to the reference portrait
+    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg')  # path to the source portrait
     driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video
     #####################################
@@ -25,9 +25,9 @@ class ArgumentConfig(PrintableConfig):
     flag_eye_retargeting: bool = False
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True  # we recommend setting it to True!
-    flag_relative: bool = True  # whether to use relative pose
+    flag_relative: bool = True  # whether to use relative motion
     flag_pasteback: bool = True  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-    flag_do_crop: bool = True  # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = True  # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
     #########################################
 
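Because ArgumentConfig is consumed by tyro, these dataclass fields double as CLI flags, and the -s/-d/-o aliases are declared right in the field annotations above. A hedged sketch of the usual entry-point pattern (the inference.py script name and the absolute import path are assumptions):

# Hypothetical entry point; tyro.cli builds the CLI from the dataclass, so the
# renamed comments change only the --help text, not the flag names, e.g.:
#   python inference.py -s source.jpg -d driving.mp4 -o animations/ --no-flag-relative
import tyro
from src.config.argument_config import ArgumentConfig

args = tyro.cli(ArgumentConfig)
print(args.source_image, args.flag_relative)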
src/config/inference_config.py
@@ -28,7 +28,7 @@ class InferenceConfig(PrintableConfig):
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True  # we recommend setting it to True!
 
-    flag_relative: bool = True  # whether to use relative pose
+    flag_relative: bool = True  # whether to use relative motion
     anchor_frame: int = 0  # set this value if find_best_frame is True
 
     input_shape: Tuple[int, int] = (256, 256)  # input shape
@@ -45,5 +45,5 @@ class InferenceConfig(PrintableConfig):
     ref_shape_n: int = 2
 
     device_id: int = 0
-    flag_do_crop: bool = False  # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = False  # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True  # whether to conduct the rotation when flag_do_crop is True
src/gradio_pipeline.py
@@ -64,7 +64,7 @@ class GradioPipeline(LivePortraitPipeline):
             gr.Info("Run successfully!", duration=2)
             return video_path, video_path_concat,
         else:
-            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
+            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
 
     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
@@ -74,12 +74,12 @@ class GradioPipeline(LivePortraitPipeline):
         elif self.f_s_user is None:
             if self.start_prepare:
                 raise gr.Error(
-                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    "The source portrait is under processing 💥! Please wait for a second.",
                     duration=5
                 )
             else:
                 raise gr.Error(
-                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
                     duration=5
                 )
         else:
@@ -107,7 +107,7 @@ class GradioPipeline(LivePortraitPipeline):
         gr.Info("Upload successfully!", duration=2)
         self.start_prepare = True
         inference_cfg = self.live_portrait_wrapper.cfg
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
         log(f"Load source image from {input_image_path}.")
         crop_info = self.cropper.crop_single_image(img_rgb)
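The two renamed messages belong to a small preparation guard: f_s_user appears to cache the extracted source features once the upload has been processed, and start_prepare marks an upload still in flight (both names come from the hunk above). Condensed, the branching reads:

# Condensed from the hunk above: which message fires in which state.
if self.f_s_user is None:
    if self.start_prepare:
        # upload received, source features still being extracted
        raise gr.Error("The source portrait is under processing 💥! Please wait for a second.", duration=5)
    else:
        # nothing uploaded or prepared yet
        raise gr.Error("The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.", duration=5)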
src/live_portrait_pipeline.py
@@ -40,7 +40,7 @@ class LivePortraitPipeline(object):
 
     def execute(self, args: ArgumentConfig):
         inference_cfg = self.live_portrait_wrapper.cfg  # for convenience
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_image_rgb(args.source_image)
         img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
         log(f"Load source image from {args.source_image}")
src/live_portrait_wrapper.py
@@ -10,12 +10,12 @@ import cv2
 import torch
 import yaml
 
-from src.utils.timer import Timer
-from src.utils.helper import load_model, concat_feat
-from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+from .utils.timer import Timer
+from .utils.helper import load_model, concat_feat
+from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from src.config.inference_config import InferenceConfig
-from src.utils.rprint import rlog as log
+from .config.inference_config import InferenceConfig
+from .utils.rprint import rlog as log
 
 
 class LivePortraitWrapper(object):
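The switch from src.-absolute to relative imports removes a hidden assumption: the absolute form only resolves when the project root is on sys.path so that `src` is importable as a top-level package, while the relative form resolves against the importing module's own package and keeps working if the package is vendored or renamed. A sketch of the layout implied by the import paths in this diff:

# Layout implied by the imports above (paths taken from the diff):
#   src/
#     live_portrait_wrapper.py       -> src.live_portrait_wrapper (this file)
#     utils/timer.py                 -> src.utils.timer
#     config/inference_config.py     -> src.config.inference_config
#
# Absolute: couples the code to the top-level name `src`:
#   from src.utils.timer import Timer
# Relative: resolves against this module's own package, whatever it is named:
#   from .utils.timer import Timer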
src/utils/helper.py
@@ -6,17 +6,14 @@ utility functions and classes to handle feature extraction and model loading
 
 import os
 import os.path as osp
-import cv2
 import torch
-from rich.console import Console
 from collections import OrderedDict
 
-from src.modules.spade_generator import SPADEDecoder
-from src.modules.warping_network import WarpingNetwork
-from src.modules.motion_extractor import MotionExtractor
-from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
-from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork
+from ..modules.spade_generator import SPADEDecoder
+from ..modules.warping_network import WarpingNetwork
+from ..modules.motion_extractor import MotionExtractor
+from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
+from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
-from .rprint import rlog as log
 
 
 def suffix(filename):
@@ -45,6 +42,7 @@ def is_video(file_path):
         return True
     return False
 
+
 def is_template(file_path):
     if file_path.endswith(".pkl"):
         return True
@@ -149,8 +147,8 @@ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R
     new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
     return new_rotation, new_expression, new_translation, new_scale
 
 
 def load_description(fp):
     with open(fp, 'r', encoding='utf-8') as f:
         content = f.read()
     return content
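The "relative motion" wording this commit standardizes is easiest to see in the new_scale line of calculate_transformation above: each driving frame contributes a delta relative to the first driving frame, applied on top of the source's own value. A hedged sketch of just the scale term (the rotation, expression, and translation analogues are not shown in this diff):

def relative_scale(scale_source: float, scale_driving_0: float, scale_driving_i: float) -> float:
    # Mirrors the new_scale line above: the source keeps its own scale,
    # modulated by how much driving frame i changed relative to frame 0.
    return scale_source * (scale_driving_i / scale_driving_0)

# Driving scale grew 10% between frame 0 and frame i, so the source scale grows 10%:
assert abs(relative_scale(1.5, 1.0, 1.1) - 1.65) < 1e-9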