feat(video2x): updated the variables in the Video2X class for compatibility with the GUI

Signed-off-by: k4yt3x <i@k4yt3x.com>
This commit is contained in:
k4yt3x 2023-09-16 18:49:29 +00:00
parent 6d934e6a98
commit 22993028b4
No known key found for this signature in database

View File

@ -39,7 +39,7 @@ import time
from enum import Enum from enum import Enum
from multiprocessing import Manager, Pool, Queue, Value from multiprocessing import Manager, Pool, Queue, Value
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Callable, Optional
import ffmpeg import ffmpeg
from cv2 import cv2 from cv2 import cv2
@ -106,8 +106,9 @@ class Video2X:
- interpolate: perform motion interpolation on a file - interpolate: perform motion interpolation on a file
""" """
def __init__(self) -> None: def __init__(self, progress_callback: Optional[Callable] = None) -> None:
self.version = __version__ self.version = __version__
self.progress_callback = progress_callback
@staticmethod @staticmethod
def _get_video_info(path: Path) -> tuple: def _get_video_info(path: Path) -> tuple:
@ -229,7 +230,7 @@ class Video2X:
processor_pool = Pool(processes, processor.process) processor_pool = Pool(processes, processor.process)
# create progress bar # create progress bar
progress = Progress( self.progress = Progress(
"[progress.description]{task.description}", "[progress.description]{task.description}",
BarColumn(complete_style="blue", finished_style="green"), BarColumn(complete_style="blue", finished_style="green"),
"[progress.percentage]{task.percentage:>3.0f}%", "[progress.percentage]{task.percentage:>3.0f}%",
@ -242,7 +243,9 @@ class Video2X:
speed_estimate_period=300.0, speed_estimate_period=300.0,
disable=True, disable=True,
) )
task = progress.add_task(f"[cyan]{mode.value['label']}", total=total_frames) task = self.progress.add_task(
f"[cyan]{mode.value['label']}", total=total_frames
)
def _toggle_pause(_signal_number: int = -1, _frame=None): def _toggle_pause(_signal_number: int = -1, _frame=None):
# allow the closure to modify external immutable flag # allow the closure to modify external immutable flag
@ -250,17 +253,17 @@ class Video2X:
# print console messages and update the progress bar's status # print console messages and update the progress bar's status
if pause_flag.value is False: if pause_flag.value is False:
progress.update( self.progress.update(
task, description=f"[cyan]{mode.value['label']} (paused)" task, description=f"[cyan]{mode.value['label']} (paused)"
) )
progress.stop_task(task) self.progress.stop_task(task)
logger.warning("Processing paused, press Ctrl+Alt+V again to resume") logger.warning("Processing paused, press Ctrl+Alt+V again to resume")
# the lock is already acquired # the lock is already acquired
elif pause_flag.value is True: elif pause_flag.value is True:
progress.update(task, description=f"[cyan]{mode.value['label']}") self.progress.update(task, description=f"[cyan]{mode.value['label']}")
logger.warning("Resuming processing") logger.warning("Resuming processing")
progress.start_task(task) self.progress.start_task(task)
# invert the flag # invert the flag
with pause_flag.get_lock(): with pause_flag.get_lock():
@ -292,7 +295,7 @@ class Video2X:
try: try:
# let the context manager automatically stop the progress bar # let the context manager automatically stop the progress bar
with progress: with self.progress:
frame_index = 0 frame_index = 0
while frame_index < total_frames: while frame_index < total_frames:
current_frame = processed_frames.get(frame_index) current_frame = processed_frames.get(frame_index)
@ -304,8 +307,8 @@ class Video2X:
# show the progress bar after the processing starts # show the progress bar after the processing starts
# reduces speed estimation inaccuracies and print overlaps # reduces speed estimation inaccuracies and print overlaps
if frame_index == 0: if frame_index == 0:
progress.disable = False self.progress.disable = False
progress.start() self.progress.start()
if current_frame is True: if current_frame is True:
encoder.write(processed_frames.get(frame_index - 1)) encoder.write(processed_frames.get(frame_index - 1))
@ -316,7 +319,9 @@ class Video2X:
if frame_index > 0: if frame_index > 0:
del processed_frames[frame_index - 1] del processed_frames[frame_index - 1]
progress.update(task, completed=frame_index + 1) self.progress.update(task, completed=frame_index + 1)
if self.progress_callback is not None:
self.progress_callback(frame_index + 1, total_frames)
frame_index += 1 frame_index += 1
# if SIGTERM is received or ^C is pressed # if SIGTERM is received or ^C is pressed
@ -389,10 +394,10 @@ class Video2X:
width, height, total_frames, frame_rate = self._get_video_info(input_path) width, height, total_frames, frame_rate = self._get_video_info(input_path)
# automatically calculate output width and height if only one is given # automatically calculate output width and height if only one is given
if output_width == 0 or output_width is None: if output_width == 0:
output_width = output_height / height * width output_width = output_height / height * width
elif output_height == 0 or output_height is None: elif output_height == 0:
output_height = output_width / width * height output_height = output_width / width * height
# sanitize output width and height to be divisible by 2 # sanitize output width and height to be divisible by 2