terminology change: driver -> algorithm

This commit is contained in:
k4yt3x 2022-02-15 00:54:17 +00:00
parent b6b1bf9f0e
commit a7f0f34751
3 changed files with 47 additions and 51 deletions

View File

@ -38,7 +38,7 @@ from PIL import ImageChops, ImageStat
from loguru import logger from loguru import logger
DRIVER_CLASSES = {"rife": Rife} ALGORITHM_CLASSES = {"rife": Rife}
class Interpolator(multiprocessing.Process): class Interpolator(multiprocessing.Process):
@ -57,7 +57,7 @@ class Interpolator(multiprocessing.Process):
def run(self): def run(self):
self.running = True self.running = True
logger.info(f"Interpolator process {self.name} initiating") logger.info(f"Interpolator process {self.name} initiating")
driver_objects = {} processor_objects = {}
while self.running: while self.running:
try: try:
try: try:
@ -65,7 +65,7 @@ class Interpolator(multiprocessing.Process):
( (
frame_index, frame_index,
(image0, image1), (image0, image1),
(difference_threshold, driver), (difference_threshold, algorithm),
) = self.processing_queue.get(False) ) = self.processing_queue.get(False)
except queue.Empty: except queue.Empty:
time.sleep(0.1) time.sleep(0.1)
@ -86,13 +86,13 @@ class Interpolator(multiprocessing.Process):
# process the interpolation # process the interpolation
if difference_ratio < difference_threshold: if difference_ratio < difference_threshold:
# select a driver object with the required settings # select a processor object with the required settings
# create a new object if none are available # create a new object if none are available
driver_object = driver_objects.get(driver) processor_object = processor_objects.get(algorithm)
if driver_object is None: if processor_object is None:
driver_object = DRIVER_CLASSES[driver](0) processor_object = ALGORITHM_CLASSES[algorithm](0)
driver_objects[driver] = driver_object processor_objects[algorithm] = processor_object
interpolated_image = driver_object.process(image0, image1) interpolated_image = processor_object.process(image0, image1)
# if the difference is greater than threshold # if the difference is greater than threshold
# there's a change in camera angle, ignore # there's a change in camera angle, ignore

View File

@ -40,15 +40,15 @@ import time
from PIL import Image, ImageChops, ImageStat from PIL import Image, ImageChops, ImageStat
from loguru import logger from loguru import logger
# fixed scaling ratios supported by the drivers # fixed scaling ratios supported by the algorithms
# that only support certain fixed scale ratios # that only support certain fixed scale ratios
DRIVER_FIXED_SCALING_RATIOS = { ALGORITHM_FIXED_SCALING_RATIOS = {
"waifu2x": [1, 2], "waifu2x": [1, 2],
"srmd": [2, 3, 4], "srmd": [2, 3, 4],
"realsr": [4], "realsr": [4],
} }
DRIVER_CLASSES = {"waifu2x": Waifu2x, "srmd": Srmd, "realsr": Realsr} ALGORITHM_CLASSES = {"waifu2x": Waifu2x, "srmd": Srmd, "realsr": Realsr}
class Upscaler(multiprocessing.Process): class Upscaler(multiprocessing.Process):
@ -69,7 +69,7 @@ class Upscaler(multiprocessing.Process):
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Upscaler process <blue>{self.name}</blue> initiating" f"Upscaler process <blue>{self.name}</blue> initiating"
) )
driver_objects = {} processor_objects = {}
while self.running: while self.running:
try: try:
try: try:
@ -82,7 +82,7 @@ class Upscaler(multiprocessing.Process):
output_height, output_height,
noise, noise,
difference_threshold, difference_threshold,
driver, algorithm,
), ),
) = self.processing_queue.get(False) ) = self.processing_queue.get(False)
@ -123,9 +123,9 @@ class Upscaler(multiprocessing.Process):
# calculate required minimum scale ratio # calculate required minimum scale ratio
output_scale = max(output_width / width, output_height / height) output_scale = max(output_width / width, output_height / height)
# select the optimal driver scaling ratio to use # select the optimal algorithm scaling ratio to use
supported_scaling_ratios = sorted( supported_scaling_ratios = sorted(
DRIVER_FIXED_SCALING_RATIOS[driver] ALGORITHM_FIXED_SCALING_RATIOS[algorithm]
) )
remaining_scaling_ratio = math.ceil(output_scale) remaining_scaling_ratio = math.ceil(output_scale)
@ -163,17 +163,17 @@ class Upscaler(multiprocessing.Process):
for job in scaling_jobs: for job in scaling_jobs:
# select a driver object with the required settings # select a processor object with the required settings
# create a new object if none are available # create a new object if none are available
driver_object = driver_objects.get((driver, job)) processor_object = processor_objects.get((algorithm, job))
if driver_object is None: if processor_object is None:
driver_object = DRIVER_CLASSES[driver]( processor_object = ALGORITHM_CLASSES[algorithm](
scale=job, noise=noise scale=job, noise=noise
) )
driver_objects[(driver, job)] = driver_object processor_objects[(algorithm, job)] = processor_object
# process the image with the selected driver # process the image with the selected algorithm
image1 = driver_object.process(image1) image1 = processor_object.process(image1)
# downscale the image to the desired output size and save the image to disk # downscale the image to the desired output size and save the image to disk
image1 = image1.resize((output_width, output_height), Image.LANCZOS) image1 = image1.resize((output_width, output_height), Image.LANCZOS)

View File

@ -81,21 +81,15 @@ Contact: k4yt3x@k4yt3x.com""".format(
__version__ __version__
) )
UPSCALING_DRIVERS = [ # algorithms available for upscaling tasks
UPSCALING_ALGORITHMS = [
"waifu2x", "waifu2x",
"srmd", "srmd",
"realsr", "realsr",
] ]
INTERPOLATION_DRIVERS = ["rife"] # algorithms available for frame interpolation tasks
INTERPOLATION_ALGORITHMS = ["rife"]
# fixed scaling ratios supported by the drivers
# that only support certain fixed scale ratios
DRIVER_FIXED_SCALING_RATIOS = {
"waifu2x": [1, 2],
"srmd": [2, 3, 4],
"realsr": [4],
}
# progress bar labels for different modes # progress bar labels for different modes
MODE_LABELS = {"upscale": "Upscaling", "interpolate": "Interpolating"} MODE_LABELS = {"upscale": "Upscaling", "interpolate": "Interpolating"}
@ -279,6 +273,10 @@ class Video2X:
logger.exception(e) logger.exception(e)
exception.append(e) exception.append(e)
# if no exceptions were produced
else:
logger.success("Processing completed successfully")
finally: finally:
# mark processing queue as closed # mark processing queue as closed
self.processing_queue.close() self.processing_queue.close()
@ -319,7 +317,7 @@ class Video2X:
noise: int, noise: int,
processes: int, processes: int,
threshold: float, threshold: float,
driver: str, algorithm: str,
) -> None: ) -> None:
# get basic video information # get basic video information
@ -354,7 +352,7 @@ class Video2X:
output_height, output_height,
noise, noise,
threshold, threshold,
driver, algorithm,
), ),
) )
@ -364,7 +362,7 @@ class Video2X:
output_path: pathlib.Path, output_path: pathlib.Path,
processes: int, processes: int,
threshold: float, threshold: float,
driver: str, algorithm: str,
) -> None: ) -> None:
# get video basic information # get video basic information
@ -386,7 +384,7 @@ class Video2X:
Interpolator, Interpolator,
"interpolate", "interpolate",
processes, processes,
(threshold, driver), (threshold, algorithm),
) )
@ -440,11 +438,11 @@ def parse_arguments() -> argparse.Namespace:
upscale.add_argument("-h", "--height", type=int, help="output height") upscale.add_argument("-h", "--height", type=int, help="output height")
upscale.add_argument("-n", "--noise", type=int, help="denoise level", default=3) upscale.add_argument("-n", "--noise", type=int, help="denoise level", default=3)
upscale.add_argument( upscale.add_argument(
"-d", "-a",
"--driver", "--algorithm",
choices=UPSCALING_DRIVERS, choices=UPSCALING_ALGORITHMS,
help="driver to use for upscaling", help="algorithm to use for upscaling",
default=UPSCALING_DRIVERS[0], default=UPSCALING_ALGORITHMS[0],
) )
upscale.add_argument( upscale.add_argument(
"-t", "-t",
@ -462,11 +460,11 @@ def parse_arguments() -> argparse.Namespace:
"--help", action="help", help="show this help message and exit" "--help", action="help", help="show this help message and exit"
) )
interpolate.add_argument( interpolate.add_argument(
"-d", "-a",
"--driver", "--algorithm",
choices=UPSCALING_DRIVERS, choices=UPSCALING_ALGORITHMS,
help="driver to use for upscaling", help="algorithm to use for upscaling",
default=INTERPOLATION_DRIVERS[0], default=INTERPOLATION_ALGORITHMS[0],
) )
interpolate.add_argument( interpolate.add_argument(
"-t", "-t",
@ -523,7 +521,7 @@ def main():
args.noise, args.noise,
args.processes, args.processes,
args.threshold, args.threshold,
args.driver, args.algorithm,
) )
elif args.action == "interpolate": elif args.action == "interpolate":
@ -532,11 +530,9 @@ def main():
args.output, args.output,
args.processes, args.processes,
args.threshold, args.threshold,
args.driver, args.algorithm,
) )
logger.success("Processing completed successfully")
# don't print the traceback for manual terminations # don't print the traceback for manual terminations
except (SystemExit, KeyboardInterrupt) as e: except (SystemExit, KeyboardInterrupt) as e:
raise SystemExit(e) raise SystemExit(e)