Balanced spiking neural networks with NumPy

Balanced spiking neural networks are a cornerstone of computational neuroscience. They make for a nice introduction to spiking neuronal networks, and they can be trained to store information (Nicola & Clopath, 2017). Here I will present a Python port I made of the MATLAB implementation by Nicola & Clopath and go through some of its features. We only need NumPy.

import numpy as np
from numpy.random import rand, randn
import matplotlib.pyplot as plt


def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000, 
                             td=0.02, tr=0.002, p=0.1, 
                             offset=-40.00, g=0.04, seed=100, 
                             nrec=10):
    """Simulate a balanced spiking neuronal network

    Parameters
    ----------
    dt : float
        Sampling interval of the simulation.
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons.
    tm : float
        Time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike.
    vpeak : float
        The voltage above which a spike is triggered.
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant.
    tr : float
        Synaptic rise time constant.
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength.
    seed : int
        The seed makes NumPy's random number generator deterministic.
    nrec : int
        The number of neurons to record.

    Returns
    -------
    ndarray
        A 2D array of recorded voltages. Rows are time points,
        columns are the recorded neurons. Shape: (round(T/dt), nrec).
    """

    np.random.seed(seed)  # Seeding randomness for reproducibility

    """Setup weight matrix"""
    w = g * (randn(n, n)) * (rand(n, n) < p) / (np.sqrt(n) * p)
    # Set the row mean to zero
    row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]
    row_means = np.repeat(row_means, w.shape[0], axis=1)
    w[np.abs(w) > 0] = w[np.abs(w) > 0] - row_means[np.abs(w) > 0]

    """Preinitialize recording"""
    nt = round(T/dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    """Initial conditions"""
    ipsc = np.zeros(n)  # Post synaptic current storage variable
    hm = np.zeros(n)  # Storage variable for filtered firing rates
    tlast = np.zeros(n)  # Used to set the refractory times
    v = vreset + rand(n)*(30-vreset)  # Initialize neuron voltage

    """Start integration loop"""
    for i in np.arange(0, nt, 1):
        inp = ipsc + offset  # Total input current

        # Voltage equation with refractory period
        # Only change if voltage outside of refractory time period
        dv = (dt * i > tlast + tref) * (-v + inp) / tm
        v = v + dt*dv

        index = np.argwhere(v >= vpeak)[:, 0]  # Spiked neurons

        # Get the weight matrix column sum of spikers
        if len(index) > 0:
            # Compute the increase in current due to spiking
            jd = w[:, index].sum(axis=1)

        else:
            jd = 0*ipsc

        # Used to set the refractory period of LIF neurons
        tlast = (tlast + (dt * i - tlast) *
                 np.array(v >= vpeak, dtype=int))

        ipsc = ipsc * np.exp(-dt / tr) + hm * dt

        # Integrate the current
        hm = (hm * np.exp(-dt / td) + jd *
              (int(len(index) > 0)) / (tr * td))

        v = v + (30 - v) * (v >= vpeak)

        rec[i, :] = v[0:nrec]  # Record voltages of the first nrec neurons
        v = v + (vreset - v) * (v >= vpeak)

    return rec


if __name__ == '__main__':
    rec = balanced_spiking_network()
    """PLOTTING"""
    fig, ax = plt.subplots(1)
    ax.plot(rec[:, 0] - 100.0)
    ax.plot(rec[:, 1])
    ax.plot(rec[:, 2] + 100.0)

Three neurons in the balanced spiking neural network.

The weight matrix

At the core of any balanced network is the weight matrix. We define it in the "Setup weight matrix" block of the code. Initializing it from a normal distribution and then subtracting the mean of the nonzero entries in each row makes sure that excitation and inhibition are in balance. That is what keeps the network spiking irregularly even though the input to the network remains constant. That constant input is the offset parameter.
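To convince ourselves that the normalization works, here is a small standalone check. It is not part of the original script; it simply repeats the weight matrix construction with the function's default parameters and verifies that the nonzero weights in each row average to zero.

import numpy as np
from numpy.random import rand, randn

n, p, g = 2000, 0.1, 0.04
np.random.seed(100)

# Sparse Gaussian weights, scaled as in the network function
w = g * randn(n, n) * (rand(n, n) < p) / (np.sqrt(n) * p)

# Subtract the mean of the nonzero entries from each row
row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]
w[np.abs(w) > 0] = (w - row_means)[np.abs(w) > 0]

# The nonzero weights in each row now average to (numerically) zero
print(np.abs(np.mean(w, axis=1, where=np.abs(w) > 0)).max())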

Refractory period

The refractory period is a time window during which no action potential can be generated. We achieve this by setting the voltage to a low value right after a spike and then not updating the voltage of that neuron for a given time. This time window is given by tref. The voltage update dv = (dt * i > tlast + tref) * (-v + inp) / tm checks how long ago the last spike occurred with the expression (dt * i > tlast + tref), which is why we need to track the most recent spike time of each neuron in tlast. Of course we have some other things to do when a neuron reaches the spiking threshold vpeak. First we set the voltage to a value well above the threshold with v = v + (30 - v) * (v >= vpeak). This is purely visual, to give the traces a spiky appearance in the recording. Right after recording, we set the voltage to its reset value vreset.
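To make the gating explicit, here is a toy example with made-up values (not part of the original code) showing how the boolean factor freezes the voltage update for a neuron that spiked recently:

import numpy as np

dt, tref, tm = 0.00005, 0.002, 0.01
i = 100                           # current time step, i.e. t = 5 ms
tlast = np.array([0.0, 0.0049])   # last spike times of two neurons
v = np.array([-60.0, -65.0])
inp = np.array([-40.0, -40.0])

# Neuron 0 spiked long ago -> factor 1, its voltage integrates.
# Neuron 1 spiked 0.1 ms ago -> factor 0, its voltage is frozen.
dv = (dt * i > tlast + tref) * (-v + inp) / tm
print(v + dt * dv)  # only the first entry changes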

Play around with some of the parameters. You can find the code here: https://gist.github.com/danielmk/9adc7409f40a076ffec0cdf85dea4519
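For example, here is one way to compare a weaker and a stronger synaptic scaling factor g using the function defined above (the specific values are just suggestions to try):

# Compare two synaptic scaling factors by recording one neuron each
rec_weak = balanced_spiking_network(g=0.02)
rec_strong = balanced_spiking_network(g=0.08)

fig, axes = plt.subplots(2, sharex=True)
axes[0].plot(rec_weak[:, 0])
axes[1].plot(rec_strong[:, 0])
axes[0].set_ylabel("g = 0.02")
axes[1].set_ylabel("g = 0.08")
axes[1].set_xlabel("time step")
plt.show()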

Spiking Neuronal Networks in Python

Spiking neural networks (SNNs) turn some input into an output much like artificial neural networks (ANNs), which are already widely used today. Both achieve the same goal in different ways. The units of an ANN are single floating-point numbers that represent their activity level for a given input. Neuroscientists loosely understand this number as the average spike rate of a unit. In ANNs this number is usually the result of multiplying the input with a weight matrix. SNNs work differently in that they simulate units as spiking point processes. How often and when a unit spikes depends on its input and on the connections between neurons. A spike causes a discontinuity in other connected units. SNNs are not yet widely used outside of laboratories, but they are important for neuroscientists modeling brain circuitry. Here we will create spiking point models and connect them. We will be using Brian2, a Python package to simulate SNNs. You can install it with either

conda install -c conda-forge brian2
or
pip install brian2

We will start by defining our spiking unit model. There are many different models of spiking. Here we will define a conductance based version of the leaky integrate-and-fire model.

from brian2 import *
import numpy as np
import matplotlib.pyplot as plt

start_scope()

# Neuronal Parameters
c = 100*pF
vl = -70*mV
gl = 5*nS

# Synaptic Parameters
ge_tau = 20*ms
ve = 0*mV
gi_tau = 100*ms
vi = -80*mV
w_ge = 1.0*nS
w_gi = 0.0*nS

lif = '''
dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I)/c : volt
dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens
I : amp
'''

The intended way to import Brian2 is from brian2 import *. This feels a bit dangerous, but it keeps the code readable, especially when dealing with physical units. Speaking of units, each of our parameters has one. We have picofarad (pF), millivolt (mV) and nanosiemens (nS). These are all imported from Brian2, and we attach units with the multiplication operator *. To find out what each of those parameters does, we can look at the actual model. The model is defined by the string we assign to lif. The first line of the string is the model of our spiking unit:

dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I)/c : volt

This is a differential equation that describes the change of voltage with respect to time. In electrical terms, voltage is changed by currents. These physical terms are not strictly necessary for the computational function of SNNs, but they are helpful for understanding the biological background. Four currents flow in our model.

The first one is gl * (v - vl). This current is determined by the conductance (gl), the current voltage (v) and the equilibrium voltage (vl). It flows whenever the voltage differs from vl, and gl is a scaling constant that determines how strongly v is driven towards vl. This is why vl is called the resting potential: v does not change when it is equal to vl. This term also makes our model a leaky integrate-and-fire model, as opposed to a plain integrate-and-fire model. Of course the voltage only remains at rest if nothing else changes it, and there are three other currents that can. Two of them correspond to excitatory and inhibitory synaptic inputs. ge * (v - ve) drives the voltage towards ve. Under most circumstances this will increase the voltage, since ve equals 0mV. On the other hand, gi * (v - vi) drives the voltage towards vi, which at -80mV is even slightly lower than vl. This current keeps the voltage low. Finally there is I, the input current each neuron receives; we will define it later for each neuron. Note that currents do not change the voltage instantaneously. They are all slowed down by the capacitance c, which is why we divide the sum of all currents by c.
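As an aside, to make the equation more concrete, here is a minimal Euler integration of just this voltage equation in plain Python. The conductance and current values are made up for illustration, the synaptic conductances are held fixed, and spiking is ignored; Brian2 handles all of this integration for us.

# Same parameters as the Brian2 model, written in SI units (F, V, S, A)
c, vl, gl = 100e-12, -70e-3, 5e-9
ve, vi = 0.0, -80e-3

# Hypothetical example values: a fixed excitatory conductance and a
# small constant input current (the real model lets ge and gi decay)
ge, gi, I = 1e-9, 0.0, 0.1e-9

dt = 0.1e-3               # 0.1 ms Euler step
v = -70e-3                # start at rest
for _ in range(1000):     # integrate for 100 ms
    dv = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I) / c
    v = v + dt * dv

# Settles around -42 mV, just below the -40 mV threshold used later,
# which is why this sketch gets away without a reset rule
print(v * 1e3, "mV")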

There are two more differential equations that describe our model:

dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens

These describe the change of the excitatory and inhibitory synaptic conductances. We did not yet implement the way spiking changes ge or gi. However, these equations tell us that both ge and gi will decay towards zero with the time constants ge_tau and gi_tau. So far so good, but what about spiking? That comes up next, when we turn the string that represents our model into something that Brian2 can actually simulate.

G = NeuronGroup(3, lif, threshold='v > -40*mV',
                reset='v = vl', method='euler')
G.I = [0.7, 0.5, 0]*nA
G.v = [-70, -70, -70]*mV

The NeuronGroup creates three units for us that are defined by the equations in lif. The threshold parameter gives the condition for registering a spike. In this case we register a spike when the voltage exceeds -40mV. When a spike is registered, the event defined by reset is triggered: we set the voltage back to the resting voltage vl. The method parameter gives the integration method used to solve the differential equations.

Once our units are defined, we can interface with some parameters of the neurons. For example, G.I = [0.7, 0.5, 0]*nA sets the input current of the zeroth neuron to 0.7nA, of the first neuron to 0.5nA and of the last neuron to 0nA. Not all parameters are accessible like this. I is available because we declared I : amp in our lif string, which makes it a per-neuron parameter we can set from outside. Next, we define the initial state of our neurons. G.v = [-70, -70, -70]*mV sets all of them to the resting voltage. A good place to start.
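As a side note, if we also want the exact spike times and which neuron fired them, Brian2 offers a SpikeMonitor. This is an optional addition to the script, not something we use below; like the StateMonitor later on, it has to be created before calling run.

# Optional: record spike times in addition to the voltage traces.
# After the simulation, spike_mon.t holds the spike times and
# spike_mon.i the indices of the neurons that fired them.
spike_mon = SpikeMonitor(G)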

You might be disappointed by this implementation of spiking. Where is the sodium? Where is the amplification of depolarization? The leaky integrate-and-fire model doesn’t feature a spiking mechanism, except for the discontinuity at the threshold. If you are interested in incorporating spike-like mechanisms you should look for the exponential leaky integrate-and-fire or a Hodgkin-Huxley like model.
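To give a flavor of the first option, the voltage equation could be extended with an exponential term that models the rapid upstroke near threshold. The sketch below is only illustrative and is not used anywhere else in this post; vt (soft threshold) and delta_t (spike sharpness) are hypothetical parameters chosen for illustration.

# Sketch of an exponential leaky integrate-and-fire equation string.
# vt and delta_t are hypothetical parameters, not part of the model above.
vt = -50*mV
delta_t = 2*mV

elif_eqs = '''
dv/dt = -(gl * (v - vl) - gl * delta_t * exp((v - vt) / delta_t) + ge * (v - ve) + gi * (v - vi) - I)/c : volt
dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens
I : amp
'''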

We are only missing one more ingredient for an actual network model: the synapses. And we are in a great position, because our model already defines the excitatory and the inhibitory conductances. Now we just need to make use of them.

Se = Synapses(G, G, on_pre='ge_post += w_ge')
Se.connect(i=0, j=2)

Si = Synapses(G, G, on_pre='gi_post += w_gi')
Si.connect(i=1, j=2)

First we create excitatory synapses within our group of neurons. The connection is excitatory because a presynaptic spike increases the conductance ge of the postsynaptic neuron by w_ge. We then call Se.connect(i=0, j=2) to define which neurons actually connect. In this case, the zeroth neuron connects to the second neuron, meaning spikes in the zeroth neuron cause ge to increase in the second neuron. We then create the inhibitory connection and make the first neuron inhibit the second neuron. Now we are ready to run the actual simulation and plot the result. Remember that for now w_gi = 0.0*nS, meaning that we will only see the effect of the excitatory connection.

M = StateMonitor(G, 'v', record=True)

run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
Voltage traces from three neurons. Spiking of N0 increases the voltage in N2. Spiking of N1 does nothing in this simulation because its weight onto N2 was set to 0nS.

The first two neurons are regularly spiking. N0 is slightly faster because it receives a larger input current than N1. N2 on the other hand rests at -70mV because it does not receive an input current. When N0 spikes, it causes the voltage of N2 to increase as we expected, because N0 increases the excitatory conductance. N1 is not doing anything here, because its weight is set to 0. Continued activity of N0 eventually causes N2 to reach the threshold of -40mV, making it register its own spike and reset the voltage. What happens if we introduce the inhibitory weight?

w_gi = 0.5*nS
run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
The same as above but now N1 has a weight of 0.5nS. This inhibits N2 and prevents a spike.

With a weight of 0.5nS, N1 inhibits N2 to an extent that prevents the spike. We have now built a network of leaky integrate-and-fire neurons that features both excitatory and inhibitory synapses. This is just the start toward a functional network that does something interesting with its input. We will need to increase the number of neurons, decide on a connectivity rule between them, initialize or even learn the weights, decide on a coding scheme for the output and much more. Many of these topics I will cover in later blog posts, so stay tuned.
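As a first taste of what scaling up might look like, here is a hedged sketch that reuses the lif equations and synaptic parameters from above and wires an excitatory and an inhibitory population with random connectivity, similar in spirit to the p parameter of the NumPy network in the first part. The group sizes, the input current and the 10% connection probability are placeholders, not tuned values.

# Sketch only: a larger network with random connectivity.
# Group sizes, input current and connection probability are placeholders.
n_exc, n_inh = 800, 200
P = NeuronGroup(n_exc + n_inh, lif, threshold='v > -40*mV',
                reset='v = vl', method='euler')
P.v = vl
P.I = 0.3*nA

exc = P[:n_exc]   # excitatory subgroup
inh = P[n_exc:]   # inhibitory subgroup

Se_big = Synapses(exc, P, on_pre='ge_post += w_ge')
Se_big.connect(p=0.1)   # each pair connects with 10% probability

Si_big = Synapses(inh, P, on_pre='gi_post += w_gi')
Si_big.connect(p=0.1)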