refactored interpolator

This commit is contained in:
k4yt3x 2023-05-25 18:44:54 +00:00
parent 5255e20283
commit cc01f2d8e2
2 changed files with 44 additions and 74 deletions

View File

@ -22,105 +22,74 @@ Date Created: May 27, 2021
Last Modified: March 20, 2022 Last Modified: March 20, 2022
""" """
import multiprocessing
import queue
import signal
import time import time
from multiprocessing.managers import ListProxy
from multiprocessing.sharedctypes import Synchronized
from loguru import logger from loguru import logger
from PIL import ImageChops, ImageStat from PIL import ImageChops, ImageStat
from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife
from .processor import Processor
class Interpolator:
ALGORITHM_CLASSES = {"rife": Rife} ALGORITHM_CLASSES = {"rife": Rife}
processor_objects = {}
class Interpolator(multiprocessing.Process): def interpolate_image(self, image0, image1, difference_threshold, algorithm):
def __init__(
self,
processing_queue: multiprocessing.Queue,
processed_frames: ListProxy,
pause: Synchronized,
) -> None:
multiprocessing.Process.__init__(self)
self.processing_queue = processing_queue
self.processed_frames = processed_frames
self.pause = pause
self.running = False
self.processor_objects = {}
signal.signal(signal.SIGTERM, self._stop)
def run(self) -> None:
self.running = True
logger.opt(colors=True).info(
f"Interpolator process <blue>{self.name}</blue> initiating"
)
while self.running is True:
try:
# pause if pause flag is set
if self.pause.value is True:
time.sleep(0.1)
continue
try:
# get new job from queue
(
frame_index,
(image0, image1),
(difference_threshold, algorithm),
) = self.processing_queue.get(False)
except queue.Empty:
time.sleep(0.1)
continue
# if image0 is None, image1 is the first frame
# skip this round
if image0 is None:
continue
# calculate the %diff between the current frame and the previous frame
difference = ImageChops.difference(image0, image1) difference = ImageChops.difference(image0, image1)
difference_stat = ImageStat.Stat(difference) difference_stat = ImageStat.Stat(difference)
difference_ratio = ( difference_ratio = (
sum(difference_stat.mean) / (len(difference_stat.mean) * 255) * 100 sum(difference_stat.mean) / (len(difference_stat.mean) * 255) * 100
) )
# if the difference is lower than threshold
# process the interpolation
if difference_ratio < difference_threshold: if difference_ratio < difference_threshold:
# select a processor object with the required settings
# create a new object if none are available
processor_object = self.processor_objects.get(algorithm) processor_object = self.processor_objects.get(algorithm)
if processor_object is None: if processor_object is None:
processor_object = ALGORITHM_CLASSES[algorithm](0) processor_object = self.ALGORITHM_CLASSES[algorithm](0)
self.processor_objects[algorithm] = processor_object self.processor_objects[algorithm] = processor_object
interpolated_image = processor_object.process(image0, image1) interpolated_image = processor_object.process(image0, image1)
# if the difference is greater than threshold
# there's a change in camera angle, ignore
else: else:
interpolated_image = image0 interpolated_image = image0
return interpolated_image
class InterpolatorProcessor(Processor, Interpolator):
def process(self) -> None:
task = self.tasks_queue.get()
while task is not None:
try:
if self.pause_flag.value is True:
time.sleep(0.1)
continue
(
frame_index,
image0,
image1,
(difference_threshold, algorithm),
) = task
if image0 is None:
task = self.tasks_queue.get()
continue
interpolated_image = self.interpolate_image(
image0, image1, difference_threshold, algorithm
)
if frame_index == 1: if frame_index == 1:
self.processed_frames[0] = image0 self.processed_frames[0] = image0
self.processed_frames[frame_index * 2 - 1] = interpolated_image self.processed_frames[frame_index * 2 - 1] = interpolated_image
self.processed_frames[frame_index * 2] = image1 self.processed_frames[frame_index * 2] = image1
# send exceptions into the client connection pipe task = self.tasks_queue.get()
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
break break
except Exception as error: except Exception as error:
logger.exception(error) logger.exception(error)
break break
logger.opt(colors=True).info(
f"Interpolator process <blue>{self.name}</blue> terminating"
)
return super().run()
def _stop(self, _signal_number, _frame) -> None:
self.running = False

View File

@ -69,7 +69,7 @@ from video2x.processor import Processor
from . import __version__ from . import __version__
from .decoder import VideoDecoder, VideoDecoderThread from .decoder import VideoDecoder, VideoDecoderThread
from .encoder import VideoEncoder from .encoder import VideoEncoder
from .interpolator import Interpolator from .interpolator import Interpolator, InterpolatorProcessor
from .upscaler import Upscaler, UpscalerProcessor from .upscaler import Upscaler, UpscalerProcessor
# for desktop environments only # for desktop environments only
@ -102,7 +102,7 @@ class ProcessingSpeedColumn(ProgressColumn):
class ProcessingMode(Enum): class ProcessingMode(Enum):
UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor} UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor}
INTERPOLATE = {"label": "Interpolating", "processor": Interpolator} INTERPOLATE = {"label": "Interpolating", "processor": InterpolatorProcessor}
class Video2X: class Video2X:
@ -179,7 +179,7 @@ class Video2X:
# elif mode == ProcessingMode.INTERPOLATE: # elif mode == ProcessingMode.INTERPOLATE:
else: else:
standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[ standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[
processing_settings[2] processing_settings[1]
] ]
if getattr(standalone_processor, "process", None) is None: if getattr(standalone_processor, "process", None) is None:
standalone_processor().process_video( standalone_processor().process_video(
@ -222,7 +222,7 @@ class Video2X:
logger.info("Starting video encoder") logger.info("Starting video encoder")
encoder = VideoEncoder( encoder = VideoEncoder(
input_path, input_path,
frame_rate * 2 if mode == "interpolate" else frame_rate, frame_rate * 2 if mode == ProcessingMode.INTERPOLATE else frame_rate,
output_path, output_path,
output_width, output_width,
output_height, output_height,
@ -337,6 +337,7 @@ class Video2X:
else: else:
logger.info("Processing has completed") logger.info("Processing has completed")
logger.info("Writing video trailer")
finally: finally:
# stop keyboard listener # stop keyboard listener