refactored interpolator

This commit is contained in:
k4yt3x 2023-05-25 18:44:54 +00:00
parent 5255e20283
commit cc01f2d8e2
2 changed files with 44 additions and 74 deletions

View File

@ -22,105 +22,74 @@ Date Created: May 27, 2021
Last Modified: March 20, 2022 Last Modified: March 20, 2022
""" """
import multiprocessing
import queue
import signal
import time import time
from multiprocessing.managers import ListProxy
from multiprocessing.sharedctypes import Synchronized
from loguru import logger from loguru import logger
from PIL import ImageChops, ImageStat from PIL import ImageChops, ImageStat
from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife
ALGORITHM_CLASSES = {"rife": Rife} from .processor import Processor
class Interpolator(multiprocessing.Process): class Interpolator:
def __init__( ALGORITHM_CLASSES = {"rife": Rife}
self,
processing_queue: multiprocessing.Queue,
processed_frames: ListProxy,
pause: Synchronized,
) -> None:
multiprocessing.Process.__init__(self)
self.processing_queue = processing_queue
self.processed_frames = processed_frames
self.pause = pause
self.running = False processor_objects = {}
self.processor_objects = {}
signal.signal(signal.SIGTERM, self._stop) def interpolate_image(self, image0, image1, difference_threshold, algorithm):
difference = ImageChops.difference(image0, image1)
def run(self) -> None: difference_stat = ImageStat.Stat(difference)
self.running = True difference_ratio = (
logger.opt(colors=True).info( sum(difference_stat.mean) / (len(difference_stat.mean) * 255) * 100
f"Interpolator process <blue>{self.name}</blue> initiating"
) )
while self.running is True:
if difference_ratio < difference_threshold:
processor_object = self.processor_objects.get(algorithm)
if processor_object is None:
processor_object = self.ALGORITHM_CLASSES[algorithm](0)
self.processor_objects[algorithm] = processor_object
interpolated_image = processor_object.process(image0, image1)
else:
interpolated_image = image0
return interpolated_image
class InterpolatorProcessor(Processor, Interpolator):
def process(self) -> None:
task = self.tasks_queue.get()
while task is not None:
try: try:
# pause if pause flag is set if self.pause_flag.value is True:
if self.pause.value is True:
time.sleep(0.1) time.sleep(0.1)
continue continue
try: (
# get new job from queue frame_index,
( image0,
frame_index, image1,
(image0, image1), (difference_threshold, algorithm),
(difference_threshold, algorithm), ) = task
) = self.processing_queue.get(False)
except queue.Empty:
time.sleep(0.1)
continue
# if image0 is None, image1 is the first frame
# skip this round
if image0 is None: if image0 is None:
task = self.tasks_queue.get()
continue continue
# calculate the %diff between the current frame and the previous frame interpolated_image = self.interpolate_image(
difference = ImageChops.difference(image0, image1) image0, image1, difference_threshold, algorithm
difference_stat = ImageStat.Stat(difference)
difference_ratio = (
sum(difference_stat.mean) / (len(difference_stat.mean) * 255) * 100
) )
# if the difference is lower than threshold
# process the interpolation
if difference_ratio < difference_threshold:
# select a processor object with the required settings
# create a new object if none are available
processor_object = self.processor_objects.get(algorithm)
if processor_object is None:
processor_object = ALGORITHM_CLASSES[algorithm](0)
self.processor_objects[algorithm] = processor_object
interpolated_image = processor_object.process(image0, image1)
# if the difference is greater than threshold
# there's a change in camera angle, ignore
else:
interpolated_image = image0
if frame_index == 1: if frame_index == 1:
self.processed_frames[0] = image0 self.processed_frames[0] = image0
self.processed_frames[frame_index * 2 - 1] = interpolated_image self.processed_frames[frame_index * 2 - 1] = interpolated_image
self.processed_frames[frame_index * 2] = image1 self.processed_frames[frame_index * 2] = image1
# send exceptions into the client connection pipe task = self.tasks_queue.get()
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
break break
except Exception as error: except Exception as error:
logger.exception(error) logger.exception(error)
break break
logger.opt(colors=True).info(
f"Interpolator process <blue>{self.name}</blue> terminating"
)
return super().run()
def _stop(self, _signal_number, _frame) -> None:
self.running = False

View File

@ -69,7 +69,7 @@ from video2x.processor import Processor
from . import __version__ from . import __version__
from .decoder import VideoDecoder, VideoDecoderThread from .decoder import VideoDecoder, VideoDecoderThread
from .encoder import VideoEncoder from .encoder import VideoEncoder
from .interpolator import Interpolator from .interpolator import Interpolator, InterpolatorProcessor
from .upscaler import Upscaler, UpscalerProcessor from .upscaler import Upscaler, UpscalerProcessor
# for desktop environments only # for desktop environments only
@ -102,7 +102,7 @@ class ProcessingSpeedColumn(ProgressColumn):
class ProcessingMode(Enum): class ProcessingMode(Enum):
UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor} UPSCALE = {"label": "Upscaling", "processor": UpscalerProcessor}
INTERPOLATE = {"label": "Interpolating", "processor": Interpolator} INTERPOLATE = {"label": "Interpolating", "processor": InterpolatorProcessor}
class Video2X: class Video2X:
@ -179,7 +179,7 @@ class Video2X:
# elif mode == ProcessingMode.INTERPOLATE: # elif mode == ProcessingMode.INTERPOLATE:
else: else:
standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[ standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[
processing_settings[2] processing_settings[1]
] ]
if getattr(standalone_processor, "process", None) is None: if getattr(standalone_processor, "process", None) is None:
standalone_processor().process_video( standalone_processor().process_video(
@ -222,7 +222,7 @@ class Video2X:
logger.info("Starting video encoder") logger.info("Starting video encoder")
encoder = VideoEncoder( encoder = VideoEncoder(
input_path, input_path,
frame_rate * 2 if mode == "interpolate" else frame_rate, frame_rate * 2 if mode == ProcessingMode.INTERPOLATE else frame_rate,
output_path, output_path,
output_width, output_width,
output_height, output_height,
@ -337,6 +337,7 @@ class Video2X:
else: else:
logger.info("Processing has completed") logger.info("Processing has completed")
logger.info("Writing video trailer")
finally: finally:
# stop keyboard listener # stop keyboard listener