Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions NAM/convnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,22 +322,27 @@ void nam::convnet::ConvNet::_rewind_buffers_()
this->Buffer::_rewind_buffers_();
}

// Config parser
// Reads the ConvNet hyper-parameters out of a model's JSON config block.
//
// "channels", "dilations", "batchnorm" and "activation" are required. They are
// fetched with at() rather than the const operator[] so that a model file that
// lacks one of them raises nlohmann::json::out_of_range instead of hitting the
// library's assertion / undefined behavior for missing keys on a const object.
// "groups", "in_channels" and "out_channels" are optional and default to 1 so
// that configs written before those keys existed keep loading.
nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config)
{
  ConvNetConfig c;
  c.channels = config.at("channels").get<int>();
  c.dilations = config.at("dilations").get<std::vector<int>>();
  c.batchnorm = config.at("batchnorm").get<bool>();
  // The activation entry is converted to its typed form here, at the loading boundary
  c.activation = activations::ActivationConfig::from_json(config.at("activation"));
  c.groups = config.value("groups", 1);
  c.in_channels = config.value("in_channels", 1);
  c.out_channels = config.value("out_channels", 1);
  return c;
}

// Factory
// Builds a ConvNet from its JSON config block.
//
// All JSON parsing is delegated to nam::convnet::parse_config_json so that the
// key names and defaults live in exactly one place; this function only forwards
// the typed values, the weights and the expected sample rate to the constructor.
std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, std::vector<float>& weights,
                                                const double expectedSampleRate)
{
  const auto c = nam::convnet::parse_config_json(config);
  return std::make_unique<nam::convnet::ConvNet>(c.in_channels, c.out_channels, c.channels, c.dilations, c.batchnorm,
                                                 c.activation, weights, expectedSampleRate, c.groups);
}

namespace
Expand Down
17 changes: 17 additions & 0 deletions NAM/convnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,23 @@ class ConvNet : public Buffer
int PrewarmSamples() override { return mPrewarmSamples; };
};

/// \brief Configuration for a ConvNet model
///
/// Plain aggregate holding the values that nam::convnet::parse_config_json reads
/// from JSON. Every scalar member carries a default so that a default-constructed
/// config never holds indeterminate values; the defaults for groups, in_channels
/// and out_channels match the fallbacks the JSON parser uses for absent keys.
struct ConvNetConfig
{
  int channels = 0; ///< Value of the "channels" key
  std::vector<int> dilations; ///< Value of the "dilations" key
  bool batchnorm = false; ///< Value of the "batchnorm" key
  activations::ActivationConfig activation; ///< Typed form of the "activation" key
  int groups = 1; ///< Value of the "groups" key (1 when absent)
  int in_channels = 1; ///< Value of the "in_channels" key (1 when absent)
  int out_channels = 1; ///< Value of the "out_channels" key (1 when absent)
};

/// \brief Parse ConvNet configuration from JSON
/// \param config JSON configuration object
/// \return ConvNetConfig
ConvNetConfig parse_config_json(const nlohmann::json& config);

/// \brief Factory function to instantiate ConvNet from JSON
/// \param config JSON configuration object
/// \param weights Model weights vector
Expand Down
26 changes: 20 additions & 6 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,30 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num
nam::Buffer::_advance_input_buffer_(num_frames);
}

// Config parser
// Reads the Linear model hyper-parameters out of a model's JSON config block.
//
// "receptive_field" and "bias" are required. They are fetched with at() rather
// than the const operator[] so that a config lacking one of them raises
// nlohmann::json::out_of_range instead of hitting the library's assertion /
// undefined behavior for missing keys on a const object.
// "in_channels" and "out_channels" are optional and default to 1 so that
// configs written before those keys existed keep loading.
nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config)
{
  LinearConfig c;
  c.receptive_field = config.at("receptive_field").get<int>();
  c.bias = config.at("bias").get<bool>();
  c.in_channels = config.value("in_channels", 1);
  c.out_channels = config.value("out_channels", 1);
  return c;
}

// Factory
// Builds a Linear model from its JSON config block.
//
// All JSON parsing is delegated to nam::linear::parse_config_json so that the
// key names and defaults live in exactly one place; this function only forwards
// the typed values, the weights and the expected sample rate to the constructor.
std::unique_ptr<nam::DSP> nam::linear::Factory(const nlohmann::json& config, std::vector<float>& weights,
                                               const double expectedSampleRate)
{
  const auto c = nam::linear::parse_config_json(config);
  return std::make_unique<nam::Linear>(c.in_channels, c.out_channels, c.receptive_field, c.bias, weights,
                                       expectedSampleRate);
}

// Register the factory
// A file-local object is constructed with the architecture name "Linear" and
// the nam::linear::Factory function.
// NOTE(review): presumably nam::factory::Helper's constructor adds this pair to
// the factory registry during static initialization — confirm in registry.h.
namespace
{
static nam::factory::Helper _register_Linear("Linear", nam::linear::Factory);
}

// NN modules =================================================================
Expand Down
15 changes: 15 additions & 0 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,21 @@ class Linear : public Buffer

namespace linear
{

/// \brief Configuration for a Linear model
///
/// Plain aggregate holding the values that nam::linear::parse_config_json reads
/// from JSON. Every member carries a default so that a default-constructed
/// config never holds indeterminate values; the channel defaults match the
/// fallbacks the JSON parser uses for absent keys.
struct LinearConfig
{
  int receptive_field = 0; ///< Value of the "receptive_field" key
  bool bias = false; ///< Value of the "bias" key
  int in_channels = 1; ///< Value of the "in_channels" key (1 when absent)
  int out_channels = 1; ///< Value of the "out_channels" key (1 when absent)
};

/// \brief Parse Linear configuration from JSON
/// \param config JSON configuration object
/// \return LinearConfig
LinearConfig parse_config_json(const nlohmann::json& config);

/// \brief Factory function to instantiate Linear model from JSON
/// \param config JSON configuration object
/// \param weights Model weights vector
Expand Down
128 changes: 85 additions & 43 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <sstream>
#include <stdexcept>
#include <unordered_set>
#include <variant>

#include "dsp.h"
#include "registry.h"
Expand All @@ -11,6 +12,7 @@
#include "convnet.h"
#include "wavenet.h"
#include "get_dsp.h"
#include "model_config.h"

namespace nam
{
Expand Down Expand Up @@ -146,62 +148,102 @@ std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConf
return get_dsp(conf);
}

struct OptionalValue
// =============================================================================
// Unified construction path
// =============================================================================

ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config, double sample_rate)
{
bool have = false;
double value = 0.0;
};
if (architecture == "Linear")
return linear::parse_config_json(config);
else if (architecture == "LSTM")
return lstm::parse_config_json(config);
else if (architecture == "ConvNet")
return convnet::parse_config_json(config);
else if (architecture == "WaveNet")
return wavenet::parse_config_json(config, sample_rate);
else
throw std::runtime_error("Unknown architecture: " + architecture);
}

std::unique_ptr<DSP> get_dsp(dspData& conf)
namespace
{
verify_config_version(conf.version);

auto& architecture = conf.architecture;
nlohmann::json& config = conf.config;
std::vector<float>& weights = conf.weights;
OptionalValue loudness, inputLevel, outputLevel;
void apply_metadata(DSP& dsp, const ModelMetadata& metadata)
{
if (metadata.loudness.has_value())
dsp.SetLoudness(metadata.loudness.value());
if (metadata.input_level.has_value())
dsp.SetInputLevel(metadata.input_level.value());
if (metadata.output_level.has_value())
dsp.SetOutputLevel(metadata.output_level.value());
}

} // anonymous namespace

// Builds the DSP object that corresponds to whichever config the variant holds,
// then applies the metadata and pre-warms the model before returning it.
//
// config   - taken by value and moved into the visitor, because the WaveNet
//            branch moves cfg.condition_dsp out of it.
// weights  - taken by value; the Linear, LSTM and ConvNet constructors receive it
//            as an lvalue, while the WaveNet constructor receives it moved.
// metadata - supplies the sample rate passed to every constructor, plus the
//            optional loudness / level values applied afterwards.
std::unique_ptr<DSP> create_dsp(ModelConfig config, std::vector<float> weights, const ModelMetadata& metadata)
{
  const double sample_rate = metadata.sample_rate;

  std::unique_ptr<DSP> out = std::visit(
    [&](auto&& cfg) -> std::unique_ptr<DSP> {
      using T = std::decay_t<decltype(cfg)>;
      if constexpr (std::is_same_v<T, linear::LinearConfig>)
      {
        return std::make_unique<Linear>(cfg.in_channels, cfg.out_channels, cfg.receptive_field, cfg.bias, weights,
                                        sample_rate);
      }
      else if constexpr (std::is_same_v<T, lstm::LSTMConfig>)
      {
        return std::make_unique<lstm::LSTM>(cfg.in_channels, cfg.out_channels, cfg.num_layers, cfg.input_size,
                                            cfg.hidden_size, weights, sample_rate);
      }
      else if constexpr (std::is_same_v<T, convnet::ConvNetConfig>)
      {
        return std::make_unique<convnet::ConvNet>(cfg.in_channels, cfg.out_channels, cfg.channels, cfg.dilations,
                                                  cfg.batchnorm, cfg.activation, weights, sample_rate, cfg.groups);
      }
      else if constexpr (std::is_same_v<T, wavenet::WaveNetConfig>)
      {
        return std::make_unique<wavenet::WaveNet>(cfg.in_channels, cfg.layer_array_params, cfg.head_scale,
                                                  cfg.with_head, std::move(weights), std::move(cfg.condition_dsp),
                                                  sample_rate);
      }
      else
      {
        // Fail the build, rather than fall off the end of the lambda, if an
        // alternative is added to ModelConfig without a branch here.
        static_assert(sizeof(T) == 0, "create_dsp: unhandled ModelConfig alternative");
      }
    },
    std::move(config));

  apply_metadata(*out, metadata);
  out->prewarm();
  return out;
}

// Initialize using registry-based factory
std::unique_ptr<DSP> out =
nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate);
// =============================================================================
// get_dsp(dspData&) — now uses unified path
// =============================================================================

if (loudness.have)
{
out->SetLoudness(loudness.value);
}
if (inputLevel.have)
{
out->SetInputLevel(inputLevel.value);
}
if (outputLevel.have)
std::unique_ptr<DSP> get_dsp(dspData& conf)
{
verify_config_version(conf.version);

// Extract metadata from JSON
ModelMetadata metadata;
metadata.version = conf.version;
metadata.sample_rate = conf.expected_sample_rate;

if (!conf.metadata.is_null())
{
out->SetOutputLevel(outputLevel.value);
auto extract = [&conf](const std::string& key) -> std::optional<double> {
if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null())
return conf.metadata[key].get<double>();
return std::nullopt;
};
metadata.loudness = extract("loudness");
metadata.input_level = extract("input_level_dbu");
metadata.output_level = extract("output_level_dbu");
}

// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();

return out;
ModelConfig model_config = parse_model_config_json(conf.architecture, conf.config, conf.expected_sample_rate);
return create_dsp(std::move(model_config), std::move(conf.weights), metadata);
}

double get_sample_rate_from_nam_file(const nlohmann::json& j)
Expand Down
23 changes: 15 additions & 8 deletions NAM/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,25 @@ void nam::lstm::LSTM::_process_sample()
this->_output.noalias() += this->_head_bias;
}

// Config parser
// Reads the LSTM hyper-parameters out of a model's JSON config block.
//
// "num_layers", "input_size" and "hidden_size" are required. They are fetched
// with at() rather than the const operator[] so that a config lacking one of
// them raises nlohmann::json::out_of_range instead of hitting the library's
// assertion / undefined behavior for missing keys on a const object.
// "in_channels" and "out_channels" are optional and default to 1 so that
// configs written before those keys existed keep loading.
nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config)
{
  LSTMConfig c;
  c.num_layers = config.at("num_layers").get<int>();
  c.input_size = config.at("input_size").get<int>();
  c.hidden_size = config.at("hidden_size").get<int>();
  c.in_channels = config.value("in_channels", 1);
  c.out_channels = config.value("out_channels", 1);
  return c;
}

// Factory to instantiate from nlohmann json
// Builds an LSTM model from its JSON config block.
//
// All JSON parsing is delegated to nam::lstm::parse_config_json so that the key
// names and defaults live in exactly one place; this function only forwards the
// typed values, the weights and the expected sample rate to the constructor.
std::unique_ptr<nam::DSP> nam::lstm::Factory(const nlohmann::json& config, std::vector<float>& weights,
                                             const double expectedSampleRate)
{
  const auto c = nam::lstm::parse_config_json(config);
  return std::make_unique<nam::lstm::LSTM>(c.in_channels, c.out_channels, c.num_layers, c.input_size, c.hidden_size,
                                           weights, expectedSampleRate);
}

// Register the factory
Expand Down
15 changes: 15 additions & 0 deletions NAM/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ class LSTM : public DSP
Eigen::VectorXf _output;
};

/// \brief Configuration for an LSTM model
///
/// Plain aggregate holding the values that nam::lstm::parse_config_json reads
/// from JSON. Every member carries a default so that a default-constructed
/// config never holds indeterminate values; the channel defaults match the
/// fallbacks the JSON parser uses for absent keys.
struct LSTMConfig
{
  int num_layers = 0; ///< Value of the "num_layers" key
  int input_size = 0; ///< Value of the "input_size" key
  int hidden_size = 0; ///< Value of the "hidden_size" key
  int in_channels = 1; ///< Value of the "in_channels" key (1 when absent)
  int out_channels = 1; ///< Value of the "out_channels" key (1 when absent)
};

/// \brief Parse LSTM configuration from JSON
/// \param config JSON configuration object
/// \return LSTMConfig
LSTMConfig parse_config_json(const nlohmann::json& config);

/// \brief Factory function to instantiate LSTM from JSON
/// \param config JSON configuration object
/// \param weights Model weights vector
Expand Down
51 changes: 51 additions & 0 deletions NAM/model_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once
// Unified model configuration types for both JSON and binary loaders.
// No circular dependencies: architecture headers define config structs,
// this header combines them into a variant.

#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>

#include "convnet.h"
#include "dsp.h"
#include "lstm.h"
#include "wavenet.h"

namespace nam
{

/// \brief Metadata common to all model formats
///
/// A default-constructed instance has an empty version string, a sample rate of
/// -1.0 and no loudness or level values.
struct ModelMetadata
{
  std::string version{}; ///< Model file version string
  double sample_rate{-1.0}; ///< Expected sample rate; -1.0 until one is assigned
  std::optional<double> loudness{std::nullopt}; ///< Loudness, when the model provides one
  std::optional<double> input_level{std::nullopt}; ///< Input level, when the model provides one
  std::optional<double> output_level{std::nullopt}; ///< Output level, when the model provides one
};

/// \brief Variant of all architecture configs
using ModelConfig = std::variant<linear::LinearConfig, lstm::LSTMConfig, convnet::ConvNetConfig, wavenet::WaveNetConfig>;

/// \brief Construct a DSP object from a typed config, weights, and metadata
///
/// This is the single construction path used by both JSON and binary loaders.
/// Handles construction, metadata application, and prewarm.
/// \param config Architecture-specific configuration (variant)
/// \param weights Model weights (taken by value to allow move for WaveNet)
/// \param metadata Model metadata (version, sample rate, loudness, levels)
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> create_dsp(ModelConfig config, std::vector<float> weights, const ModelMetadata& metadata);

/// \brief Parse a ModelConfig from a JSON architecture name and config block
/// \param architecture Architecture name string (e.g., "WaveNet", "LSTM")
/// \param config JSON config block for this architecture
/// \param sample_rate Expected sample rate from metadata
/// \return ModelConfig variant
ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config,
double sample_rate);

} // namespace nam
Loading