diff --git a/assets/examples/source/MY_photo.jpg b/assets/examples/source/MY_photo.jpg
new file mode 100644
index 0000000..85043d3
Binary files /dev/null and b/assets/examples/source/MY_photo.jpg differ
diff --git a/assets/examples/source/k1.png b/assets/examples/source/k1.png
new file mode 100644
index 0000000..7d9a350
Binary files /dev/null and b/assets/examples/source/k1.png differ
diff --git a/assets/examples/source/k2.png b/assets/examples/source/k2.png
new file mode 100644
index 0000000..e0658ca
Binary files /dev/null and b/assets/examples/source/k2.png differ
diff --git a/assets/examples/source/s10.jpg b/assets/examples/source/s10.jpg
deleted file mode 100644
index ee9616b..0000000
Binary files a/assets/examples/source/s10.jpg and /dev/null differ
diff --git a/assets/examples/source/solo.png b/assets/examples/source/solo.png
new file mode 100644
index 0000000..76f9961
Binary files /dev/null and b/assets/examples/source/solo.png differ
diff --git a/inference.py b/inference.py
index 8387e7f..2edfea4 100644
--- a/inference.py
+++ b/inference.py
@@ -1,33 +1,182 @@
-# coding: utf-8
-
 import tyro
 from src.config.argument_config import ArgumentConfig
 from src.config.inference_config import InferenceConfig
 from src.config.crop_config import CropConfig
 from src.live_portrait_pipeline import LivePortraitPipeline
+import cv2
+import time
+import numpy as np
 
 
 def partial_fields(target_class, kwargs):
     return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
 
-
 def main():
     # set tyro theme
     tyro.extras.set_accent_color("bright_cyan")
     args = tyro.cli(ArgumentConfig)
 
     # specify configs for inference
-    inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
-    crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
+    inference_cfg = partial_fields(InferenceConfig, args.__dict__)
+    crop_cfg = partial_fields(CropConfig, args.__dict__)
 
     live_portrait_pipeline = LivePortraitPipeline(
         inference_cfg=inference_cfg,
         crop_cfg=crop_cfg
     )
 
-    # run
-    live_portrait_pipeline.execute(args)
+    # Initialize the driving input: webcam 0 (a video file such as
+    # 'assets/examples/driving/d6.mp4' can be used instead)
+    cap = cv2.VideoCapture(0)
+
+    # Grab a first frame, then precompute the source portrait state once
+    # (keypoints, features, rotation, crop info)
+    ret, frame = cap.read()
+    if not ret:
+        print("Failed to capture image")
+        return
+
+    source_image_path = args.source_image  # the source portrait to animate
+    x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
+
+    while True:
+        # Capture frame-by-frame
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Animate the source portrait with the current webcam frame
+        result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
+
+        # Show the source portrait and the current driving frame
+        # (img_rgb is RGB, so convert it for OpenCV display)
+        cv2.imshow('Source Image', cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))
+        cv2.imshow('Driving Frame', frame)
+
+        # Convert the result from RGB to BGR before displaying
+        result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+
+        # Display the resulting frame
+        cv2.imshow('Live Portrait', result_bgr)
+
+        # Press 'q' to exit the loop
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    # When everything is done, release the capture
+    cap.release()
+    cv2.destroyAllWindows()
 
 
 if __name__ == '__main__':
+    st = time.time()
     main()
+    print("Total run time (ms):", (time.time() - st) * 1000)
+
+
+# 3. Reduced webcam latency from ~350 ms to ~160 ms (threaded-capture variant kept below, commented out)
+
+# import cv2
+# import time
+# import threading
+# import numpy as np
+# import tyro
+# from src.config.argument_config import ArgumentConfig
+# from src.config.inference_config import InferenceConfig
+# from src.config.crop_config import CropConfig
+# from src.live_portrait_pipeline import LivePortraitPipeline
+
+# def partial_fields(target_class, kwargs):
+#     return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
+
+# class VideoCaptureThread:
+#     def __init__(self, src=0):
+#         self.cap = cv2.VideoCapture(src)
+#         self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 480)
+#         self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
+#         self.cap.set(cv2.CAP_PROP_FPS, 60)
+
+#         if not self.cap.isOpened():
+#             print("Failed to open camera")
+#             self.running = False
+#         else:
+#             self.ret = False
+#             self.frame = None
+#             self.running = True
+#             self.thread = threading.Thread(target=self.update, args=())
+#             self.thread.start()
+
+#     def update(self):
+#         while self.running:
+#             self.ret, self.frame = self.cap.read()
+#             if not self.ret:
+#                 print("Failed to read frame")
+#                 break
+
+#     def read(self):
+#         return self.ret, self.frame
+
+#     def release(self):
+#         self.running = False
+#         self.thread.join()
+#         self.cap.release()
+
+# def main():
+#     # Set tyro theme
+#     tyro.extras.set_accent_color("bright_cyan")
+#     args = tyro.cli(ArgumentConfig)
+
+#     # Specify configs for inference
+#     inference_cfg = partial_fields(InferenceConfig, args.__dict__)
+#     crop_cfg = partial_fields(CropConfig, args.__dict__)
+
+#     live_portrait_pipeline = LivePortraitPipeline(
+#         inference_cfg=inference_cfg,
+#         crop_cfg=crop_cfg
+#     )
+
+#     # Initialize webcam 'assets/examples/driving/d6.mp4'
+#     cap_thread = VideoCaptureThread(0)
+
+#     # Wait for the first frame to be captured
+#     while not cap_thread.ret and cap_thread.running:
+#         time.sleep(0.1)
+
+#     if not cap_thread.ret:
+#         print("Failed to capture image")
+#         cap_thread.release()
+#         return
+
+#     source_image_path = args.source_image  # Set the source image path here
+#     ret, frame = cap_thread.read()
+#     x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb = live_portrait_pipeline.execute_frame(frame, source_image_path)
+
+#     while cap_thread.running:
+#         # Capture frame-by-frame
+#         ret, frame = cap_thread.read()
+#         if not ret:
+#             break
+
+#         # Process the frame
+#         result = live_portrait_pipeline.generate_frame(x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, frame)
+#         # cv2.imshow('img_rgb Image', img_rgb)
+#         cv2.imshow('Webcam Frame', frame)
+
+#         # Convert the result from RGB to BGR before displaying
+#         result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+
+#         # Display the resulting frame
+#         cv2.imshow('Webcam Live Portrait', result_bgr)
+
+#         # Press 'q' to exit the loop
+#         if cv2.waitKey(1) & 0xFF == ord('q'):
+#             break
+
+#     # When everything is done, release the capture
+#     cap_thread.release()
+#     cv2.destroyAllWindows()

+# if __name__ == '__main__':
+#     st = time.time()
+#     main()
+#     print("Generation time:", (time.time() - st) * 1000)
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 7fda1f5..b777ed4 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -1,13 +1,7 @@
-# coding: utf-8
-
 """
 Pipeline of LivePortrait
 """
 
-# TODO:
-# 1. 当前假定所有的模板都是已经裁好的,需要修改下
-# 2. pick样例图 source + driving
-
 import cv2
 import numpy as np
 import pickle
@@ -38,153 +32,67 @@ class LivePortraitPipeline(object):
         self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
         self.cropper = Cropper(crop_cfg=crop_cfg)
 
-    def execute(self, args: ArgumentConfig):
-        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience
-        ######## process source portrait ########
-        img_rgb = load_image_rgb(args.source_image)
+    def execute_frame(self, frame, source_image_path):
+        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience
+
+        # Load and preprocess source image
+        img_rgb = load_image_rgb(source_image_path)
         img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
-        log(f"Load source image from {args.source_image}")
+        log(f"Load source image from {source_image_path}")
         crop_info = self.cropper.crop_single_image(img_rgb)
         source_lmk = crop_info['lmk_crop']
         img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
+
         if inference_cfg.flag_do_crop:
             I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
         else:
             I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
+
         x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
         x_c_s = x_s_info['kp']
         R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
         f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
         x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
+
         lip_delta_before_animation = None
         if inference_cfg.flag_lip_zero:
-            # let lip-open scalar to be 0 at first
             c_d_lip_before_animation = [0.]
-            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
+            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, crop_info['lmk_crop'])
             if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
                 inference_cfg.flag_lip_zero = False
             else:
                 lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
-        ############################################
-        ######## process driving info ########
-        if is_video(args.driving_info):
-            log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
-            # TODO: 这里track一下驱动视频 -> 构建模板
-            driving_rgb_lst = load_driving_info(args.driving_info)
-            driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
-            I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
-            n_frames = I_d_lst.shape[0]
-            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
-                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
-                input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
-        elif is_template(args.driving_info):
-            log(f"Load from video templates {args.driving_info}")
-            with open(args.driving_info, 'rb') as f:
-                template_lst, driving_lmk_lst = pickle.load(f)
-            n_frames = template_lst[0]['n_frames']
-            input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
-        else:
-            raise Exception("Unsupported driving types!")
-        #########################################
+        return x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb
+
+    def generate_frame(self, x_s, f_s, R_s, x_s_info, lip_delta_before_animation, crop_info, img_rgb, driving_info):
+        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience
+
+        # Prepare the driving frame (resize to 256x256)
+        driving_rgb = cv2.resize(driving_info, (256, 256))
+        I_d_i = self.live_portrait_wrapper.prepare_driving_videos([driving_rgb])[0]
+
+        # Extract driving keypoint info and rotation
+        x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
+        R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
+
+        # Relative-motion update using the source frame itself as the reference
+        # (the removed batch pipeline referenced the first driving frame x_d_0 instead)
+        R_new = R_d_i @ R_s
+        delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_s_info['exp'])
+        scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_s_info['scale'])
+        t_new = x_s_info['t'] + (x_d_i_info['t'] - x_s_info['t'])
+        t_new[..., 2].fill_(0)  # zero tz
+
+        # NOTE: drives the transformed source keypoints x_s (the removed execute() used the canonical keypoints x_c_s here)
+        x_d_i_new = scale_new * (x_s @ R_new + delta_new) + t_new
+        if inference_cfg.flag_lip_zero and lip_delta_before_animation is not None:
+            x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
+
+        out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
+        I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
 
-        ######## prepare for pasteback ########
         if inference_cfg.flag_pasteback:
             mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
-            I_p_paste_lst = []
-        #########################################
-
-        I_p_lst = []
-        R_d_0, x_d_0_info = None, None
-        for i in track(range(n_frames), description='Animating...', total=n_frames):
-            if is_video(args.driving_info):
-                # extract kp info by M
-                I_d_i = I_d_lst[i]
-                x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
-                R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
-            else:
-                # from template
-                x_d_i_info = template_lst[i]
-                x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
-                R_d_i = x_d_i_info['R_d']
-
-            if i == 0:
-                R_d_0 = R_d_i
-                x_d_0_info = x_d_i_info
-
-            if inference_cfg.flag_relative:
-                R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
-                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
-                scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
-                t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
-            else:
-                R_new = R_d_i
-                delta_new = x_d_i_info['exp']
-                scale_new = x_s_info['scale']
-                t_new = x_d_i_info['t']
-
-            t_new[..., 2].fill_(0)  # zero tz
-            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
-
-            # Algorithm 1:
-            if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
-                # without stitching or retargeting
-                if inference_cfg.flag_lip_zero:
-                    x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
-                else:
-                    pass
-            elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
-                # with stitching and without retargeting
-                if inference_cfg.flag_lip_zero:
-                    x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
-                else:
-                    x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
-            else:
-                eyes_delta, lip_delta = None, None
-                if inference_cfg.flag_eye_retargeting:
-                    c_d_eyes_i = input_eye_ratio_lst[i]
-                    combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
-                    # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-                    eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
-                if inference_cfg.flag_lip_retargeting:
-                    c_d_lip_i = input_lip_ratio_lst[i]
-                    combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
-                    # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-                    lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
-
-                if inference_cfg.flag_relative:  # use x_s
-                    x_d_i_new = x_s + \
-                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
-                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
-                else:  # use x_d,i
-                    x_d_i_new = x_d_i_new + \
-                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
-                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
-
-                if inference_cfg.flag_stitching:
-                    x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
-
-            out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
-            I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
-            I_p_lst.append(I_p_i)
-
-            if inference_cfg.flag_pasteback:
-                I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
-                I_p_paste_lst.append(I_p_i_to_ori_blend)
-
-        mkdir(args.output_dir)
-        wfp_concat = None
-        if is_video(args.driving_info):
-            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
-            # save (driving frames, source image, drived frames) result
-            wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
-            images2video(frames_concatenated, wfp=wfp_concat)
-
-        # save drived result
-        wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
-        if inference_cfg.flag_pasteback:
-            images2video(I_p_paste_lst, wfp=wfp)
+            I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
+            return I_p_i_to_ori_blend
         else:
-            images2video(I_p_lst, wfp=wfp)
+            return I_p_i
-        return wfp, wfp_concat
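For reference, here is a minimal usage sketch (not part of the patch) of the new `execute_frame` / `generate_frame` API added in `src/live_portrait_pipeline.py`, driven from a video file rather than the webcam loop in `inference.py`. The pipeline setup and per-frame calls mirror the diff above; the driving-video path, the window title, and the `state` tuple name are illustrative assumptions.

```python
import cv2
import tyro

from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline


def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


def main():
    args = tyro.cli(ArgumentConfig)
    pipeline = LivePortraitPipeline(
        inference_cfg=partial_fields(InferenceConfig, args.__dict__),
        crop_cfg=partial_fields(CropConfig, args.__dict__),
    )

    # Assumed driving video; any file readable by OpenCV should behave the same way
    cap = cv2.VideoCapture('assets/examples/driving/d6.mp4')
    ret, frame = cap.read()
    if not ret:
        return

    # One-time source-side setup: keypoints, features, rotation, crop info
    state = pipeline.execute_frame(frame, args.source_image)

    while ret:
        # One generate_frame call per driving frame; state is the 7-tuple returned above
        result = pipeline.generate_frame(*state, frame)
        cv2.imshow('Live Portrait (video-driven)', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        ret, frame = cap.read()

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
```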