diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 0b86fd4..ea8f99b 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -11,4 +11,4 @@ jobs: - uses: psf/black@stable with: options: "--check --verbose" - version: "24.1.1" + version: "25.1.0" diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 6e61454..e1a2c71 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -17,4 +17,4 @@ jobs: sudo apt install openmpi-bin libopenmpi-dev # Install dependencies for proper 1st/2nd/3rd party import sorting - run: pip install -e .[parallel] - - uses: isort/isort-action@v1.1.0 + - uses: isort/isort-action@master diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 33153b1..76525ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,11 +3,11 @@ default_install_hook_types: - commit-msg repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.1.1 + rev: 25.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 6.0.0 hooks: - id: isort name: isort (python) diff --git a/bsb_neuron/adapter.py b/bsb_neuron/adapter.py index 8262b3b..11a735b 100644 --- a/bsb_neuron/adapter.py +++ b/bsb_neuron/adapter.py @@ -14,7 +14,7 @@ SimulatorAdapter, report, ) -from neo import AnalogSignal +from neo import AnalogSignal, SpikeTrain if typing.TYPE_CHECKING: from bsb import Simulation @@ -30,6 +30,36 @@ def __init__(self, simulation: "Simulation", result=None): class NeuronResult(SimulationResult): + + def record_spike( + self, time_vect, id_vect, cell_model_id, loc_label_id, locs_ids, **annotations + ): + def flush(segment): + if "units" not in annotations.keys(): + annotations["units"] = "ms" + if cell_model_id: + segment.spiketrains.append( + SpikeTrain( + np.array(time_vect) if len(time_vect) > 0 else [], + gids=np.array(id_vect) if len(id_vect) > 0 else [], + array_annotations={ + "senders": np.array([cell_model_id[gid] for gid in id_vect]) + }, + labels=np.array([loc_label_id[gid] for gid in id_vect]), + loc=np.array([locs_ids[gid] for gid in id_vect]), + **annotations, + ) + ) + else: + segment.spiketrains.append( + SpikeTrain( + np.array(time_vect) if len(time_vect) > 0 else [], + **annotations, + ) + ) + + self.create_recorder(flush) + def record(self, obj, **annotations): from patch import p from quantities import ms diff --git a/bsb_neuron/devices/__init__.py b/bsb_neuron/devices/__init__.py index ce2852e..8923c49 100644 --- a/bsb_neuron/devices/__init__.py +++ b/bsb_neuron/devices/__init__.py @@ -1,6 +1,7 @@ from .current_clamp import CurrentClamp from .ion_recorder import IonRecorder from .spike_generator import SpikeGenerator +from .spike_recorder import SpikeRecorder from .synapse_recorder import SynapseRecorder from .voltage_clamp import VoltageClamp from .voltage_recorder import VoltageRecorder diff --git a/bsb_neuron/devices/spike_recorder.py b/bsb_neuron/devices/spike_recorder.py new file mode 100644 index 0000000..b4a4325 --- /dev/null +++ b/bsb_neuron/devices/spike_recorder.py @@ -0,0 +1,92 @@ +import numpy as np +from bsb import LocationTargetting, config +from patch import p +from patch.objects import Vector + +from ..device import NeuronDevice + + +@config.node +class SpikeRecorder(NeuronDevice, classmap_entry="spike_recorder"): + """ + Device to record the spike events in selected locations. + + :param location: The LocationTargetting chosen to select location on cells, default selects "soma". + :type LocatioTargetting: ~bsb.simulation.targetting.LocationTargetting + + :param join_population: If set to True, a SpikeTrain object will be created for each cell population; if set to False, + a SpikeTrain will be stored for each individual location. The default value is False. + :type bool + """ + + locations = config.attr(type=LocationTargetting, default={"strategy": "soma"}) + join_population = config.attr(type=bool, default=False) + + def implement(self, adapter, simulation, simdata): + for model, pop in self.targetting.get_targets( + adapter, simulation, simdata + ).items(): + if self.join_population: + spike_times = p.parallel.Vector + neuron_gids = p.parallel.Vector + gids_to_cell = {} + gids_to_labels = {} + gids_to_locs = {} + for target in pop: + for location in self.locations.get_locations(target): + + gid = target.check_netcon(location, adapter) + # Call record_spike() method on selected gid using common spike_times and neuron_gids Vector for + # cells in the same population + gids_to_cell[gid] = target.id + gids_to_labels[gid] = location.section.labels + gids_to_locs[gid] = location._loc + spike_times, neuron_gids = p.parallel.spike_record( + gid, spike_times, neuron_gids + ) + # If join_population is selected Record a SpikeTrain obj for every model + self._add_spike_recorder( + simdata.result, + spike_times, + neuron_gids, + gids_to_cell, + gids_to_labels, + gids_to_locs, + device=self.name, + t_stop=simulation.duration, + cell_type=target.cell_model.name, + pop_size=len(pop), + ) + else: # We are splitting the outputs + for target in pop: + for location in self.locations.get_locations(target): + gid = target.check_netcon(location, adapter) + + # Call record_spike() method on selected gid, it will build a SpikeTrain obj for every location + spike_times, neuron_gids = p.parallel.spike_record(gid) + self._add_spike_recorder( + simdata.result, + spike_times, + neuron_gids, + device=self.name, + t_stop=simulation.duration, + cell_type=target.cell_model.name, + cell_id=target.id, + labels=location.section.labels, + loc=location._loc, + pop_size=len(pop), + ) + + def _add_spike_recorder( + self, + results, + spike_times, + gids, + cell_dict=None, + labels_dict=None, + locs_dict=None, + **annotations + ): + results.record_spike( + spike_times, gids, cell_dict, labels_dict, locs_dict, **annotations + ) diff --git a/pyproject.toml b/pyproject.toml index cf13240..53c9319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ dev = [ "build~=1.0", "twine~=4.0", "pre-commit~=3.5", - "black~=24.1.1", - "isort~=5.12", + "black~=25.1.0", + "isort~=6.0", "snakeviz~=2.1", "bump-my-version~=0.24" ] diff --git a/tests/test_devices.py b/tests/test_devices.py new file mode 100644 index 0000000..87853ae --- /dev/null +++ b/tests/test_devices.py @@ -0,0 +1,189 @@ +import importlib +import unittest +from copy import copy + +import numpy as np +from bsb.services import MPI +from bsb.simulation import get_simulation_adapter +from bsb_test import ( + ConfigFixture, + MorphologiesFixture, + NetworkFixture, + NumpyTestCase, + RandomStorageFixture, +) +from patch import p + +from bsb_neuron.cell import ArborizedModel +from bsb_neuron.connection import TransceiverModel + + +def neuron_installed(): + return importlib.util.find_spec("neuron") + + +@unittest.skipIf(not neuron_installed(), "NEURON is not installed") +class TestSpikeRecorder( + RandomStorageFixture, + ConfigFixture, + NetworkFixture, + MorphologiesFixture, + NumpyTestCase, + unittest.TestCase, + config="complete", + morpho_filters=["3branch"], + engine_name="hdf5", +): + + def setUp(self): + super().setUp() + p.parallel.gid_clear() + self.network.network.chunk_size = [10, 10, 10] + for ct in self.network.cell_types.values(): + ct.spatial.morphologies = ["3branch"] + hh_soma = { + "cable_types": { + "soma": { + "cable": {"Ra": 10, "cm": 1}, + "mechanisms": {"pas": {}, "hh": {}}, + }, + "dendrites": { + "cable": {"Ra": 2, "cm": 5}, + "mechanisms": {"pas": {}, "hh": {}}, + }, + }, + "synapse_types": {"ExpSyn": {}}, + } + self.network.simulations.add( + "test", + simulator="neuron", + duration=50, + resolution=0.1, + temperature=32, + cell_models=dict( + A=ArborizedModel(model=hh_soma), + B=ArborizedModel(model=hh_soma), + C=ArborizedModel(model=hh_soma), + ), + connection_models=dict( + B_to_C=TransceiverModel( + synapses=[dict(synapse="ExpSyn", weight=0.001, delay=1)] + ), + ), + devices=dict( + spike_detector=dict( + device="spike_recorder", + targetting={ + "strategy": "cell_model", + "cell_models": ["A", "B", "C"], + }, + ), + first_current=dict( + device="current_clamp", + targetting={ + "strategy": "cell_model", + "cell_models": ["A", "C"], + }, + locations={"strategy": "soma"}, + before=5, + amplitude=50, + duration=1, + ), + second_current=dict( + device="current_clamp", + targetting={ + "strategy": "cell_model", + "cell_models": ["C"], + }, + locations={"strategy": "soma"}, + before=35, + amplitude=50, + duration=1, + ), + ), + ) + self.network.compile() + + def test_simple_stimulus(self): + sim = self.network.simulations.test + adapter = get_simulation_adapter(sim.simulator) + simdata = adapter.prepare(sim) + results = adapter.run(sim) + result = adapter.collect(results)[0] + pop_lenghts = [] + ids = [] + for cm in sim.cell_models: + pop = [cell.id for cell in simdata.populations[sim.cell_models[cm]]] + pop_lenghts.append(len(pop)) + ids.append(pop) + + for index, spk in zip(ids[0], result.spiketrains[: pop_lenghts[0] : 1]): + self.assertEqual(spk.annotations["cell_type"], "A") + self.assertEqual(spk.annotations["cell_id"], index) + self.assertClose( + spk.magnitude, + np.full(pop_lenghts[0], 5.1, dtype=np.float64), + f"SpikeTrains for cell A do not match!", + ) + second_interval = pop_lenghts[0] + pop_lenghts[1] + for index, spk in zip( + ids[1], result.spiketrains[pop_lenghts[0] : second_interval : 1] + ): + self.assertEqual(spk.annotations["cell_type"], "B") + self.assertEqual(spk.annotations["cell_id"], index) + self.assertClose( + spk.magnitude, np.array([]), f"SpikeTrains for cell B should be empty!" + ) + for index, spk in zip(ids[2], result.spiketrains[second_interval::1]): + self.assertEqual(spk.annotations["cell_type"], "C") + self.assertEqual(spk.annotations["cell_id"], index) + self.assertClose( + spk.magnitude, + np.full((pop_lenghts[0], 2), [5.1, 35.1], dtype=np.float64), + f"SpikeTrains for cell C do not match!", + ) + + def test_join_population(self): + """Test that spike_recorder correctly records stimulus and stores a spiketrain for every cell population""" + cfg = self.network.configuration + cfg.simulations.test.devices.spike_detector.join_population = True + self.network.storage.store_active_config(cfg) + sim = self.network.simulations.test + adapter = get_simulation_adapter(sim.simulator) + simdata = adapter.prepare(sim) + results = adapter.run(sim) + result = adapter.collect(results)[0] + self.assertEqual( + len(result.spiketrains), + 3, + "No event should be recorded for B cells but a SpikeTrain should still be allocated", + ) + control_data = [] + for cm in sim.cell_models: + appo = [] + pop = [cell.id for cell in simdata.populations[sim.cell_models[cm]]] + pop_len = len(pop) + appo.append(cm) + if cm == "A": + appo.append(list(pop)) + appo.append(np.full(pop_len, [5.1], dtype=np.float64)) + if cm == "B": + appo.append([]) + appo.append(np.array([], dtype=np.float64)) + if cm == "C": + appo.append([*pop, *pop]) + tot_times = np.concatenate( + ( + np.full(pop_len, [5.1], dtype=np.float64), + np.full(pop_len, [35.1], dtype=np.float64), + ) + ) + appo.append(tot_times) + control_data.append(appo) + + for elem, spike_train in enumerate(result.spiketrains): + self.assertEqual(control_data[elem][0], spike_train.annotations["cell_type"]) + self.assertEqual( + control_data[elem][1], list(spike_train.array_annotations["senders"]) + ) + self.assertClose(control_data[elem][2], spike_train.magnitude)