Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions NAM/conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
output_ptr[off + 7] += w7 * input_ptr[off + 7];
}
}
else if (channels == 3)
{
const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2];
for (int f = 0; f < num_frames; f++)
{
const int off = f * 3;
output_ptr[off] += w0 * input_ptr[off];
output_ptr[off + 1] += w1 * input_ptr[off + 1];
output_ptr[off + 2] += w2 * input_ptr[off + 2];
}
}
else
{
// General depthwise path with loop unrolling
Expand Down Expand Up @@ -349,6 +360,52 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
(w0_10 * i0_0 + w0_11 * i0_1) + (w1_10 * i1_0 + w1_11 * i1_1) + (w2_10 * i2_0 + w2_11 * i2_1);
}
}
else if (kernel_size == 6 && out_ch == 3 && in_ch == 3)
{
// Fused 3x3 kernel_size=6: read all 6 input blocks and compute in one pass
const long dil = this->_dilation;
auto in0 = _input_buffer.Read(num_frames, 5 * dil);
auto in1 = _input_buffer.Read(num_frames, 4 * dil);
auto in2 = _input_buffer.Read(num_frames, 3 * dil);
auto in3 = _input_buffer.Read(num_frames, 2 * dil);
auto in4 = _input_buffer.Read(num_frames, dil);
auto in5 = _input_buffer.Read(num_frames, 0);

const float* __restrict__ in_ptrs[6] = {
in0.data(), in1.data(), in2.data(),
in3.data(), in4.data(), in5.data()
};
float* __restrict__ output_ptr = _output.data();

// Cache all 54 weights on stack (6 taps x 3x3 matrix, column-major)
float w[6][9];
for (int k = 0; k < 6; k++)
{
const float* wp = this->_weight[k].data();
for (int j = 0; j < 9; j++)
w[k][j] = wp[j];
}

for (int f = 0; f < num_frames; f++)
{
const int off = f * 3;
float o0 = 0.0f, o1 = 0.0f, o2 = 0.0f;

for (int k = 0; k < 6; k++)
{
const float i0 = in_ptrs[k][off];
const float i1 = in_ptrs[k][off + 1];
const float i2 = in_ptrs[k][off + 2];
o0 += w[k][0] * i0 + w[k][3] * i1 + w[k][6] * i2;
o1 += w[k][1] * i0 + w[k][4] * i1 + w[k][7] * i2;
o2 += w[k][2] * i0 + w[k][5] * i1 + w[k][8] * i2;
}

output_ptr[off] = o0;
output_ptr[off + 1] = o1;
output_ptr[off + 2] = o2;
}
}
else
{
// General inline GEMM path uses += accumulation, so needs setZero
Expand Down
101 changes: 91 additions & 10 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,9 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu
void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames)
{
assert(num_frames <= _output.cols());
#ifdef NAM_USE_INLINE_GEMM
bool bias_fused = false;
#endif

if (this->_is_depthwise)
{
Expand Down Expand Up @@ -499,6 +502,17 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
output_ptr[f * 2 + 1] = w1 * in_val;
}
}
else if (out_ch == 3 && in_ch == 1)
{
const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2];
for (int f = 0; f < num_frames; f++)
{
const float in_val = input_ptr[f * in_stride];
output_ptr[f * 3] = w0 * in_val;
output_ptr[f * 3 + 1] = w1 * in_val;
output_ptr[f * 3 + 2] = w2 * in_val;
}
}
else if (out_ch == 4 && in_ch == 1)
{
const float w0 = weight_ptr[0], w1 = weight_ptr[1];
Expand All @@ -521,6 +535,28 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
output_ptr[f] = w0 * in_col[0] + w1 * in_col[1];
}
}
else if (out_ch == 1 && in_ch == 3)
{
const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2];
if (this->_do_bias)
{
const float b0 = this->_bias(0);
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
output_ptr[f] = w0 * in_col[0] + w1 * in_col[1] + w2 * in_col[2] + b0;
}
bias_fused = true;
}
else
{
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
output_ptr[f] = w0 * in_col[0] + w1 * in_col[1] + w2 * in_col[2];
}
}
}
else if (out_ch == 2 && in_ch == 2)
{
// 2x2 fully unrolled
Expand Down Expand Up @@ -583,15 +619,33 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2];
const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5];
const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8];
for (int f = 0; f < num_frames; f++)
if (this->_do_bias)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
output_ptr[f * 3] = w00 * i0 + w01 * i1 + w02 * i2;
output_ptr[f * 3 + 1] = w10 * i0 + w11 * i1 + w12 * i2;
output_ptr[f * 3 + 2] = w20 * i0 + w21 * i1 + w22 * i2;
const float b0 = this->_bias(0), b1 = this->_bias(1), b2 = this->_bias(2);
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
output_ptr[f * 3] = w00 * i0 + w01 * i1 + w02 * i2 + b0;
output_ptr[f * 3 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + b1;
output_ptr[f * 3 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + b2;
}
bias_fused = true;
}
else
{
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
output_ptr[f * 3] = w00 * i0 + w01 * i1 + w02 * i2;
output_ptr[f * 3 + 1] = w10 * i0 + w11 * i1 + w12 * i2;
output_ptr[f * 3 + 2] = w20 * i0 + w21 * i1 + w22 * i2;
}
}
}
else if (out_ch == 4 && in_ch == 4)
Expand Down Expand Up @@ -673,8 +727,21 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
}
else
{
// Fall back to Eigen for larger matrices where it's more efficient
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
// Generic inline GEMM for any matrix size (avoids Eigen overhead for small matrices)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes — if you want, make the threshold for falling back to Eigen a configurable compile-time setting, with a default such that sizes of 8 and above use Eigen; that seems reasonable.

[I also wonder if this is us just re-inventing Eigen ;) ]

for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
float* __restrict__ out_col = output_ptr + f * out_ch;
for (int o = 0; o < out_ch; o++)
{
float sum = 0.0f;
for (int i = 0; i < in_ch; i++)
{
sum += weight_ptr[i * out_ch + o] * in_col[i];
}
out_col[o] = sum;
}
}
}
#else
// Single GEMM for all cases - block-diagonal zero structure handles grouping
Expand All @@ -685,6 +752,8 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
if (this->_do_bias)
{
#ifdef NAM_USE_INLINE_GEMM
if (!bias_fused)
{
const int out_ch = (int)get_out_channels();
float* __restrict__ output_ptr = _output.data();
const float* __restrict__ bias_ptr = this->_bias.data();
Expand All @@ -700,6 +769,17 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
output_ptr[off + 1] += b1;
}
}
else if (out_ch == 3)
{
const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2];
for (int f = 0; f < num_frames; f++)
{
const int off = f * 3;
output_ptr[off] += b0;
output_ptr[off + 1] += b1;
output_ptr[off + 2] += b2;
}
}
else if (out_ch == 4)
{
const float b0 = bias_ptr[0], b1 = bias_ptr[1];
Expand All @@ -724,6 +804,7 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
}
}
}
} // !bias_fused
#else
_output.leftCols(num_frames).colwise() += this->_bias;
#endif
Expand Down
85 changes: 58 additions & 27 deletions NAM/film.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,47 +98,78 @@ class FiLM
if (_do_shift)
{
// scale = top input_dim rows, shift = bottom input_dim rows
for (int f = 0; f < num_frames; f++)
if (input_dim == 3)
{
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
const float* __restrict__ shift_col = scale_col + input_dim;
float* __restrict__ out_col = output_ptr + f * input_dim;

int i = 0;
for (; i + 3 < input_dim; i += 4)
for (int f = 0; f < num_frames; f++)
{
out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
out_col[i + 1] = in_col[i + 1] * scale_col[i + 1] + shift_col[i + 1];
out_col[i + 2] = in_col[i + 2] * scale_col[i + 2] + shift_col[i + 2];
out_col[i + 3] = in_col[i + 3] * scale_col[i + 3] + shift_col[i + 3];
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
const float* __restrict__ shift_col = scale_col + 3;
float* __restrict__ out_col = output_ptr + f * 3;
out_col[0] = in_col[0] * scale_col[0] + shift_col[0];
out_col[1] = in_col[1] * scale_col[1] + shift_col[1];
out_col[2] = in_col[2] * scale_col[2] + shift_col[2];
}
for (; i < input_dim; i++)
}
else
{
for (int f = 0; f < num_frames; f++)
{
out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
const float* __restrict__ shift_col = scale_col + input_dim;
float* __restrict__ out_col = output_ptr + f * input_dim;

int i = 0;
for (; i + 3 < input_dim; i += 4)
{
out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
out_col[i + 1] = in_col[i + 1] * scale_col[i + 1] + shift_col[i + 1];
out_col[i + 2] = in_col[i + 2] * scale_col[i + 2] + shift_col[i + 2];
out_col[i + 3] = in_col[i + 3] * scale_col[i + 3] + shift_col[i + 3];
}
for (; i < input_dim; i++)
{
out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
}
}
}
}
else
{
// scale only
for (int f = 0; f < num_frames; f++)
if (input_dim == 3)
{
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
float* __restrict__ out_col = output_ptr + f * input_dim;

int i = 0;
for (; i + 3 < input_dim; i += 4)
for (int f = 0; f < num_frames; f++)
{
out_col[i] = in_col[i] * scale_col[i];
out_col[i + 1] = in_col[i + 1] * scale_col[i + 1];
out_col[i + 2] = in_col[i + 2] * scale_col[i + 2];
out_col[i + 3] = in_col[i + 3] * scale_col[i + 3];
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
float* __restrict__ out_col = output_ptr + f * 3;
out_col[0] = in_col[0] * scale_col[0];
out_col[1] = in_col[1] * scale_col[1];
out_col[2] = in_col[2] * scale_col[2];
}
for (; i < input_dim; i++)
}
else
{
for (int f = 0; f < num_frames; f++)
{
out_col[i] = in_col[i] * scale_col[i];
const float* __restrict__ in_col = input_ptr + f * input_stride;
const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
float* __restrict__ out_col = output_ptr + f * input_dim;

int i = 0;
for (; i + 3 < input_dim; i += 4)
{
out_col[i] = in_col[i] * scale_col[i];
out_col[i + 1] = in_col[i + 1] * scale_col[i + 1];
out_col[i + 2] = in_col[i + 2] * scale_col[i + 2];
out_col[i + 3] = in_col[i + 3] * scale_col[i + 3];
}
for (; i < input_dim; i++)
{
out_col[i] = in_col[i] * scale_col[i];
}
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,8 @@ std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector
{
auto out = config->create(std::move(weights), metadata.sample_rate);
apply_metadata(*out, metadata);
// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();
// Prewarm is left to the caller so it can call SetMaxBufferSize() first.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[C] For backward compatibility, can you instead make the "4096" configurable at compilation (and keep 4096 the default)?

// On embedded targets the default (4096) wastes memory and time.
return out;
}

Expand Down
Loading