Skip to content
This repository was archived by the owner on Jun 11, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
- uses: psf/black@stable
with:
options: "--check --verbose"
version: "24.1.1"
version: "25.1.0"
2 changes: 1 addition & 1 deletion .github/workflows/isort.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ jobs:
sudo apt install openmpi-bin libopenmpi-dev
# Install dependencies for proper 1st/2nd/3rd party import sorting
- run: pip install -e .[parallel]
- uses: isort/isort-action@v1.1.0
- uses: isort/isort-action@master
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ default_install_hook_types:
- commit-msg
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.1.1
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 6.0.0
hooks:
- id: isort
name: isort (python)
Expand Down
32 changes: 31 additions & 1 deletion bsb_neuron/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SimulatorAdapter,
report,
)
from neo import AnalogSignal
from neo import AnalogSignal, SpikeTrain

if typing.TYPE_CHECKING:
from bsb import Simulation
Expand All @@ -30,6 +30,36 @@ def __init__(self, simulation: "Simulation", result=None):


class NeuronResult(SimulationResult):

def record_spike(
self, time_vect, id_vect, cell_model_id, loc_label_id, locs_ids, **annotations
):
def flush(segment):
if "units" not in annotations.keys():
annotations["units"] = "ms"
if cell_model_id:
segment.spiketrains.append(
SpikeTrain(
np.array(time_vect) if len(time_vect) > 0 else [],
gids=np.array(id_vect) if len(id_vect) > 0 else [],
array_annotations={
"senders": np.array([cell_model_id[gid] for gid in id_vect])
},
labels=np.array([loc_label_id[gid] for gid in id_vect]),
loc=np.array([locs_ids[gid] for gid in id_vect]),
**annotations,
)
)
else:
segment.spiketrains.append(
SpikeTrain(
np.array(time_vect) if len(time_vect) > 0 else [],
**annotations,
)
)

self.create_recorder(flush)

def record(self, obj, **annotations):
from patch import p
from quantities import ms
Expand Down
1 change: 1 addition & 0 deletions bsb_neuron/devices/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .current_clamp import CurrentClamp
from .ion_recorder import IonRecorder
from .spike_generator import SpikeGenerator
from .spike_recorder import SpikeRecorder
from .synapse_recorder import SynapseRecorder
from .voltage_clamp import VoltageClamp
from .voltage_recorder import VoltageRecorder
92 changes: 92 additions & 0 deletions bsb_neuron/devices/spike_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
from bsb import LocationTargetting, config
from patch import p
from patch.objects import Vector

from ..device import NeuronDevice


@config.node
class SpikeRecorder(NeuronDevice, classmap_entry="spike_recorder"):
"""
Device to record the spike events in selected locations.

:param location: The LocationTargetting chosen to select location on cells, default selects "soma".
:type LocatioTargetting: ~bsb.simulation.targetting.LocationTargetting

:param join_population: If set to True, a SpikeTrain object will be created for each cell population; if set to False,
a SpikeTrain will be stored for each individual location. The default value is False.
:type bool
"""

locations = config.attr(type=LocationTargetting, default={"strategy": "soma"})
join_population = config.attr(type=bool, default=False)

def implement(self, adapter, simulation, simdata):
for model, pop in self.targetting.get_targets(
adapter, simulation, simdata
).items():
if self.join_population:
spike_times = p.parallel.Vector
neuron_gids = p.parallel.Vector
gids_to_cell = {}
gids_to_labels = {}
gids_to_locs = {}
for target in pop:
for location in self.locations.get_locations(target):

gid = target.check_netcon(location, adapter)
# Call record_spike() method on selected gid using common spike_times and neuron_gids Vector for
# cells in the same population
gids_to_cell[gid] = target.id
gids_to_labels[gid] = location.section.labels
gids_to_locs[gid] = location._loc
spike_times, neuron_gids = p.parallel.spike_record(
gid, spike_times, neuron_gids
)
# If join_population is selected Record a SpikeTrain obj for every model
self._add_spike_recorder(
simdata.result,
spike_times,
neuron_gids,
gids_to_cell,
gids_to_labels,
gids_to_locs,
device=self.name,
t_stop=simulation.duration,
cell_type=target.cell_model.name,
pop_size=len(pop),
)
else: # We are splitting the outputs
for target in pop:
for location in self.locations.get_locations(target):
gid = target.check_netcon(location, adapter)

# Call record_spike() method on selected gid, it will build a SpikeTrain obj for every location
spike_times, neuron_gids = p.parallel.spike_record(gid)
self._add_spike_recorder(
simdata.result,
spike_times,
neuron_gids,
device=self.name,
t_stop=simulation.duration,
cell_type=target.cell_model.name,
cell_id=target.id,
labels=location.section.labels,
loc=location._loc,
pop_size=len(pop),
)

def _add_spike_recorder(
self,
results,
spike_times,
gids,
cell_dict=None,
labels_dict=None,
locs_dict=None,
**annotations
):
results.record_spike(
spike_times, gids, cell_dict, labels_dict, locs_dict, **annotations
)
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ dev = [
"build~=1.0",
"twine~=4.0",
"pre-commit~=3.5",
"black~=24.1.1",
"isort~=5.12",
"black~=25.1.0",
"isort~=6.0",
"snakeviz~=2.1",
"bump-my-version~=0.24"
]
Expand Down
189 changes: 189 additions & 0 deletions tests/test_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import importlib
import unittest
from copy import copy

import numpy as np
from bsb.services import MPI
from bsb.simulation import get_simulation_adapter
from bsb_test import (
ConfigFixture,
MorphologiesFixture,
NetworkFixture,
NumpyTestCase,
RandomStorageFixture,
)
from patch import p

from bsb_neuron.cell import ArborizedModel
from bsb_neuron.connection import TransceiverModel


def neuron_installed():
return importlib.util.find_spec("neuron")


@unittest.skipIf(not neuron_installed(), "NEURON is not installed")
class TestSpikeRecorder(
RandomStorageFixture,
ConfigFixture,
NetworkFixture,
MorphologiesFixture,
NumpyTestCase,
unittest.TestCase,
config="complete",
morpho_filters=["3branch"],
engine_name="hdf5",
):

def setUp(self):
super().setUp()
p.parallel.gid_clear()
self.network.network.chunk_size = [10, 10, 10]
for ct in self.network.cell_types.values():
ct.spatial.morphologies = ["3branch"]
hh_soma = {
"cable_types": {
"soma": {
"cable": {"Ra": 10, "cm": 1},
"mechanisms": {"pas": {}, "hh": {}},
},
"dendrites": {
"cable": {"Ra": 2, "cm": 5},
"mechanisms": {"pas": {}, "hh": {}},
},
},
"synapse_types": {"ExpSyn": {}},
}
self.network.simulations.add(
"test",
simulator="neuron",
duration=50,
resolution=0.1,
temperature=32,
cell_models=dict(
A=ArborizedModel(model=hh_soma),
B=ArborizedModel(model=hh_soma),
C=ArborizedModel(model=hh_soma),
),
connection_models=dict(
B_to_C=TransceiverModel(
synapses=[dict(synapse="ExpSyn", weight=0.001, delay=1)]
),
),
devices=dict(
spike_detector=dict(
device="spike_recorder",
targetting={
"strategy": "cell_model",
"cell_models": ["A", "B", "C"],
},
),
first_current=dict(
device="current_clamp",
targetting={
"strategy": "cell_model",
"cell_models": ["A", "C"],
},
locations={"strategy": "soma"},
before=5,
amplitude=50,
duration=1,
),
second_current=dict(
device="current_clamp",
targetting={
"strategy": "cell_model",
"cell_models": ["C"],
},
locations={"strategy": "soma"},
before=35,
amplitude=50,
duration=1,
),
),
)
self.network.compile()

def test_simple_stimulus(self):
sim = self.network.simulations.test
adapter = get_simulation_adapter(sim.simulator)
simdata = adapter.prepare(sim)
results = adapter.run(sim)
result = adapter.collect(results)[0]
pop_lenghts = []
ids = []
for cm in sim.cell_models:
pop = [cell.id for cell in simdata.populations[sim.cell_models[cm]]]
pop_lenghts.append(len(pop))
ids.append(pop)

for index, spk in zip(ids[0], result.spiketrains[: pop_lenghts[0] : 1]):
self.assertEqual(spk.annotations["cell_type"], "A")
self.assertEqual(spk.annotations["cell_id"], index)
self.assertClose(
spk.magnitude,
np.full(pop_lenghts[0], 5.1, dtype=np.float64),
f"SpikeTrains for cell A do not match!",
)
second_interval = pop_lenghts[0] + pop_lenghts[1]
for index, spk in zip(
ids[1], result.spiketrains[pop_lenghts[0] : second_interval : 1]
):
self.assertEqual(spk.annotations["cell_type"], "B")
self.assertEqual(spk.annotations["cell_id"], index)
self.assertClose(
spk.magnitude, np.array([]), f"SpikeTrains for cell B should be empty!"
)
for index, spk in zip(ids[2], result.spiketrains[second_interval::1]):
self.assertEqual(spk.annotations["cell_type"], "C")
self.assertEqual(spk.annotations["cell_id"], index)
self.assertClose(
spk.magnitude,
np.full((pop_lenghts[0], 2), [5.1, 35.1], dtype=np.float64),
f"SpikeTrains for cell C do not match!",
)

def test_join_population(self):
"""Test that spike_recorder correctly records stimulus and stores a spiketrain for every cell population"""
cfg = self.network.configuration
cfg.simulations.test.devices.spike_detector.join_population = True
self.network.storage.store_active_config(cfg)
sim = self.network.simulations.test
adapter = get_simulation_adapter(sim.simulator)
simdata = adapter.prepare(sim)
results = adapter.run(sim)
result = adapter.collect(results)[0]
self.assertEqual(
len(result.spiketrains),
3,
"No event should be recorded for B cells but a SpikeTrain should still be allocated",
)
control_data = []
for cm in sim.cell_models:
appo = []
pop = [cell.id for cell in simdata.populations[sim.cell_models[cm]]]
pop_len = len(pop)
appo.append(cm)
if cm == "A":
appo.append(list(pop))
appo.append(np.full(pop_len, [5.1], dtype=np.float64))
if cm == "B":
appo.append([])
appo.append(np.array([], dtype=np.float64))
if cm == "C":
appo.append([*pop, *pop])
tot_times = np.concatenate(
(
np.full(pop_len, [5.1], dtype=np.float64),
np.full(pop_len, [35.1], dtype=np.float64),
)
)
appo.append(tot_times)
control_data.append(appo)

for elem, spike_train in enumerate(result.spiketrains):
self.assertEqual(control_data[elem][0], spike_train.annotations["cell_type"])
self.assertEqual(
control_data[elem][1], list(spike_train.array_annotations["senders"])
)
self.assertClose(control_data[elem][2], spike_train.magnitude)
Loading