Skip to content

Commit f6de923

Browse files
committed
Refactor Inlet to respond to settings switch on-the-fly.
1 parent 6b807c3 commit f6de923

File tree

2 files changed

+137
-89
lines changed

2 files changed

+137
-89
lines changed

src/ezmsg/lsl/inlet.py

+86-73
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class LSLInletUnit(ez.Unit):
124124
SETTINGS = LSLInletSettings
125125
STATE = LSLInletState
126126

127+
INPUT_SETTINGS = ez.InputStream(LSLInletSettings)
127128
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
128129

129130
# Share clock correction across all instances
@@ -137,9 +138,15 @@ def __init__(self, *args, **kwargs) -> None:
137138
"""
138139
kwargs = _sanitize_kwargs(kwargs)
139140
super().__init__(*args, **kwargs)
141+
self._msg_template: typing.Optional[AxisArray] = None
142+
self._fetch_buffer: typing.Optional[npt.NDArray] = None
143+
144+
def _reset_resolver(self) -> None:
145+
self.STATE.resolver = pylsl.ContinuousResolver(pred=None, forget_after=30.0)
140146

141147
def _reset_inlet(self) -> None:
142-
self._fetch_buffer: npt.NDArray | None = None
148+
self._msg_template: typing.Optional[AxisArray] = None
149+
self._fetch_buffer: typing.Optional[npt.NDArray] = None
143150
if self.STATE.inlet is not None:
144151
self.STATE.inlet.close_stream()
145152
del self.STATE.inlet
@@ -166,90 +173,96 @@ def _reset_inlet(self) -> None:
166173
self.STATE.inlet = pylsl.StreamInlet(
167174
info, max_chunklen=1, processing_flags=self.SETTINGS.processing_flags
168175
)
176+
else:
177+
results: list[pylsl.StreamInfo] = self.STATE.resolver.results()
178+
for strm_info in results:
179+
b_match = True
180+
b_match = b_match and ((not self.SETTINGS.info.name) or strm_info.name() == self.SETTINGS.info.name)
181+
b_match = b_match and ((not self.SETTINGS.info.type) or strm_info.type() == self.SETTINGS.info.type)
182+
if b_match:
183+
self.STATE.inlet = pylsl.StreamInlet(
184+
strm_info, max_chunklen=1, processing_flags=self.SETTINGS.processing_flags
185+
)
186+
break
169187

170-
def _reset_resolver(self) -> None:
171-
# Build the predicate string. This uses XPATH syntax and can filter on anything in the stream info. e.g.,
172-
# `"name='BioSemi'" or "type='EEG' and starts-with(name,'BioSemi') and count(info/desc/channel)=32"`
173-
pred = ""
174-
if self.SETTINGS.info.name:
175-
pred += f"name='{self.SETTINGS.info.name}'"
176-
if self.SETTINGS.info.type:
177-
if len(pred):
178-
pred += " and "
179-
pred += f"type='{self.SETTINGS.info.type}'"
180-
if not len(pred):
181-
pred = None
182-
self.STATE.resolver = pylsl.ContinuousResolver(pred=pred)
188+
if self.STATE.inlet is not None:
189+
self.STATE.inlet.open_stream()
190+
inlet_info = self.STATE.inlet.info()
191+
self.SETTINGS.info.nominal_srate = inlet_info.nominal_srate()
192+
# If possible, create a destination buffer for faster pulls
193+
fmt = inlet_info.channel_format()
194+
n_ch = inlet_info.channel_count()
195+
if fmt in fmt2npdtype:
196+
dtype = fmt2npdtype[fmt]
197+
n_buff = (
198+
int(self.SETTINGS.local_buffer_dur * inlet_info.nominal_srate()) or 1000
199+
)
200+
self._fetch_buffer = np.zeros((n_buff, n_ch), dtype=dtype)
201+
ch_labels = []
202+
chans = inlet_info.desc().child("channels")
203+
if not chans.empty():
204+
ch = chans.first_child()
205+
while not ch.empty():
206+
ch_labels.append(ch.child_value("label"))
207+
ch = ch.next_sibling()
208+
while len(ch_labels) < n_ch:
209+
ch_labels.append(str(len(ch_labels) + 1))
210+
# Pre-allocate a message template.
211+
fs = inlet_info.nominal_srate()
212+
self._msg_template = AxisArray(
213+
data=np.empty((0, n_ch)),
214+
dims=["time", "ch"],
215+
axes={
216+
"time": AxisArray.Axis.TimeAxis(
217+
fs=fs if fs else 1.0
218+
), # HACK: Use 1.0 for irregular rate.
219+
"ch": AxisArray.Axis.SpaceAxis(labels=ch_labels),
220+
},
221+
key=inlet_info.name(),
222+
)
183223

184224
async def initialize(self) -> None:
185-
self._reset_inlet()
186225
self._reset_resolver()
226+
self._reset_inlet()
227+
# TODO: Let the clock_sync task do its job at the beginning.
187228

188229
def shutdown(self) -> None:
189-
if self.STATE.resolver is not None:
190-
del self.STATE.resolver
191-
self.STATE.resolver = None
192230
if self.STATE.inlet is not None:
193231
self.STATE.inlet.close_stream()
194232
del self.STATE.inlet
195233
self.STATE.inlet = None
234+
if self.STATE.resolver is not None:
235+
del self.STATE.resolver
236+
self.STATE.resolver = None
196237

197238
@ez.task
198239
async def clock_sync_task(self) -> None:
199240
while True:
200241
force = self.clock_sync.count < 1000
201242
await self.clock_sync.update(force=force, burst=1000 if force else 4)
202243

244+
@ez.subscriber(INPUT_SETTINGS)
245+
async def on_settings(self, msg: LSLInletSettings) -> None:
246+
# The message may be full LSLInletSettings, a dict of settings, just the info, or dict of just info.
247+
if isinstance(msg, dict):
248+
# First make sure the info is in the right place.
249+
msg = _sanitize_kwargs(msg)
250+
# Next, convert to LSLInletSettings object.
251+
msg = LSLInletSettings(**msg)
252+
if msg != self.SETTINGS:
253+
self.apply_settings(msg)
254+
self._reset_resolver()
255+
self._reset_inlet()
256+
203257
@ez.publisher(OUTPUT_SIGNAL)
204258
async def lsl_pull(self) -> typing.AsyncGenerator:
205-
while self.STATE.inlet is None:
206-
results: list[pylsl.StreamInfo] = self.STATE.resolver.results()
207-
if len(results):
208-
self.STATE.inlet = pylsl.StreamInlet(
209-
results[0], max_chunklen=1, processing_flags=pylsl.proc_ALL
210-
)
211-
else:
212-
await asyncio.sleep(0.5)
213-
214-
self.STATE.inlet.open_stream()
215-
inlet_info = self.STATE.inlet.info()
216-
# If possible, create a destination buffer for faster pulls
217-
fmt = inlet_info.channel_format()
218-
n_ch = inlet_info.channel_count()
219-
if fmt in fmt2npdtype:
220-
dtype = fmt2npdtype[fmt]
221-
n_buff = (
222-
int(self.SETTINGS.local_buffer_dur * inlet_info.nominal_srate()) or 1000
223-
)
224-
self._fetch_buffer = np.zeros((n_buff, n_ch), dtype=dtype)
225-
ch_labels = []
226-
chans = inlet_info.desc().child("channels")
227-
if not chans.empty():
228-
ch = chans.first_child()
229-
while not ch.empty():
230-
ch_labels.append(ch.child_value("label"))
231-
ch = ch.next_sibling()
232-
while len(ch_labels) < n_ch:
233-
ch_labels.append(str(len(ch_labels) + 1))
234-
# Pre-allocate a message template.
235-
fs = inlet_info.nominal_srate()
236-
msg_template = AxisArray(
237-
data=np.empty((0, n_ch)),
238-
dims=["time", "ch"],
239-
axes={
240-
"time": AxisArray.Axis.TimeAxis(
241-
fs=fs if fs else 1.0
242-
), # HACK: Use 1.0 for irregular rate.
243-
"ch": AxisArray.Axis.SpaceAxis(labels=ch_labels),
244-
},
245-
key=inlet_info.name(),
246-
)
247-
248-
while self.clock_sync.count < 1000:
249-
# Let the clock_sync task do its job at the beginning.
250-
await asyncio.sleep(0.001)
251-
252-
while self.STATE.inlet is not None:
259+
while True:
260+
if self.STATE.inlet is None:
261+
# Inlet not yet created, or recently destroyed because settings changed.
262+
self._reset_inlet()
263+
await asyncio.sleep(0.1)
264+
continue
265+
253266
if self._fetch_buffer is not None:
254267
samples, timestamps = self.STATE.inlet.pull_chunk(
255268
max_samples=self._fetch_buffer.shape[0], dest_obj=self._fetch_buffer
@@ -270,16 +283,16 @@ async def lsl_pull(self) -> typing.AsyncGenerator:
270283
t0 = time.time() - (timestamps[-1] - timestamps[0])
271284
else:
272285
t0 = self.clock_sync.convert_timestamp(timestamps[0])
273-
if fs <= 0.0:
286+
if self.SETTINGS.info.nominal_srate <= 0.0:
274287
# Irregular rate streams need to be streamed sample-by-sample
275288
for ts, samp in zip(timestamps, data):
276289
out_msg = replace(
277-
msg_template,
290+
self._msg_template,
278291
data=samp[None, ...],
279292
axes={
280-
**msg_template.axes,
293+
**self._msg_template.axes,
281294
"time": replace(
282-
msg_template.axes["time"],
295+
self._msg_template.axes["time"],
283296
offset=t0 + (ts - timestamps[0]),
284297
),
285298
},
@@ -288,11 +301,11 @@ async def lsl_pull(self) -> typing.AsyncGenerator:
288301
else:
289302
# Regular-rate streams can go in a chunk
290303
out_msg = replace(
291-
msg_template,
304+
self._msg_template,
292305
data=data,
293306
axes={
294-
**msg_template.axes,
295-
"time": replace(msg_template.axes["time"], offset=t0),
307+
**self._msg_template.axes,
308+
"time": replace(self._msg_template.axes["time"], offset=t0),
296309
},
297310
)
298311
yield self.OUTPUT_SIGNAL, out_msg

tests/test_inlet.py

+51-16
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
These unit tests aren't really testable in a runner without a complicated setup with inlets and outlets.
33
This code exists mostly to use during development and debugging.
44
"""
5-
import os
5+
import asyncio
66
import json
7+
import os
78
from pathlib import Path
89
import tempfile
10+
import typing
911

1012
import numpy as np
1113

@@ -21,6 +23,24 @@ def test_inlet_init_defaults():
2123
assert True
2224

2325

26+
class StreamSwitcher(ez.Unit):
27+
STATE = ez.State
28+
SETTINGS = ez.Settings
29+
OUTPUT_SETTINGS = ez.OutputStream(LSLInletSettings)
30+
31+
@ez.publisher(OUTPUT_SETTINGS)
32+
async def switch_stream(self) -> typing.AsyncGenerator:
33+
switch_counter = 0
34+
35+
while True:
36+
if switch_counter % 2 == 0:
37+
yield self.OUTPUT_SETTINGS, LSLInletSettings(info=LSLInfo(type="ECoG"))
38+
else:
39+
yield self.OUTPUT_SETTINGS, LSLInletSettings(info=LSLInfo(type="Markers"))
40+
switch_counter += 1
41+
await asyncio.sleep(2)
42+
43+
2444
class MessageReceiverSettings(ez.Settings):
2545
num_msgs: int
2646
output_fn: str
@@ -34,38 +54,53 @@ class AxarrReceiver(ez.Unit):
3454
STATE = MessageReceiverState
3555
SETTINGS = MessageReceiverSettings
3656
INPUT_SIGNAL = ez.InputStream(AxisArray)
57+
OUTPUT_SETTINGS = ez.OutputStream(LSLInletSettings)
3758

3859
@ez.subscriber(INPUT_SIGNAL)
3960
async def on_message(self, msg: AxisArray) -> None:
4061
self.STATE.num_received += 1
41-
t_ax = msg.axes["time"]
42-
tvec = np.arange(msg.data.shape[0]) * t_ax.gain + t_ax.offset
43-
payload = {self.STATE.num_received: tvec.tolist()}
44-
with open(self.SETTINGS.output_fn, "a") as output_file:
45-
output_file.write(json.dumps(payload) + "\n")
62+
try:
63+
t_ax = msg.axes["time"]
64+
tvec = np.arange(msg.data.shape[0]) * t_ax.gain + t_ax.offset
65+
payload = {self.STATE.num_received: tvec.tolist()}
66+
with open(self.SETTINGS.output_fn, "a") as output_file:
67+
output_file.write(json.dumps(payload) + "\n")
68+
except Exception as e:
69+
print(f"Debug {e}")
4670
if self.STATE.num_received == self.SETTINGS.num_msgs:
4771
raise ez.NormalTermination
4872

4973

5074
def test_inlet_init_with_settings():
5175
test_name = os.environ.get("PYTEST_CURRENT_TEST")
76+
if test_name is None:
77+
test_name = "test_inlet:test_inlet_init_with_settings na"
5278
test_name = test_name.split(":")[-1].split(" ")[0]
5379
file_path = Path(tempfile.gettempdir())
5480
file_path = file_path / Path(f"{test_name}.json")
5581

5682
comps = {
5783
"SRC": LSLInletUnit(info=LSLInfo(name="BrainVision RDA", type="EEG")),
58-
"SINK": AxarrReceiver(num_msgs=10_000, output_fn=file_path),
84+
"SINK": AxarrReceiver(num_msgs=500, output_fn=file_path),
85+
"FLIPFLOP": StreamSwitcher(),
5986
}
60-
conns = ((comps["SRC"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL),)
87+
conns = (
88+
(comps["FLIPFLOP"].OUTPUT_SETTINGS, comps["SRC"].INPUT_SETTINGS),
89+
(comps["SRC"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL),
90+
)
6191
ez.run(components=comps, connections=conns)
6292

63-
tvecs = []
64-
with open(file_path, "r") as file:
65-
for ix, line in enumerate(file.readlines()):
66-
tvecs.append(json.loads(line)[str(ix + 1)])
67-
os.remove(str(file_path))
68-
tvec = np.hstack(tvecs)
93+
# tvecs = []
94+
# with open(file_path, "r") as file:
95+
# for ix, line in enumerate(file.readlines()):
96+
# tmp = json.loads(line)
97+
# tvecs.append(tmp[str(ix + 1)])
98+
# os.remove(str(file_path))
99+
# tvec = np.hstack(tvecs)
100+
#
101+
# # counts, bins = np.histogram(np.diff(tvec), 20)
102+
# assert np.max(np.diff(tvec)) < 0.003
103+
69104

70-
# counts, bins = np.histogram(np.diff(tvec), 20)
71-
assert np.max(np.diff(tvec)) < 0.003
105+
if __name__ == "__main__":
106+
test_inlet_init_with_settings()

0 commit comments

Comments
 (0)