Skip to content

Commit a34139e

Browse files
committed
add more examples for pulse detection for muse anthena
Signed-off-by: Andrey Parfenov <a1994ndrey@gmail.com>
1 parent 5bfaaab commit a34139e

3 files changed

Lines changed: 353 additions & 118 deletions

File tree

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import argparse
2+
import math
3+
from pathlib import Path
4+
5+
import matplotlib
6+
7+
matplotlib.use('Agg')
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
11+
from brainflow.board_shim import BoardShim, BoardIds, BrainFlowPresets
12+
from brainflow.data_filter import DataFilter, DetrendOperations, FilterTypes
13+
14+
15+
def build_time_axis(data, timestamp_channel, sampling_rate):
16+
timestamps = data[timestamp_channel, :]
17+
if timestamps.size == data.shape[1] and timestamps.size > 1 and np.all(np.isfinite(timestamps)):
18+
diffs = np.diff(timestamps)
19+
if np.count_nonzero(diffs > 0) > timestamps.size * 0.8:
20+
return timestamps - timestamps[0], 1.0 / np.median(diffs[diffs > 0])
21+
return np.arange(data.shape[1], dtype=np.float64) / sampling_rate, float(sampling_rate)
22+
23+
24+
def get_active_channels(optics_data, optical_channels):
25+
active = []
26+
active_labels = []
27+
for index, channel in enumerate(optical_channels):
28+
row = optics_data[index, :]
29+
finite = row[np.isfinite(row)]
30+
if finite.size > 0 and np.std(finite) > 1e-9:
31+
active.append(index)
32+
active_labels.append(f'Optics {channel}')
33+
return active, active_labels
34+
35+
36+
def zscore(data):
37+
stddev = np.std(data)
38+
if stddev < 1e-12:
39+
return np.zeros(data.shape)
40+
return (data - np.mean(data)) / stddev
41+
42+
43+
def filter_active_channels(optics_data, active_indexes, sampling_rate, low_cut, high_cut):
44+
filtered = []
45+
for index in active_indexes:
46+
row = np.asarray(optics_data[index, :], dtype=np.float64).copy()
47+
finite = np.isfinite(row)
48+
if not np.all(finite):
49+
row[~finite] = np.interp(np.flatnonzero(~finite), np.flatnonzero(finite), row[finite])
50+
DataFilter.detrend(row, DetrendOperations.LINEAR.value)
51+
DataFilter.perform_bandpass(
52+
row,
53+
int(round(sampling_rate)),
54+
low_cut,
55+
high_cut,
56+
4,
57+
FilterTypes.BUTTERWORTH_ZERO_PHASE.value,
58+
0.0,
59+
)
60+
filtered.append(row)
61+
return np.asarray(filtered)
62+
63+
64+
def find_local_peaks(signal, sampling_rate, min_bpm, max_bpm):
65+
if signal.size < 3:
66+
return np.asarray([], dtype=np.int64)
67+
68+
min_distance = max(1, int(round(sampling_rate * 60.0 / max_bpm)))
69+
threshold = np.percentile(signal, 65.0)
70+
candidates = np.flatnonzero((signal[1:-1] > signal[:-2]) & (signal[1:-1] >= signal[2:])) + 1
71+
candidates = candidates[signal[candidates] > threshold]
72+
73+
peaks = []
74+
for candidate in candidates:
75+
if not peaks or candidate - peaks[-1] >= min_distance:
76+
peaks.append(int(candidate))
77+
elif signal[candidate] > signal[peaks[-1]]:
78+
peaks[-1] = int(candidate)
79+
peaks = np.asarray(peaks, dtype=np.int64)
80+
81+
if peaks.size < 3:
82+
return peaks
83+
84+
intervals = np.diff(peaks) / sampling_rate
85+
valid = (intervals >= 60.0 / max_bpm) & (intervals <= 60.0 / min_bpm)
86+
keep = np.concatenate(([True], valid))
87+
return peaks[keep]
88+
89+
90+
def score_peaks(peaks, signal, sampling_rate, min_bpm, max_bpm):
91+
if peaks.size < 3:
92+
return -math.inf
93+
intervals = np.diff(peaks) / sampling_rate
94+
valid = (intervals >= 60.0 / max_bpm) & (intervals <= 60.0 / min_bpm)
95+
if np.count_nonzero(valid) < 2:
96+
return -math.inf
97+
valid_intervals = intervals[valid]
98+
rr_cv = np.std(valid_intervals) / np.mean(valid_intervals)
99+
return float(np.mean(signal[peaks]) - rr_cv)
100+
101+
102+
def select_pulse_signal(combined, sampling_rate, min_bpm, max_bpm):
103+
best_signal = combined
104+
best_peaks = find_local_peaks(combined, sampling_rate, min_bpm, max_bpm)
105+
best_score = score_peaks(best_peaks, combined, sampling_rate, min_bpm, max_bpm)
106+
107+
inverted = -combined
108+
inverted_peaks = find_local_peaks(inverted, sampling_rate, min_bpm, max_bpm)
109+
inverted_score = score_peaks(inverted_peaks, inverted, sampling_rate, min_bpm, max_bpm)
110+
111+
if inverted_score > best_score:
112+
best_signal = inverted
113+
best_peaks = inverted_peaks
114+
115+
return best_signal, best_peaks
116+
117+
118+
def bpm_from_peaks(peaks, time_axis):
119+
if peaks.size < 3:
120+
return None
121+
intervals = np.diff(time_axis[peaks])
122+
intervals = intervals[intervals > 0]
123+
if intervals.size == 0:
124+
return None
125+
return 60.0 / np.median(intervals)
126+
127+
128+
def spectrum(signal, sampling_rate, low_cut, high_cut):
129+
if signal.size < 4:
130+
return np.asarray([]), np.asarray([]), None
131+
centered = signal - np.mean(signal)
132+
windowed = centered * np.hanning(centered.size)
133+
freqs = np.fft.rfftfreq(windowed.size, d=1.0 / sampling_rate)
134+
power = np.abs(np.fft.rfft(windowed)) ** 2
135+
band = (freqs >= low_cut) & (freqs <= high_cut)
136+
if not np.any(band):
137+
return freqs, power, None
138+
peak_freq = freqs[band][np.argmax(power[band])]
139+
return freqs, power, peak_freq * 60.0
140+
141+
142+
def save_raw_plot(time_axis, optics_data, active_indexes, labels, output_file):
143+
fig, axes = plt.subplots(len(active_indexes), 1, figsize=(12, 2.1 * len(active_indexes)), sharex=True)
144+
axes = np.atleast_1d(axes)
145+
for axis, index, label in zip(axes, active_indexes, labels):
146+
axis.plot(time_axis, optics_data[index, :], linewidth=1.0)
147+
axis.set_ylabel(label)
148+
axis.grid(True)
149+
axes[-1].set_xlabel('Time, sec')
150+
fig.tight_layout()
151+
fig.savefig(output_file, dpi=150)
152+
plt.close(fig)
153+
154+
155+
def save_filtered_plot(time_axis, filtered, labels, output_file):
156+
fig, axes = plt.subplots(filtered.shape[0], 1, figsize=(12, 2.1 * filtered.shape[0]), sharex=True)
157+
axes = np.atleast_1d(axes)
158+
for axis, row, label in zip(axes, filtered, labels):
159+
axis.plot(time_axis, zscore(row), linewidth=1.0)
160+
axis.set_ylabel(label)
161+
axis.grid(True)
162+
axes[-1].set_xlabel('Time, sec')
163+
fig.tight_layout()
164+
fig.savefig(output_file, dpi=150)
165+
plt.close(fig)
166+
167+
168+
def save_pulse_plot(time_axis, pulse_signal, peaks, peak_bpm, spectrum_bpm, output_file):
169+
title_parts = []
170+
if peak_bpm is not None:
171+
title_parts.append(f'peaks {peak_bpm:.1f} BPM')
172+
if spectrum_bpm is not None:
173+
title_parts.append(f'spectrum {spectrum_bpm:.1f} BPM')
174+
175+
fig, axis = plt.subplots(1, 1, figsize=(12, 4))
176+
axis.plot(time_axis, pulse_signal, linewidth=1.0)
177+
if peaks.size > 0:
178+
axis.plot(time_axis[peaks], pulse_signal[peaks], 'ro', markersize=3)
179+
axis.set_xlabel('Time, sec')
180+
axis.set_ylabel('Combined filtered optics, z-score')
181+
axis.set_title(', '.join(title_parts) if title_parts else 'Combined filtered optics')
182+
axis.grid(True)
183+
fig.tight_layout()
184+
fig.savefig(output_file, dpi=150)
185+
plt.close(fig)
186+
187+
188+
def save_spectrum_plot(freqs, power, spectrum_bpm, output_file):
189+
fig, axis = plt.subplots(1, 1, figsize=(10, 4))
190+
if freqs.size > 0:
191+
axis.plot(freqs * 60.0, power, linewidth=1.0)
192+
if spectrum_bpm is not None:
193+
axis.axvline(spectrum_bpm, color='r', linestyle='--', linewidth=1.0)
194+
axis.set_xlabel('Frequency, BPM')
195+
axis.set_ylabel('Power')
196+
axis.set_xlim(30, 210)
197+
axis.grid(True)
198+
fig.tight_layout()
199+
fig.savefig(output_file, dpi=150)
200+
plt.close(fig)
201+
202+
203+
def main():
204+
parser = argparse.ArgumentParser()
205+
parser.add_argument('--input-file', type=str, required=False, default='muse_anthena_optics_recording.csv')
206+
parser.add_argument('--output-prefix', type=str, required=False, default='')
207+
parser.add_argument('--discard-seconds', type=float, required=False, default=5.0)
208+
parser.add_argument('--low-cut', type=float, required=False, default=0.7)
209+
parser.add_argument('--high-cut', type=float, required=False, default=3.5)
210+
parser.add_argument('--min-bpm', type=float, required=False, default=40.0)
211+
parser.add_argument('--max-bpm', type=float, required=False, default=180.0)
212+
args = parser.parse_args()
213+
214+
data = DataFilter.read_file(args.input_file)
215+
board_id = BoardIds.MUSE_S_ANTHENA_BOARD.value
216+
preset = BrainFlowPresets.ANCILLARY_PRESET
217+
optical_channels = BoardShim.get_optical_channels(board_id, preset)
218+
timestamp_channel = BoardShim.get_timestamp_channel(board_id, preset)
219+
expected_sampling_rate = BoardShim.get_sampling_rate(board_id, preset)
220+
221+
time_axis, actual_sampling_rate = build_time_axis(data, timestamp_channel, expected_sampling_rate)
222+
start_index = int(np.searchsorted(time_axis, args.discard_seconds))
223+
if start_index >= data.shape[1] - 4:
224+
start_index = 0
225+
data = data[:, start_index:]
226+
time_axis = time_axis[start_index:] - time_axis[start_index]
227+
228+
optics_data = data[optical_channels, :]
229+
active_indexes, active_labels = get_active_channels(optics_data, optical_channels)
230+
if not active_indexes:
231+
raise RuntimeError('no active optical channels found')
232+
233+
filtered = filter_active_channels(
234+
optics_data,
235+
active_indexes,
236+
actual_sampling_rate,
237+
args.low_cut,
238+
args.high_cut,
239+
)
240+
combined = np.mean(np.asarray([zscore(row) for row in filtered]), axis=0)
241+
pulse_signal, peaks = select_pulse_signal(combined, actual_sampling_rate, args.min_bpm, args.max_bpm)
242+
peak_bpm = bpm_from_peaks(peaks, time_axis)
243+
freqs, power, spectrum_bpm = spectrum(pulse_signal, actual_sampling_rate, args.low_cut, args.high_cut)
244+
245+
prefix = args.output_prefix or str(Path(args.input_file).with_suffix(''))
246+
raw_plot = f'{prefix}_raw.png'
247+
filtered_plot = f'{prefix}_filtered.png'
248+
pulse_plot = f'{prefix}_pulse.png'
249+
spectrum_plot = f'{prefix}_spectrum.png'
250+
251+
save_raw_plot(time_axis, optics_data, active_indexes, active_labels, raw_plot)
252+
save_filtered_plot(time_axis, filtered, active_labels, filtered_plot)
253+
save_pulse_plot(time_axis, pulse_signal, peaks, peak_bpm, spectrum_bpm, pulse_plot)
254+
save_spectrum_plot(freqs, power, spectrum_bpm, spectrum_plot)
255+
256+
print(f'input file: {args.input_file}')
257+
print(f'samples analyzed: {data.shape[1]}')
258+
print(f'duration analyzed: {time_axis[-1] - time_axis[0]:.3f} sec')
259+
print(f'sampling rate: {actual_sampling_rate:.2f} Hz, expected: {expected_sampling_rate} Hz')
260+
print(f'active optical channels: {active_labels}')
261+
if peak_bpm is not None:
262+
print(f'peak estimate: {peak_bpm:.1f} BPM from {peaks.size} peaks')
263+
else:
264+
print('peak estimate: unavailable')
265+
if spectrum_bpm is not None:
266+
print(f'spectrum estimate: {spectrum_bpm:.1f} BPM')
267+
else:
268+
print('spectrum estimate: unavailable')
269+
print(f'wrote plot: {raw_plot}')
270+
print(f'wrote plot: {filtered_plot}')
271+
print(f'wrote plot: {pulse_plot}')
272+
print(f'wrote plot: {spectrum_plot}')
273+
274+
275+
if __name__ == '__main__':
276+
main()

python_package/examples/tests/muse_anthena_optics_plot.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

0 commit comments

Comments
 (0)