mirror of
https://github.com/k4yt3x/video2x.git
synced 2025-01-30 23:58:11 +00:00
106 lines
3.3 KiB
Python
Executable File
106 lines
3.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
Name: Interpolator
|
|
Author: K4YT3X
|
|
Date Created: May 27, 2021
|
|
Last Modified: February 2, 2022
|
|
"""
|
|
|
|
# local imports
|
|
from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife
|
|
|
|
# built-in imports
|
|
import multiprocessing
|
|
import multiprocessing.managers
|
|
import multiprocessing.sharedctypes
|
|
import queue
|
|
import signal
|
|
import time
|
|
|
|
# third-party imports
|
|
from PIL import ImageChops, ImageStat
|
|
from loguru import logger
|
|
|
|
|
|
# maps the driver name carried in each job tuple to the class that is
# instantiated (once per worker process) to perform the interpolation
DRIVER_CLASSES = {"rife": Rife}
|
|
|
|
|
|
class Interpolator(multiprocessing.Process):
    """Worker process that inserts one interpolated frame between frame pairs.

    Jobs are taken from ``processing_queue``. Each job is a tuple of the form
    ``(frame_index, (image0, image1), (difference_threshold, driver))``.

    Results are written into ``processed_frames`` by index:

    * ``processed_frames[0]`` receives ``image0`` when ``frame_index == 1``
    * ``processed_frames[frame_index * 2 - 1]`` receives the in-between frame
    * ``processed_frames[frame_index * 2]`` receives ``image1``

    The loop ends when ``running`` is cleared (see ``_stop``), when
    ``SystemExit``/``KeyboardInterrupt`` is raised, or after any other
    exception has been logged.
    """

    # seconds to wait for a new job before re-checking the stop flag
    _POLL_INTERVAL = 0.1

    def __init__(
        self,
        processing_queue: multiprocessing.Queue,
        processed_frames: multiprocessing.managers.ListProxy,
    ):
        """Store the job queue and the shared output list.

        :param processing_queue: queue the interpolation jobs are read from
        :param processed_frames: shared list the resulting frames are
            written into by index
        """
        multiprocessing.Process.__init__(self)
        self.running = False
        self.processing_queue = processing_queue
        self.processed_frames = processed_frames

        # NOTE(review): this registers the handler in the process that
        # constructs the object, not in run(); a forked child inherits it,
        # but a spawned child presumably would not -- confirm this is intended
        signal.signal(signal.SIGTERM, self._stop)

    def run(self):
        """Process jobs from the queue until the worker is told to stop."""
        self.running = True
        logger.info(f"Interpolator process {self.name} initiating")

        # driver instances keyed by driver name, created lazily on first use
        driver_objects = {}

        while self.running:
            try:
                try:
                    # block for a short while instead of sleeping between
                    # polls: a job is picked up as soon as it arrives and the
                    # stop flag is still re-checked every _POLL_INTERVAL
                    (
                        frame_index,
                        (image0, image1),
                        (difference_threshold, driver),
                    ) = self.processing_queue.get(timeout=self._POLL_INTERVAL)
                except queue.Empty:
                    continue

                # if image0 is None, image1 is the first frame
                # there is nothing to interpolate yet, skip this round
                if image0 is None:
                    continue

                # if the difference is lower than threshold
                # process the interpolation
                if self._difference_ratio(image0, image1) < difference_threshold:

                    # reuse the driver object created for this driver name
                    # create a new object if none is available
                    driver_object = driver_objects.get(driver)
                    if driver_object is None:
                        # presumably 0 selects the first GPU -- TODO confirm
                        driver_object = DRIVER_CLASSES[driver](0)
                        driver_objects[driver] = driver_object
                    interpolated_image = driver_object.process(image0, image1)

                # if the difference is greater than threshold
                # there's a change in camera angle, do not interpolate
                # and repeat the earlier frame instead
                else:
                    interpolated_image = image0

                # the first pair is also responsible for storing frame 0
                if frame_index == 1:
                    self.processed_frames[0] = image0
                self.processed_frames[frame_index * 2 - 1] = interpolated_image
                self.processed_frames[frame_index * 2] = image1

            # leave the loop quietly when the process is asked to exit
            except (SystemExit, KeyboardInterrupt):
                break

            # any other error is logged and ends this worker
            except Exception as e:
                logger.exception(e)
                break

        logger.info(f"Interpolator process {self.name} terminating")
        self.running = False

        # no target is passed to Process.__init__ in this class, so the base
        # implementation has nothing to call; kept for compatibility
        return super().run()

    @staticmethod
    def _difference_ratio(image0, image1):
        """Return the mean per-channel difference of two frames as a percentage.

        The value is relative to 255 per channel (assumes 8-bit channels).
        """
        channel_means = ImageStat.Stat(ImageChops.difference(image0, image1)).mean
        return sum(channel_means) / (len(channel_means) * 255) * 100

    def _stop(self, _signal_number, _frame):
        """Signal handler: clear the flag that keeps the run() loop going."""
        self.running = False
|