diff --git a/README.rst b/README.rst index 12f9892..eb1c3a0 100644 --- a/README.rst +++ b/README.rst @@ -95,8 +95,12 @@ and then "Load"): .. image:: images/Convolver.png -We go back to the spectrum measurement and set the uncorrected -spectrum as reference (to compare with later measurements). +We go back to the spectrum measurement and store the uncorrected +spectrum with the "Store" button (to compare with later measurements). +More measurements can be stored as well, for example where the microphone +is placed in different locatations, The total average of the stored +measurements is shown in orange + Measuring the equalized system gives this: .. image:: images/laptop-flattened-spectrum.png diff --git a/hifiscan/analyzer.py b/hifiscan/analyzer.py index 2823b30..1aadc46 100644 --- a/hifiscan/analyzer.py +++ b/hifiscan/analyzer.py @@ -1,4 +1,5 @@ import array +import types from functools import lru_cache from typing import List, NamedTuple, Optional, Tuple @@ -44,10 +45,10 @@ class Analyzer: x: np.ndarray y: np.ndarray rate: int - secs: float fmin: float fmax: float time: float + numMeasurements: int def __init__( self, f0: int, f1: int, secs: float, rate: int, ampl: float, @@ -58,18 +59,41 @@ class Analyzer: self.chirp, np.zeros(int(self.MAX_DELAY_SECS * rate)) ]) - self.secs = self.x.size / rate + self.y = np.zeros(self.x.size) self.rate = rate self.fmin = min(f0, f1) self.fmax = max(f0, f1) self.time = 0 + self.numMeasurements = 0 self._calibration = calibration self._target = target + self._sumH = np.zeros(self.X().size) - # Cache the methods in a way that allows garbage collection of self. - for meth in ['X', 'Y', 'H', 'H2', 'h', 'h_inv', 'spectrum', + def setCaching(self): + """ + Cache the main methods in a way that allows garbage collection of self. + Calling this method again will in effect clear the previous caching. + """ + for name in ['X', 'Y', 'calcH', 'H', 'H2', 'h', 'h_inv', 'spectrum', 'frequency', 'calibration', 'target']: - setattr(self, meth, lru_cache(getattr(self, meth))) + unbound = getattr(Analyzer, name) + bound = types.MethodType(unbound, self) + setattr(self, name, lru_cache(bound)) + + def addMeasurements(self, analyzer): + """Add measurements from other analyzer to this one.""" + if not self.isCompatible(analyzer): + raise ValueError('Incompatible analyzers') + self._sumH = self._sumH + analyzer._sumH + self.numMeasurements += analyzer.numMeasurements + self.setCaching() + + def isCompatible(self, analyzer): + """ + See if other analyzer is compatible for adding measurement to this one. + """ + return isinstance(analyzer, Analyzer) and np.array_equal( + analyzer.x, self.x) def findMatch(self, recording: array.array) -> bool: """ @@ -84,15 +108,19 @@ class Analyzer: corr = np.fft.ifft(X * Y).real idx = int(corr.argmax()) - self.x.size + 1 if idx >= 0: - self.y = np.array(recording[idx:idx + self.x.size], 'f') + self.y = np.array(recording[idx:idx + self.x.size]) + self.numMeasurements += 1 + self._sumH += self.calcH() + self.setCaching() return True return False def timedOut(self) -> bool: """See if time to find a match has exceeded the timeout limit.""" - return self.time > self.secs + self.TIMEOUT_SECS + return self.time > self.x.size / self.rate + self.TIMEOUT_SECS def frequency(self) -> np.ndarray: + """Frequency array, from 0 to the Nyquist frequency.""" return np.linspace(0, self.rate // 2, self.X().size) def freqRange(self, size: int = 0) -> slice: @@ -107,9 +135,11 @@ class Analyzer: return slice(i0, i1 + 1) def calibration(self) -> Optional[np.ndarray]: + """Interpolated calibration curve.""" return self.interpolateCorrection(self._calibration) def target(self) -> Optional[np.ndarray]: + """Interpolated target curve.""" return self.interpolateCorrection(self._target) def interpolateCorrection(self, corr: Correction) -> Optional[np.ndarray]: @@ -134,10 +164,9 @@ class Analyzer: def Y(self) -> np.ndarray: return np.fft.rfft(self.y) - def H(self) -> XY: + def calcH(self) -> np.ndarray: """ - Calculate complex-valued transfer function H in the - frequency domain. + Calculate transfer function H of the last measurement. """ X = self.X() Y = self.Y() @@ -145,13 +174,20 @@ class Analyzer: H = Y * np.conj(X) / (np.abs(X) ** 2 + 1e-3) if self._calibration: H *= 10 ** (-self.calibration() / 20) + H = np.abs(H) + return H + + def H(self) -> XY: + """ + Transfer function H averaged over all measurements. + """ freq = self.frequency() + H = self._sumH / (self.numMeasurements or 1) return XY(freq, H) - def H2(self, smoothing: float): + def H2(self, smoothing: float) -> XY: """Calculate smoothed squared transfer function |H|^2.""" freq, H = self.H() - H = np.abs(H) r = self.freqRange() H2 = np.empty_like(H) # Perform smoothing on the squared amplitude. diff --git a/hifiscan/app.py b/hifiscan/app.py index d9d19b2..31e6785 100644 --- a/hifiscan/app.py +++ b/hifiscan/app.py @@ -1,4 +1,5 @@ import asyncio +import copy import datetime as dt import logging import os @@ -68,9 +69,6 @@ class App(qt.QMainWindow): if analyzer.timedOut(): break - def setPaused(self): - self.paused = not self.paused - def plot(self, *_): if self.stack.currentIndex() == 0: self.plotSpectrum() @@ -92,7 +90,7 @@ class App(qt.QMainWindow): self.refSpectrumPlot.setData(*spectrum) def plotIR(self): - if self.refAnalyzer and self.useRefBox.isChecked(): + if self.refAnalyzer and self.useBox.currentIndex() == 0: analyzer = self.refAnalyzer else: analyzer = self.analyzer @@ -131,7 +129,7 @@ class App(qt.QMainWindow): self.saveDir = Path(filename).parent def saveIR(self): - if self.refAnalyzer and self.useRefBox.isChecked(): + if self.refAnalyzer and self.useBox.currentIndex() == 0: analyzer = self.refAnalyzer else: analyzer = self.analyzer @@ -151,16 +149,6 @@ class App(qt.QMainWindow): hifi.write_wav(filename, analyzer.rate, irInv) self.saveDir = Path(filename).parent - def setReference(self, withRef: bool): - if withRef: - if self.analyzer: - self.refAnalyzer = self.analyzer - self.plot() - else: - self.refAnalyzer = None - self.refSpectrumPlot.clear() - self.spectrumPlotWidget.repaint() - def run(self): """Run both the Qt and asyncio event loops.""" @@ -210,8 +198,6 @@ class App(qt.QMainWindow): self.spectrumSmoothing = pg.SpinBox( value=15, step=1, bounds=[0, 30]) self.spectrumSmoothing.sigValueChanging.connect(self.plot) - refBox = qt.QCheckBox('Reference') - refBox.stateChanged.connect(self.setReference) hbox = qt.QHBoxLayout() hbox.addStretch(1) @@ -229,8 +215,6 @@ class App(qt.QMainWindow): hbox.addSpacing(32) hbox.addWidget(qt.QLabel('Smoothing: ')) hbox.addWidget(self.spectrumSmoothing) - hbox.addSpacing(32) - hbox.addWidget(refBox) hbox.addStretch(1) vbox.addLayout(hbox) @@ -287,8 +271,11 @@ class App(qt.QMainWindow): value=15, step=1, bounds=[0, 30]) self.irSmoothing.sigValueChanging.connect(self.plot) self.kaiserBeta.sigValueChanging.connect(self.plot) - self.useRefBox = qt.QCheckBox('Use reference') - self.useRefBox.stateChanged.connect(self.plot) + + self.useBox = qt.QComboBox() + self.useBox.addItems(['Stored measurements', 'Last measurement']) + self.useBox.currentIndexChanged.connect(self.plot) + exportButton = qt.QPushButton('Export as WAV') exportButton.setShortcut('E') exportButton.setToolTip('') @@ -308,10 +295,10 @@ class App(qt.QMainWindow): hbox.addWidget(qt.QLabel('Smoothing: ')) hbox.addWidget(self.irSmoothing) hbox.addSpacing(32) - hbox.addWidget(self.useRefBox) - hbox.addSpacing(32) - hbox.addWidget(exportButton) + hbox.addWidget(qt.QLabel('Use: ')) + hbox.addWidget(self.useBox) hbox.addStretch(1) + hbox.addWidget(exportButton) vbox.addLayout(hbox) return topWidget @@ -379,20 +366,49 @@ class App(qt.QMainWindow): correctionsButton = qt.QPushButton('Corrections...') correctionsButton.pressed.connect(correctionsPressed) + def storeButtonClicked(): + if self.analyzer: + if self.analyzer.isCompatible(self.refAnalyzer): + self.refAnalyzer.addMeasurements(self.analyzer) + else: + self.refAnalyzer = copy.copy(self.analyzer) + measurementsLabel.setText( + f'Measurements: {self.refAnalyzer.numMeasurements}') + self.plot() + + def clearButtonClicked(): + self.refAnalyzer = None + self.refSpectrumPlot.clear() + measurementsLabel.setText('Measurements: ') + self.plot() + + measurementsLabel = qt.QLabel('Measurements: ') + + storeButton = qt.QPushButton('Store') + storeButton.clicked.connect(storeButtonClicked) + storeButton.setShortcut('S') + storeButton.setToolTip('') + + clearButton = qt.QPushButton('Clear') + clearButton.clicked.connect(clearButtonClicked) + clearButton.setShortcut('C') + clearButton.setToolTip('') + screenshotButton = qt.QPushButton('Screenshot') - screenshotButton.setShortcut('S') - screenshotButton.setToolTip('') screenshotButton.clicked.connect(self.screenshot) + def setPaused(): + self.paused = not self.paused + pauseButton = qt.QPushButton('Pause') pauseButton.setShortcut('Space') pauseButton.setToolTip('') pauseButton.setFocusPolicy(qtcore.Qt.FocusPolicy.NoFocus) - pauseButton.clicked.connect(self.setPaused) + pauseButton.clicked.connect(setPaused) exitButton = qt.QPushButton('Exit') - exitButton.setShortcut('Esc') - exitButton.setToolTip('') + exitButton.setShortcut('Ctrl+Q') + exitButton.setToolTip('Ctrl+Q') exitButton.clicked.connect(self.close) hbox = qt.QHBoxLayout() @@ -402,6 +418,10 @@ class App(qt.QMainWindow): hbox.addSpacing(64) hbox.addWidget(correctionsButton) hbox.addStretch(1) + hbox.addWidget(measurementsLabel) + hbox.addWidget(storeButton) + hbox.addWidget(clearButton) + hbox.addStretch(1) hbox.addWidget(screenshotButton) hbox.addSpacing(32) hbox.addWidget(pauseButton)