Skip to content

Commit

Permalink
Add post and preprocessing for offline pipeline (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
RinoReyns authored Nov 16, 2024
1 parent 15657d8 commit ff8556f
Show file tree
Hide file tree
Showing 25 changed files with 402 additions and 260 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ is why this repository was created. I know that it will take a lot of work but i
- [ ] Validate RtAudio for different OSes
- [ ] Move *vst_host_config_* from *WaveProcessingPipeline* class to *AudioProcessingVstHost* class.
- [ ] Rename ** to *AudioProcessingTool*
- [ ] Create Baseclass for processing modules and extract common parameters

1. Android
- TODO:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,29 @@ class VstHostConfigGenerator
{
"input_wave",
"output_wave",
"preprocessing",
"vst_host",
"postprocessing"
SAMPLING_RATE_PARAM_STR,
PREPROCESSING,
VST_HOST_CONFIG_PARAM_STR,
POSTPROCESSING
};

const nlohmann::json sub_sections_params_
{
{"preprocessing",
{SAMPLING_RATE_PARAM_STR, 0},

{PREPROCESSING,
{
{"filter",
{
{"enable", false}
{ENABLE_STRING, false}
}
}
}
},

{"vst_host",
{VST_HOST_CONFIG_PARAM_STR,
{
{"enable", false},
{ENABLE_STRING, false},
{"processing_config",
{
{"plugin_1",
Expand All @@ -62,11 +65,11 @@ class VstHostConfigGenerator
}
},

{"postprocessing",
{POSTPROCESSING,
{
{"filter",
{
{"enable", false}
{ENABLE_STRING, false}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions VstHost_VisualC++/modules/ArgParser/header/arg_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class ArgParser
int CheckInputArgsFormat(std::vector<std::string> args);
int ValidateVstHostConfigParam();
int DumpVstHostConfig();
int ValidateSamplingRate(float sampling_rate);

std::unique_ptr<argparse::ArgumentParser> arg_parser_;
std::string input_wave_path_ = "";
Expand Down
26 changes: 19 additions & 7 deletions VstHost_VisualC++/modules/ArgParser/src/arg_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ int ArgParser::ParsParameters(std::vector<std::string> args)
{
status = this->DumpVstHostConfig();
RETURN_ERROR_IF_NOT_SUCCESS(status);
status = ValidateSamplingRate(main_config_[SAMPLING_RATE_PARAM_STR]);
RETURN_ERROR_IF_NOT_SUCCESS(status);

for(nlohmann::json params : main_config_[VST_HOST_CONFIG_PARAM_STR][PROCESSING_CONFIG_PARAM_STR])
{
Expand All @@ -116,6 +118,7 @@ int ArgParser::ParsParameters(std::vector<std::string> args)
if (!enable_audio_capture_)
{
status = this->ValidateVstHostConfigParam();
RETURN_ERROR_IF_NOT_SUCCESS(status);

std::unique_ptr<VstHostConfigGenerator> config_generator(new VstHostConfigGenerator());
main_config_ = config_generator->ReadAppConfig(vst_host_config_);
Expand All @@ -127,6 +130,9 @@ int ArgParser::ParsParameters(std::vector<std::string> args)

status = CheckOutputWave();
RETURN_ERROR_IF_NOT_SUCCESS(status);

status = ValidateSamplingRate(main_config_[SAMPLING_RATE_PARAM_STR]);
RETURN_ERROR_IF_NOT_SUCCESS(status);
}

// verbosity_
Expand All @@ -142,12 +148,21 @@ int ArgParser::ParsParameters(std::vector<std::string> args)

int ArgParser::CheckIfPathExists(std::string path)
{
int status = VST_ERROR_STATUS::SUCCESS;
if (!std::filesystem::exists(path))
{
status = VST_ERROR_STATUS::PATH_NOT_EXISTS;
return VST_ERROR_STATUS::PATH_NOT_EXISTS;
}
return status;
return VST_ERROR_STATUS::SUCCESS;
}

int ArgParser::ValidateSamplingRate(float sampling_rate)
{
if (sampling_rate == NULL || sampling_rate <= 0)
{
LOG(ERROR) << "Unsupported sampling rate: " << sampling_rate << "Hz.";
return VST_ERROR_STATUS::UNSUPPORTED_SAMPLING_RATE;
}
return VST_ERROR_STATUS::SUCCESS;
}

int ArgParser::ValidateVstHostConfigParam()
Expand All @@ -167,10 +182,7 @@ int ArgParser::ValidateVstHostConfigParam()
int ArgParser::DumpVstHostConfig()
{
int status = this->ValidateVstHostConfigParam();
if (status == VST_ERROR_STATUS::MISSING_PARAMETER_VALUE)
{
return status;
}
RETURN_IF_MISSING_PARAMETER_VALUE(status);

vst_host_config_ = arg_parser_->get<std::string>(VST_HOST_CMD_PARAM_STR);
std::unique_ptr<VstHostConfigGenerator> config_generator(new VstHostConfigGenerator());
Expand Down
11 changes: 8 additions & 3 deletions VstHost_VisualC++/modules/AudioProcessing/header/FilterWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ class FilterWrapper
{
public:
explicit FilterWrapper();
int Init(float sampling_rate);
int ApplyBwLowPassFilter(std::vector<float> input, std::vector<float>& output);
int Init(size_t sampling_rate);
int Process(std::vector<float> input, std::vector<float>& output);
int SetEnableProcessing(bool enable);
~FilterWrapper();

public:
const std::string module_name_ = "filter";

private:
BWLowPass* bw_low_pass_filter_;
BWLowPass* bw_low_pass_filter_ = nullptr;
bool enable_processing_ = false;
};

#endif //FILTER_WRAPPER_H
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ class WaveProcessingPipeline
private:
int CreateVstHost();
int ProcessingVstHost();

int CreatePreprocessingModule();
int CreatePostprocessingModule();
int CreatePreprocessingModules();
int PreprocessingProcessing();
int CreatePostprocessingModules();
int PostprocessingProcessing();
int SwapInOutBuffers();

private:
Expand All @@ -32,6 +33,9 @@ class WaveProcessingPipeline
std::unique_ptr<AudioProcessingVstHost> vst_host_;
std::unique_ptr<WaveDataContainer> input_wave_;
std::unique_ptr<WaveDataContainer> output_wave_;
std::unique_ptr<FilterWrapper> preprocessing_filter_wrapper_;
std::unique_ptr<FilterWrapper> postprocessing_filter_wrapper_;
size_t processing_sampling_rate_ = 0;
};

#endif //WAVE_PROCESSING_PIPELINE_H
42 changes: 35 additions & 7 deletions VstHost_VisualC++/modules/AudioProcessing/source/FilterWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,59 @@ FilterWrapper::~FilterWrapper()
}
}

int FilterWrapper::Init(float sampling_rate)
int FilterWrapper::Init(size_t sampling_rate)
{
if (sampling_rate < 0)
{
LOG(ERROR) << "Unsupported sampling rate: " << sampling_rate << "Hz.";
return VST_ERROR_STATUS::UNSUPPORTED_SAMPLING_RATE;
}

// TODO:
// parameterization for each module is needed.
bw_low_pass_filter_ = create_bw_low_pass_filter(
8,
sampling_rate,
400);
bw_low_pass_filter_ = create_bw_low_pass_filter(
8,
static_cast<FTR_PRECISION>(sampling_rate),
400);

RETURN_ERROR_IF_NULL(bw_low_pass_filter_);

return VST_ERROR_STATUS::SUCCESS;
}

int FilterWrapper::ApplyBwLowPassFilter(
int FilterWrapper::Process(
std::vector<float> input,
std::vector<float>& output)
{
if (!enable_processing_)
{
return VST_ERROR_STATUS::BYPASS;
}

if (bw_low_pass_filter_ == nullptr)
{
return VST_ERROR_STATUS::NULL_POINTER;
}

if (output.size() < input.size())
{
return VST_ERROR_STATUS::SIZE_MISSMATCH;
}

for (int i = 0; i < input.size(); i++)
{
output[i] = bw_low_pass(bw_low_pass_filter_, input[i]);
}

return VST_ERROR_STATUS::SUCCESS;;
return VST_ERROR_STATUS::SUCCESS;
}

int FilterWrapper::SetEnableProcessing(bool enable)
{
enable_processing_ = enable;
if (!enable_processing_)
{
return VST_ERROR_STATUS::BYPASS;
}
return VST_ERROR_STATUS::SUCCESS;
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
#include "WaveProcessingPipeline.h"
#include "VstHostMacro.h"


WaveProcessingPipeline::WaveProcessingPipeline(uint8_t verbosity) :
verbosity_(verbosity)
{}

int WaveProcessingPipeline::Init(nlohmann::json pipeline_config)
{
pipeline_config_ = pipeline_config;
// TODO:
// move vst_host_config_ to vst_host
vst_host_config_ = pipeline_config[VST_HOST_CONFIG_PARAM_STR][PROCESSING_CONFIG_PARAM_STR];
processing_sampling_rate_ = static_cast<size_t>(pipeline_config_[SAMPLING_RATE_PARAM_STR]);

int status = this->CreateVstHost();
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);

status = this->CreatePreprocessingModule();
status = this->CreatePreprocessingModules();
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);

status = this->CreatePostprocessingModule();
status = this->CreatePostprocessingModules();
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);

return status;
return VST_ERROR_STATUS::SUCCESS;
}

int WaveProcessingPipeline::CreateVstHost()
{
vst_host_.reset(new AudioProcessingVstHost());
int status = vst_host_->SetEnableProcessing(pipeline_config_[VST_HOST_CONFIG_PARAM_STR][ENABLE_STRING]);
// TODO:
// move vst_host_config_ to vst_host
vst_host_config_ = pipeline_config_[vst_host_->module_name_][PROCESSING_CONFIG_PARAM_STR];
int status = vst_host_->SetEnableProcessing(pipeline_config_[vst_host_->module_name_][ENABLE_STRING]);
RETURN_IF_BYPASS(status);

vst_host_->SetVerbosity(this->verbosity_);
Expand All @@ -41,18 +42,45 @@ int WaveProcessingPipeline::ProcessingVstHost()
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);
status = vst_host_->BufferProcessing(input_wave_.get(), output_wave_.get());
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);

return status;
}

int WaveProcessingPipeline::CreatePreprocessingModule()
int WaveProcessingPipeline::CreatePreprocessingModules()
{
return VST_ERROR_STATUS::SUCCESS;
preprocessing_filter_wrapper_.reset(new FilterWrapper());
int status = preprocessing_filter_wrapper_->SetEnableProcessing(pipeline_config_[PREPROCESSING][preprocessing_filter_wrapper_->module_name_][ENABLE_STRING]);
RETURN_IF_BYPASS(status);
return preprocessing_filter_wrapper_->Init(processing_sampling_rate_);
}

int WaveProcessingPipeline::CreatePostprocessingModule()
int WaveProcessingPipeline::PreprocessingProcessing()
{
return VST_ERROR_STATUS::SUCCESS;
LOG(INFO) << "Prepocessing in progress..";
int status = preprocessing_filter_wrapper_->Process(input_wave_->data, output_wave_->data);
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(status);
if (status == VST_ERROR_STATUS::SUCCESS)
{
status = SwapInOutBuffers();
RETURN_ERROR_IF_NOT_SUCCESS(status);
}
LOG(INFO) << "Prepocessing finished.";
return status;
}

int WaveProcessingPipeline::CreatePostprocessingModules()
{
postprocessing_filter_wrapper_.reset(new FilterWrapper());
int status = postprocessing_filter_wrapper_->SetEnableProcessing(pipeline_config_[POSTPROCESSING][postprocessing_filter_wrapper_->module_name_][ENABLE_STRING]);
RETURN_IF_BYPASS(status);
return postprocessing_filter_wrapper_->Init(processing_sampling_rate_);
}

int WaveProcessingPipeline::PostprocessingProcessing()
{
LOG(INFO) << "Postpocessing in progress..";
int status = postprocessing_filter_wrapper_->Process(input_wave_->data, output_wave_->data);
LOG(INFO) << "Postpocessing finished.";
return status;
}

int WaveProcessingPipeline::GetConfig()
Expand All @@ -62,6 +90,10 @@ int WaveProcessingPipeline::GetConfig()

int WaveProcessingPipeline::SwapInOutBuffers()
{
input_wave_->data = output_wave_->data;
input_wave_->frame_number = output_wave_->frame_number;
input_wave_->channel_number = output_wave_->channel_number;
input_wave_->bits_per_sample = output_wave_->bits_per_sample;

return VST_ERROR_STATUS::SUCCESS;
}
Expand All @@ -74,25 +106,28 @@ int WaveProcessingPipeline::Run(std::string input_path, std::string output_path)
}

std::unique_ptr<WaveIOClass> wave_io(new WaveIOClass());
input_wave_.reset(new WaveDataContainer(input_path));
output_wave_.reset(new WaveDataContainer(output_path));
input_wave_.reset(new WaveDataContainer(input_path, processing_sampling_rate_));
output_wave_.reset(new WaveDataContainer(output_path, processing_sampling_rate_));

//Load wave file
// TODO:
// add resampling in wave loader for setting proper sampling rate for all modules
int status = wave_io->LoadWave(input_wave_.get());
RETURN_ERROR_IF_NOT_SUCCESS(status);

// TODO:
// add preprocessing
output_wave_->SetParams(input_wave_.get());

// Preprocessing
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(PreprocessingProcessing());

// VST Host Processing
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(ProcessingVstHost());

// TODO:
// Add option to run only model with OV
// Add option to enable/disable filtration and OV processing
// Preprocessing
RETURN_ERROR_IF_NOT_SUCCESS_OR_BYPASS(PostprocessingProcessing());

// TODO:
// add postprocessing
// Add option to run only model with OV

// Save wave file
status = wave_io->SaveWave(output_wave_.get());
Expand Down
1 change: 1 addition & 0 deletions VstHost_VisualC++/modules/Common/header/VstHostMacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define RETURN_IF_AUDIO_CAPTURE_FAILED(X) if(FAILED(hr = (X))){ return VST_ERROR_STATUS::AUDIO_CAPTURE_ERROR; }
#define RETURN_IF_AUDIO_RENDER_FAILED(X) if(FAILED(hr = (X))){ return VST_ERROR_STATUS::AUDIO_RENDER_ERROR; }
#define RETURN_IF_BYPASS(value) {if(value == VST_ERROR_STATUS::BYPASS) return value;}
#define RETURN_IF_MISSING_PARAMETER_VALUE(value) {if(value == VST_ERROR_STATUS::MISSING_PARAMETER_VALUE) return value;}

#define CLOSE_HANDLE_IF(h) if(h != INVALID_HANDLE_VALUE){ CloseHandle(h); h = INVALID_HANDLE_VALUE; }
#define IF_ERROR_RETURN(b) if(b == FALSE){ return b; }
Expand Down
Loading

0 comments on commit ff8556f

Please sign in to comment.