mirror of
https://github.com/KwaiVGI/LivePortrait.git
synced 2025-01-12 08:48:52 +00:00
Added new inference code for webcam
This commit is contained in:
parent
de88675006
commit
1a754339b5
BIN
assets/examples/source/MY_photo.jpg
Normal file
BIN
assets/examples/source/MY_photo.jpg
Normal file
Binary file not shown.
After (image error) Size: 36 KiB |
BIN
assets/examples/source/k1.png
Normal file
BIN
assets/examples/source/k1.png
Normal file
Binary file not shown.
After (image error) Size: 324 KiB |
BIN
assets/examples/source/k2.png
Normal file
BIN
assets/examples/source/k2.png
Normal file
Binary file not shown.
After (image error) Size: 475 KiB |
Binary file not shown.
Before (image error) Size: 525 KiB |
BIN
assets/examples/source/solo.png
Normal file
BIN
assets/examples/source/solo.png
Normal file
Binary file not shown.
After (image error) Size: 374 KiB |
163
inference.py
163
inference.py
@ -1,33 +1,182 @@
|
|||||||
# coding: utf-8
|
|
||||||
|
|
||||||
import tyro
|
import tyro
|
||||||
from src.config.argument_config import ArgumentConfig
|
from src.config.argument_config import ArgumentConfig
|
||||||
from src.config.inference_config import InferenceConfig
|
from src.config.inference_config import InferenceConfig
|
||||||
from src.config.crop_config import CropConfig
|
from src.config.crop_config import CropConfig
|
||||||
from src.live_portrait_pipeline import LivePortraitPipeline
|
from src.live_portrait_pipeline import LivePortraitPipeline
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def partial_fields(target_class, kwargs):
|
def partial_fields(target_class, kwargs):
|
||||||
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# set tyro theme
|
# set tyro theme
|
||||||
tyro.extras.set_accent_color("bright_cyan")
|
tyro.extras.set_accent_color("bright_cyan")
|
||||||
args = tyro.cli(ArgumentConfig)
|
args = tyro.cli(ArgumentConfig)
|
||||||
|
|
||||||
# specify configs for inference
|
# specify configs for inference
|
||||||
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
|
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||||
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
|
crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||||
|
|
||||||
live_portrait_pipeline = LivePortraitPipeline(
|
live_portrait_pipeline = LivePortraitPipeline(
|
||||||
inference_cfg=inference_cfg,
|
inference_cfg=inference_cfg,
|
||||||
crop_cfg=crop_cfg
|
crop_cfg=crop_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
# run
|
# Initialize webcam 'assets/examples/driving/d6.mp4'
|
||||||
live_portrait_pipeline.execute(args)
|
cap = cv2.VideoCapture(0)
|
||||||
|
|
||||||
|
# Process the first frame to initialize
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
print("Failed to capture image")
|
||||||
|
return
|
||||||
|
|
||||||
|
source_image_path = args.source_image # Set the source image path here
|
||||||
|
x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Capture frame-by-frame
|
||||||
|
ret, frame = cap.read()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process the frame
|
||||||
|
|
||||||
|
result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
|
||||||
|
cv2.imshow('img_rgb Image', img_rgb)
|
||||||
|
cv2.imshow('Source Frame', frame)
|
||||||
|
|
||||||
|
|
||||||
|
# [Key Change] Convert the result from RGB to BGR before displaying
|
||||||
|
result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
|
||||||
|
# Display the resulting frame
|
||||||
|
cv2.imshow('Live Portrait', result_bgr)
|
||||||
|
|
||||||
|
# Press 'q' to exit the loop
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
# When everything is done, release the capture
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
# live_portrait_pipeline.execute_frame(result_bgr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
st = time.time()
|
||||||
main()
|
main()
|
||||||
|
print("Generation time:", (time.time() - st) * 1000)
|
||||||
|
|
||||||
|
# 3. Reduced webcam latency 350 to 160
|
||||||
|
|
||||||
|
# import cv2
|
||||||
|
# import time
|
||||||
|
# import threading
|
||||||
|
# import numpy as np
|
||||||
|
# import tyro
|
||||||
|
# from src.config.argument_config import ArgumentConfig
|
||||||
|
# from src.config.inference_config import InferenceConfig
|
||||||
|
# from src.config.crop_config import CropConfig
|
||||||
|
# from src.live_portrait_pipeline import LivePortraitPipeline
|
||||||
|
|
||||||
|
# def partial_fields(target_class, kwargs):
|
||||||
|
# return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
|
||||||
|
|
||||||
|
# class VideoCaptureThread:
|
||||||
|
# def __init__(self, src=0):
|
||||||
|
# self.cap = cv2.VideoCapture(src)
|
||||||
|
# self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
|
||||||
|
# self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
|
||||||
|
# self.cap.set(cv2.CAP_PROP_FPS, 60)
|
||||||
|
|
||||||
|
# if not self.cap.isOpened():
|
||||||
|
# print("Failed to open camera")
|
||||||
|
# self.running = False
|
||||||
|
# else:
|
||||||
|
# self.ret = False
|
||||||
|
# self.frame = None
|
||||||
|
# self.running = True
|
||||||
|
# self.thread = threading.Thread(target=self.update, args=())
|
||||||
|
# self.thread.start()
|
||||||
|
|
||||||
|
# def update(self):
|
||||||
|
# while self.running:
|
||||||
|
# self.ret, self.frame = self.cap.read()
|
||||||
|
# if not self.ret:
|
||||||
|
# print("Failed to read frame")
|
||||||
|
# break
|
||||||
|
|
||||||
|
# def read(self):
|
||||||
|
# return self.ret, self.frame
|
||||||
|
|
||||||
|
# def release(self):
|
||||||
|
# self.running = False
|
||||||
|
# self.thread.join()
|
||||||
|
# self.cap.release()
|
||||||
|
|
||||||
|
# def main():
|
||||||
|
# # Set tyro theme
|
||||||
|
# tyro.extras.set_accent_color("bright_cyan")
|
||||||
|
# args = tyro.cli(ArgumentConfig)
|
||||||
|
|
||||||
|
# # Specify configs for inference
|
||||||
|
# inference_cfg = partial_fields(InferenceConfig, args.__dict__)
|
||||||
|
# crop_cfg = partial_fields(CropConfig, args.__dict__)
|
||||||
|
|
||||||
|
# live_portrait_pipeline = LivePortraitPipeline(
|
||||||
|
# inference_cfg=inference_cfg,
|
||||||
|
# crop_cfg=crop_cfg
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Initialize webcam 'assets/examples/driving/d6.mp4'
|
||||||
|
# cap_thread = VideoCaptureThread(0)
|
||||||
|
|
||||||
|
# # Wait for the first frame to be captured
|
||||||
|
# while not cap_thread.ret and cap_thread.running:
|
||||||
|
# time.sleep(0.1)
|
||||||
|
|
||||||
|
# if not cap_thread.ret:
|
||||||
|
# print("Failed to capture image")
|
||||||
|
# cap_thread.release()
|
||||||
|
# return
|
||||||
|
|
||||||
|
# source_image_path = args.source_image # Set the source image path here
|
||||||
|
# ret, frame = cap_thread.read()
|
||||||
|
# x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
|
||||||
|
|
||||||
|
# while cap_thread.running:
|
||||||
|
# # Capture frame-by-frame
|
||||||
|
# ret, frame = cap_thread.read()
|
||||||
|
# if not ret:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# # Process the frame
|
||||||
|
# result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
|
||||||
|
# # cv2.imshow('img_rgb Image', img_rgb)
|
||||||
|
# cv2.imshow('Webcam Frame', frame)
|
||||||
|
|
||||||
|
# # Convert the result from RGB to BGR before displaying
|
||||||
|
# result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# # Display the resulting frame
|
||||||
|
# cv2.imshow('Webcam Live Portrait', result_bgr)
|
||||||
|
|
||||||
|
# # Press 'q' to exit the loop
|
||||||
|
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
# break
|
||||||
|
|
||||||
|
# # When everything is done, release the capture
|
||||||
|
# cap_thread.release()
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# st = time.time()
|
||||||
|
# main()
|
||||||
|
# print("Generation time:", (time.time() - st) * 1000)
|
||||||
|
@ -1,13 +1,7 @@
|
|||||||
# coding: utf-8
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Pipeline of LivePortrait
|
Pipeline of LivePortrait
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# 1. 当前假定所有的模板都是已经裁好的,需要修改下
|
|
||||||
# 2. pick样例图 source + driving
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
@ -38,153 +32,67 @@ class LivePortraitPipeline(object):
|
|||||||
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
||||||
self.cropper = Cropper(crop_cfg=crop_cfg)
|
self.cropper = Cropper(crop_cfg=crop_cfg)
|
||||||
|
|
||||||
def execute(self, args: ArgumentConfig):
|
def execute_frame(self, frame, source_image_path):
|
||||||
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
||||||
######## process source portrait ########
|
|
||||||
img_rgb = load_image_rgb(args.source_image)
|
# Load and preprocess source image
|
||||||
|
img_rgb = load_image_rgb(source_image_path)
|
||||||
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
||||||
log(f"Load source image from {args.source_image}")
|
log(f"Load source image from {source_image_path}")
|
||||||
crop_info = self.cropper.crop_single_image(img_rgb)
|
crop_info = self.cropper.crop_single_image(img_rgb)
|
||||||
source_lmk = crop_info['lmk_crop']
|
source_lmk = crop_info['lmk_crop']
|
||||||
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
||||||
|
|
||||||
if inference_cfg.flag_do_crop:
|
if inference_cfg.flag_do_crop:
|
||||||
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
||||||
else:
|
else:
|
||||||
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
||||||
|
|
||||||
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
||||||
x_c_s = x_s_info['kp']
|
x_c_s = x_s_info['kp']
|
||||||
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
||||||
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
||||||
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
||||||
|
|
||||||
|
lip_delta_before_animation = None
|
||||||
if inference_cfg.flag_lip_zero:
|
if inference_cfg.flag_lip_zero:
|
||||||
# let lip-open scalar to be 0 at first
|
|
||||||
c_d_lip_before_animation = [0.]
|
c_d_lip_before_animation = [0.]
|
||||||
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, crop_info['lmk_crop'])
|
||||||
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
||||||
inference_cfg.flag_lip_zero = False
|
inference_cfg.flag_lip_zero = False
|
||||||
else:
|
else:
|
||||||
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
||||||
############################################
|
|
||||||
|
|
||||||
######## process driving info ########
|
return x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb
|
||||||
if is_video(args.driving_info):
|
|
||||||
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
|
def generate_frame(self, x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, driving_info):
|
||||||
# TODO: 这里track一下驱动视频 -> 构建模板
|
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
||||||
driving_rgb_lst = load_driving_info(args.driving_info)
|
|
||||||
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
# Process driving info
|
||||||
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
|
driving_rgb = cv2.resize(driving_info, (256, 256))
|
||||||
n_frames = I_d_lst.shape[0]
|
I_d_i = self.live_portrait_wrapper.prepare_driving_videos([driving_rgb])[0]
|
||||||
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
|
|
||||||
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
|
||||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
|
||||||
elif is_template(args.driving_info):
|
|
||||||
log(f"Load from video templates {args.driving_info}")
|
|
||||||
with open(args.driving_info, 'rb') as f:
|
|
||||||
template_lst, driving_lmk_lst = pickle.load(f)
|
|
||||||
n_frames = template_lst[0]['n_frames']
|
|
||||||
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
|
||||||
else:
|
|
||||||
raise Exception("Unsupported driving types!")
|
|
||||||
#########################################
|
|
||||||
|
|
||||||
######## prepare for pasteback ########
|
|
||||||
if inference_cfg.flag_pasteback:
|
|
||||||
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
|
||||||
I_p_paste_lst = []
|
|
||||||
#########################################
|
|
||||||
|
|
||||||
I_p_lst = []
|
|
||||||
R_d_0, x_d_0_info = None, None
|
|
||||||
for i in track(range(n_frames), description='Animating...', total=n_frames):
|
|
||||||
if is_video(args.driving_info):
|
|
||||||
# extract kp info by M
|
|
||||||
I_d_i = I_d_lst[i]
|
|
||||||
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
||||||
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
||||||
else:
|
|
||||||
# from template
|
|
||||||
x_d_i_info = template_lst[i]
|
|
||||||
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
|
|
||||||
R_d_i = x_d_i_info['R_d']
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
R_d_0 = R_d_i
|
|
||||||
x_d_0_info = x_d_i_info
|
|
||||||
|
|
||||||
if inference_cfg.flag_relative:
|
|
||||||
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
|
||||||
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
|
||||||
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
|
||||||
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
|
||||||
else:
|
|
||||||
R_new = R_d_i
|
|
||||||
delta_new = x_d_i_info['exp']
|
|
||||||
scale_new = x_s_info['scale']
|
|
||||||
t_new = x_d_i_info['t']
|
|
||||||
|
|
||||||
|
R_new = R_d_i @ R_s
|
||||||
|
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_s_info['exp'])
|
||||||
|
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_s_info['scale'])
|
||||||
|
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_s_info['t'])
|
||||||
t_new[..., 2].fill_(0) # zero tz
|
t_new[..., 2].fill_(0) # zero tz
|
||||||
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
|
||||||
|
|
||||||
# Algorithm 1:
|
x_d_i_new = scale_new * (x_s @ R_new + delta_new) + t_new
|
||||||
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
if inference_cfg.flag_lip_zero and lip_delta_before_animation is not None:
|
||||||
# without stitching or retargeting
|
|
||||||
if inference_cfg.flag_lip_zero:
|
|
||||||
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
||||||
else:
|
|
||||||
pass
|
|
||||||
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
|
||||||
# with stitching and without retargeting
|
|
||||||
if inference_cfg.flag_lip_zero:
|
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
|
||||||
else:
|
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
|
||||||
else:
|
|
||||||
eyes_delta, lip_delta = None, None
|
|
||||||
if inference_cfg.flag_eye_retargeting:
|
|
||||||
c_d_eyes_i = input_eye_ratio_lst[i]
|
|
||||||
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
|
||||||
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
|
||||||
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
|
||||||
if inference_cfg.flag_lip_retargeting:
|
|
||||||
c_d_lip_i = input_lip_ratio_lst[i]
|
|
||||||
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
|
||||||
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
|
||||||
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
|
||||||
|
|
||||||
if inference_cfg.flag_relative: # use x_s
|
|
||||||
x_d_i_new = x_s + \
|
|
||||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
|
||||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
|
||||||
else: # use x_d,i
|
|
||||||
x_d_i_new = x_d_i_new + \
|
|
||||||
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
|
||||||
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
|
||||||
|
|
||||||
if inference_cfg.flag_stitching:
|
|
||||||
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
|
||||||
|
|
||||||
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
||||||
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
||||||
I_p_lst.append(I_p_i)
|
|
||||||
|
|
||||||
if inference_cfg.flag_pasteback:
|
if inference_cfg.flag_pasteback:
|
||||||
|
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
||||||
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
||||||
I_p_paste_lst.append(I_p_i_to_ori_blend)
|
return I_p_i_to_ori_blend
|
||||||
|
|
||||||
mkdir(args.output_dir)
|
|
||||||
wfp_concat = None
|
|
||||||
if is_video(args.driving_info):
|
|
||||||
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
|
|
||||||
# save (driving frames, source image, drived frames) result
|
|
||||||
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
|
||||||
images2video(frames_concatenated, wfp=wfp_concat)
|
|
||||||
|
|
||||||
# save drived result
|
|
||||||
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
|
||||||
if inference_cfg.flag_pasteback:
|
|
||||||
images2video(I_p_paste_lst, wfp=wfp)
|
|
||||||
else:
|
else:
|
||||||
images2video(I_p_lst, wfp=wfp)
|
return I_p_i
|
||||||
|
|
||||||
return wfp, wfp_concat
|
|
||||||
|
Loading…
Reference in New Issue
Block a user