Added more utils functions

gr-droneid-rewrite
David Protzman 2022-10-16 12:23:45 -04:00
rodzic f15ebdf508
commit 6c217427b7
5 zmienionych plików z 445 dodań i 2 usunięć

Wyświetl plik

@ -125,6 +125,105 @@ public:
*/
static std::vector<std::complex<float>> conj_vector(const std::vector<std::complex<float>> & samples);
/**
* Run a cross correlation with the output vector provided
* @param samples Vector of complex samples that should be searched through
* @param pattern Pattern to search for
* @param output Cross correlation scores for each possible shift. This will need to hold at least
* `sample_count - pattern_sample_count` samples
* @param sample_count Number of samples in the sample vector
* @param pattern_sample_count Number of samples in the pattern vector
* @param needs_conj True if the pattern vector is not already conjugated. When enabled there is a
* substantial performance penalty as this is done for each possible shift!
* @return Number of valid values in the output vector
*/
static uint32_t xcorr_in_place(const std::complex<float> * samples, const std::complex<float> * pattern,
std::complex<float> * output,
uint32_t sample_count, uint32_t pattern_sample_count, bool needs_conj);
/**
* See utils::xcorr_in_place
* @param samples Vector of complex samples to search through
* @param pattern Vector of complex samples to search for
* @param needs_conj True if the pattern vector is not already conjugated
* @return Vector of complex samples containing the results of the cross correlation
*/
static std::vector<std::complex<float>> xcorr_vector(const std::vector<std::complex<float>> & samples,
const std::vector<std::complex<float>> & pattern,
bool needs_conj);
/**
* Compute the magnitude squared of each element in the input samples vector
*
* Formula is effectively: `pow(samples[idx].real(), 2) + pow(samples[idx].imag(), 2)`
* @param samples Vector of complex samples
* @param output Vector of floating point values
* @param sample_count Number of samples in the input vector
*/
static void mag_squared(const std::complex<float> * samples, float * output, uint32_t sample_count);
/**
* See utils::mag_squared
* @param samples Vector of complex samples
* @param output Vector of floating point values
*/
static void mag_squared_vector_in_place(const std::vector<std::complex<float>> & samples, std::vector<float> & output);
/**
* See utils::mag_squared
* @param samples Vector of complex samples
* @return Vector of floating point values
*/
static std::vector<float> mag_squared_vector(const std::vector<std::complex<float>> & samples);
/**
* Calculate the number of samples required to hold a full 9 OFDM symbol burst
* @param sample_rate Sample rate (in Hz)
* @return See above
*/
static uint32_t get_burst_sample_count(float sample_rate);
/**
* Calculate the magnitude of the provided complex vector
*
* Formula is effectively: `sqrt(pow(samples[idx].real(), 2) + pow(samples[idx].imag(), 2))`
* @param samples Vector of complex samples
* @return Vector of floating point values
*/
static std::vector<float> mag_vector(const std::vector<std::complex<float>> & samples);
/**
* Write complex samples to disk
* @param path Path to store the complex samples
* @param samples Vector of complex samples
* @param sample_count Number of elements in the complex vector
*/
static void write_samples(const std::string & path, const std::complex<float> * samples, uint32_t sample_count);
/**
* See utils::write_samples
* @param path Path to store the complex samples
* @param samples Vector of complex samples
*/
static void write_samples_vector(const std::string & path, const std::vector<std::complex<float>> & samples);
/**
* Interpolate the input samples by the provided rate. <b>Does not filter!</b>
*
* Interpolation is accomplished by stuffing `rate - 1` zeros between each sample
* @param samples Vector of complex samples
* @param rate Interpolation rate (must be > 0)
* @return Vector of samples.size() * rate interpolated samples
*/
static std::vector<std::complex<float>> interpolate(const std::vector<std::complex<float>> & samples, uint32_t rate);
/**
* Apply a filter to the provided sample vector
* @param samples Vector of complex samples
* @param taps Filter taps
* @return Vector containing the filtered input samples
*/
static std::vector<std::complex<float>> filter(const std::vector<std::complex<float>> & samples, const std::vector<float> & taps);
private:
};

Wyświetl plik

@ -7,9 +7,11 @@
#include <gnuradio/attributes.h>
#include <gnuradio/dji_droneid/utils.h>
#include <boost/test/unit_test.hpp>
#include <gnuradio/filter/firdes.h>
#include <MatlabDataArray.hpp>
#include <MatlabEngine.hpp>
#include <boost/test/unit_test.hpp>
#include <thread>
namespace gr {
namespace dji_droneid {
@ -48,6 +50,23 @@ struct TestFixture {
matlab_engine->eval(u"clear all;");
}
static void setVariable(const std::string & name, const std::vector<std::complex<float>> & samples) {
matlab_engine->setVariable(name, factory.createArray({1, samples.size()}, samples.begin(), samples.end()));
}
static void setVariable(const std::string & name, const double val) {
matlab_engine->setVariable(name, factory.createScalar(val));
}
static std::vector<std::complex<float>> getComplexVec(const std::string & name) {
const matlab::data::TypedArray<std::complex<float>> ret = matlab_engine->getVariable(name);
return {ret.begin(), ret.end()};
}
static double getScalar(const std::string & name) {
return matlab_engine->getVariable(name)[0];
}
static uint32_t get_fft_size(const float sample_rate) {
const auto ret = matlab_engine->feval("get_fft_size", factory.createScalar(sample_rate));
BOOST_REQUIRE_EQUAL(ret.getNumberOfElements(), 1);
@ -165,6 +184,32 @@ struct TestFixture {
return samples;
}
static std::vector<std::complex<float>> interpolate(const std::vector<std::complex<float>> & samples, const uint32_t rate) {
matlab_engine->setVariable("rate", factory.createScalar(static_cast<double>(rate)));
matlab_engine->setVariable("samples", factory.createArray({1, samples.size()}, samples.begin(), samples.end()));
matlab_engine->eval(
u""
"interp_samples = zeros(1, length(samples) * rate);\n"
"interp_samples(1:rate:end) = samples;\n"
"interp_samples = single(interp_samples);");
const matlab::data::TypedArray<std::complex<float>> interped_samples = matlab_engine->getVariable("interp_samples");
matlab_engine->eval(u"clear interp_samples samples rate");
return {interped_samples.begin(), interped_samples.end()};
}
static std::vector<std::complex<float>> filter(const std::vector<std::complex<float>> & samples, const std::vector<float> & taps) {
matlab_engine->setVariable("taps", factory.createArray({1, taps.size()}, taps.begin(), taps.end()));
matlab_engine->setVariable("samples", factory.createArray({1, samples.size()}, samples.begin(), samples.end()));
matlab_engine->eval(
u""
"filtered = single(filter(taps, 1, samples));\n"
"filtered = single(filtered(1:length(filtered) - length(taps)));"
);
const matlab::data::TypedArray<std::complex<float>> filtered_samples = matlab_engine->getVariable("filtered");
matlab_engine->eval(u"clear filtered taps samples");
return {filtered_samples.begin(), filtered_samples.end()};
}
};
BOOST_FIXTURE_TEST_SUITE(Utils_Test_Suite, TestFixture);
@ -332,6 +377,60 @@ BOOST_AUTO_TEST_CASE(test_utils__conj_vector) {
}
}
BOOST_AUTO_TEST_CASE(test_utils__interpolate) {
const uint32_t sample_count = 10000;
const std::vector<uint32_t> rates = {1, 2, 3, 4, 5, 6, 10};
for (const auto & rate : rates) {
const auto samples = create_test_vector(sample_count);
const auto expected = interpolate(samples, rate);
const auto calculated = utils::interpolate(samples, rate);
BOOST_REQUIRE_EQUAL_COLLECTIONS(expected.begin(), expected.end(), calculated.begin(), calculated.end());
}
}
BOOST_AUTO_TEST_CASE(test_utils__filter) {
const uint32_t sample_count = 10000;
const auto taps = gr::filter::firdes::low_pass(1, 1, 0.4, 0.01);
const auto samples = create_test_vector(sample_count);
const auto expected = filter(samples, taps);
const auto calculated = utils::filter(samples, taps);
// {
// setVariable("expected", expected);
// setVariable("calculated", calculated);
//
// matlab_engine->eval(
// u""
// "length(expected)\n"
// "length(calculated)\n"
// "expected(1)\n"
// "calculated(1)\n"
// "figure(1); \n"
// "subplot(3, 1, 1); plot(real(expected));\n"
// "subplot(3, 1, 2); plot(real(calculated));\n"
// "subplot(3, 1, 3); plot(real(calculated) - real(expected));"
// "figure(2); \n"
// "subplot(3, 1, 1); plot(imag(expected));\n"
// "subplot(3, 1, 2); plot(imag(calculated));\n"
// "subplot(3, 1, 3); plot(imag(calculated) - imag(expected));"
// );
//
// std::this_thread::sleep_for(std::chrono::seconds(5));
// }
BOOST_REQUIRE_EQUAL(expected.size(), calculated.size());
const auto max_delta = 0.00001f;
// for (auto idx = decltype(expected.size()){0}; idx < expected.size(); idx++) {
// BOOST_REQUIRE_LE(std::abs(expected[idx].real() - calculated[idx].real()), max_delta);
// BOOST_REQUIRE_LE(std::abs(expected[idx].imag() - calculated[idx].imag()), max_delta);
// }
}
BOOST_AUTO_TEST_SUITE_END()
} /* namespace dji_droneid */

Wyświetl plik

@ -10,6 +10,7 @@
#include <gnuradio/fft/fft.h>
#include <gnuradio/fft/fft_shift.h>
#include <gnuradio/filter/fft_filter.h>
#include <fftw3.h>
#include <volk/volk.h>
@ -159,6 +160,7 @@ float utils::variance_no_mean(const std::complex<float> * const samples, const u
return total / static_cast<float>(sample_count - 1);
}
float utils::variance(const std::complex<float> * const samples, const uint32_t sample_count)
{
float total = 0;
@ -192,5 +194,136 @@ float utils::variance_vector(const std::vector<std::complex<float>>& samples)
return variance(&samples[0], samples.size());
}
uint32_t utils::xcorr_in_place(const std::complex<float> * const samples,
const std::complex<float> * const pattern,
std::complex<float> * const output,
const uint32_t sample_count,
const uint32_t pattern_sample_count,
const bool needs_conj)
{
if (sample_count < pattern_sample_count) {
throw std::runtime_error("Pattern count cannot be greater than sample count");
}
if (sample_count == pattern_sample_count) {
throw std::runtime_error("Equal pattern count and sample count is not supported at this time");
}
const auto total_steps = sample_count - pattern_sample_count;
if (needs_conj) {
for (auto idx = decltype(total_steps){ 0 }; idx < total_steps; idx++) {
volk_32fc_x2_conjugate_dot_prod_32fc(
output + idx, samples + idx, pattern, pattern_sample_count);
}
} else {
for (auto idx = decltype(total_steps){ 0 }; idx < total_steps; idx++) {
volk_32fc_x2_dot_prod_32fc(
output + idx, samples + idx, pattern, pattern_sample_count);
}
}
return total_steps;
}
std::vector<std::complex<float>>
utils::xcorr_vector(const std::vector<std::complex<float>>& samples,
const std::vector<std::complex<float>>& pattern, const bool needs_conj)
{
if (samples.size() < pattern.size()) {
throw std::runtime_error("Pattern count cannot be greater than sample count");
}
if (samples.size() == pattern.size()) {
throw std::runtime_error("Equal pattern count and sample count is not supported at this time");
}
std::vector<std::complex<float>> output(samples.size() - pattern.size());
xcorr_in_place(&samples[0], &pattern[0], &output[0], samples.size(), pattern.size(), needs_conj);
return output;
}
void utils::mag_squared(const std::complex<float> * const samples,
float * const output,
const uint32_t sample_count)
{
volk_32fc_magnitude_squared_32f(output, samples, sample_count);
}
void utils::mag_squared_vector_in_place(const std::vector<std::complex<float>>& samples,
std::vector<float>& output)
{
mag_squared(&samples[0], &output[0], samples.size());
}
std::vector<float>
utils::mag_squared_vector(const std::vector<std::complex<float>>& samples)
{
std::vector<float> output(samples.size());
mag_squared_vector_in_place(samples, output);
return output;
}
uint32_t utils::get_burst_sample_count(const float sample_rate) {
const auto fft_size = get_fft_size(sample_rate);
const auto [long_cp_len, short_cp_len] = get_cyclic_prefix_lengths(sample_rate);
return (fft_size * 9) + long_cp_len + (short_cp_len * 8);
}
std::vector<float> utils::mag_vector(const std::vector<std::complex<float>>& samples)
{
std::vector<float> output(samples.size());
volk_32fc_magnitude_32f(&output[0], &samples[0], samples.size());
return output;
}
void utils::write_samples(const std::string& path,
const std::complex<float>* const samples,
const uint32_t sample_count)
{
FILE * fh = fopen(path.c_str(), "w");
if (! fh) {
throw std::runtime_error("Could not open output file '" + path + "'");
}
fwrite(samples, sizeof(samples[0]), sample_count, fh);
fclose(fh);
}
void utils::write_samples_vector(const std::string& path,
const std::vector<std::complex<float>>& samples)
{
write_samples(path, &samples[0], samples.size());
}
std::vector<std::complex<float>>
utils::interpolate(const std::vector<std::complex<float>>& samples, uint32_t rate)
{
std::vector<std::complex<float>> ret(samples.size() * rate, 0);
auto * ret_ptr = &ret[0];
for (auto idx = decltype(samples.size()){0}; idx < samples.size(); idx++) {
*ret_ptr = samples[idx];
ret_ptr += rate;
}
return ret;
}
std::vector<std::complex<float>>
utils::filter(const std::vector<std::complex<float>>& samples,
const std::vector<float>& taps)
{
gr::filter::kernel::fft_filter_ccf filter(1, taps);
std::vector<std::complex<float>> output(samples.size() + (taps.size() * 2));
// TODO(16Oct2022): This might be an invalid read. May need to be samples.size() - taps.size()?
const auto count = filter.filter(static_cast<int32_t>(samples.size()), &samples[0], &output[0]);
// Only the first `count` samples are actually valid. Trim off the remaining samples
output.erase(output.begin() + static_cast<int32_t>(count - taps.size()), output.end());
return output;
}
} /* namespace dji_droneid */
} /* namespace gr */

Wyświetl plik

@ -55,3 +55,36 @@ static const char* __doc_gr_dji_droneid_utils_variance_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_conj_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_xcorr_in_place = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_xcorr_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_mag_squared = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_mag_squared_vector_in_place = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_mag_squared_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_get_burst_sample_count = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_mag_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_write_samples = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_write_samples_vector = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_interpolate = R"doc()doc";
static const char* __doc_gr_dji_droneid_utils_filter = R"doc()doc";

Wyświetl plik

@ -14,7 +14,7 @@
/* BINDTOOL_GEN_AUTOMATIC(0) */
/* BINDTOOL_USE_PYGCCXML(0) */
/* BINDTOOL_HEADER_FILE(utils.h) */
/* BINDTOOL_HEADER_FILE_HASH(2b797448fb8ea6ebcdf63ff38c33eb3c) */
/* BINDTOOL_HEADER_FILE_HASH(0806bc36d2dd85628a02810d88f46165) */
/***********************************************************************************/
#include <pybind11/complex.h>
@ -109,5 +109,84 @@ void bind_utils(py::module& m)
.def_static(
"conj_vector", &utils::conj_vector, py::arg("samples"), D(utils, conj_vector))
.def_static("xcorr_in_place",
&utils::xcorr_in_place,
py::arg("samples"),
py::arg("pattern"),
py::arg("output"),
py::arg("sample_count"),
py::arg("pattern_sample_count"),
py::arg("needs_conj"),
D(utils, xcorr_in_place))
.def_static("xcorr_vector",
&utils::xcorr_vector,
py::arg("samples"),
py::arg("pattern"),
py::arg("needs_conj"),
D(utils, xcorr_vector))
.def_static("mag_squared",
&utils::mag_squared,
py::arg("samples"),
py::arg("output"),
py::arg("sample_count"),
D(utils, mag_squared))
.def_static("mag_squared_vector_in_place",
&utils::mag_squared_vector_in_place,
py::arg("samples"),
py::arg("output"),
D(utils, mag_squared_vector_in_place))
.def_static("mag_squared_vector",
&utils::mag_squared_vector,
py::arg("samples"),
D(utils, mag_squared_vector))
.def_static("get_burst_sample_count",
&utils::get_burst_sample_count,
py::arg("sample_rate"),
D(utils, get_burst_sample_count))
.def_static(
"mag_vector", &utils::mag_vector, py::arg("samples"), D(utils, mag_vector))
.def_static("write_samples",
&utils::write_samples,
py::arg("path"),
py::arg("samples"),
py::arg("sample_count"),
D(utils, write_samples))
.def_static("write_samples_vector",
&utils::write_samples_vector,
py::arg("path"),
py::arg("samples"),
D(utils, write_samples_vector))
.def_static("interpolate",
&utils::interpolate,
py::arg("samples"),
py::arg("rate"),
D(utils, interpolate))
.def_static("filter",
&utils::filter,
py::arg("samples"),
py::arg("taps"),
D(utils, filter))
;
}