From 1d2698c0af153d349b12f2cd1bce54f0d80589c0 Mon Sep 17 00:00:00 2001 From: David Protzman Date: Thu, 28 Apr 2022 22:43:57 -0400 Subject: [PATCH] Added a new normalized xcorr function that's faster Does trade some accuracy in the process --- .../updated_scripts/normalized_xcorr_fast.m | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 matlab/updated_scripts/normalized_xcorr_fast.m diff --git a/matlab/updated_scripts/normalized_xcorr_fast.m b/matlab/updated_scripts/normalized_xcorr_fast.m new file mode 100644 index 0000000..72d6c28 --- /dev/null +++ b/matlab/updated_scripts/normalized_xcorr_fast.m @@ -0,0 +1,80 @@ +% A cross correlation function that takes some shortcuts to be faster than xcorr(x,y,0,'normalized') with some small +% tradeoffs in accuracy. Return values are normalized to be between 0 and 1.0 with 1.0 being a perfect match +% +% Will return a vector that is `length(filter)` samples shorter than the input +% +% Correlation peaks point to the *beginning* of the `filter` sequence +% +% @param input_samples Complex row/column vector of samples (must have at least as many samples as `filter`) +% @param filter Complex row/column vector to correlate for in `input_samples` +% @param varargin Variable arguments (see above) +% @return scores Vector of correlation scores as complex values (use `abs(x).^2` to get score in range 0-1.0) +function [scores] = normalized_xcorr_fast(input_samples, filter, varargin) + assert(isrow(input_samples) || iscolumn(input_samples), "Input samples must be row or column vector"); + assert(isrow(filter) || iscolumn(filter), "Filter must be a row or column vector"); + assert(mod(length(varargin), 2) == 0, "Varargs length must be a multiple of 2"); + + % Placeholder for any varargs that might be needed in the future + for idx=1:2:length(varargin) + key = varargin{idx}; + val = varargin{idx+1}; + + switch(key) + otherwise + error("Invalid vararg key '%s'", key); + end + end + + % Create the output vector using the same dimensions as the input samples vector. Not all samples can be + % computed, so don't include the last `length(filter)` samples + dims = size(input_samples); + if (dims(1) == 1) + scores = zeros(1, length(input_samples) - length(filter)); + else + scores = zeros(length(input_samples) - length(filter), 1); + end + + % Make the filter zero mean + filter = filter - mean(filter); + + % Will be using dot product, so the conjugate is needed + filter_conj = conj(filter); + + % Pre-calculate the variance, and the square root of the variance + filter_conj_var = var(filter_conj); + filter_conj_var_sqrt = sqrt(filter_conj_var); + + % Pre-calculate and convert to multiplication, the divisions that will need to happen later + window_size = length(filter); + recip_window_size = 1 / window_size; + recip_window_size_minus_one = 1 / (window_size - 1); + + % To prevent needing an if statement in the critical path, start the running sum with the first element + % missing, and the value being removed first in the loop set to 0. This means that on startup, the loop + % will work properly without needing a conditional + running_sum = sum(input_samples(2:window_size - 1)); + prev_val = 0; + + for idx=1:length(scores) + % Get the next `window_size` samples starting at the current offset + window = input_samples(idx:idx + window_size - 1); + + % Since the window is shifting to the right, subtract off the left-most value that was just removed and add + % on the new value on the right + running_sum = running_sum - prev_val + window(end); + + % Make the window zero mean by subtracting the average power + window = window - (running_sum * recip_window_size); + + % Compute the dot product + prod = sum(window .* filter_conj) * recip_window_size; + + % Get the variance of the window + x = sum(real(window).^2 + imag(window).^2) * recip_window_size_minus_one; + + % Divide the dot product result by the square root of the std deviation of both windows combined + scores(idx) = prod / (sqrt(x) * filter_conj_var_sqrt); + end + +end +