From 8485e205f88c6ef1f2952365c3224baa65e34d29 Mon Sep 17 00:00:00 2001 From: Tanmay Bankar Date: Thu, 30 Nov 2023 11:04:27 +0530 Subject: [PATCH] functions for cube visualisation, spike raster, spike count, storing feature vectors, connection weights --- neucube/reservoir.py | 148 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 143 insertions(+), 5 deletions(-) diff --git a/neucube/reservoir.py b/neucube/reservoir.py index de13d7c..f282596 100644 --- a/neucube/reservoir.py +++ b/neucube/reservoir.py @@ -3,6 +3,9 @@ import math from .topology import small_world_connectivity from .utils import print_summary +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np class Reservoir(): def __init__(self, cube_shape=(10,10,10), inputs=None, coordinates=None, mapping=None, c=1.2, l=1.6, c_in = 0.9, l_in = 1.2): @@ -45,6 +48,58 @@ def __init__(self, cube_shape=(10,10,10), inputs=None, coordinates=None, mapping self.w_latent = conn_mat.to(self.device) self.w_in = input_conn.to(self.device) + self.coordinates = coordinates + self.mapping = mapping + + + def retrieve_conn_mat(self): + """ + Retrieves the connection matrix established after small world connectivity, and converts + it into a csv file to be saved. + """ + mat = self.w_latent + DF = pd.DataFrame(mat.cpu()) + DF.to_csv("conn.csv") + + + + def visualize_cube(self,cube_shape=(10,10,10),coordinates=None, mapping=None): + """ + Visualises the cube in a 3D space, indicating the input neurons and their positions + + Parameters: + cube_shape(tuple): Dimensions of the cube + coordinates(torch.Tensor): Coordinates of the neurons in the reservoir. + If not provided, the coordinates were generated based on `cube_shape`. + mapping (torch.Tensor): Coordinates of the input neurons. + If not provided, random connectivity was used. + """ + fig = plt.figure(figsize=(15,9)) + ax = fig.add_subplot(111, projection='3d') + if coordinates is None: + x, y, z = torch.meshgrid(torch.linspace(0, 1, cube_shape[0]), torch.linspace(0, 1, cube_shape[1]), torch.linspace(0, 1, cube_shape[2]), indexing='xy') + ax.scatter(x.flatten(), y.flatten(), z.flatten(), s = 8, c='#3258a8') #f5d1b6 #957d5f #A18A6C + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.grid(False) + plt.show() + else: + coordinates_np = coordinates.numpy() + ax.scatter(coordinates_np[:,1], coordinates_np[:,0], coordinates_np[:,2], s = 8, c='#3258a8', zorder = 0) + if mapping is not None: + mapping_np = mapping.numpy() + ax.scatter(mapping_np[:,1], mapping_np[:,0], mapping_np[:,2], s =60, c = 'black', zorder = 10 ) + + + ax.set_xlabel('Y') + ax.set_ylabel('X') + ax.set_zlabel('Z') + ax.invert_xaxis() + ax.grid(False) + plt.show() + def simulate(self, X, mem_thr=0.1, refractory_period=5, train=True, verbose=True): """ Simulates the reservoir activity given input data. @@ -63,7 +118,7 @@ def simulate(self, X, mem_thr=0.1, refractory_period=5, train=True, verbose=True spike_rec = torch.zeros(self.batch_size, self.n_time, self.n_neurons) - for s in tqdm(range(X.shape[0]), disable = not verbose): + for s in tqdm(range(X.shape[0]), disable = not verbose): #range is from 0 to 59, samples spike_latent = torch.zeros(self.n_neurons).to(self.device) mem_poten = torch.zeros(self.n_neurons).to(self.device) @@ -71,9 +126,9 @@ def simulate(self, X, mem_thr=0.1, refractory_period=5, train=True, verbose=True refrac_count = torch.zeros(self.n_neurons).to(self.device) spike_times = torch.zeros(self.n_neurons).to(self.device) - for k in range(self.n_time): + for k in range(self.n_time): # k goes from 0 to 127 (timestamps) - spike_in = X[s,k,:] + spike_in = X[s,k,:] #spike input for all 14 features spike_in = spike_in.to(self.device) refrac[refrac_count < 1] = 1 @@ -99,11 +154,94 @@ def simulate(self, X, mem_thr=0.1, refractory_period=5, train=True, verbose=True self.w_latent += pre_updates self.w_latent += pos_updates + + spike_times[mem_poten >= mem_thr] = k - - spike_rec[s,k,:] = spike_latent + spike_rec[s,k,:] = spike_latent + self.output = spike_rec + return spike_rec + + def post_weights(self): + """ + Retrieves the weight matrix obtained after running the simulate function, and converts + it into a csv file to be saved. + """ + mat = self.w_latent + DF = pd.DataFrame(mat.cpu()) + DF.to_csv("post_conn.csv") + + def input_spike_count(self): + """ + This caclulates the total spike count for input neurons over time, for each sample + + """ + mapping = self.mapping + coordinates = self.coordinates + out = self.output + eeg_np = mapping.numpy() + brain_np = coordinates.numpy() + idx = [] + for row in eeg_np: + mask = np.all(brain_np == row, axis=1) + + # Find the indices where the mask is True + indices = np.where(mask)[0] + + idx.append(indices) + + indexi = [item[0] for item in idx] + + matrix = torch.zeros(out.shape[0], len(indexi)) + + for sm in range (len(out)): + for i in range (len(indexi)): + matrix[sm][i] = out[sm][:,indexi[i]].sum().item() + + + return matrix + + + def feature_vectors(self, k_vec): + """ + This function can be used to extract and store K feature vectors + of length N that represent the number of spikes of each neurons + from the cube within each of the k time intervals from T. + + """ + + out = self.output + window = int(out.shape[1]/k_vec)+1 + feature = torch.zeros(out.shape[0], int(k_vec), out.shape[2]) + + for sm in range(out.shape[0]): + + for neuron in range(out.shape[2]): + i = 0 + idx = 0 + while(i