Added new inference code for webcam

This commit is contained in:
unknown 2024-07-10 12:48:10 +09:00
parent de88675006
commit 1a754339b5
7 changed files with 194 additions and 137 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 324 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 475 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 525 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 374 KiB

View File

@ -1,33 +1,182 @@
# coding: utf-8
import tyro import tyro
from src.config.argument_config import ArgumentConfig from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline from src.live_portrait_pipeline import LivePortraitPipeline
import cv2
import time
import numpy as np
def partial_fields(target_class, kwargs): def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def main(): def main():
# set tyro theme # set tyro theme
tyro.extras.set_accent_color("bright_cyan") tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig) args = tyro.cli(ArgumentConfig)
# specify configs for inference # specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig inference_cfg = partial_fields(InferenceConfig, args.__dict__)
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig crop_cfg = partial_fields(CropConfig, args.__dict__)
live_portrait_pipeline = LivePortraitPipeline( live_portrait_pipeline = LivePortraitPipeline(
inference_cfg=inference_cfg, inference_cfg=inference_cfg,
crop_cfg=crop_cfg crop_cfg=crop_cfg
) )
# run # Initialize webcam 'assets/examples/driving/d6.mp4'
live_portrait_pipeline.execute(args) cap = cv2.VideoCapture(0)
# Process the first frame to initialize
ret, frame = cap.read()
if not ret:
print("Failed to capture image")
return
source_image_path = args.source_image # Set the source image path here
x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
while True:
# Capture frame-by-frame
ret, frame = cap.read()
if not ret:
break
# Process the frame
result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
cv2.imshow('img_rgb Image', img_rgb)
cv2.imshow('Source Frame', frame)
# [Key Change] Convert the result from RGB to BGR before displaying
result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
# Display the resulting frame
cv2.imshow('Live Portrait', result_bgr)
# Press 'q' to exit the loop
if cv2.waitKey(1) & 0xFF == ord('q'):
break
# When everything is done, release the capture
cap.release()
cv2.destroyAllWindows()
# live_portrait_pipeline.execute_frame(result_bgr)
if __name__ == '__main__': if __name__ == '__main__':
st = time.time()
main() main()
print("Generation time:", (time.time() - st) * 1000)
# 3. Reduced webcam latency 350 to 160
# import cv2
# import time
# import threading
# import numpy as np
# import tyro
# from src.config.argument_config import ArgumentConfig
# from src.config.inference_config import InferenceConfig
# from src.config.crop_config import CropConfig
# from src.live_portrait_pipeline import LivePortraitPipeline
# def partial_fields(target_class, kwargs):
# return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
# class VideoCaptureThread:
# def __init__(self, src=0):
# self.cap = cv2.VideoCapture(src)
# self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
# self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
# self.cap.set(cv2.CAP_PROP_FPS, 60)
# if not self.cap.isOpened():
# print("Failed to open camera")
# self.running = False
# else:
# self.ret = False
# self.frame = None
# self.running = True
# self.thread = threading.Thread(target=self.update, args=())
# self.thread.start()
# def update(self):
# while self.running:
# self.ret, self.frame = self.cap.read()
# if not self.ret:
# print("Failed to read frame")
# break
# def read(self):
# return self.ret, self.frame
# def release(self):
# self.running = False
# self.thread.join()
# self.cap.release()
# def main():
# # Set tyro theme
# tyro.extras.set_accent_color("bright_cyan")
# args = tyro.cli(ArgumentConfig)
# # Specify configs for inference
# inference_cfg = partial_fields(InferenceConfig, args.__dict__)
# crop_cfg = partial_fields(CropConfig, args.__dict__)
# live_portrait_pipeline = LivePortraitPipeline(
# inference_cfg=inference_cfg,
# crop_cfg=crop_cfg
# )
# # Initialize webcam 'assets/examples/driving/d6.mp4'
# cap_thread = VideoCaptureThread(0)
# # Wait for the first frame to be captured
# while not cap_thread.ret and cap_thread.running:
# time.sleep(0.1)
# if not cap_thread.ret:
# print("Failed to capture image")
# cap_thread.release()
# return
# source_image_path = args.source_image # Set the source image path here
# ret, frame = cap_thread.read()
# x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
# while cap_thread.running:
# # Capture frame-by-frame
# ret, frame = cap_thread.read()
# if not ret:
# break
# # Process the frame
# result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
# # cv2.imshow('img_rgb Image', img_rgb)
# cv2.imshow('Webcam Frame', frame)
# # Convert the result from RGB to BGR before displaying
# result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
# # Display the resulting frame
# cv2.imshow('Webcam Live Portrait', result_bgr)
# # Press 'q' to exit the loop
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# # When everything is done, release the capture
# cap_thread.release()
# cv2.destroyAllWindows()
# if __name__ == '__main__':
# st = time.time()
# main()
# print("Generation time:", (time.time() - st) * 1000)

View File

@ -1,13 +1,7 @@
# coding: utf-8
""" """
Pipeline of LivePortrait Pipeline of LivePortrait
""" """
# TODO:
# 1. 当前假定所有的模板都是已经裁好的,需要修改下
# 2. pick样例图 source + driving
import cv2 import cv2
import numpy as np import numpy as np
import pickle import pickle
@ -38,153 +32,67 @@ class LivePortraitPipeline(object):
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg) self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
self.cropper = Cropper(crop_cfg=crop_cfg) self.cropper = Cropper(crop_cfg=crop_cfg)
def execute(self, args: ArgumentConfig): def execute_frame(self, frame, source_image_path):
inference_cfg = self.live_portrait_wrapper.cfg # for convenience inference_cfg = self.live_portrait_wrapper.cfg # for convenience
######## process source portrait ########
img_rgb = load_image_rgb(args.source_image) # Load and preprocess source image
img_rgb = load_image_rgb(source_image_path)
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n) img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
log(f"Load source image from {args.source_image}") log(f"Load source image from {source_image_path}")
crop_info = self.cropper.crop_single_image(img_rgb) crop_info = self.cropper.crop_single_image(img_rgb)
source_lmk = crop_info['lmk_crop'] source_lmk = crop_info['lmk_crop']
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
if inference_cfg.flag_do_crop: if inference_cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else: else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb) I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp'] x_c_s = x_s_info['kp']
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
lip_delta_before_animation = None
if inference_cfg.flag_lip_zero: if inference_cfg.flag_lip_zero:
# let lip-open scalar to be 0 at first
c_d_lip_before_animation = [0.] c_d_lip_before_animation = [0.]
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, crop_info['lmk_crop'])
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold: if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
inference_cfg.flag_lip_zero = False inference_cfg.flag_lip_zero = False
else: else:
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
############################################
######## process driving info ######## return x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb
if is_video(args.driving_info):
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}") def generate_frame(self, x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, driving_info):
# TODO: 这里track一下驱动视频 -> 构建模板 inference_cfg = self.live_portrait_wrapper.cfg # for convenience
driving_rgb_lst = load_driving_info(args.driving_info)
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # Process driving info
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256) driving_rgb = cv2.resize(driving_info, (256, 256))
n_frames = I_d_lst.shape[0] I_d_i = self.live_portrait_wrapper.prepare_driving_videos([driving_rgb])[0]
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
elif is_template(args.driving_info): R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
log(f"Load from video templates {args.driving_info}")
with open(args.driving_info, 'rb') as f: R_new = R_d_i @ R_s
template_lst, driving_lmk_lst = pickle.load(f) delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_s_info['exp'])
n_frames = template_lst[0]['n_frames'] scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_s_info['scale'])
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) t_new = x_s_info['t'] + (x_d_i_info['t'] - x_s_info['t'])
else: t_new[..., 2].fill_(0) # zero tz
raise Exception("Unsupported driving types!")
######################################### x_d_i_new = scale_new * (x_s @ R_new + delta_new) + t_new
if inference_cfg.flag_lip_zero and lip_delta_before_animation is not None:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
######## prepare for pasteback ########
if inference_cfg.flag_pasteback: if inference_cfg.flag_pasteback:
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = [] I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
######################################### return I_p_i_to_ori_blend
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(n_frames), description='Animating...', total=n_frames):
if is_video(args.driving_info):
# extract kp info by M
I_d_i = I_d_lst[i]
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
else:
# from template
x_d_i_info = template_lst[i]
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if inference_cfg.flag_relative:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
# Algorithm 1:
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
# without stitching or retargeting
if inference_cfg.flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
# with stitching and without retargeting
if inference_cfg.flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if inference_cfg.flag_eye_retargeting:
c_d_eyes_i = input_eye_ratio_lst[i]
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
if inference_cfg.flag_lip_retargeting:
c_d_lip_i = input_lip_ratio_lst[i]
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
if inference_cfg.flag_relative: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
if inference_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
if inference_cfg.flag_pasteback:
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
I_p_paste_lst.append(I_p_i_to_ori_blend)
mkdir(args.output_dir)
wfp_concat = None
if is_video(args.driving_info):
frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
# save (driving frames, source image, drived frames) result
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat)
# save drived result
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
if inference_cfg.flag_pasteback:
images2video(I_p_paste_lst, wfp=wfp)
else: else:
images2video(I_p_lst, wfp=wfp) return I_p_i
return wfp, wfp_concat