2024-12-01 09:55:56 +00:00
|
|
|
#include "filter_realesrgan.h"
|
2024-10-10 07:23:13 +00:00
|
|
|
|
2024-10-08 02:29:00 +00:00
|
|
|
#include <cstdint>
|
|
|
|
#include <cstdio>
|
2024-11-02 00:00:00 +00:00
|
|
|
#include <filesystem>
|
2024-10-08 02:29:00 +00:00
|
|
|
|
2024-10-14 02:46:59 +00:00
|
|
|
#include <spdlog/spdlog.h>
|
|
|
|
|
2024-10-08 02:29:00 +00:00
|
|
|
#include "conversions.h"
|
|
|
|
#include "fsutils.h"
|
2024-12-20 04:46:10 +00:00
|
|
|
#include "logger_manager.h"
|
2024-10-08 02:29:00 +00:00
|
|
|
|
2024-12-17 16:24:51 +00:00
|
|
|
namespace video2x {
|
|
|
|
namespace processors {
|
|
|
|
|
2024-12-01 09:55:56 +00:00
|
|
|
/// Constructs a Real-ESRGAN filter and stores its configuration.
///
/// No Real-ESRGAN instance is created here; `realesrgan_` stays null until init() runs.
///
/// @param gpuid          GPU device index, later passed to RealESRGAN and ncnn::get_gpu_device.
/// @param tta_mode       TTA flag, later passed to the RealESRGAN constructor.
/// @param scaling_factor Integer scaling factor; also used to build the "-x<N>" model file names.
/// @param model_name     Base name of the model used to build the model file names.
FilterRealesrgan::FilterRealesrgan(
    int gpuid,
    bool tta_mode,
    int scaling_factor,
    // Top-level const is deliberately omitted in this definition: it is not part of the
    // function type (so this still matches the declaration), and a const by-value
    // parameter would turn the std::move below into a silent copy.
    fsutils::StringType model_name
)
    : realesrgan_(nullptr),
      gpuid_(gpuid),
      tta_mode_(tta_mode),
      scaling_factor_(scaling_factor),
      model_name_(std::move(model_name)) {}
|
|
|
|
|
|
|
|
/// Releases the Real-ESRGAN instance owned by this filter, if one was created.
FilterRealesrgan::~FilterRealesrgan() {
    if (realesrgan_ != nullptr) {
        delete realesrgan_;
        realesrgan_ = nullptr;
    }
}
|
|
|
|
|
2024-12-31 00:00:00 +00:00
|
|
|
/// Locates the model files, creates the Real-ESRGAN instance, and configures it.
///
/// @param dec_ctx Decoder context; its time base is stored as the input time base.
/// @param enc_ctx Encoder context; its time base and pixel format are stored for output frames.
/// @return 0 on success; -1 if a model file cannot be found, the model fails to load,
///         or the GPU device for `gpuid_` cannot be obtained.
int FilterRealesrgan::init(AVCodecContext* dec_ctx, AVCodecContext* enc_ctx, AVBufferRef*) {
    // Model files are named "<model_name>-x<scaling_factor>.param" / ".bin"
    // and are looked up under "models/realesrgan".
    const fsutils::StringType model_base_name =
        model_name_ + STR("-x") + fsutils::to_string_type(scaling_factor_);
    const std::filesystem::path model_dir =
        std::filesystem::path(STR("models")) / STR("realesrgan");
    const std::filesystem::path model_param_path = model_dir / (model_base_name + STR(".param"));
    const std::filesystem::path model_bin_path = model_dir / (model_base_name + STR(".bin"));

    // Resolve the relative resource paths; an empty optional means the file was not found
    std::optional<std::filesystem::path> model_param_full_path =
        fsutils::find_resource(model_param_path);
    std::optional<std::filesystem::path> model_bin_full_path =
        fsutils::find_resource(model_bin_path);

    // Check if the model files exist
    if (!model_param_full_path.has_value()) {
        logger()->error("Real-ESRGAN model param file not found: {}", model_param_path.u8string());
        return -1;
    }
    if (!model_bin_full_path.has_value()) {
        logger()->error("Real-ESRGAN model bin file not found: {}", model_bin_path.u8string());
        return -1;
    }

    // Release any instance left over from a previous init() call so it is not leaked.
    // The pointer is reset before `new` so a throwing allocation cannot leave it dangling.
    delete realesrgan_;
    realesrgan_ = nullptr;

    // Create a new Real-ESRGAN instance
    realesrgan_ = new RealESRGAN(gpuid_, tta_mode_);

    // Store the time bases and the output pixel format
    in_time_base_ = dec_ctx->time_base;
    out_time_base_ = enc_ctx->time_base;
    out_pix_fmt_ = enc_ctx->pix_fmt;

    // Load the model
    if (realesrgan_->load(model_param_full_path.value(), model_bin_full_path.value()) != 0) {
        logger()->error("Failed to load Real-ESRGAN model");
        return -1;
    }

    // Set Real-ESRGAN parameters
    realesrgan_->scale = scaling_factor_;
    realesrgan_->prepadding = 10;

    // ncnn returns a null device for an index it cannot serve; fail cleanly instead of
    // dereferencing a null pointer.
    auto* gpu_device = ncnn::get_gpu_device(gpuid_);
    if (gpu_device == nullptr) {
        logger()->error("Failed to get GPU device {} for Real-ESRGAN", gpuid_);
        return -1;
    }

    // Calculate tilesize based on GPU heap budget: the larger the budget, the larger the tile.
    // NOTE(review): the thresholds presumably are in MiB, as reported by ncnn — confirm.
    const uint32_t heap_budget = gpu_device->get_heap_budget();
    if (heap_budget > 1900) {
        realesrgan_->tilesize = 200;
    } else if (heap_budget > 550) {
        realesrgan_->tilesize = 100;
    } else if (heap_budget > 190) {
        realesrgan_->tilesize = 64;
    } else {
        realesrgan_->tilesize = 32;
    }

    return 0;
}
|
|
|
|
|
2024-12-31 00:00:00 +00:00
|
|
|
/// Upscales one frame with Real-ESRGAN.
///
/// @param in_frame  Frame to process; its PTS is rescaled onto the output frame.
/// @param out_frame Receives the newly created output frame on success.
/// @return 0 on success; -1 if the filter has no Real-ESRGAN instance or a frame
///         conversion fails; otherwise the non-zero code returned by RealESRGAN::process.
int FilterRealesrgan::filter(AVFrame* in_frame, AVFrame** out_frame) {
    // init() must have created the instance before any frame can be processed
    if (realesrgan_ == nullptr) {
        logger()->error("Real-ESRGAN instance is not initialized");
        return -1;
    }

    // Convert the input frame to an ncnn::Mat
    ncnn::Mat in_mat = conversions::avframe_to_ncnn_mat(in_frame);
    if (in_mat.empty()) {
        logger()->error("Failed to convert AVFrame to ncnn::Mat");
        return -1;
    }

    // Allocate space for output ncnn::Mat (input dimensions multiplied by the scale)
    const int output_width = in_mat.w * realesrgan_->scale;
    const int output_height = in_mat.h * realesrgan_->scale;
    ncnn::Mat out_mat = ncnn::Mat(output_width, output_height, static_cast<size_t>(3), 3);

    const int ret = realesrgan_->process(in_mat, out_mat);
    if (ret != 0) {
        logger()->error("Real-ESRGAN processing failed");
        return ret;
    }

    // Convert ncnn::Mat to AVFrame; guard against a null result before touching its fields
    *out_frame = conversions::ncnn_mat_to_avframe(out_mat, out_pix_fmt_);
    if (*out_frame == nullptr) {
        logger()->error("Failed to convert ncnn::Mat to AVFrame");
        return -1;
    }

    // Rescale PTS to encoder's time base
    (*out_frame)->pts = av_rescale_q(in_frame->pts, in_time_base_, out_time_base_);

    // Return the processed frame to the caller
    return ret;
}
|
2024-12-01 09:55:56 +00:00
|
|
|
|
|
|
|
void FilterRealesrgan::get_output_dimensions(
|
2024-12-31 00:00:00 +00:00
|
|
|
const ProcessorConfig&,
|
2024-12-01 09:55:56 +00:00
|
|
|
int in_width,
|
|
|
|
int in_height,
|
2024-12-31 00:00:00 +00:00
|
|
|
int& out_width,
|
|
|
|
int& out_height
|
2024-12-01 09:55:56 +00:00
|
|
|
) const {
|
|
|
|
out_width = in_width * scaling_factor_;
|
|
|
|
out_height = in_height * scaling_factor_;
|
|
|
|
}
|
2024-12-17 16:24:51 +00:00
|
|
|
|
|
|
|
} // namespace processors
|
|
|
|
} // namespace video2x
|