|
| 1 | +import argparse |
| 2 | +import math |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import matplotlib |
| 6 | + |
| 7 | +matplotlib.use('Agg') |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +from brainflow.board_shim import BoardShim, BoardIds, BrainFlowPresets |
| 12 | +from brainflow.data_filter import DataFilter, DetrendOperations, FilterTypes |
| 13 | + |
| 14 | + |
| 15 | +def build_time_axis(data, timestamp_channel, sampling_rate): |
| 16 | + timestamps = data[timestamp_channel, :] |
| 17 | + if timestamps.size == data.shape[1] and timestamps.size > 1 and np.all(np.isfinite(timestamps)): |
| 18 | + diffs = np.diff(timestamps) |
| 19 | + if np.count_nonzero(diffs > 0) > timestamps.size * 0.8: |
| 20 | + return timestamps - timestamps[0], 1.0 / np.median(diffs[diffs > 0]) |
| 21 | + return np.arange(data.shape[1], dtype=np.float64) / sampling_rate, float(sampling_rate) |
| 22 | + |
| 23 | + |
| 24 | +def get_active_channels(optics_data, optical_channels): |
| 25 | + active = [] |
| 26 | + active_labels = [] |
| 27 | + for index, channel in enumerate(optical_channels): |
| 28 | + row = optics_data[index, :] |
| 29 | + finite = row[np.isfinite(row)] |
| 30 | + if finite.size > 0 and np.std(finite) > 1e-9: |
| 31 | + active.append(index) |
| 32 | + active_labels.append(f'Optics {channel}') |
| 33 | + return active, active_labels |
| 34 | + |
| 35 | + |
| 36 | +def zscore(data): |
| 37 | + stddev = np.std(data) |
| 38 | + if stddev < 1e-12: |
| 39 | + return np.zeros(data.shape) |
| 40 | + return (data - np.mean(data)) / stddev |
| 41 | + |
| 42 | + |
| 43 | +def filter_active_channels(optics_data, active_indexes, sampling_rate, low_cut, high_cut): |
| 44 | + filtered = [] |
| 45 | + for index in active_indexes: |
| 46 | + row = np.asarray(optics_data[index, :], dtype=np.float64).copy() |
| 47 | + finite = np.isfinite(row) |
| 48 | + if not np.all(finite): |
| 49 | + row[~finite] = np.interp(np.flatnonzero(~finite), np.flatnonzero(finite), row[finite]) |
| 50 | + DataFilter.detrend(row, DetrendOperations.LINEAR.value) |
| 51 | + DataFilter.perform_bandpass( |
| 52 | + row, |
| 53 | + int(round(sampling_rate)), |
| 54 | + low_cut, |
| 55 | + high_cut, |
| 56 | + 4, |
| 57 | + FilterTypes.BUTTERWORTH_ZERO_PHASE.value, |
| 58 | + 0.0, |
| 59 | + ) |
| 60 | + filtered.append(row) |
| 61 | + return np.asarray(filtered) |
| 62 | + |
| 63 | + |
| 64 | +def find_local_peaks(signal, sampling_rate, min_bpm, max_bpm): |
| 65 | + if signal.size < 3: |
| 66 | + return np.asarray([], dtype=np.int64) |
| 67 | + |
| 68 | + min_distance = max(1, int(round(sampling_rate * 60.0 / max_bpm))) |
| 69 | + threshold = np.percentile(signal, 65.0) |
| 70 | + candidates = np.flatnonzero((signal[1:-1] > signal[:-2]) & (signal[1:-1] >= signal[2:])) + 1 |
| 71 | + candidates = candidates[signal[candidates] > threshold] |
| 72 | + |
| 73 | + peaks = [] |
| 74 | + for candidate in candidates: |
| 75 | + if not peaks or candidate - peaks[-1] >= min_distance: |
| 76 | + peaks.append(int(candidate)) |
| 77 | + elif signal[candidate] > signal[peaks[-1]]: |
| 78 | + peaks[-1] = int(candidate) |
| 79 | + peaks = np.asarray(peaks, dtype=np.int64) |
| 80 | + |
| 81 | + if peaks.size < 3: |
| 82 | + return peaks |
| 83 | + |
| 84 | + intervals = np.diff(peaks) / sampling_rate |
| 85 | + valid = (intervals >= 60.0 / max_bpm) & (intervals <= 60.0 / min_bpm) |
| 86 | + keep = np.concatenate(([True], valid)) |
| 87 | + return peaks[keep] |
| 88 | + |
| 89 | + |
| 90 | +def score_peaks(peaks, signal, sampling_rate, min_bpm, max_bpm): |
| 91 | + if peaks.size < 3: |
| 92 | + return -math.inf |
| 93 | + intervals = np.diff(peaks) / sampling_rate |
| 94 | + valid = (intervals >= 60.0 / max_bpm) & (intervals <= 60.0 / min_bpm) |
| 95 | + if np.count_nonzero(valid) < 2: |
| 96 | + return -math.inf |
| 97 | + valid_intervals = intervals[valid] |
| 98 | + rr_cv = np.std(valid_intervals) / np.mean(valid_intervals) |
| 99 | + return float(np.mean(signal[peaks]) - rr_cv) |
| 100 | + |
| 101 | + |
| 102 | +def select_pulse_signal(combined, sampling_rate, min_bpm, max_bpm): |
| 103 | + best_signal = combined |
| 104 | + best_peaks = find_local_peaks(combined, sampling_rate, min_bpm, max_bpm) |
| 105 | + best_score = score_peaks(best_peaks, combined, sampling_rate, min_bpm, max_bpm) |
| 106 | + |
| 107 | + inverted = -combined |
| 108 | + inverted_peaks = find_local_peaks(inverted, sampling_rate, min_bpm, max_bpm) |
| 109 | + inverted_score = score_peaks(inverted_peaks, inverted, sampling_rate, min_bpm, max_bpm) |
| 110 | + |
| 111 | + if inverted_score > best_score: |
| 112 | + best_signal = inverted |
| 113 | + best_peaks = inverted_peaks |
| 114 | + |
| 115 | + return best_signal, best_peaks |
| 116 | + |
| 117 | + |
| 118 | +def bpm_from_peaks(peaks, time_axis): |
| 119 | + if peaks.size < 3: |
| 120 | + return None |
| 121 | + intervals = np.diff(time_axis[peaks]) |
| 122 | + intervals = intervals[intervals > 0] |
| 123 | + if intervals.size == 0: |
| 124 | + return None |
| 125 | + return 60.0 / np.median(intervals) |
| 126 | + |
| 127 | + |
| 128 | +def spectrum(signal, sampling_rate, low_cut, high_cut): |
| 129 | + if signal.size < 4: |
| 130 | + return np.asarray([]), np.asarray([]), None |
| 131 | + centered = signal - np.mean(signal) |
| 132 | + windowed = centered * np.hanning(centered.size) |
| 133 | + freqs = np.fft.rfftfreq(windowed.size, d=1.0 / sampling_rate) |
| 134 | + power = np.abs(np.fft.rfft(windowed)) ** 2 |
| 135 | + band = (freqs >= low_cut) & (freqs <= high_cut) |
| 136 | + if not np.any(band): |
| 137 | + return freqs, power, None |
| 138 | + peak_freq = freqs[band][np.argmax(power[band])] |
| 139 | + return freqs, power, peak_freq * 60.0 |
| 140 | + |
| 141 | + |
| 142 | +def save_raw_plot(time_axis, optics_data, active_indexes, labels, output_file): |
| 143 | + fig, axes = plt.subplots(len(active_indexes), 1, figsize=(12, 2.1 * len(active_indexes)), sharex=True) |
| 144 | + axes = np.atleast_1d(axes) |
| 145 | + for axis, index, label in zip(axes, active_indexes, labels): |
| 146 | + axis.plot(time_axis, optics_data[index, :], linewidth=1.0) |
| 147 | + axis.set_ylabel(label) |
| 148 | + axis.grid(True) |
| 149 | + axes[-1].set_xlabel('Time, sec') |
| 150 | + fig.tight_layout() |
| 151 | + fig.savefig(output_file, dpi=150) |
| 152 | + plt.close(fig) |
| 153 | + |
| 154 | + |
| 155 | +def save_filtered_plot(time_axis, filtered, labels, output_file): |
| 156 | + fig, axes = plt.subplots(filtered.shape[0], 1, figsize=(12, 2.1 * filtered.shape[0]), sharex=True) |
| 157 | + axes = np.atleast_1d(axes) |
| 158 | + for axis, row, label in zip(axes, filtered, labels): |
| 159 | + axis.plot(time_axis, zscore(row), linewidth=1.0) |
| 160 | + axis.set_ylabel(label) |
| 161 | + axis.grid(True) |
| 162 | + axes[-1].set_xlabel('Time, sec') |
| 163 | + fig.tight_layout() |
| 164 | + fig.savefig(output_file, dpi=150) |
| 165 | + plt.close(fig) |
| 166 | + |
| 167 | + |
| 168 | +def save_pulse_plot(time_axis, pulse_signal, peaks, peak_bpm, spectrum_bpm, output_file): |
| 169 | + title_parts = [] |
| 170 | + if peak_bpm is not None: |
| 171 | + title_parts.append(f'peaks {peak_bpm:.1f} BPM') |
| 172 | + if spectrum_bpm is not None: |
| 173 | + title_parts.append(f'spectrum {spectrum_bpm:.1f} BPM') |
| 174 | + |
| 175 | + fig, axis = plt.subplots(1, 1, figsize=(12, 4)) |
| 176 | + axis.plot(time_axis, pulse_signal, linewidth=1.0) |
| 177 | + if peaks.size > 0: |
| 178 | + axis.plot(time_axis[peaks], pulse_signal[peaks], 'ro', markersize=3) |
| 179 | + axis.set_xlabel('Time, sec') |
| 180 | + axis.set_ylabel('Combined filtered optics, z-score') |
| 181 | + axis.set_title(', '.join(title_parts) if title_parts else 'Combined filtered optics') |
| 182 | + axis.grid(True) |
| 183 | + fig.tight_layout() |
| 184 | + fig.savefig(output_file, dpi=150) |
| 185 | + plt.close(fig) |
| 186 | + |
| 187 | + |
| 188 | +def save_spectrum_plot(freqs, power, spectrum_bpm, output_file): |
| 189 | + fig, axis = plt.subplots(1, 1, figsize=(10, 4)) |
| 190 | + if freqs.size > 0: |
| 191 | + axis.plot(freqs * 60.0, power, linewidth=1.0) |
| 192 | + if spectrum_bpm is not None: |
| 193 | + axis.axvline(spectrum_bpm, color='r', linestyle='--', linewidth=1.0) |
| 194 | + axis.set_xlabel('Frequency, BPM') |
| 195 | + axis.set_ylabel('Power') |
| 196 | + axis.set_xlim(30, 210) |
| 197 | + axis.grid(True) |
| 198 | + fig.tight_layout() |
| 199 | + fig.savefig(output_file, dpi=150) |
| 200 | + plt.close(fig) |
| 201 | + |
| 202 | + |
| 203 | +def main(): |
| 204 | + parser = argparse.ArgumentParser() |
| 205 | + parser.add_argument('--input-file', type=str, required=False, default='muse_anthena_optics_recording.csv') |
| 206 | + parser.add_argument('--output-prefix', type=str, required=False, default='') |
| 207 | + parser.add_argument('--discard-seconds', type=float, required=False, default=5.0) |
| 208 | + parser.add_argument('--low-cut', type=float, required=False, default=0.7) |
| 209 | + parser.add_argument('--high-cut', type=float, required=False, default=3.5) |
| 210 | + parser.add_argument('--min-bpm', type=float, required=False, default=40.0) |
| 211 | + parser.add_argument('--max-bpm', type=float, required=False, default=180.0) |
| 212 | + args = parser.parse_args() |
| 213 | + |
| 214 | + data = DataFilter.read_file(args.input_file) |
| 215 | + board_id = BoardIds.MUSE_S_ANTHENA_BOARD.value |
| 216 | + preset = BrainFlowPresets.ANCILLARY_PRESET |
| 217 | + optical_channels = BoardShim.get_optical_channels(board_id, preset) |
| 218 | + timestamp_channel = BoardShim.get_timestamp_channel(board_id, preset) |
| 219 | + expected_sampling_rate = BoardShim.get_sampling_rate(board_id, preset) |
| 220 | + |
| 221 | + time_axis, actual_sampling_rate = build_time_axis(data, timestamp_channel, expected_sampling_rate) |
| 222 | + start_index = int(np.searchsorted(time_axis, args.discard_seconds)) |
| 223 | + if start_index >= data.shape[1] - 4: |
| 224 | + start_index = 0 |
| 225 | + data = data[:, start_index:] |
| 226 | + time_axis = time_axis[start_index:] - time_axis[start_index] |
| 227 | + |
| 228 | + optics_data = data[optical_channels, :] |
| 229 | + active_indexes, active_labels = get_active_channels(optics_data, optical_channels) |
| 230 | + if not active_indexes: |
| 231 | + raise RuntimeError('no active optical channels found') |
| 232 | + |
| 233 | + filtered = filter_active_channels( |
| 234 | + optics_data, |
| 235 | + active_indexes, |
| 236 | + actual_sampling_rate, |
| 237 | + args.low_cut, |
| 238 | + args.high_cut, |
| 239 | + ) |
| 240 | + combined = np.mean(np.asarray([zscore(row) for row in filtered]), axis=0) |
| 241 | + pulse_signal, peaks = select_pulse_signal(combined, actual_sampling_rate, args.min_bpm, args.max_bpm) |
| 242 | + peak_bpm = bpm_from_peaks(peaks, time_axis) |
| 243 | + freqs, power, spectrum_bpm = spectrum(pulse_signal, actual_sampling_rate, args.low_cut, args.high_cut) |
| 244 | + |
| 245 | + prefix = args.output_prefix or str(Path(args.input_file).with_suffix('')) |
| 246 | + raw_plot = f'{prefix}_raw.png' |
| 247 | + filtered_plot = f'{prefix}_filtered.png' |
| 248 | + pulse_plot = f'{prefix}_pulse.png' |
| 249 | + spectrum_plot = f'{prefix}_spectrum.png' |
| 250 | + |
| 251 | + save_raw_plot(time_axis, optics_data, active_indexes, active_labels, raw_plot) |
| 252 | + save_filtered_plot(time_axis, filtered, active_labels, filtered_plot) |
| 253 | + save_pulse_plot(time_axis, pulse_signal, peaks, peak_bpm, spectrum_bpm, pulse_plot) |
| 254 | + save_spectrum_plot(freqs, power, spectrum_bpm, spectrum_plot) |
| 255 | + |
| 256 | + print(f'input file: {args.input_file}') |
| 257 | + print(f'samples analyzed: {data.shape[1]}') |
| 258 | + print(f'duration analyzed: {time_axis[-1] - time_axis[0]:.3f} sec') |
| 259 | + print(f'sampling rate: {actual_sampling_rate:.2f} Hz, expected: {expected_sampling_rate} Hz') |
| 260 | + print(f'active optical channels: {active_labels}') |
| 261 | + if peak_bpm is not None: |
| 262 | + print(f'peak estimate: {peak_bpm:.1f} BPM from {peaks.size} peaks') |
| 263 | + else: |
| 264 | + print('peak estimate: unavailable') |
| 265 | + if spectrum_bpm is not None: |
| 266 | + print(f'spectrum estimate: {spectrum_bpm:.1f} BPM') |
| 267 | + else: |
| 268 | + print('spectrum estimate: unavailable') |
| 269 | + print(f'wrote plot: {raw_plot}') |
| 270 | + print(f'wrote plot: {filtered_plot}') |
| 271 | + print(f'wrote plot: {pulse_plot}') |
| 272 | + print(f'wrote plot: {spectrum_plot}') |
| 273 | + |
| 274 | + |
| 275 | +if __name__ == '__main__': |
| 276 | + main() |
0 commit comments