Skip to content

Commit 3b61ce2

Browse files
authored
Merge pull request #177 from buligar/master
Add mnist-learn spike logging and raster plot script
2 parents a271cb3 + 019e831 commit 3b61ce2

4 files changed

Lines changed: 260 additions & 2 deletions

File tree

examples/mnist-learn/inference.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
#include <map>
3030
#include <memory>
31+
#include <filesystem>
32+
#include <fstream>
3133
#include <string>
3234
#include <utility>
3335
#include <vector>
@@ -83,6 +85,7 @@ std::vector<knp::core::messaging::SpikeMessage> infer_network(
8385
model_executor.get_backend()->stop_learning();
8486

8587
std::ofstream log_stream;
88+
std::ofstream raw_spikes_stream;
8689

8790
// This variable should have the same lifetime as model_executor, or else UB.
8891
// cppcheck-suppress variableScope
@@ -97,6 +100,42 @@ std::vector<knp::core::messaging::SpikeMessage> infer_network(
97100

98101
knp::framework::monitoring::model::add_spikes_logger(model_executor, pop_names, std::cout);
99102

103+
if (!model_desc.log_path_.empty())
104+
{
105+
std::filesystem::create_directories(model_desc.log_path_);
106+
raw_spikes_stream.open(model_desc.log_path_ / "spikes_inference_raw.csv", std::ofstream::out);
107+
if (!raw_spikes_stream.is_open())
108+
{
109+
std::cout << "Couldn't open raw inference spikes log file : " << model_desc.log_path_ << std::endl;
110+
}
111+
else
112+
{
113+
raw_spikes_stream << "send_time,sender_name,sender_uid,neuron_index" << std::endl;
114+
115+
std::vector<knp::core::UID> all_senders_uids(pop_names.size());
116+
std::transform(
117+
pop_names.begin(), pop_names.end(), all_senders_uids.begin(),
118+
[](const auto& sender) -> knp::core::UID { return sender.first; });
119+
120+
model_executor.add_observer<knp::core::messaging::SpikeMessage>(
121+
[&raw_spikes_stream, &pop_names](const std::vector<knp::core::messaging::SpikeMessage>& messages)
122+
{
123+
for (const auto& message : messages)
124+
{
125+
const auto name_iter = pop_names.find(message.header_.sender_uid_);
126+
const std::string sender_name =
127+
name_iter == pop_names.end() ? "UNKNOWN" : name_iter->second;
128+
for (const auto neuron_index : message.neuron_indexes_)
129+
{
130+
raw_spikes_stream << message.header_.send_time_ << "," << sender_name << ","
131+
<< message.header_.sender_uid_ << "," << neuron_index << std::endl;
132+
}
133+
}
134+
},
135+
all_senders_uids);
136+
}
137+
}
138+
100139
// All loggers go here
101140
if (!model_desc.log_path_.empty())
102141
{

examples/mnist-learn/parse_arguments.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,9 @@ std::optional<ModelDescription> parse_arguments(int argc, char** argv)
152152
model_desc.inference_backend_path_ = model_desc.training_backend_path_;
153153
}
154154

155-
if (vm.count("log_path"))
155+
if (vm.count("extensive_logs_path"))
156156
{
157-
model_desc.log_path_ = vm["log_path"].as<std::string>();
157+
model_desc.log_path_ = vm["extensive_logs_path"].as<std::string>();
158158
}
159159
else
160160
{
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
#!/usr/bin/env python3
2+
"""Build a raster plot from mnist-learn spikes_inference_raw.csv."""
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import csv
8+
from collections import OrderedDict, defaultdict
9+
from dataclasses import dataclass
10+
from pathlib import Path
11+
12+
13+
@dataclass
14+
class SpikeEvent:
15+
send_time: int
16+
sender_name: str
17+
sender_uid: str
18+
neuron_index: int
19+
20+
21+
@dataclass
22+
class SenderTrack:
23+
sender_name: str
24+
sender_uid: str
25+
label: str
26+
start_index: int
27+
end_index: int
28+
29+
30+
def parse_args() -> argparse.Namespace:
31+
parser = argparse.ArgumentParser(description=__doc__)
32+
parser.add_argument("csv_path", type=Path, help="Path to spikes_inference_raw.csv")
33+
parser.add_argument(
34+
"-o",
35+
"--output",
36+
type=Path,
37+
help="Output PNG path. Default: <csv_path>.png",
38+
)
39+
parser.add_argument(
40+
"--sender",
41+
action="append",
42+
default=[],
43+
help="Filter by sender_name. Can be passed multiple times.",
44+
)
45+
parser.add_argument(
46+
"--title",
47+
default="Inference Spike Raster",
48+
help="Figure title.",
49+
)
50+
parser.add_argument(
51+
"--dpi",
52+
type=int,
53+
default=150,
54+
help="PNG DPI.",
55+
)
56+
parser.add_argument(
57+
"--marker-size",
58+
type=float,
59+
default=8.0,
60+
help="Scatter marker size.",
61+
)
62+
parser.add_argument(
63+
"--neuron-gap",
64+
type=int,
65+
default=4,
66+
help="Vertical gap between sender blocks.",
67+
)
68+
return parser.parse_args()
69+
70+
71+
def load_events(csv_path: Path, sender_filters: set[str]) -> list[SpikeEvent]:
72+
events: list[SpikeEvent] = []
73+
with csv_path.open("r", newline="", encoding="utf-8") as handle:
74+
reader = csv.DictReader(handle)
75+
required_columns = {"send_time", "sender_name", "sender_uid", "neuron_index"}
76+
if reader.fieldnames is None or not required_columns.issubset(reader.fieldnames):
77+
raise ValueError(
78+
f"{csv_path} must contain columns: {', '.join(sorted(required_columns))}"
79+
)
80+
81+
for row in reader:
82+
event = SpikeEvent(
83+
send_time=int(row["send_time"]),
84+
sender_name=row["sender_name"],
85+
sender_uid=row["sender_uid"],
86+
neuron_index=int(row["neuron_index"]),
87+
)
88+
if sender_filters and event.sender_name not in sender_filters:
89+
continue
90+
events.append(event)
91+
92+
if not events:
93+
raise ValueError("No spikes matched the selected filters.")
94+
95+
return events
96+
97+
98+
def build_tracks(events: list[SpikeEvent], neuron_gap: int) -> tuple[list[SenderTrack], dict[str, int]]:
99+
sender_order: OrderedDict[str, tuple[str, int]] = OrderedDict()
100+
max_neuron_per_sender: dict[str, int] = {}
101+
uids_per_name: dict[str, set[str]] = defaultdict(set)
102+
103+
for event in events:
104+
sender_key = event.sender_uid
105+
if sender_key not in sender_order:
106+
sender_order[sender_key] = (event.sender_name, event.neuron_index)
107+
max_neuron_per_sender[sender_key] = max(
108+
max_neuron_per_sender.get(sender_key, event.neuron_index), event.neuron_index
109+
)
110+
uids_per_name[event.sender_name].add(event.sender_uid)
111+
112+
tracks: list[SenderTrack] = []
113+
sender_offsets: dict[str, int] = {}
114+
current_offset = 0
115+
116+
for sender_uid, (sender_name, _) in sender_order.items():
117+
max_neuron = max_neuron_per_sender[sender_uid]
118+
label = sender_name
119+
if len(uids_per_name[sender_name]) > 1:
120+
label = f"{sender_name} ({sender_uid[:8]})"
121+
122+
start_index = current_offset
123+
end_index = current_offset + max_neuron
124+
tracks.append(
125+
SenderTrack(
126+
sender_name=sender_name,
127+
sender_uid=sender_uid,
128+
label=label,
129+
start_index=start_index,
130+
end_index=end_index,
131+
)
132+
)
133+
sender_offsets[sender_uid] = current_offset
134+
current_offset = end_index + 1 + neuron_gap
135+
136+
return tracks, sender_offsets
137+
138+
139+
def build_output_path(csv_path: Path, output: Path | None) -> Path:
140+
if output is not None:
141+
return output
142+
return csv_path.with_suffix(".png")
143+
144+
145+
def plot_raster(
146+
events: list[SpikeEvent],
147+
tracks: list[SenderTrack],
148+
sender_offsets: dict[str, int],
149+
output_path: Path,
150+
title: str,
151+
dpi: int,
152+
marker_size: float,
153+
) -> None:
154+
import matplotlib.pyplot as plt
155+
156+
x_values = [event.send_time for event in events]
157+
y_values = [sender_offsets[event.sender_uid] + event.neuron_index for event in events]
158+
159+
max_time = max(x_values)
160+
max_y = max(y_values)
161+
162+
fig_width = max(10.0, min(18.0, max_time / 40.0 + 4.0))
163+
fig_height = max(5.0, min(14.0, max_y / 60.0 + 3.0))
164+
165+
fig, ax = plt.subplots(figsize=(fig_width, fig_height), constrained_layout=True)
166+
ax.scatter(x_values, y_values, s=marker_size, c="black", marker="|", linewidths=0.7)
167+
168+
for track in tracks:
169+
ax.axhline(track.start_index - 0.5, color="0.85", linewidth=0.8)
170+
y_center = (track.start_index + track.end_index) / 2.0
171+
ax.text(
172+
max_time + 1,
173+
y_center,
174+
track.label,
175+
va="center",
176+
ha="left",
177+
fontsize=8,
178+
color="0.35",
179+
)
180+
181+
ax.axhline(tracks[-1].end_index + 0.5, color="0.85", linewidth=0.8)
182+
183+
ax.set_title(title)
184+
ax.set_xlabel("time step")
185+
ax.set_ylabel("stacked neuron index")
186+
ax.set_xlim(-1, max_time + max(2, max_time * 0.08))
187+
ax.set_ylim(-1, max_y + 1)
188+
ax.grid(axis="x", color="0.92", linewidth=0.8)
189+
190+
output_path.parent.mkdir(parents=True, exist_ok=True)
191+
fig.savefig(output_path, dpi=dpi)
192+
plt.close(fig)
193+
194+
195+
def main() -> int:
196+
args = parse_args()
197+
sender_filters = set(args.sender)
198+
events = load_events(args.csv_path, sender_filters)
199+
tracks, sender_offsets = build_tracks(events, args.neuron_gap)
200+
output_path = build_output_path(args.csv_path, args.output)
201+
plot_raster(
202+
events=events,
203+
tracks=tracks,
204+
sender_offsets=sender_offsets,
205+
output_path=output_path,
206+
title=args.title,
207+
dpi=args.dpi,
208+
marker_size=args.marker_size,
209+
)
210+
print(output_path)
211+
return 0
212+
213+
214+
if __name__ == "__main__":
215+
raise SystemExit(main())

examples/mnist-learn/training.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
#include <map>
3030
#include <memory>
31+
#include <filesystem>
32+
#include <fstream>
3133
#include <string>
3234
#include <utility>
3335
#include <vector>
@@ -124,6 +126,8 @@ void train_network(
124126
// All loggers go here
125127
if (!model_desc.log_path_.empty())
126128
{
129+
std::filesystem::create_directories(model_desc.log_path_);
130+
127131
log_stream.open(model_desc.log_path_ / "spikes_training.csv", std::ofstream::out);
128132
if (log_stream.is_open())
129133
knp::framework::monitoring::model::add_aggregated_spikes_logger(

0 commit comments

Comments
 (0)