refactored interpolator

This commit is contained in:
k4yt3x 2023-05-25 18:44:54 +00:00
parent 5255e20283
commit cc01f2d8e2
2 changed files with 44 additions and 74 deletions

View File

@ -22,105 +22,74 @@ Date Created: May 27, 2021
Last Modified: March 20, 2022
"""
import multiprocessing
import queue
import signal
import time
from multiprocessing.managers import ListProxy
from multiprocessing.sharedctypes import Synchronized
from loguru import logger
from PIL import ImageChops, ImageStat
from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife
from .processor import Processor
class Interpolator:
ALGORITHM_CLASSES = {"rife": Rife}
processor_objects = {}
class Interpolator(multiprocessing.Process):
def __init__(
self,
processing_queue: multiprocessing.Queue,
processed_frames: ListProxy,
pause: Synchronized,
) -> None:
multiprocessing.Process.__init__(self)
self.processing_queue = processing_queue
self.processed_frames = processed_frames
self.pause = pause
self.running = False
self.processor_objects = {}
signal.signal(signal.SIGTERM, self._stop)
def run(self) -> None:
self.running = True
logger.opt(colors=True).info(
f"Interpolator process <blue>{self.name}</blue> initiating"
)
while self.running is True:
try:
# pause if pause flag is set
if self.pause.value is True:
time.sleep(0.1)
continue
try:
# get new job from queue
(
frame_index,
(image0, image1),
(difference_threshold, algorithm),
) = self.processing_queue.get(False)
except queue.Empty:
time.sleep(0.1)
continue
# if image0 is None, image1 is the first frame
# skip this round
if image0 is None:
continue
# calculate the %diff between the current frame and the previous frame
def interpolate_image(self, image0, image1, difference_threshold, algorithm):
difference = ImageChops.difference(image0, image1)
difference_stat = ImageStat.Stat(difference)
difference_ratio = (
sum(difference_stat.mean) / (len(difference_stat.mean) * 255) * 100
)
# if the difference is lower than threshold
# process the interpolation
if difference_ratio < difference_threshold:
# select a processor object with the required settings
# create a new object if none are available
processor_object = self.processor_objects.get(algorithm)
if processor_object is None:
processor_object = ALGORITHM_CLASSES[algorithm](0)
processor_object = self.ALGORITHM_CLASSES[algorithm](0)
self.processor_objects[algorithm] = processor_object
interpolated_image = processor_object.process(image0, image1)
# if the difference is greater than threshold
# there's a change in camera angle, ignore
else:
interpolated_image = image0
return interpolated_image
class InterpolatorProcessor(Processor, Interpolator):
def process(self) -> None:
task = self.tasks_queue.get()
while task is not None:
try:
if self.pause_flag.value is True:
time.sleep(0.1)
continue
(
frame_index,
image0,
image1,
(difference_threshold, algorithm),
) = task
if image0 is None:
task = self.tasks_queue.get()
continue
interpolated_image = self.interpolate_image(
image0, image1, difference_threshold, algorithm
)
if frame_index == 1:
self.processed_frames[0] = image0
self.processed_frames[frame_index * 2 - 1] = interpolated_image
self.processed_frames[frame_index * 2] = image1
# send exceptions into the client connection pipe
task = self.tasks_queue.get()
except (SystemExit, KeyboardInterrupt):
break
except Exception as error:
logger.exception(error)
break
logger.opt(colors=True).info(
f"Interpolator process <blue>{self.name}</blue> terminating"
)
return super().run()
def _stop(self, _signal_number, _frame) -> None:
self.running = False

View File

@ -69,7 +69,7 @@ from video2x.processor import Processor
from . import __version__
from .decoder import VideoDecoder, VideoDecoderThread
from .encoder import VideoEncoder
from .interpolator import Interpolator
from .interpolator import Interpolator, InterpolatorProcessor
from .upscaler import Upscaler, UpscalerProcessor
# for desktop environments only
@ -102,7 +102,7 @@ class ProcessingSpeedColumn(ProgressColumn):
class ProcessingMode(Enum):
UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor}
INTERPOLATE = {"label": "Interpolating", "processor": Interpolator}
INTERPOLATE = {"label": "Interpolating", "processor": InterpolatorProcessor}
class Video2X:
@ -179,7 +179,7 @@ class Video2X:
# elif mode == ProcessingMode.INTERPOLATE:
else:
standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[
processing_settings[2]
processing_settings[1]
]
if getattr(standalone_processor, "process", None) is None:
standalone_processor().process_video(
@ -222,7 +222,7 @@ class Video2X:
logger.info("Starting video encoder")
encoder = VideoEncoder(
input_path,
frame_rate * 2 if mode == "interpolate" else frame_rate,
frame_rate * 2 if mode == ProcessingMode.INTERPOLATE else frame_rate,
output_path,
output_width,
output_height,
@ -337,6 +337,7 @@ class Video2X:
else:
logger.info("Processing has completed")
logger.info("Writing video trailer")
finally:
# stop keyboard listener