redirected STDOUT and STDERR to prevent output from breaking the progress bar

This commit is contained in:
k4yt3x 2022-02-12 23:51:30 +00:00
parent 7c0e9c45d8
commit 0d0fd70a24

View File

@ -58,6 +58,8 @@ import time
# third-party imports # third-party imports
from loguru import logger from loguru import logger
from rich import print from rich import print
from rich.console import Console
from rich.file_proxy import FileProxy
from rich.progress import ( from rich.progress import (
BarColumn, BarColumn,
Progress, Progress,
@ -98,6 +100,13 @@ DRIVER_FIXED_SCALING_RATIOS = {
# progress bar labels for different modes # progress bar labels for different modes
MODE_LABELS = {"upscale": "Upscaling", "interpolate": "Interpolating"} MODE_LABELS = {"upscale": "Upscaling", "interpolate": "Interpolating"}
# format string for Loguru loggers
LOGURU_FORMAT = (
"<green>{time:HH:mm:ss.SSSSSS!UTC}</green> | "
"<level>{level: <8}</level> | "
"<level>{message}</level>"
)
class ProcessingSpeedColumn(ProgressColumn): class ProcessingSpeedColumn(ProgressColumn):
"""Custom progress bar column that displays the processing speed""" """Custom progress bar column that displays the processing speed"""
@ -165,6 +174,22 @@ class Video2X:
processes: int, processes: int,
processing_settings: tuple, processing_settings: tuple,
): ):
# record original STDOUT and STDERR for restoration
original_stdout = sys.stdout
original_stderr = sys.stderr
# create console for rich's Live display
console = Console()
# redirect STDOUT and STDERR to console
sys.stdout = FileProxy(console, sys.stdout)
sys.stderr = FileProxy(console, sys.stderr)
# re-add Loguru to point to the new STDERR
logger.remove()
logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
# initialize values # initialize values
self.processor_processes = [] self.processor_processes = []
self.processing_queue = multiprocessing.Queue(maxsize=processes * 10) self.processing_queue = multiprocessing.Queue(maxsize=processes * 10)
@ -197,7 +222,7 @@ class Video2X:
) )
self.encoder.start() self.encoder.start()
# create upscaler processes # create processor processes
for process_name in range(processes): for process_name in range(processes):
process = Processor(self.processing_queue, processed_frames) process = Processor(self.processing_queue, processed_frames)
process.name = str(process_name) process.name = str(process_name)
@ -212,13 +237,14 @@ class Video2X:
# create progress bar # create progress bar
with Progress( with Progress(
"[progress.description]{task.description}", "[progress.description]{task.description}",
BarColumn(), BarColumn(finished_style="green"),
"[progress.percentage]{task.percentage:>3.0f}%", "[progress.percentage]{task.percentage:>3.0f}%",
"[color(240)]({task.completed}/{task.total})", "[color(240)]({task.completed}/{task.total})",
ProcessingSpeedColumn(), ProcessingSpeedColumn(),
TimeElapsedColumn(), TimeElapsedColumn(),
"<", "<",
TimeRemainingColumn(), TimeRemainingColumn(),
console=console,
disable=True, disable=True,
) as progress: ) as progress:
task = progress.add_task( task = progress.add_task(
@ -226,7 +252,8 @@ class Video2X:
) )
# wait for jobs in queue to deplete # wait for jobs in queue to deplete
while self.encoder.is_alive() is True: while self.processed.value < total_frames - 1:
time.sleep(0.5)
for process in self.processor_processes: for process in self.processor_processes:
if not process.is_alive(): if not process.is_alive():
raise Exception("process died unexpectedly") raise Exception("process died unexpectedly")
@ -238,10 +265,9 @@ class Video2X:
# update progress # update progress
progress.update(task, completed=self.processed.value) progress.update(task, completed=self.processed.value)
time.sleep(0.5)
logger.info("Encoding has completed") progress.update(task, completed=total_frames)
progress.update(task, completed=self.processed.value) logger.info("Processing has completed")
# if SIGTERM is received or ^C is pressed # if SIGTERM is received or ^C is pressed
# TODO: pause and continue here # TODO: pause and continue here
@ -276,6 +302,14 @@ class Video2X:
if len(exception) > 0: if len(exception) > 0:
raise exception[0] raise exception[0]
# restore original STDOUT and STDERR
sys.stdout = original_stdout
sys.stderr = original_stderr
# re-add Loguru to point to the restored STDERR
logger.remove()
logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
def upscale( def upscale(
self, self,
input_path: pathlib.Path, input_path: pathlib.Path,
@ -461,18 +495,10 @@ def main():
os.environ["LOGURU_LEVEL"] = args.loglevel.upper() os.environ["LOGURU_LEVEL"] = args.loglevel.upper()
# remove default handler # remove default handler
logger.remove(0) logger.remove()
# add new sink with custom handler # add new sink with custom handler
logger.add( logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)
sys.stderr,
colorize=True,
format=(
"<green>{time:HH:mm:ss.SSSSSS!UTC}</green> | "
"<level>{level: <8}</level> | "
"<level>{message}</level>"
),
)
# display version and lawful informaition # display version and lawful informaition
if args.version: if args.version:
@ -509,6 +535,8 @@ def main():
args.driver, args.driver,
) )
logger.success("Processing completed successfully")
# don't print the traceback for manual terminations # don't print the traceback for manual terminations
except (SystemExit, KeyboardInterrupt) as e: except (SystemExit, KeyboardInterrupt) as e:
raise SystemExit(e) raise SystemExit(e)