Balanced spiking neural networks with NumPy

Balanced spiking neural networks are a cornerstone of computational neuroscience. They make for a nice introduction to spiking neuronal networks and they can be trained to store information (Nicola & Clopath, 2017). Here I present a Python port of a MATLAB implementation by Nicola & Clopath and go through some of its features. The simulation itself only needs NumPy; Matplotlib is used just for plotting.

```python
import numpy as np
from numpy.random import rand, randn
import matplotlib.pyplot as plt


def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000,
                             td=0.02, tr=0.002, p=0.1,
                             offset=-40.00, g=0.04, seed=100,
                             nrec=10):
    """Simulate a balanced spiking neuronal network

    Parameters
    ----------
    dt : float
        Sampling interval of the simulation.
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons.
    tm : float
        Time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike.
    vpeak : float
        The voltage above which a spike is triggered
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant.
    tr : float
        Synaptic rise time constant.
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength
    seed : int
        The seed makes NumPy random number generator deterministic.
    nrec : int
        The number of neurons to record.

    Returns
    -------
    ndarray
        A 2D array of recorded voltages. Rows are time points,
        columns are the recorded neurons. Shape: (round(T/dt), nrec).
    """

    np.random.seed(seed)  # Seeding randomness for reproducibility

    # Setup weight matrix
    w = g * (randn(n, n)) * (rand(n, n) < p) / (np.sqrt(n) * p)
    # Set the row mean to zero
    row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]
    row_means = np.repeat(row_means, w.shape[0], axis=1)
    w[np.abs(w) > 0] = w[np.abs(w) > 0] - row_means[np.abs(w) > 0]

    # Preinitialize recording
    nt = round(T/dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    # Initial conditions
    ipsc = np.zeros(n)  # Post synaptic current storage variable
    hm = np.zeros(n)  # Storage variable for filtered firing rates
    tlast = np.zeros(n)  # Used to set the refractory times
    v = vreset + rand(n)*(30-vreset)  # Initialize neuron voltage

    # Start integration loop
    for i in range(nt):
        inp = ipsc + offset  # Total input current

        # Voltage equation with refractory period
        # Only change if voltage outside of refractory time period
        dv = (dt * i > tlast + tref) * (-v + inp) / tm
        v = v + dt*dv

        index = np.argwhere(v >= vpeak)[:, 0]  # Spiked neurons

        # Get the weight matrix column sum of spikers
        if len(index) > 0:
            # Compute the increase in current due to spiking
            jd = w[:, index].sum(axis=1)

        else:
            jd = 0*ipsc

        # Used to set the refractory period of LIF neurons
        tlast = (tlast + (dt * i - tlast) *
                 np.array(v >= vpeak, dtype=int))

        ipsc = ipsc * np.exp(-dt / tr) + hm * dt

        # Integrate the current
        hm = (hm * np.exp(-dt / td) + jd *
              (int(len(index) > 0)) / (tr * td))

        # Purely visual: lift spiking neurons to 30 so spikes show in the trace
        v = v + (30 - v) * (v >= vpeak)

        rec[i, :] = v[0:nrec]  # Record the first nrec neurons
        v = v + (vreset - v) * (v >= vpeak)

    return rec


if __name__ == '__main__':
    rec = balanced_spiking_network()
    # Plot three recorded neurons, shifted vertically so they do not overlap
    fig, ax = plt.subplots(1)
    ax.plot(rec[:, 0] - 100.0)
    ax.plot(rec[:, 1])
    ax.plot(rec[:, 2] + 100.0)
    plt.show()
```

The weight matrix

At the core of any balanced network is the weight matrix. We define it in the code below the `Setup weight matrix` comment. Drawing the weights from a normal distribution and setting the mean of the nonzero weights in each row to zero makes sure that excitation and inhibition are in balance. That is what keeps the network spiking irregularly although the input to the network remains constant. The constant input to the network is the `offset` parameter.
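To see what the normalization does, here is a minimal sketch that builds a smaller weight matrix in the same way and checks the row sums. The helper name `make_balanced_weights`, the size `n = 200` and the use of `np.random.default_rng` are choices for this illustration, so the numbers will not match the function above.

```python
import numpy as np


def make_balanced_weights(n=200, p=0.1, g=0.04, seed=100):
    """Sparse Gaussian weights whose existing entries average to zero per row."""
    rng = np.random.default_rng(seed)
    mask = rng.random((n, n)) < p            # which connections exist
    w = g * rng.standard_normal((n, n)) * mask / (np.sqrt(n) * p)
    # Subtract each row's mean, computed over its existing connections only
    row_means = w.sum(axis=1) / np.maximum(mask.sum(axis=1), 1)
    w = w - row_means[:, None] * mask
    return w, mask


w, mask = make_balanced_weights()
row_sums = w.sum(axis=1)
print("largest absolute row sum:", np.abs(row_sums).max())
print("fraction of existing connections:", mask.mean())
```

Each row holds the incoming weights of one neuron, so a row sum of zero means that the positive and negative inputs to that neuron cancel when all of its presynaptic partners are equally active.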

Refractory period

The refractory period is a time window in which no action potential can be generated. We achieve this by setting the voltage to a low value right after the spike and then not updating the voltage of the neuron that spiked for a given time. This time window is given by `tref`. The voltage change is computed in the line `dv = (dt * i > tlast + tref) * (-v + inp) / tm`. The factor `(dt * i > tlast + tref)` checks whether the last spike occurred more than `tref` ago and is zero otherwise, so the voltage does not change during the refractory period. Therefore, we need to track the most recent spike time with `tlast`. Of course we have some other things to do when a neuron reaches the spiking threshold `vpeak`. First we set the voltage to a value well above the threshold with `v = v + (30 - v) * (v >= vpeak)`. This is purely visual to give a spiky appearance in the recording. Right after recording with `rec[i, :] = v[0:nrec]`, we set the voltage to its reset value `vreset`.
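The same mechanism can be isolated in a single neuron. The following is a minimal sketch, not part of the original port: the function name `lif_with_refractory` and the constant input of `-30.0` are assumptions for illustration (the input has to lie above `vpeak`, otherwise a lone neuron never spikes), and `tlast` starts at minus infinity so the neuron is not refractory at the beginning.

```python
import numpy as np


def lif_with_refractory(inp=-30.0, dt=0.00005, T=0.2, tref=0.002,
                        tm=0.01, vreset=-65.0, vpeak=-40.0):
    """Simulate one leaky integrate-and-fire neuron and return spike times."""
    nt = round(T / dt)
    v = vreset
    tlast = -np.inf          # no spike yet, so not refractory at the start
    spike_times = []
    for i in range(nt):
        t = dt * i
        # The voltage only changes once the refractory time has passed
        if t > tlast + tref:
            v = v + dt * (-v + inp) / tm
        if v >= vpeak:
            spike_times.append(t)
            tlast = t        # remember the most recent spike time
            v = vreset       # reset right after the spike
    return np.array(spike_times)


spikes = lif_with_refractory()
intervals = np.diff(spikes)
print("number of spikes:", len(spikes))
print("shortest inter-spike interval (s):", intervals.min())
```

No inter-spike interval can be shorter than `tref`, because the voltage stays at `vreset` for that long after every spike.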

Play around with some of the parameters. You can find the code here: https://gist.github.com/danielmk/9adc7409f40a076ffec0cdf85dea4519

Spiking neural networks for a low-energy future

Spiking neural networks (SNNs) have some disadvantages compared to artificial neural networks (ANNs) but they have the potential to run for a fraction of the energy. Whether SNNs will be able to replace ANNs and how much energy they will be using depends on many engineering and neuroscience advances. Here I will go through some of the technical background of the SNN energy advantages and some of the current numbers.

Energy efficient SNN features

The energy efficiency of SNNs comes primarily from two features. Firstly, the spike is a discrete event and energy is only used when a spike occurs. This is probably the most fundamental feature that distinguishes SNNs from ANNs. It means that the energy efficiency of an SNN depends not only on the number of neurons but also on the number of spikes the model requires to perform. The second feature is local memory. At the heart of all models are parameters. On traditional hardware such as CPUs and GPUs, the part of the chip that performs the calculations is not the same as the part that stores the parameters. Loading the parameters onto the chip is much more energy intensive than the computation itself. Therefore, storing the parameters locally on the chip that computes brings an efficiency advantage. This is not unique to spiking neural networks. Some tensor processing units (TPUs) also feature local memory and they are specifically designed for ANNs. When most people speak about the energy advantages of SNNs, they assume local memory.
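A toy calculation can make the first feature more concrete. The sketch below only counts operations, not energy, and every number in it is hypothetical: the layer size, the fraction of spiking inputs and the two helper functions are assumptions for illustration.

```python
def dense_operations(n_in, n_out):
    """Multiply-accumulates for one dense layer update: every weight is used."""
    return n_in * n_out


def event_driven_operations(n_in, n_out, spike_fraction):
    """Accumulates for one time step when only spiking inputs trigger work."""
    n_spikes = round(n_in * spike_fraction)
    return n_spikes * n_out


n_in, n_out = 1000, 1000      # hypothetical layer size
spike_fraction = 0.05         # hypothetical: 5% of inputs spike in a time step

dense = dense_operations(n_in, n_out)
sparse = event_driven_operations(n_in, n_out, spike_fraction)
print("dense operations:", dense)
print("event-driven operations:", sparse)
print("ratio:", dense / sparse)
```

An SNN usually needs several time steps per input, so the real advantage depends on how many steps and how many spikes the model requires in total.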

SNNs also require specialized hardware to run efficiently. That hardware is called neuromorphic. It makes efficient use of the binary nature of spikes and local memory. Neuromorphic hardware is so far only available for research purposes and making it more widely available will be one of the challenges to SNN adoption. Next will be some numbers on energy efficiency.

How efficient are we talking?

How much more efficient SNNs are depends on many factors of the comparison: what the task is, what the model architectures are and what hardware is used. Making projections into the future is even harder, since machine learning advances are made quickly on both SNNs and ANNs. Projecting the absolute amount of energy that could be saved is harder still because it requires AI demand predictions which can change non-linearly with technical advances. I would be interested in finding formal work on some of these uncertainties or working on some myself but for now here are some numbers.

The Loihi processor from Intel Labs is a recent piece of neuromorphic hardware. Depending on the size of their example problem, Davies et al. (2018) find that Loihi is 2.58x, 8.08x or 48.74x more energy efficient than a 1.67-GHz Atom CPU.

Yin et al. (2020) present a method to train SNNs (backpropagation of surrogate gradients). They calculate the theoretical energy consumption for a spiking recurrent network they train with the method and for some ANN architectures. Depending on the task, their SNN was 126.2x, 935x, 1602x, 1776x or 3353.3x more efficient than a Long Short-Term Memory network (LSTM; this also depends on some details of the LSTM implementation). Their network was 41.3x more efficient compared to a recurrent ANN. In a talk, the last author Sander Bohte summarizes the findings as >100x more efficient than the best recurrent ANN and 1000x more efficient than an LSTM. All their calculations assume local memory.

Panda et al. (2019) tried several methods to generate SNNs for image classification and calculated theoretical energy consumption. They estimate that SNNs are 6.52x, 7.7x, 10.6x, 74.9x, 81.3x or 104.8x more efficient, depending on model architecture and parameter space.

Merolla et al. (2014) present the TrueNorth neuromorphic architecture. They compare synaptic operations per second (SOPS) of their architecture to floating-point operations per second (FLOPS) of traditional chips. They report that TrueNorth can deliver 46 billion SOPS per watt, whereas the most energy-efficient supercomputer at the time of their writing delivered 4.5 billion FLOPS per watt.

These numbers highlight the potential for some massive energy savings but benchmarks are always complicated. Making good comparisons can be hard, especially since the units of computation (synaptic operations versus floating-point operations) are fundamentally different. Either way, SNNs on neuromorphic hardware are extremely energy efficient but to truly save energy they must become better at the tasks ANNs already solve.

References

Davies et al. 2018. Loihi: A Neuromorphic Manycore Processor with On-Chip Learning. IEEE Micro. 10.1109/MM.2018.112130359

Yin, Corradi & Bohte 2020. Effective and Efficient Computation with Multiple-timescale Spiking Recurrent Neural Networks. https://arxiv.org/abs/2005.11633

Panda, Aketi & Roy 2019. Towards Scalable, Efficient and Accurate Deep Spiking Neural Networks with Backward Residual Connections, Stochastic Softmax and Hybridization. https://arxiv.org/abs/1910.13931

Merolla et al. 2014. A million spiking-neuron integrated circuit with a scalable communication network and interface. https://science.sciencemag.org/content/345/6197/668

Spiking Neuronal Networks in Python

Spiking neural networks (SNNs) turn some input into an output much like artificial neural networks (ANNs), which are already widely used today. Both achieve the same goal in different ways. The units of an ANN are single floating-point numbers that represent the activity levels of the units for a given input. Neuroscientists loosely understand this number as the average spike rate of a unit. In ANNs this number is usually the result of multiplying the input with a weight matrix. SNNs work differently in that they simulate units as spiking point processes. How often and when a unit spikes depends on the input and the connections between neurons. A spike causes a discontinuity in other connected units. SNNs are not yet widely used outside of laboratories but they are important to help neuroscientists model brain circuitry. Here we will create spiking point models and connect them. We will be using Brian2, a Python package to simulate SNNs. You can install it with either

```
conda install -c conda-forge brian2
# or
pip install brian2
```

We will start by defining our spiking unit model. There are many different models of spiking. Here we will define a conductance-based version of the leaky integrate-and-fire model.

```python
from brian2 import *
import numpy as np
import matplotlib.pyplot as plt

start_scope()

# Neuronal Parameters
c = 100*pF
vl = -70*mV
gl = 5*nS

# Synaptic Parameters
ge_tau = 20*ms
ve = 0*mV
gi_tau = 100*ms
vi = -80*mV
w_ge = 1.0*nS
w_gi = 0.0*nS

lif = '''
dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi *(v - vi) - I)/c : volt
dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens
I : amp
'''
```

The intended way to import Brian2 is `from brian2 import *`. This feels a bit dangerous but it is very important to keep the code readable, especially when dealing with physical units. Speaking of units, each of our parameters has a physical unit. We have picofarad (`pF`), millivolt (`mV`) and nanosiemens (`nS`). These are all imported from Brian2 and we assign units with the multiplication operator `*`. To find out what each of those parameters does, we can look at our actual model. The model is defined by the string we assign to `lif`. The first line of the string is the model of our spiking unit:

``dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi *(v - vi) - I)/c : volt``

This is a differential equation that describes the change of voltage with respect to time. In electrical terms, voltage is changed by currents. These physical terms are not strictly necessary for the computational function of SNNs but they are helpful to understand the biological background behind them. Four currents flow in our model: a leak current, two synaptic currents and an input current.

The first one is `gl * (v - vl)`. This current is given by the leak conductance (`gl`), the present voltage (`v`) and the equilibrium voltage (`vl`). It flows whenever the voltage differs from `vl`. `gl` is just a scaling constant that determines how strong the drive towards `vl` is. `vl` is called the resting potential because `v` does not change when it is equal to `vl` and no other current flows. This term makes our model a leaky integrate-and-fire model as opposed to an integrate-and-fire model. There are three other currents that can move the voltage away from rest. Two of them correspond to excitatory and inhibitory synaptic inputs. `ge * (v - ve)` drives the voltage towards `ve`. Under most circumstances, this will increase the voltage because `ve` equals 0mV. On the other hand, `gi *(v - vi)` drives the voltage towards `vi`, which at -80mV is slightly lower than `vl`. This current keeps the voltage low. The last current is `I`, the input current that the model receives. We will define this later for each neuron. Finally, currents don’t change the voltage immediately. They are all slowed down by the capacitance `c`. Therefore, we divide the sum of all currents by `c`.
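To get a feeling for the leak term and the capacitance, here is a small sketch that integrates the voltage equation with plain NumPy instead of Brian2. It assumes no synaptic input (`ge = gi = 0`), no spiking threshold, SI units written out by hand and a step size of 0.1 ms; none of this is part of the Brian2 model itself.

```python
import numpy as np

# Parameters from the model above, written in SI units without Brian2
c = 100e-12      # farad
vl = -70e-3      # volt
gl = 5e-9        # siemens
I = 0.7e-9       # ampere, the input current used for the first neuron later on

tau = c / gl             # membrane time constant in seconds
v_inf = vl + I / gl      # voltage where leak and input current cancel

dt = 0.1e-3              # assumed step size of 0.1 ms
steps = 2000             # 200 ms, i.e. ten time constants
v = np.empty(steps + 1)
v[0] = vl
for k in range(steps):
    # Forward Euler step of dv/dt = -(gl * (v - vl) - I) / c with ge = gi = 0
    dv = -(gl * (v[k] - vl) - I) / c
    v[k + 1] = v[k] + dt * dv

print("time constant (ms):", tau * 1e3)
print("steady state (mV):", v_inf * 1e3)
print("voltage after 200 ms (mV):", v[-1] * 1e3)
```

Without a threshold the voltage simply relaxes towards `vl + I / gl`, and the ratio `c / gl` sets how quickly it gets there.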

There are two more differential equations that describe our model:

``dge/dt = -ge / ge_tau : siemens``

``dgi/dt = -gi / gi_tau : siemens``

These describe the change of the excitatory and inhibitory synaptic conductances. We did not yet implement the way spiking changes `ge` or `gi`. However, these equations tell us that both `ge` and `gi` will decay towards zero with the time constants `ge_tau` and `gi_tau`. So far so good, but what about spiking? That comes up next, when we turn the string that represents our model into something that Brian2 can actually simulate.
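The decay can be written down directly, because an equation of this form has an exponential solution. The sketch below evaluates it at a few time points; the starting conductance of 1 nS is a hypothetical value chosen for illustration.

```python
import numpy as np

ge_tau = 20e-3    # seconds, as in the parameters above
gi_tau = 100e-3
g0 = 1e-9         # hypothetical starting conductance of 1 nS

t = np.array([0.0, 20e-3, 100e-3])    # time points in seconds
ge = g0 * np.exp(-t / ge_tau)         # exact solution of dge/dt = -ge / ge_tau
gi = g0 * np.exp(-t / gi_tau)         # exact solution of dgi/dt = -gi / gi_tau

for ti, e, i in zip(t, ge, gi):
    print(f"t = {ti * 1e3:5.1f} ms  ge = {e * 1e9:.3f} nS  gi = {i * 1e9:.3f} nS")
```

After one time constant a conductance has dropped to about 37% of its starting value, so with these parameters inhibition lingers five times longer than excitation.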

```python
G = NeuronGroup(3, lif, threshold='v > -40*mV',
                reset='v = vl', method='euler')
G.I = [0.7, 0.5, 0]*nA
G.v = [-70, -70, -70]*mV
```

The `NeuronGroup` creates three units for us that are defined by the equations in `lif`. The `threshold` parameter gives the condition to register a spike. In this case, we register a spike when the voltage is larger than -40mV. When a spike is registered, an event is triggered, defined by `reset`. Here, we reset the voltage to the resting voltage `vl`. The `method` parameter gives the integration method to solve the differential equations.

Once our units are defined, we can interface with some parameters of the neurons. For example, `G.I = [0.7, 0.5, 0]*nA` sets the input current of the zeroth neuron to 0.7nA, the first neuron to 0.5nA and the last neuron to 0nA. Not all parameters are accessible like this. `I` is available because we defined `I : amp` in our `lif` string. This says that `I` is available for changes. Next, we define the initial state of our neurons. `G.v = [-70, -70, -70]*mV` sets all of them to the resting voltage. A good place to start.
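What threshold and reset do to the three units can be sketched without Brian2. The loop below is a plain NumPy stand-in, not Brian2's implementation: it ignores the synaptic conductances, writes the units out in SI and assumes a step size of 0.1 ms.

```python
import numpy as np

c, vl, gl = 100e-12, -70e-3, 5e-9        # SI units, values from above
v_threshold = -40e-3
I = np.array([0.7e-9, 0.5e-9, 0.0])      # input currents of the three units
dt = 0.1e-3                              # assumed step size of 0.1 ms
steps = 1000                             # 100 ms

v = np.full(3, vl)                       # all units start at rest
spike_counts = np.zeros(3, dtype=int)
for _ in range(steps):
    v = v + dt * (-(gl * (v - vl) - I) / c)   # Euler step without synapses
    spiked = v > v_threshold                  # the threshold condition
    spike_counts += spiked
    v[spiked] = vl                            # the reset statement 'v = vl'

print("spikes in 100 ms:", spike_counts)
```

The unit with the larger input current reaches the threshold sooner after every reset and therefore spikes more often, while the unit without input never leaves rest.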

You might be disappointed by this implementation of spiking. Where is the sodium? Where is the amplification of depolarization? The leaky integrate-and-fire model doesn’t feature a spiking mechanism, except for the discontinuity at the threshold. If you are interested in incorporating spike-like mechanisms you should look for the exponential leaky integrate-and-fire or a Hodgkin-Huxley like model.

We are only missing one more ingredient for an actual network model: the synapses. And we are in a great position, because our model already defines the excitatory and the inhibitory conductances. Now we just need to make use of them.

```python
Se = Synapses(G, G, on_pre='ge_post += w_ge')
Se.connect(i=0, j=2)

Si = Synapses(G, G, on_pre='gi_post += w_gi')
Si.connect(i=1, j=2)
```

First we create an excitatory connection from our neuron group onto itself. The connection is excitatory because it increases the conductance `ge` of the postsynaptic neuron by `w_ge`. We then call `Se.connect(i=0, j=2)` to define which neurons actually connect. In this case, the zeroth neuron connects to the second neuron. This means that spikes in the zeroth neuron cause `ge` to increase in the second neuron. We then create the inhibitory connection and make the first neuron inhibit the second neuron. Now we are ready to run the actual simulation and plot the result. Remember that for now `w_gi = 0.0*nS`, meaning that we will only see the effect of the excitatory connection.

```python
M = StateMonitor(G, 'v', record=True)

run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
```

The first two neurons are regularly spiking. N0 is slightly faster because it receives a larger input current than N1. N2 on the other hand rests at -70mV because it does not receive an input current. When N0 spikes, it causes the voltage of N2 to increase as we expected, because N0 increases the excitatory conductance. N1 is not doing anything here, because its weight is set to 0. Continued activity of N0 eventually causes N2 to reach the threshold of -40mV, making it register its own spike and reset the voltage. What happens if we introduce the inhibitory weight?

```python
w_gi = 0.5*nS
# The simulation continues from t = 100 ms, so M now holds both runs
run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
```

With a weight of 0.5nS, N1 inhibits N2 to an extent that prevents the spike. We have now built a network of leaky integrate-and-fire neurons that features both excitatory and inhibitory synapses. This is just the start to getting a functional network that does something interesting with the input. We will need to increase the number of neurons, decide on a connectivity rule between them, initialize or even learn weights, decide on a coding scheme for the output and much more. Many of these I will cover in later blog posts so stay tuned.
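For readers who want to check the effect of the inhibitory weight without Brian2, here is a rough NumPy re-implementation of the three-neuron network. It is a sketch under several assumptions: forward Euler with a 0.1 ms step, the inhibitory current written as `gi * (v - vi)`, SI units by hand, and every run started from rest instead of continuing a previous one. The function name `simulate_three_neurons` is made up for this example, and the exact spike times will not match Brian2.

```python
import numpy as np


def simulate_three_neurons(w_gi, w_ge=1e-9, T=0.1, dt=0.1e-3):
    """Euler simulation of N0 -> N2 (excitatory) and N1 -> N2 (inhibitory).

    All quantities are in SI units. Returns the spike count of each neuron.
    """
    c, vl, gl = 100e-12, -70e-3, 5e-9
    ve, vi = 0.0, -80e-3
    ge_tau, gi_tau = 20e-3, 100e-3
    v_threshold = -40e-3
    I = np.array([0.7e-9, 0.5e-9, 0.0])

    v = np.full(3, vl)
    ge = np.zeros(3)
    gi = np.zeros(3)
    counts = np.zeros(3, dtype=int)
    for _ in range(round(T / dt)):
        dv = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I) / c
        v = v + dt * dv
        ge = ge - dt * ge / ge_tau     # conductances decay towards zero
        gi = gi - dt * gi / gi_tau
        spiked = v > v_threshold
        counts += spiked
        if spiked[0]:
            ge[2] += w_ge              # N0 excites N2
        if spiked[1]:
            gi[2] += w_gi              # N1 inhibits N2
        v[spiked] = vl                 # reset spiking neurons
    return counts


without_inhibition = simulate_three_neurons(w_gi=0.0)
with_inhibition = simulate_three_neurons(w_gi=0.5e-9)
print("w_gi = 0.0 nS:", without_inhibition)
print("w_gi = 0.5 nS:", with_inhibition)
```

The third entry of each result is the spike count of N2, which is the number to compare between the two runs.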