feat: update

zhangdingyun 2024-08-14 15:56:55 +08:00
parent 200f84dd1f
commit 3f45c776da
5 changed files with 159 additions and 59 deletions

app.py

@@ -85,12 +85,12 @@ data_examples_i2v = [
[osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
]
data_examples_v2v = [
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7],
# [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7],
# [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, False, 3e-7],
[osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7],
]
#################### interface logic ####################
@@ -126,6 +126,75 @@ output_video = gr.Video(autoplay=False)
output_video_paste_back = gr.Video(autoplay=False)
output_video_i2v = gr.Video(autoplay=False)
output_video_concat_i2v = gr.Video(autoplay=False)
output_image_i2i = gr.Image(type="numpy")
output_image_concat_i2i = gr.Image(type="numpy")
"""
Expression controlled by each keypoint index and each dimension
(0,0): top of the head tilts left/right
(0,1): top of the head tilts up/down
(0,2): top of the head tilts forward/backward
(1,0): eyebrows up/down, eyes left/right
(1,1): eyebrows up/down, eyes up/down
(1,2): mouth and eye movement
(2,0): eyebrows up/down, eyes left/right
(2,1): eyebrows up/down, eyes up/down
(2,2): mouth movement
(3,0): left cheek fuller/thinner, eyebrows up/down
(3,1): left cheek up/down, eyebrows up/down
(3,2): left cheek forward/backward (causes distortion)
(4,0): right cheek fuller/thinner
(4,1): right cheek up/down
(4,2): right cheek forward/backward (causes distortion)
(5,0): head translates left/right
(5,1): head translates up/down
(5,2): mouth movement
(6,0): mouth movement
(6,1): mouth movement
(6,2): mouth movement
(7,0): right cheek fuller/thinner
(7,1): right cheek up/down
(7,2): right cheek forward/backward
(8,0): right cheek fuller/thinner
(8,1): right cheek up/down
(8,2): mouth movement
(9,0): chin fuller/thinner
(9,1): mouth movement
(9,2): eye movement
(10,0): left side scaling
(10,1): left side scaling, eye movement
(10,2): chin scaling
(11,0): left eye turns left/right
(11,1): left eye opens/closes (up/down)
(11,2): left eye forward/backward
(12,0): mouth movement
(12,1): no obvious effect
(12,2): mouth movement
(13,0): eye movement
(13,1): eye movement
(13,2): eye movement
(14,0): mouth movement
(14,1): mouth movement
(14,2): mouth movement
(15,0): eye movement
(15,1): eye movement, mouth movement
(15,2): eye movement
(16,0): eyes
(16,1): right eye opens/closes, mouth movement
(16,2): eye movement
(17,0): mouth movement, eye movement
(17,1): mouth movement, eye movement
(17,2): pucker the mouth / flatten the mouth
(18,0): eye direction
(18,1): eyes up/down
(18,2): mouth movement, eye movement
(19,0): mouth corners drawn down (pouting)
(19,1): mouth opens/closes
(19,2): lips rolled inward/outward
(20,0): mouth curves downward
(20,1): teeth shown / teeth hidden
(20,2): mouth pulled down / "O"-shaped mouth
"""
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
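The docstring above maps each of the 21 implicit keypoints, per dimension, to the facial region it drives. Below is a minimal sketch (illustrative, not from this commit) of nudging a single (keypoint, dimension) entry of an expression delta, assuming the (1, 21, 3) tensor layout used elsewhere in app.py; the helper name and the 0.1 offset are made up for the example:

```python
import torch

# hypothetical helper: bump one (keypoint, dimension) entry of the expression delta
def nudge_expression(x_s_info: dict, kp_idx: int, dim: int, amount: float) -> torch.Tensor:
    # x_s_info['exp'] is assumed to be a (1, 21, 3) tensor of implicit-keypoint offsets
    delta_new = x_s_info['exp'].clone()
    delta_new[:, kp_idx, dim] += amount
    return delta_new

# e.g. (11, 1) is listed above as "left eye opens/closes (up/down)"
# delta_new = nudge_expression(x_s_info, kp_idx=11, dim=1, amount=0.1)
```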
@@ -196,6 +265,19 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[driving_video_input],
cache_examples=False,
)
with gr.TabItem("🖼️ Driving Image") as v_tab_image:
with gr.Accordion(open=True, label="Driving Image"):
driving_image_input = gr.Image(type="filepath")
gr.Examples(
examples=[
[osp.join(example_video_dir, "d3.jpg")],
[osp.join(example_video_dir, "d9.jpg")],
[osp.join(example_video_dir, "d11.jpg")],
],
inputs=[driving_image_input],
cache_examples=False,
)
with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
with gr.Accordion(open=True, label="Driving Pickle"):
driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
@@ -212,8 +294,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
)
v_tab_selection = gr.Textbox(visible=False)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
v_tab_video.select(lambda: "Video", None, v_tab_selection)
v_tab_image.select(lambda: "Image", None, v_tab_selection)
v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
# with gr.Accordion(open=False, label="Animation Instructions"):
# gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@@ -229,9 +312,9 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_relative_input = gr.Checkbox(value=True, label="relative motion")
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_stitching_input = gr.Checkbox(value=True, label="stitching")
animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region")
driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)")
driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02)
flag_video_editing_head_rotation = gr.Checkbox(value=False, label="relative head rotation (v2v)")
driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8)
gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
@@ -244,8 +327,14 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat_i2v.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated image in the original image space"):
output_image_i2i.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated image"):
output_image_concat_i2i.render()
with gr.Row():
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear")
with gr.Row():
# Examples
@@ -279,7 +368,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
flag_do_crop_input,
flag_remap_input,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
driving_smooth_observation_variance,
],
outputs=[output_image, output_image_paste_back],
@@ -413,16 +501,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
inputs=[
source_image_input,
source_video_input,
driving_video_pickle_input,
driving_video_input,
driving_image_input,
driving_video_pickle_input,
flag_relative_input,
flag_do_crop_input,
flag_remap_input,
flag_stitching_input,
animation_region,
driving_option_input,
driving_multiplier,
flag_crop_driving_video_input,
flag_video_editing_head_rotation,
scale,
vx_ratio,
vy_ratio,
@@ -433,10 +522,11 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta San
tab_selection,
v_tab_selection,
],
outputs=[output_video_i2v, output_video_concat_i2v],
outputs=[output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i],
show_progress=True
)
retargeting_input_image.change(
fn=gradio_pipeline.init_retargeting_image,
inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image],

src/config/argument_config.py

@@ -13,8 +13,8 @@ from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False) # use repr from PrintableConfig
class ArgumentConfig(PrintableConfig):
########## input arguments ##########
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to the source portrait (human/animal) or video (human)
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d6.pkl') # path to driving video or template (.pkl format)
source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s18.mp4') # path to the source portrait (human/animal) or video (human)
driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d3.jpg') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
########## inference arguments ##########
@@ -24,7 +24,6 @@ class ArgumentConfig(PrintableConfig):
flag_force_cpu: bool = False # force cpu inference, WIP!
flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering
flag_video_editing_head_rotation: bool = False # when the input is a source video, whether to inherit the relative head rotation from the driving video
flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal
@@ -35,7 +34,7 @@ class ArgumentConfig(PrintableConfig):
driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly"
driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "eyes" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose
########## source crop arguments ##########
det_thresh: float = 0.15 # detection threshold
scale: float = 2.3 # the ratio of face area is smaller if scale is larger
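The new defaults pair a source video (s18.mp4) with a single driving image (d3.jpg) and set animation_region to "all". A hedged sketch of building the config programmatically; the import path is assumed from this repository's layout, and the asset paths mirror the defaults above:

```python
# sketch only: instantiate the updated config with explicit overrides
from src.config.argument_config import ArgumentConfig  # module path assumed

args = ArgumentConfig(
    source="assets/examples/source/s18.mp4",   # source video
    driving="assets/examples/driving/d3.jpg",  # single driving image
    animation_region="all",                    # new default region
    output_dir="animations/",
)
print(args)  # readable repr comes from PrintableConfig
```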

src/config/inference_config.py

@@ -34,7 +34,6 @@ class InferenceConfig(PrintableConfig):
device_id: int = 0
flag_normalize_lip: bool = True
flag_source_video_eye_retargeting: bool = False
flag_video_editing_head_rotation: bool = False
flag_eye_retargeting: bool = False
flag_lip_retargeting: bool = False
flag_stitching: bool = True

src/gradio_pipeline.py

@@ -146,16 +146,18 @@ class GradioPipeline(LivePortraitPipeline):
self,
input_source_image_path=None,
input_source_video_path=None,
input_driving_video_pickle_path=None,
input_driving_video_path=None,
input_driving_image_path=None,
input_driving_video_pickle_path=None,
flag_relative_input=True,
flag_do_crop_input=True,
flag_remap_input=True,
flag_stitching_input=True,
animation_region="all",
driving_option_input="pose-friendly",
driving_multiplier=1.0,
flag_crop_driving_video_input=True,
flag_video_editing_head_rotation=False,
# flag_video_editing_head_rotation=False,
scale=2.3,
vx_ratio=0.0,
vy_ratio=-0.125,
@@ -177,6 +179,8 @@ class GradioPipeline(LivePortraitPipeline):
if v_tab_selection == 'Video':
input_driving_path = input_driving_video_path
elif v_tab_selection == 'Image':
input_driving_path = input_driving_image_path
elif v_tab_selection == 'Pickle':
input_driving_path = input_driving_video_pickle_path
else:
@@ -195,10 +199,10 @@ class GradioPipeline(LivePortraitPipeline):
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input,
'flag_stitching': flag_stitching_input,
'animation_region': animation_region,
'driving_option': driving_option_input,
'driving_multiplier': driving_multiplier,
'flag_crop_driving_video': flag_crop_driving_video_input,
'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
'scale': scale,
'vx_ratio': vx_ratio,
'vy_ratio': vy_ratio,
@@ -211,10 +215,13 @@ class GradioPipeline(LivePortraitPipeline):
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
output_path, output_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat
if output_path.endswith(".jpg"):
return None, None, output_path, output_path_concat
else:
return output_path, output_path_concat, None, None
else:
raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
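The handler above now returns four values so a single callback can fill either the video outputs or the new image outputs, leaving the unused pair as None. A minimal sketch of that dispatch, assuming the result path's extension distinguishes an image from a video (the commit itself only checks ".jpg"); the function name is illustrative:

```python
# sketch: map (output_path, output_path_concat) onto the four Gradio output slots
def dispatch_outputs(output_path: str, output_path_concat: str):
    if output_path.lower().endswith((".jpg", ".jpeg", ".png")):  # image-driven result
        return None, None, output_path, output_path_concat       # fill the image outputs
    return output_path, output_path_concat, None, None           # fill the video outputs
```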

src/live_portrait_pipeline.py

@@ -112,8 +112,12 @@ class LivePortraitPipeline(object):
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
driving_n_frames = driving_template_dct['n_frames']
flag_is_driving_video = True if driving_n_frames > 1 else False
if flag_is_source_video:
# if flag_is_source_video and not flag_is_driving_video:
# raise Exception(f"Animating a source video with a driving image is not supported!")
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else:
n_frames = driving_n_frames
@@ -134,8 +138,10 @@ class LivePortraitPipeline(object):
driving_rgb_lst = load_video(args.driving)
elif is_image(args.driving):
flag_is_driving_video = False
# if flag_is_source_video:
# raise Exception(f"Animating a source video with a driving image is not supported!")
driving_img_rgb = load_image_rgb(args.driving)
output_fps = 1
output_fps = 25
log(f"Load driving image from {args.driving}")
driving_rgb_lst = [driving_img_rgb]
else:
@@ -143,9 +149,11 @@
######## make motion template ########
log("Start making driving motion template...")
driving_n_frames = len(driving_rgb_lst)
if flag_is_source_video:
if flag_is_source_video and flag_is_driving_video:
n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames
driving_rgb_lst = driving_rgb_lst[:n_frames]
elif flag_is_source_video and not flag_is_driving_video:
n_frames = len(source_rgb_lst)
else:
n_frames = driving_n_frames
if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)):
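This hunk and the template-loading hunk earlier apply the same frame-count rule. A compact sketch of that rule; the function name is illustrative:

```python
def pick_n_frames(is_source_video: bool, is_driving_video: bool,
                  n_source: int, n_driving: int) -> int:
    # video source + video driving: clip to the shorter of the two sequences
    if is_source_video and is_driving_video:
        return min(n_source, n_driving)
    # video source + single driving image: animate every source frame
    if is_source_video:
        return n_source
    # image source: follow the driving length
    return n_driving
```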
@@ -207,15 +215,23 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_relative_motion:
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
if flag_is_driving_video:
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance)
if inf_cfg.flag_video_editing_head_rotation:
else:
x_d_exp_lst = [driving_template_dct['motion'][0]['exp']]
x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
if flag_is_driving_video:
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance)
else:
x_d_r_lst = [driving_template_dct['motion'][0][key_r]]
x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst]*n_frames
else: # if the input is a source image, process it only once
if inf_cfg.flag_do_crop:
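When the driving input is a single image, its one motion frame is broadcast over all animated frames by list replication, so every frame reuses the same tensor object. A self-contained sketch of that broadcast; the (1, 21, 3) shape is assumed from the expression layout used elsewhere in the commit:

```python
import numpy as np
import torch

n_frames = 4
single_exp = np.zeros((1, 21, 3), dtype=np.float32)       # stand-in for motion[0]['exp']
exp_t = torch.tensor(single_exp[0], dtype=torch.float32)  # drop the batch dim -> (21, 3)
x_d_exp_lst_smooth = [exp_t] * n_frames                   # same tensor repeated per frame
assert all(t is exp_t for t in x_d_exp_lst_smooth)
```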
@@ -281,7 +297,9 @@ class LivePortraitPipeline(object):
if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back
mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
if flag_is_source_video and not flag_is_driving_video:
x_d_i_info = driving_template_dct['motion'][0]
else:
x_d_i_info = driving_template_dct['motion'][i]
x_d_i_info = dct2device(x_d_i_info, device)
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
@@ -292,24 +310,17 @@ class LivePortraitPipeline(object):
delta_new = x_s_info['exp'].clone()
if inf_cfg.flag_relative_motion:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else:
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
elif inf_cfg.animation_region == "lip":
for lip_idx in [14, 17, 19, 20]:
for lip_idx in [6, 12, 14, 17, 19, 20]:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16]:
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
if inf_cfg.animation_region == "all":
scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
@@ -319,20 +330,12 @@ class LivePortraitPipeline(object):
t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
t_new = x_s_info['t']
else:
if flag_is_source_video:
if inf_cfg.flag_video_editing_head_rotation:
R_new = x_d_r_lst_smooth[i]
else:
R_new = R_s
else:
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
R_new = R_d_i
R_new = x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i
else:
R_new = R_s
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp":
# delta_new = x_d_exp_lst_smooth[i] if flag_is_source_video else x_d_i_info['exp']
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :]
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1]
@@ -340,10 +343,10 @@ class LivePortraitPipeline(object):
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2]
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:]
elif inf_cfg.animation_region == "lip":
for lip_idx in [14, 17, 19, 20]:
for lip_idx in [6, 12, 14, 17, 19, 20]:
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :]
elif inf_cfg.animation_region == "eyes":
for eyes_idx in [11, 13, 15, 16]:
for eyes_idx in [11, 13, 15, 16, 18]:
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :]
scale_new = x_s_info['scale']
if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose":
@@ -421,12 +424,14 @@ class LivePortraitPipeline(object):
wfp_concat = None
######### build the final concatenation result #########
# driving frame | source frame | generation
if flag_is_source_video:
if flag_is_source_video and flag_is_driving_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst)
elif flag_is_source_video and not flag_is_driving_video:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst)
else:
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
if flag_is_driving_video:
if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video):
flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source)
flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving)