diff --git a/app.py b/app.py
index 15cbaa2..12b4dbf 100644
--- a/app.py
+++ b/app.py
@@ -98,6 +98,7 @@ data_examples_v2v = [
 retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale")
 video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale")
 driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8)
+video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent")
 eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
 lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
 video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
@@ -124,10 +125,6 @@ retargeting_output_image = gr.Image(type="numpy")
 retargeting_output_image_paste_back = gr.Image(type="numpy")
 output_video = gr.Video(autoplay=False)
 output_video_paste_back = gr.Video(autoplay=False)
-output_video_i2v = gr.Video(autoplay=False)
-output_video_concat_i2v = gr.Video(autoplay=False)
-output_image_i2i = gr.Image(type="numpy")
-output_image_concat_i2i = gr.Image(type="numpy")
 
 """
 每个点和每个维度对应的表情:
@@ -274,6 +271,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     [osp.join(example_video_dir, "d9.jpg")],
                     [osp.join(example_video_dir, "d19.jpg")],
                     [osp.join(example_video_dir, "d8.jpg")],
+                    [osp.join(example_video_dir, "d12.jpg")],
+                    [osp.join(example_video_dir, "d38.jpg")],
                 ],
                 inputs=[driving_image_input],
                 cache_examples=False,
@@ -323,18 +322,14 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
         process_button_animation = gr.Button("🚀 Animate", variant="primary")
     with gr.Row():
         with gr.Column():
-            with gr.Accordion(open=True, label="The animated video in the original image space"):
-                output_video_i2v.render()
+            output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space")
         with gr.Column():
-            with gr.Accordion(open=True, label="The animated video"):
-                output_video_concat_i2v.render()
+            output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video")
     with gr.Row():
         with gr.Column():
-            with gr.Accordion(open=True, label="The animated image in the original image space"):
-                output_image_i2i.render()
+            output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False)
         with gr.Column():
-            with gr.Accordion(open=True, label="The animated image"):
-                output_image_concat_i2i.render()
+            output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False)
     with gr.Row():
         process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
@@ -463,6 +458,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
                     video_retargeting_source_scale.render()
                     video_lip_retargeting_slider.render()
                     driving_smooth_observation_variance_retargeting.render()
+                    video_retargeting_silence.render()
             with gr.Row(visible=True):
                 process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary")
             with gr.Row(visible=True):
@@ -524,7 +520,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
             tab_selection,
             v_tab_selection,
         ],
-        outputs=[output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i],
+        outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i],
         show_progress=True
     )
 
@@ -550,7 +546,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
 
     process_button_retargeting_video.click(
         fn=gpu_wrapped_execute_video_retargeting,
-        inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, flag_do_crop_input_retargeting_video],
+        inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video],
         outputs=[output_video, output_video_paste_back],
         show_progress=True
     )
diff --git a/src/config/argument_config.py b/src/config/argument_config.py
index e0dccd8..aa482b8 100644
--- a/src/config/argument_config.py
+++ b/src/config/argument_config.py
@@ -13,7 +13,7 @@ from .base_config import PrintableConfig, make_abs_path
 @dataclass(repr=False)  # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to the source portrait (human/animal) or video (human)
+    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s3.jpg')  # path to the source portrait (human/animal) or video (human)
     driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d30.jpg')  # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/'  # directory to save output video
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py
index a364c3d..8a67b40 100644
--- a/src/gradio_pipeline.py
+++ b/src/gradio_pipeline.py
@@ -219,9 +219,9 @@ class GradioPipeline(LivePortraitPipeline):
             output_path, output_path_concat = self.execute(self.args)
             gr.Info("Run successfully!", duration=2)
             if output_path.endswith(".jpg"):
-                return None, None, output_path, output_path_concat
+                return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
             else:
-                return output_path, output_path_concat, None, None
+                return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         else:
             raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
@@ -396,29 +396,51 @@ class GradioPipeline(LivePortraitPipeline):
         return source_eye_ratio, source_lip_ratio
 
     @torch.no_grad()
-    def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, flag_do_crop_input_retargeting_video=True):
+    def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
         """ retargeting the lip-open ratio of each source frame
         """
         # disposable feature
         device = self.live_portrait_wrapper.device
-        f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
-            self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
-        if input_lip_ratio is None:
-            raise gr.Error("Invalid ratio input 💥!", duration=5)
+        if not video_retargeting_silence:
+            f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
+                self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
+            if input_lip_ratio is None:
+                raise gr.Error("Invalid ratio input 💥!", duration=5)
+            else:
+                inference_cfg = self.live_portrait_wrapper.inference_cfg
+
+                I_p_pstbk_lst = None
+                if flag_do_crop_input_retargeting_video:
+                    I_p_pstbk_lst = []
+                I_p_lst = []
+                for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
+                    x_s_user_i = x_s_user_lst[i].to(device)
+                    f_s_user_i = f_s_user_lst[i].to(device)
+
+                    lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
+                    x_d_i_new = x_s_user_i + lip_delta_retargeting
+                    x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
+                    out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
+                    I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
+                    I_p_lst.append(I_p_i)
+
+                    if flag_do_crop_input_retargeting_video:
+                        I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
+                        I_p_pstbk_lst.append(I_p_pstbk)
         else:
             inference_cfg = self.live_portrait_wrapper.inference_cfg
+            f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
+                self.prepare_video_lip_silence(input_video, device, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
 
             I_p_pstbk_lst = None
             if flag_do_crop_input_retargeting_video:
                 I_p_pstbk_lst = []
             I_p_lst = []
-            for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
+            for i in track(range(n_frames), description='Silencing video...', total=n_frames):
                 x_s_user_i = x_s_user_lst[i].to(device)
                 f_s_user_i = f_s_user_lst[i].to(device)
-
-                lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
-                x_d_i_new = x_s_user_i + lip_delta_retargeting
+                x_d_i_new = x_d_i_new_lst[i]
                 x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
                 out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
                 I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
@@ -428,37 +450,37 @@ class GradioPipeline(LivePortraitPipeline):
                     I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
                     I_p_pstbk_lst.append(I_p_pstbk)
 
-            mkdir(self.args.output_dir)
-            flag_source_has_audio = has_audio_stream(input_video)
-
-            ######### build the final concatenation result #########
-            # source frame | generation
-            frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
-            wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
-            images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
+        mkdir(self.args.output_dir)
+        flag_source_has_audio = has_audio_stream(input_video)
+
+        ######### build the final concatenation result #########
+        # source frame | generation
+        frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
+        wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
+        images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
 
-            if flag_source_has_audio:
-                # final result with concatenation
-                wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
-                add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
-                os.replace(wfp_concat_with_audio, wfp_concat)
-                log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
+        if flag_source_has_audio:
+            # final result with concatenation
+            wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
+            add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
+            os.replace(wfp_concat_with_audio, wfp_concat)
+            log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
 
-            # save the animated result
-            wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
-            if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
-                images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
-            else:
-                images2video(I_p_lst, wfp=wfp, fps=source_fps)
+        # save the animated result
+        wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
+        if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
+            images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
+        else:
+            images2video(I_p_lst, wfp=wfp, fps=source_fps)
 
-            ######### build the final result #########
-            if flag_source_has_audio:
-                wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
-                add_audio_to_video(wfp, input_video, wfp_with_audio)
-                os.replace(wfp_with_audio, wfp)
-                log(f"Replace {wfp_with_audio} with {wfp}")
-            gr.Info("Run successfully!", duration=2)
-            return wfp_concat, wfp
+        ######### build the final result #########
+        if flag_source_has_audio:
+            wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
+            add_audio_to_video(wfp, input_video, wfp_with_audio)
+            os.replace(wfp_with_audio, wfp)
+            log(f"Replace {wfp_with_audio} with {wfp}")
+        gr.Info("Run successfully!", duration=2)
+        return wfp_concat, wfp
 
     @torch.no_grad()
     def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
@@ -517,6 +539,59 @@ class GradioPipeline(LivePortraitPipeline):
             # when press the clear button, go here
             raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
 
+    @torch.no_grad()
+    def prepare_video_lip_silence(self, input_video, device, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
+        """ for keeping lips in the source video silent
+        """
+        if input_video is not None:
+            # gr.Info("Upload successfully!", duration=2)
+            inference_cfg = self.live_portrait_wrapper.inference_cfg
+            ######## process source video ########
+            source_rgb_lst = load_video(input_video)
+            source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
+            source_fps = int(get_fps(input_video))
+            n_frames = len(source_rgb_lst)
+            log(f"Load source video from {input_video}. FPS is {source_fps}")
+
+            if flag_do_crop:
+                ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
+                log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
+                if len(ret_s["frame_crop_lst"]) != n_frames:
+                    n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
+                img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
+                mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
+            else:
+                source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
+                img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst]  # force to resize to 256x256
+                source_M_c2o_lst, mask_ori_lst = None, None
+
+            c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
+            # save the motion template
+            I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
+            source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
+
+            f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
+            for i in track(range(n_frames), description='Preparing silence lip...', total=n_frames):
+                x_s_info = source_template_dct['motion'][i]
+                x_s_info = dct2device(x_s_info, device)
+                x_s_user = x_s_info['x_s']
+                delta_new = torch.zeros_like(x_s_info['exp'])
+                for eyes_idx in [11, 13, 15, 16, 18]:
+                    delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
+
+                source_lmk = source_lmk_crop_lst[i]
+                img_crop_256x256 = img_crop_256x256_lst[i]
+                I_s = I_s_lst[i]
+                f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
+                x_d_i_new = x_s_info['scale'] * (x_s_info['kp'] @ x_s_info['R'] + delta_new + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)) + x_s_info['t']
+
+                f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
+
+            return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
+        else:
+            # when press the clear button, go here
+            raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)
+
 
 class GradioPipelineAnimal(LivePortraitPipelineAnimal):
     """gradio for animal
     """
diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py
index 54651a8..d36db88 100644
--- a/src/live_portrait_pipeline.py
+++ b/src/live_portrait_pipeline.py
@@ -217,7 +217,8 @@ class LivePortraitPipeline(object):
                         x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
                         x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
                     else:
-                        x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+                        # x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - 0) for i in range(n_frames)] if driving_template_dct['motion'][0]['exp'].mean() > 0 else [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
+                        x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)]
                         x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]
                 if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
                     if flag_is_driving_video:
@@ -340,14 +341,25 @@ class LivePortraitPipeline(object):
                         if flag_is_driving_video:
                             delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                         else:
-                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
-
+                            # delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - 0) if x_d_i_info['exp'].mean() > 0 else x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
+                            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device))
                 elif inf_cfg.animation_region == "lip":
                     for lip_idx in [6, 12, 14, 17, 19, 20]:
-                        delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
+                        if flag_is_source_video:
+                            delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
+                        elif flag_is_driving_video:
+                            delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
+                        else:
+                            # delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, lip_idx, :] if x_d_i_info['exp'].mean() > 0 else (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
+                            delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :]
                 elif inf_cfg.animation_region == "eyes":
                     for eyes_idx in [11, 13, 15, 16, 18]:
-                        delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
+                        if flag_is_source_video:
+                            delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
+                        elif flag_is_driving_video:
+                            delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
+                        else:
+                            delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :]
                 if inf_cfg.animation_region == "all":
                     scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                 else:
@@ -384,7 +396,7 @@ class LivePortraitPipeline(object):
                 # x_d_i_new = x_s_info['scale'] * (x_c_s @ R_s) + x_s_info['t']
                 x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
 
-            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video and inf_cfg.animation_region == "all":
+            if inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video:
                 if i == 0:
                     x_d_0_new = x_d_i_new
                     motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
@@ -454,7 +466,10 @@ class LivePortraitPipeline(object):
         if flag_is_source_video and flag_is_driving_video:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
         elif flag_is_source_video and not flag_is_driving_video:
-            frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
+            if flag_load_from_template:
+                frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
+            else:
+                frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
         else:
             frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)