-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_inlet.py
71 lines (54 loc) · 2.15 KB
/
test_inlet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
These tests cannot really be run by an automated test runner without a complicated setup involving live LSL inlets and outlets.
They exist mostly as an aid during development and debugging.
"""
import os
import json
from pathlib import Path
import tempfile
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.lsl.units import LSLInfo, LSLInletSettings, LSLInletUnit
def test_inlet_init_defaults():
    """Smoke test: an inlet unit can be constructed from blank name/type settings."""
    blank_settings = LSLInletSettings(name="", type="")
    LSLInletUnit(blank_settings)
    # Reaching this point without an exception is the entire check.
    assert True
class MessageReceiverSettings(ez.Settings):
    """Configuration for the message-receiving sink used by these tests."""

    # Number of received messages after which the sink ends the run.
    num_msgs: int
    # Path of the file the sink appends one JSON line to per received message.
    output_fn: str
class MessageReceiverState(ez.State):
    """Mutable state for the message-receiving sink used by these tests."""

    # Running count of messages handled so far.
    num_received: int = 0
class AxarrReceiver(ez.Unit):
    """Sink unit that logs the time vector of every incoming AxisArray.

    Each received message appends one JSON object to ``SETTINGS.output_fn``,
    keyed by the 1-based message count and holding the per-sample timestamps
    rebuilt from the message's "time" axis. Once ``SETTINGS.num_msgs``
    messages have been seen, the unit raises ``ez.NormalTermination``.
    """

    STATE = MessageReceiverState
    SETTINGS = MessageReceiverSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SIGNAL)
    async def on_message(self, msg: AxisArray) -> None:
        """Record the timestamps of ``msg`` and stop after the configured count."""
        self.STATE.num_received += 1
        count = self.STATE.num_received

        # Rebuild one timestamp per sample along the first data dimension
        # from the time axis' gain and offset.
        time_axis = msg.axes["time"]
        sample_ix = np.arange(msg.data.shape[0])
        timestamps = sample_ix * time_axis.gain + time_axis.offset

        # Append mode: every message adds exactly one line to the file.
        with open(self.SETTINGS.output_fn, "a") as sink:
            record = {count: timestamps.tolist()}
            sink.write(json.dumps(record) + "\n")

        if count == self.SETTINGS.num_msgs:
            raise ez.NormalTermination
def test_inlet_init_with_settings():
    """Run an LSL inlet into a logging sink and check the timestamps are gap-free.

    The inlet is configured for a stream named "BrainVision RDA" of type
    "EEG"; the sink writes the time vector of each message to a temporary
    JSON-lines file and ends the run after 10,000 messages. The logged time
    vectors are then concatenated and the largest step between consecutive
    timestamps must be below 0.003 (presumably seconds -- TODO confirm).

    Requires a matching live LSL stream; see the module docstring.
    """
    # Under pytest the variable looks like "path/test_inlet.py::test_name (call)".
    # It is unset when the function is called directly (e.g. from a debugger),
    # so fall back to this function's own name instead of crashing on None.
    current_test = os.environ.get("PYTEST_CURRENT_TEST")
    if current_test:
        test_name = current_test.split(":")[-1].split(" ")[0]
    else:
        test_name = "test_inlet_init_with_settings"
    file_path = Path(tempfile.gettempdir()) / f"{test_name}.json"

    # The sink opens the file in append mode, so leftovers from an aborted
    # earlier run would break the per-line key check below.
    if file_path.exists():
        os.remove(str(file_path))

    comps = {
        "SRC": LSLInletUnit(info=LSLInfo(name="BrainVision RDA", type="EEG")),
        # output_fn is declared as str, so pass the path as a string.
        "SINK": AxarrReceiver(num_msgs=10_000, output_fn=str(file_path)),
    }
    conns = ((comps["SRC"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL),)

    try:
        ez.run(components=comps, connections=conns)

        # Line ix (0-based) must hold the payload keyed by message number ix + 1.
        tvecs = []
        with open(file_path, "r") as file:
            for ix, line in enumerate(file):
                tvecs.append(json.loads(line)[str(ix + 1)])
    finally:
        # Always clean up the temp file, even if the run or the parsing failed.
        if file_path.exists():
            os.remove(str(file_path))

    tvec = np.hstack(tvecs)
    # Debugging aid: counts, bins = np.histogram(np.diff(tvec), 20)
    assert np.max(np.diff(tvec)) < 0.003