Balanced spiking neural networks with NumPy

Balanced spiking neural networks are a cornerstone of computational neuroscience. They make for a nice introduction to spiking neuronal networks and they can be trained to store information (Nicola & Clopath, 2017). Here I will present a Python port I made of a MATLAB implementation by Nicola & Clopath and I will go through some of its features. We only need NumPy for the simulation and Matplotlib for plotting.

import numpy as np
from numpy.random import rand, randn
import matplotlib.pyplot as plt

def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000, 
                             td=0.02, tr=0.002, p=0.1, 
                             offset=-40.00, g=0.04, seed=100,
                             nrec=10):
    """Simulate a balanced spiking neuronal network

    Parameters
    ----------
    dt : float
        Sampling interval of the simulation.
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons.
    tm : float
        Time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike.
    vpeak : float
        The voltage above which a spike is triggered
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant.
    tr : float
        Synaptic rise time constant.
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength
    seed : int
        The seed makes NumPy random number generator deterministic.
    nrec : int
        The number of neurons to record.

    Returns
    -------
    rec : numpy.ndarray
        A 2D array of recorded voltages. Rows are time points,
        columns are the recorded neurons. Shape: (round(T/dt), nrec).
    """

    np.random.seed(seed)  # Seeding randomness for reproducibility

    """Setup weight matrix"""
    w = g * (randn(n, n)) * (rand(n, n) < p) / (np.sqrt(n) * p)
    # Set the row mean to zero
    row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]
    row_means = np.repeat(row_means, w.shape[0], axis=1)
    w[np.abs(w) > 0] = w[np.abs(w) > 0] - row_means[np.abs(w) > 0]

    """Preinitialize recording"""
    nt = round(T/dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    """Initial conditions"""
    ipsc = np.zeros(n)  # Post synaptic current storage variable
    hm = np.zeros(n)  # Storage variable for filtered firing rates
    tlast = np.zeros((n))  # Used to set  the refractory times
    v = vreset + rand(n)*(30-vreset)  # Initialize neuron voltage

    """Start integration loop"""
    for i in np.arange(0, nt, 1):
        inp = ipsc + offset  # Total input current

        # Voltage equation with refractory period
        # Only change if voltage outside of refractory time period
        dv = (dt * i > tlast + tref) * (-v + inp) / tm
        v = v + dt*dv

        index = np.argwhere(v >= vpeak)[:, 0]  # Spiked neurons

        # Get the weight matrix column sum of spikers
        if len(index) > 0:
            # Compute the increase in current due to spiking
            jd = w[:, index].sum(axis=1)
        else:
            # No spikes in this time step, so no increase in current
            jd = 0*ipsc

        # Used to set the refractory period of LIF neurons
        tlast = (tlast + (dt * i - tlast) *
                 np.array(v >= vpeak, dtype=int))

        ipsc = ipsc * np.exp(-dt / tr) + hm * dt

        # Integrate the current
        hm = (hm * np.exp(-dt / td) + jd *
              (int(len(index) > 0)) / (tr * td))

        v = v + (30 - v) * (v >= vpeak)

        rec[i, :] = v[0:nrec]  # Record the voltage of the first nrec neurons
        v = v + (vreset - v) * (v >= vpeak)

    return rec

if __name__ == '__main__':
    rec = balanced_spiking_network()
    fig, ax = plt.subplots(1)
    ax.plot(rec[:, 0] - 100.0)
    ax.plot(rec[:, 1])
    ax.plot(rec[:, 2] + 100.0)
Three neurons in the balanced spiking neural network.

The weight matrix

At the core of any balanced network is the weight matrix. We define it in the block marked "Setup weight matrix". Initializing it from a normal distribution and setting the row mean to zero makes sure that excitation and inhibition are in balance. That is what keeps the network spiking irregularly although the input to the network remains constant. The constant input to the network is the offset parameter.
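To see the balancing step in isolation, here is a minimal sketch on a small toy network. The sizes and the explicit `mask` variable are my own choices for illustration (the function above finds existing connections with `np.abs(w) > 0` instead), so treat it as a sketch rather than the exact code of the port.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for this sketch
n, p, g = 200, 0.1, 0.04        # small toy network, not the defaults above

# Sparse Gaussian weights: each connection exists with probability p
mask = rng.random((n, n)) < p
w = g * rng.standard_normal((n, n)) * mask / (np.sqrt(n) * p)

# Subtract each row's mean from its existing connections only
counts = mask.sum(axis=1, keepdims=True)
row_means = w.sum(axis=1, keepdims=True) / np.maximum(counts, 1)
w = w - row_means * mask

row_sums = w.sum(axis=1)  # net incoming weight per neuron
```

After the subtraction every entry of `row_sums` is zero up to floating point error, so the excitatory and inhibitory weights onto each neuron cancel.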

Refractory period

The refractory period is a time window where no action potential can be generated. We achieve this by setting the voltage to a low value right after the spike and then not updating the voltage of that neuron for a given time. This time window is given by tref. We update the voltage in the line that computes dv. In the same line we check how long ago the last spike occurred with the expression (dt * i > tlast + tref). Therefore, we need to track the most recent spike time with tlast. Of course we have some other things to do when a neuron reaches the spiking threshold vpeak. First we set the voltage to a value well above the threshold (30) just before we record. This is purely visual to give a spiky appearance in the recording. Right after we have written the voltages to rec, we set the voltage to its reset value vreset.
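The refractory mechanism is easier to follow on a single neuron. The sketch below uses the same update rule and default constants as the function above, but the constant drive `inp` is a hypothetical value I picked so the neuron fires, and `tlast` starts at minus infinity instead of zero so the neuron is not refractory at the start.

```python
import numpy as np

dt, tref, tm = 0.00005, 0.002, 0.01   # same defaults as the function above
vreset, vpeak = -65.0, -40.0
inp = -30.0            # hypothetical constant drive above vpeak
nt = 4000              # 0.2 s of simulated time

v = vreset
tlast = -np.inf        # no spike yet, so not refractory at the start
spike_times = []

for i in range(nt):
    t = dt * i
    # The voltage only changes once the refractory window has passed
    active = t > tlast + tref
    v = v + dt * active * (-v + inp) / tm
    if v >= vpeak:
        spike_times.append(t)
        tlast = t      # remember the most recent spike time
        v = vreset     # reset the voltage

isi = np.diff(spike_times)  # inter-spike intervals
```

Every interval in `isi` is longer than `tref`, because the voltage is frozen at `vreset` for that long after each spike.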

Play around with some of the parameters. You can find the code here:

Animations with Matplotlib

Anything that can be plotted with Matplotlib can also be animated. This is especially useful when data changes over time. Animations allow us to see the dynamics in our data, which is nearly impossible with most static plots. Here we will learn how to animate with Matplotlib by producing this traveling wave animation.

This is the code to make the animation. It creates the traveling wave, defines two functions that handle the animation and creates the animation with the FuncAnimation class. Let’s take it step by step.

import numpy as np
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

# Create the traveling wave
def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0,4,0.01)[np.newaxis,:]
t = np.arange(0,2,0.01)[:,np.newaxis]
wavelength = 1
speed = 1
yt = wave(x, t, wavelength, speed)  # shape is [t, x]

# Create the figure and axes to animate
fig, ax = plt.subplots(1)
# init_func() is called at the beginning of the animation
def init_func():
    ax.clear()

# update_plot() is called between frames
def update_plot(i):
    ax.clear()
    ax.plot(x[0,:], yt[i,:], color='k')

# Create animation
anim = FuncAnimation(fig,
                     update_plot,
                     frames=np.arange(0, len(t[:,0])),
                     init_func=init_func)

# Save animation (the dpi and fps values here are illustrative)
anim.save('traveling_wave.mp4',
          dpi=150,
          fps=30,
          writer='ffmpeg')

On the first three lines we import NumPy, Matplotlib and most importantly the FuncAnimation class. It will take the center stage in our code as it will create the animation later on by combining all the parts we need. On lines 5-13 we create the traveling wave. I don’t want to go into too much detail, as it is just a toy example for the animation. The important part is that we get the array yt, which defines the wave at each time point. So yt[0] contains the wave at t0 , yt[1] at t1 and so on. This is important, since we will be iterating over time during the animation. If you want to learn more about the traveling wave, you can change wavelength, speed and play around with the wave() function.

Now that we have our wave, we can start preparing the animation. We create the figure and the axes we want to use with plt.subplots(1). Then we create init_func(). This one will be called whenever the animation starts or repeats. In this particular example it is pretty useless. I include it here because it is a useful feature for more complex animations.

Now we get to update_plot(), the heart of our animation. This function updates our figure between frames. It determines what we see on each frame. It is the most important function and it is shockingly simple. The parameter i is an integer that defines what frame we are at. We use that integer as an index into the first dimension of yt. We plot the wave as it looks at time step i. Importantly, we must clean up our axes with ax.clear(). If we forgot to clear them, our plot would quickly become all black, filled with waves.
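Clearing and replotting is not the only way to update a frame. As an alternative sketch (not what the code in this post does), we can draw the line once and only swap its y-data on each frame. The function name `update_line` and the Agg backend are my own choices so that the example runs without a display.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

x = np.arange(0, 4, 0.01)
t = np.arange(0, 2, 0.01)
# Same traveling wave as above with wavelength = speed = 1
yt = np.sin(2 * np.pi * (x[np.newaxis, :] - t[:, np.newaxis]))

fig, ax = plt.subplots(1)
(line,) = ax.plot(x, yt[0], color='k')  # draw the line once

def update_line(i):
    # Replace only the y-data of the existing line
    line.set_ydata(yt[i])
    return (line,)

anim = FuncAnimation(fig, update_line, frames=len(t))
```

Because the axes are never cleared, the axis limits and labels stay fixed between frames, which can be convenient for longer animations.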

Now FuncAnimation is where it all comes together. We pass it fig, update_plot and init_func. We also pass frames, those are the values that i will take on during the animation. Technically, this gets the animation going in your interactive Python console but most of the time we want to save our animation. We do that by calling anim.save(). We pass it the file name as a string, the resolution in dpi, the frames per second and finally the writer class used for generating the animation. Not all writers work for all file formats. I prefer .mp4 with the ffmpeg writer. If there are issues with saving, the most common problem is that the writer we are trying to use is not installed. If you want to find out if the ffmpeg writer is available on your machine, you can type matplotlib.animation.FFMpegWriter().isAvailable(). It returns True if the writer is available and False otherwise. If you are using Anaconda you can install the codec from here.

This wraps up our tutorial. This particular example is very simple, but anything that can be plotted can also be animated. I hope you are now on your way to create your own animations. I will leave you with a more involved animation I created.

Smoothing Data by Rolling Average with NumPy

Time series data often comes with some amount of noise. One of the easiest ways to get rid of noise is to smooth the data with a simple uniform kernel, also called a rolling average. The title image shows data and their smoothed version. The data is the second discrete derivative from the recording of a neuronal action potential. Derivatives are notoriously noisy. We can get the result shown in the title image with np.convolve

import numpy as np

data = np.load("example_data.npy")
kernel_size = 10
kernel = np.ones(kernel_size) / kernel_size
data_convolved = np.convolve(data, kernel, mode='same')

Convolution is a mathematical operation that combines two arrays. One of those arrays is our data and we convolve it with the kernel array. During convolution we center the kernel at a data point. We multiply each value in the kernel with the corresponding data point, sum up all the results and that is the new data point at the center. Let’s look at an example.

data = [2, 3, 1, 4, 1]
kernel = [1, 2, 3, 4]
np.convolve(data, kernel)
# array([ 2,  7, 13, 23, 24, 18, 19,  4])

For this result to make sense you must know that np.convolve flips the kernel around. So step by step the calculations go as follows:

[4, 3, 2, 1]  # The flipped kernel
         [2, 3, 1, 4, 1]  # The data
          2= 2

   [4, 3, 2, 1]
          x  x
         [2, 3, 1, 4, 1]
          4+ 3= 7
         [2, 7]

      [4, 3, 2, 1]
          x  x  x
         [2, 3, 1, 4, 1]
          6+ 6+ 1=13
         [2, 7, 13]

         [4, 3, 2, 1]
          x  x  x  x
         [2, 3, 1, 4, 1]
          8+ 9+ 2+ 4= 23
         [2, 7,13,23]
# This continues until the arrays stop touching. You get the idea.
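The same steps can be written out as a short loop. This is only a sketch to check the walkthrough against np.convolve, the zero padding is my way of letting the flipped kernel hang over both edges and not how NumPy implements it internally.

```python
import numpy as np

data = np.array([2, 3, 1, 4, 1])
kernel = np.array([1, 2, 3, 4])

flipped = kernel[::-1]
n_out = len(data) + len(kernel) - 1

# Pad the data with zeros so the flipped kernel can overlap both edges
pad = np.zeros(len(kernel) - 1, dtype=int)
padded = np.concatenate([pad, data, pad])

# Slide the flipped kernel along the padded data, multiply and sum
manual = np.array([np.sum(flipped * padded[i:i + len(kernel)])
                   for i in range(n_out)])
```

Here `manual` contains the same eight values that np.convolve(data, kernel) returned above.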

One thing you’ll notice is that the edges are problematic. There is really no good way to avoid that. Data points at the edges only see part of the kernel but the mode parameter defines what should happen at the edges. I prefer the 'same' mode because it means that the new array will have the same shape as the original data, which makes plotting easier. However, if you start to use more complicated kernels, the edges might become virtually useless. In that case, mode should be 'valid'. Then, the values at the edges that did not see the entire kernel are discarded. The output array is smaller in shape than the input array.

data = [2, 3, 1, 4, 1]
kernel = [1, 2, 3, 4]
np.convolve(data, kernel, mode='valid')
# array([23, 24])

The default behavior you saw above is called 'full'. It keeps all data points, so the output array is larger in shape than the input array. You might also have noticed that the size of the kernel is very important. Actually, we need to divide the array of ones by its length. Can you guess what would happen if we forgot about dividing it?

If you guessed that the signal would become larger in magnitude you guessed right. We would be summing up all data points in the kernel. By dividing it we ensure that we take the average of the data points. But the kernel size is even more important. If we make the kernel larger the outcome changes dramatically.
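A quick way to convince yourself is a constant toy signal (my own example, not the data from the title image). With 'valid' mode there are no edge effects, so the difference between summing and averaging is easy to read off.

```python
import numpy as np

signal = np.full(50, 3.0)   # constant toy signal
kernel_size = 10

# Without dividing, the kernel sums up kernel_size data points
summed = np.convolve(signal, np.ones(kernel_size), mode='valid')

# With dividing, the kernel averages them
averaged = np.convolve(signal, np.ones(kernel_size) / kernel_size, mode='valid')
```

Every value in `summed` is 30, ten times the signal, while `averaged` stays at 3.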

kernel_size = 10
kernel = np.ones(kernel_size) / kernel_size
data_convolved_10 = np.convolve(data, kernel, mode='same')

kernel_size = 20
kernel = np.ones(kernel_size) / kernel_size
data_convolved_20 = np.convolve(data, kernel, mode='same')

import matplotlib.pyplot as plt

plt.plot(data_convolved_10)
plt.plot(data_convolved_20)
plt.legend(("Kernel Size 10", "Kernel Size 20"))

The larger we make the kernel, the smaller sharp peaks become. The peaks are also shifted in time. To be specific, a rolling mean is a low-pass filter. This means that it leaves low frequency signals mostly alone, while making high frequency signals smaller. Sharp increases in the data have a high frequency. If we make the kernel larger, the filter attenuates high frequency signals more. This is exactly how the rolling average works. It gets rid of high frequency noise. It also means that we must be careful not to distort the signal too much with the rolling average filter.
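The low-pass behavior can be checked with two synthetic sine waves. The sampling interval and the two frequencies below are hypothetical values I chose for illustration; the kernel is the same uniform kernel as before.

```python
import numpy as np

dt = 0.001                          # hypothetical 1 kHz sampling
t = np.arange(0, 1, dt)
slow = np.sin(2 * np.pi * 2 * t)    # 2 Hz component
fast = np.sin(2 * np.pi * 50 * t)   # 50 Hz component

kernel_size = 20
kernel = np.ones(kernel_size) / kernel_size

slow_smoothed = np.convolve(slow, kernel, mode='same')
fast_smoothed = np.convolve(fast, kernel, mode='same')

# Ignore the edges, where the kernel only partly overlaps the data
inner = slice(kernel_size, -kernel_size)
slow_amp = np.abs(slow_smoothed[inner]).max()
fast_amp = np.abs(fast_smoothed[inner]).max()
```

The slow wave keeps almost all of its amplitude, while the fast wave is nearly wiped out because the 20-sample kernel averages over exactly one of its periods.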

Threshold Detection in NumPy

Many signals are easily detected by their size. We will learn how to detect the indices where signals cross a threshold with NumPy. These are our practice signals.

import numpy as np
import matplotlib.pyplot as plt
data = np.array([0, 0, 0, 5, 5, 5, 5, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0])
plt.plot(data, marker='o')

We will perform two simple steps to detect the threshold crossings: 1. Make the data binary, in a way that they are true when larger than the threshold and false when lower or equal. 2. Take the difference of the binary signal. This gives us a boolean array that is true when the threshold was crossed. We can combine those steps into one line.

threshold = 2
threshold_crossings = np.diff(data > threshold, prepend=False)

Plotting shows us that threshold_crossings is true after the threshold was crossed.

plt.plot(data, marker='o')
plt.plot(threshold_crossings, marker='o')
plt.legend(("Data", "Threshold Crossings"))

To get the indices of the threshold crossings we can use np.argwhere(), which returns the true indices from a boolean array.

np.argwhere(threshold_crossings)[:, 0]
# array([ 3,  7, 11, 14], dtype=int64)

Threshold crossings occur at 3, 7, 11 and 14. Sometimes we only need the upward or downward crossings. Because prepend=False makes the first crossing an upward one and crossings alternate from there, we can isolate them by slicing the index array.

np.argwhere(threshold_crossings)[::2,0]  # Upward crossings
# array([ 3, 11], dtype=int64)
np.argwhere(threshold_crossings)[1::2,0]  # Downward crossings
# array([ 7, 14], dtype=int64)
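If you prefer to spell out the direction of the crossing instead of relying on the alternating order, the two cases can also be written with explicit comparisons of neighboring samples. This is an alternative sketch, not the np.diff approach used in this post, and the variable names are mine.

```python
import numpy as np

data = np.array([0, 0, 0, 5, 5, 5, 5, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0])
threshold = 2
above = data > threshold

# Upward: below at i, above at i + 1. Downward: the reverse.
upward = np.flatnonzero(~above[:-1] & above[1:]) + 1
downward = np.flatnonzero(above[:-1] & ~above[1:]) + 1
```

For the practice signal this gives the same indices as the slicing above: 3 and 11 for upward, 7 and 14 for downward crossings.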

Sometimes we want to find the point before the threshold is crossed, rather than after. There is one simple trick in np.diff: instead of setting prepend=False, we set append=False.

threshold = 2
post_threshold_crossings = np.diff(data > threshold, prepend=False)
pre_threshold_crossings = np.diff(data > threshold, append=False)
plt.plot(data, marker='o')
plt.plot(post_threshold_crossings, marker ='o')
plt.plot(pre_threshold_crossings, marker='o')
plt.legend(("Data", "Post Crossings", "Pre Crossings"))

Make sure to check out the documentation for np.diff and bonus points if you can figure out why exactly this works. Detecting threshold crossings is an easy but important part in most of my analysis pipelines and now you can do it too.

Comparisons and Logic Functions in NumPy

  • We can compare arrays with scalars and other arrays using Python's standard comparison operators
  • There are two different boolean representations of an array. arr.all() is True if all elements are True, arr.any() is True if any element is True.
  • To perform logical functions we need the special NumPy functions np.logical_and, np.logical_or, np.logical_not, np.logical_xor


Logic functions allow us to check if logical statements about our arrays are true or false. Luckily, logic functions are very consistent with other array functions. They are performed element-wise by default and they can be performed on specific axes.

Comparing Arrays and Scalars

The same way we can use the arithmetic operators we can use all the comparison operators: >, >=, <, <=, ==, !=

import numpy as np
arr = np.array([-1, 0, 1])
arr > 0
# array([False, False,  True])
arr >= 0
# array([False,  True,  True])
arr < 0
# array([ True, False, False])
arr <= 0
# array([ True,  True, False])
arr == 0
# array([False,  True, False])
arr != 0
# array([ True, False,  True])

Comparing Arrays with Arrays

The result of a logic operation is a boolean array containing binary values of True or False. This particular case shows us logic operations of arrays with scalars. All boolean operations are performed element-wise, so each element is compared against the scalar. We can also compare arrays to arrays if their shape allows it.

arr_one = np.array([-1, 0, 1])
arr_two = np.array([1, 0, -1])
arr_one > arr_two
# array([False, False,  True])
arr_one >= arr_two
# array([False,  True,  True])
arr_one < arr_two
# array([ True, False, False])
arr_one <= arr_two
# array([ True,  True, False])
arr_one == arr_two
# array([False,  True, False])
arr_one != arr_two
# array([ True, False,  True])

Truth Value of Arrays and Elements

Sometimes we need to check if an array contains any elements that are considered True in a boolean context. While the boolean value of array elements is well defined, the truth value of an entire array is not defined.

arr = np.array([0, 1])
bool(arr[0])
# False
bool(arr[1])
# True
bool(arr)
# ValueError: The truth value of an array with more 
# than one element is ambiguous. Use a.any() or a.all()

NumPy error messages are great. This one is so great that it even tells us which method we need to use to get at the truth value of an array. We can use arr.any() to find out if any of the elements evaluate to True or arr.all() to find out if all elements are True

arr_one = np.array([0, 0])
arr_one.any()
# False
arr_one.all()
# False
arr_two = np.array([0, 1])
arr_two.any()
# True
arr_two.all()
# False
arr_three = np.array([1, 1])
arr_three.any()
# True
arr_three.all()
# True

For an empty array, arr.any() is also False. Since an array of zeros gives the same result, arr.size == 0 is the safer test for emptiness.

arr = np.array([])
arr.any()
# False
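As mentioned at the start, logic functions can also be performed on specific axes. A small sketch with a 2D toy array (the variable names are mine):

```python
import numpy as np

arr = np.array([[0, 1],
                [0, 0]])

any_per_column = arr.any(axis=0)   # collapse rows, one result per column
any_per_row = arr.any(axis=1)      # collapse columns, one result per row
all_per_row = arr.all(axis=1)
```

Here `any_per_column` is [False, True], `any_per_row` is [True, False] and `all_per_row` is [False, False].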

Logical Operations

Finally, we need to look at four more logical operations: and, or, not & xor.
Unfortunately we can’t just use the Python keywords. The reason is in the error message above: “ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()”. It is ambiguous because NumPy does not know if we want to perform the operation element-wise or if we want to perform the operation on the truth value of the array. NumPy does not try to guess which one we mean, so it throws the error. To get these logical functions we need to call some more explicit NumPy functions.

arr_one = np.array([0,1,1])
arr_two = np.array([0,0,1])
np.logical_and(arr_one, arr_two)
# array([False, False,  True])
np.logical_or(arr_one, arr_two)
# array([False,  True,  True])
np.logical_not(arr_one)
# array([ True, False, False])
np.logical_xor(arr_one, arr_two)
# array([False,  True, False])
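For boolean arrays, such as the results of comparisons, the operators &, |, ~ and ^ give the same element-wise results as the functions above. On integer arrays they act bitwise instead, so this sketch sticks to comparison results.

```python
import numpy as np

arr = np.array([-2, -1, 0, 1, 2])

# Parentheses are needed because & and | bind more tightly than comparisons
in_range = (arr > -2) & (arr < 2)
out_of_range = (arr <= -2) | (arr >= 2)
negated = ~in_range
```

Here `in_range` is True for -1, 0 and 1, and `negated` equals `out_of_range`.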


We are now well equipped to deal with arrays. We can compare arrays with scalar values and other arrays using the standard comparison operators. We can also perform logical operations on arrays with the special NumPy functions (logical_and, logical_or, logical_not and logical_xor). Finally we can get two different boolean values of an array using arr.all() and arr.any().

NumPy Array Data Type

  • Any array has a data type (dtype)
  • The dtype determines what kind of data is stored in the array
  • Not all operations work for all dtypes

Introduction to Data Types

Having a data type (dtype) is one of the key features that distinguishes NumPy arrays from lists. In lists, the types of elements can be mixed. One index of a list can contain an integer, another can contain a string. This is not the case for arrays. In an array, each element must be of the same type. This gives the array some of its efficiency, because operations can know in advance, what kind of data they will find in each element simply by looking up the data type. At the same time it makes arrays slightly less flexible, because some operations are undefined for some data types and we cannot assign any kind of data to an array. But how does NumPy decide what data type an array should have in the first place?

Guessing or Defining the dtype

So far we were able to create arrays effortlessly without knowing what dtype even means. That is because NumPy will just take a guess, what the dtype should be, based on the input it gets for the array.

arr = np.array([4, 3, 2])
arr.dtype
# dtype('int32')
arr = np.array([4, 3.0, 2])
arr.dtype
# dtype('float64')
arr = np.array([4, '3', 2])
arr.dtype
# dtype('<U11')

In the first case, each element of the list we pass to the array constructor is an integer. Therefore, NumPy decides that the dtype should be integer (a 32 bit integer on the machine used here; on most Linux and macOS systems, and on Windows with NumPy 2.0 or later, the default is int64). In the second case, one of the elements (3.0) is a floating-point number. Floats are a more complex data type in Python, which means that all other data types have to follow the more complex one. Therefore, all elements of the array are converted to floats and are stored with the dtype float64. Strings are an even more complex dtype. Because ‘3’ is a string in the final example, the dtype becomes ‘<U11’. U stands for unicode, a type of string encoding, and the number indicates the maximum string length the array can hold (with int64 as the default integer you will see ‘<U21’ instead). In all three cases NumPy guesses the dtype according to the content of the list. This works well most of the time but we can also explicitly define the dtype.

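# The old aliases np.float and np.str were removed in NumPy 1.24,
# the Python builtins work in every version
arr = np.array([4, 3, 2], dtype=float)
arr.dtype
# dtype('float64')
arr = np.array([4, 3, 2], dtype=str)
arr.dtype
# dtype('<U1')
arr = np.array([4, 3, 2], dtype=bool)
arr.dtype
# dtype('bool')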

Converting arrays to other dtypes can be necessary because some operations will not work on arrays of mixed types. A dtype that is particularly problematic is the object dtype. It is the most flexible dtype but it can cause a lot of problems for both experts and beginners.
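For an array that already exists, conversion is usually done with the .astype method, which returns a new array with the requested dtype. A minimal sketch with toy values of my own:

```python
import numpy as np

arr = np.array(['4', '3', '2'])      # unicode strings
as_int = arr.astype(int)             # parse each string as an integer
as_float = as_int.astype(float)

# Zero becomes False, any other number becomes True
as_bool = np.array([0, 1, 2]).astype(bool)
```

After the conversion, arithmetic such as `as_int + 1` works as expected, which it would not on the original array of strings.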

np.object and the Curse of Flexibility

Most dtypes are very specific. They let you know if the array contains a number (integer or float dtypes such as np.int64 and np.float64) or a string (all unicode ‘U’ dtypes). Not so much the object dtype (formerly available as np.object, now written as the Python builtin object). It tells you that whatever is inside the array is a thing. Because everything is an object anyway. This can make an array as flexible as a list. Anything can be stored. That is also where the problems come in.

arr = np.array([[3,2,1],[2,5]], dtype=object)
arr.dtype
# dtype('O')  # 'O' means object
arr + 5
# TypeError: can only concatenate list (not "int") to list

Suddenly, the plus operation between an array and a scalar fails. What went wrong? Starting from the top, the nested list entries have different lengths. Think of it this way: this array can neither be a (2, 3) nor a (2, 2) array of dtype integer. Older NumPy versions therefore silently made it a (2,) array of dtype object. Since NumPy 1.24 this raises a ValueError unless we pass dtype=object explicitly, as we do here. So the array contains two lists, the first one is of length 3 and the second one of length 2. NumPy generally turns anything that is more complex than a string into an object. A list is one of those. The error then occurs because the plus operation is not defined for a list with an integer. But that also means that the operation will work if the objects contained in the array happen to support it.

arr = np.array([3,2,1], dtype=object)
arr + 5
# array([8, 7, 6], dtype=object)

This is one of the main problems of the object dtype. Operations work only sometimes and to know if an operation will work, each element has to be checked. With other dtypes, we know which operations will work just by looking at the dtype.


The dtype is one of the concepts that is closely related to the internal workings of NumPy. There is a lot that could be said about the details but effective beginners only need to remember a few points. First, the dtype determines what is stored in the array. All elements of an array have to conform to a specific type and dtype tells us which one. Second, NumPy guesses the dtype based on the literal data unless we specify which dtype we want. Guessing works most of the time but sometimes explicit type conversion is necessary. Third, operations that we know and love from numeric types (integers and floats) may not work on other types (strings, objects). This is particularly annoying for beginners. If you have hard to debug errors, find out what dtype your arrays actually have.
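Checking the dtype takes one line. The sketch below uses a hypothetical array that ended up with the object dtype and converts it back to a numeric one, which only works because all of its elements happen to be numbers.

```python
import numpy as np

# Hypothetical array that ended up as object
arr = np.array([3, 2.5, 1], dtype=object)

is_object = arr.dtype == object   # True for object arrays
converted = arr.astype(float)     # works because every element is numeric
```

If `is_object` is True for an array you expected to be numeric, that is a good place to start debugging.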

Array Indexing with NumPy

  • Indexing is used to retrieve or change elements of an array
  • Slice syntax (start:stop:step) gets a range of elements
  • Integer and boolean arrays can get an arbitrary set of elements

Introduction to Array Indexing

Indexing is an important feature that allows us to retrieve and reassign specific parts of an array. You probably already know the basics of indexing from Python lists and tuples. You can index into NumPy arrays the same way you index into those sequences but NumPy indexing comes with many extra features we will learn about here. First, let's look at single value indexing.

Single Value Indexing

We can use indexing to get single (scalar) values from an array. Indexing is always done with square brackets and we always start counting at 0.

import numpy as np
arr = np.arange(10,15)
# array([10, 11, 12, 13, 14])
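A short sketch of single value indexing on this array. The chosen indices and variable names are only illustrative.

```python
import numpy as np

arr = np.arange(10, 15)   # array([10, 11, 12, 13, 14])

first = arr[0]            # counting starts at 0
third = arr[2]
last = arr[-1]            # negative indices count from the end, as in lists
```

Here `first` is 10, `third` is 12 and `last` is 14.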

Note that single value indexing does not return an array with a single entry but rather a numpy integer. To get a single value from a multi dimensional array we need to use multiple indices that are separated by commas.

arr = np.arange(20)
arr = arr.reshape((2,2,5))
arr
# array([[[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9]],
#
#        [[10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19]]])
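With comma separated indices, one index per dimension picks out a single value. The indices below are arbitrary examples of mine.

```python
import numpy as np

arr = np.arange(20).reshape((2, 2, 5))

value = arr[1, 0, 2]    # block 1, row 0, column 2
corner = arr[0, 1, 4]   # block 0, row 1, column 4
```

Here `value` is 12 and `corner` is 9.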

I recommend this way of indexing but you can also use multiple square brackets like you would for Python sequences.

arr
# array([[[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9]],
#
#        [[10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19]]])
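As a small sketch of the difference, chained square brackets index step by step, while the comma version resolves all dimensions in one step. Both arrive at the same element.

```python
import numpy as np

arr = np.arange(20).reshape((2, 2, 5))

chained = arr[1][0][2]     # three separate indexing steps
combined = arr[1, 0, 2]    # one indexing step
```

Both `chained` and `combined` are 12.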

We can also use indexing to reassign elements of an array.

arr = np.arange(10,15)
# array([10, 11, 12, 13, 14])
arr[1] = 20
# array([10, 20, 12, 13, 14])

Slice Indexing

To retrieve a single value, our indices need to resolve all dimensions of the array and arrive at a single value. Whenever one dimension remains unspecified, we get an array (array view technically).

arr = np.array([[[ 0,  1,  2,  3,  4],
                 [ 5,  6,  7,  8,  9]],
                [[10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19]]])
arr[0, 1]
# array([5, 6, 7, 8, 9])
arr[1]
# array([[10, 11, 12, 13, 14],
#        [15, 16, 17, 18, 19]])
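The remark about views matters in practice: writing to the result of such an indexing operation also changes the original array. A minimal sketch, with variable names of my own:

```python
import numpy as np

arr = np.arange(20).reshape((2, 2, 5))

row = arr[0, 1]          # basic indexing returns a view, not a copy
row[0] = -1              # writing to the view ...
changed = arr[0, 1, 0]   # ... also changes the original array

independent = arr[0, 1].copy()   # an explicit copy is decoupled
independent[1] = -2              # leaves arr untouched
```

After this, `changed` is -1, while `arr[0, 1, 1]` is still 6 because only the copy was modified.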

To take an entire dimension we can use the colon.

arr = np.array([[[ 0,  1,  2,  3,  4],
                 [ 5,  6,  7,  8,  9]],
                [[10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19]]])
arr[0, :, 0]
# array([0, 5])
arr[:, 0, 0]
# array([ 0, 10])

The colon is very useful for indexing in general, because it allows us to take a slice of values instead of a single value. The syntax of the slice follows start:stop:step. If we leave out start, the slice starts at 0. If we leave out stop, it goes to the end of the dimension. If we leave out step, the step defaults to 1.

arr = np.array([[[ 0,  1,  2,  3,  4],
                 [ 5,  6,  7,  8,  9]],
                [[10, 11, 12, 13, 14],
                 [15, 16, 17, 18, 19]]])
arr[0, 0, 1:5:2]
# array([1, 3])
arr[0, 0, 1:4]
# array([1, 2, 3])
arr[0, 0, 1:]
# array([1, 2, 3, 4])
arr[0, 0, :3]
# array([0, 1, 2])
arr[0, 0, :]
# array([0, 1, 2, 3, 4])

Index Array

So far we learned that we can use integers and slices for indexing. Now we learn that we can also use arrays to index into an array. When we use an array to index, that array has to either contain integers or boolean values. Let's take a look at integer array indexing first.

arr = np.arange(10,50,3)
idc = np.arange(5)
arr[idc]
# array([10, 13, 16, 19, 22])
idc = np.arange(5,8)
arr[idc]
# array([25, 28, 31])
idc = np.array([1,2,4])
arr[idc]
# array([13, 16, 22])

Note that in the examples where we generate index arrays with arange, we could achieve the same result with a slice as shown above and save one line of code. Integer arrays are most useful when they are generated by a process that is more complicated than the arange method. One example is the np.argwhere method we will learn more about in a later post.
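As a small preview, here is a sketch where the index array comes out of np.argwhere. The condition (divisible by 5) is an arbitrary one I picked for illustration.

```python
import numpy as np

arr = np.arange(10, 50, 3)

# Indices of all elements divisible by 5
idc = np.argwhere(arr % 5 == 0)[:, 0]
selected = arr[idc]
```

Here `idc` is [0, 5, 10] and `selected` is [10, 25, 40], something a single slice could not express for an arbitrary condition.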

Boolean Array

Boolean arrays also deserve at least one post of their own but here I will give you a teaser. We only want to retrieve those values that are larger than 10.

arr = np.array([[[ 0,  1,  2,  3,  4],
                 [10, 11, 12, 13, 14]],
                 [[5,  6,  7,  8,  9],
                 [15, 16, 17, 18, 19]]])
boolean_idc = arr > 10
boolean_idc
# array([[[False, False, False, False, False],
#         [False,  True,  True,  True,  True]],
#
#        [[False, False, False, False, False],
#         [ True,  True,  True,  True,  True]]])
arr[boolean_idc]
# array([11, 12, 13, 14, 15, 16, 17, 18, 19])
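Boolean indexing works for reassignment too. In this sketch (toy values of my own) we cap all values above 10 and count how many there were.

```python
import numpy as np

arr = np.array([3, 12, 7, 25, 10])

clipped = arr.copy()
clipped[clipped > 10] = 10    # reassign only elements that satisfy the condition

# True counts as 1, so the sum counts the matching elements
n_large = np.sum(arr > 10)
```

Here `clipped` is [3, 10, 7, 10, 10] and `n_large` is 2.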


We learned that indexing is useful to retrieve values and reassign parts of an array. There are several ways to index. First, we can use single integers to get to an element of a certain dimension. We can also use slices with the colon syntax start:stop:step to get at a sequence of elements. Furthermore, there are two advanced indexing techniques, where we can use arrays containing integers or booleans to find an arbitrary collection of elements.

Broadcasting in NumPy

  • Broadcasting is triggered when an arithmetic operation is done on two arrays of different shape
  • The goal of broadcasting is to make both arrays the same shape by performing transformations on the shape of the smaller array
  • Once arrays have the same shape, the operation is applied element-wise
  • If the arrays cannot be broadcast an error is raised

Broadcasting Introduction

When we try to add two arrays together with the plus operator, addition is performed element-wise. That means each element is added to the corresponding element in the other array. However, this only works when both arrays have the same shape. If two arrays have different shapes, a process called broadcasting tries to resolve the difference between the arrays by performing a series of transformations on the shape of the array with lower dimensionality. To understand broadcasting we need to understand the steps broadcasting performs. Let's look at a quick example.

import numpy as np
arr_one = np.array([[4, 3, 2, 5, 6, 2],
                    [30, 34, 1, 50, 60, 56],
                    [22, 34, 32, 21, 12, 6]])
arr_two = np.array([1, 10, 20, 30, 40, 50])
arr_one.shape
# (3, 6)
arr_two.shape
# (6,)
arrs_plus = arr_one + arr_two
arrs_plus
# array([[  5,  13,  22,  35,  46,  52],
#        [ 31,  44,  21,  80, 100, 106],
#        [ 23,  44,  52,  51,  52,  56]])

This one works despite both arrays having different shapes, even different number of elements. This next one does not work under seemingly similar circumstances.

arr_one = np.array([[4, 3, 2, 5, 6, 2],
                    [30, 34, 1, 50, 60, 56],
                    [22, 34, 32, 21, 12, 6]])
arr_two = np.array([4, 40, 20])
arr_one.shape
# (3, 6)
arr_two.shape
# (3,)
arrs_plus = arr_one + arr_two
# ValueError: operands could not be broadcast together with shapes (3,6) (3,)

What happened here? In the first example we add an array of shape (6,) to an array of shape (3, 6) and it works. In the second example we add an array of shape (3,) to a (3, 6) array and get an error. In both examples, the arrays have different shapes. Therefore, broadcasting is triggered and to understand what happens we need to understand the broadcasting sequence. Let's first work through the working example.

# Broadcasting rules in order
Rule #1: The array with fewer dimensions is broadcast to match
Rule #2: Array shapes are aligned to the right.
         (3, 6)
            (6,)
Rule #3: All array dimensions must be equal or one
         Otherwise broadcasting fails as: ValueError
         Here 6 is equal to 6, so we don't get an error and continue
Rule #4: Array dimensions are expanded in the leftward direction
         (3, 6)
         (1, 6)
Rule #5: Array dimensions of size 1 are duplicated to match.
         (3, 6)
         (3, 6)
Done. Array operation can now be executed element wise.

These are the five broadcasting rules that are followed in order. They explain why operations between a (3,) and a (3, 6) array fail. After aligning both array shapes we encounter a problem. 3 does not equal 6, so rule #3 is violated and gives us the ValueError. There are several ways to make this operation work. However, it is best to first understand array indexing before delving into those. We will learn about array indexing in the next post.
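
As a rough sketch, the rules above can be written out in plain Python. The helper name broadcast_shape is hypothetical and not part of NumPy. It only mirrors the rules listed here, and the result is compared with what NumPy itself does for the working example.

```python
import numpy as np

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: apply the broadcasting rules to two shapes."""
    ndim = max(len(shape_a), len(shape_b))
    # Right-align the shapes by padding the shorter one with 1s on the left
    padded_a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim_a, dim_b in zip(padded_a, padded_b):
        # Aligned dimensions must be equal, or one of them must be 1
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
        # A dimension of size 1 is stretched to match the other array
        result.append(dim_b if dim_a == 1 else dim_a)
    return tuple(result)

working = broadcast_shape((3, 6), (6,))
try:
    failing = broadcast_shape((3, 6), (3,))
except ValueError as err:
    failing = str(err)

# Shape NumPy actually produces for the working example
numpy_shape = (np.zeros((3, 6)) + np.zeros(6)).shape
print(working, numpy_shape)
print(failing)
```

The (3, 6) and (6,) pair resolves to (3, 6), while the (3, 6) and (3,) pair stops at the equality check, just like in the two examples.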


If it weren’t for broadcasting we would have to manually convert the shape of arrays so that they are equal before we can perform arithmetic operations on them. Luckily, we learned that broadcasting always happens when we want to perform an operation on two arrays of different shapes. It tries to resolve the difference in shape by performing a series of steps on the smaller array. When broadcasting finishes successfully the operation can be performed element-wise. If broadcasting fails an error is raised.

NumPy Arrays and Shape

  • The same values can be stored in arrays with different shapes
  • Array methods can perform different operations depending on the array shape
  • The methods .reshape and .flatten change the shape of an array

Introducing Array Shape

Every array has a shape, and the shape of an array determines what kind of operations we can perform on it. Array shape is sometimes hard to imagine, even for experienced programmers, so let’s just look at some code.

import numpy as np
my_array = np.array([3, 2, 5, 6, 3, 4])
my_array.shape
(6,)
my_array_reshaped = my_array.reshape((2, 3))
my_array_reshaped.shape
(2, 3)
my_array_reshaped
array([[3, 2, 5],
       [6, 3, 4]])

Here we create an array with 6 elements and my_array.shape tells us that these 6 elements are arranged in a single dimension that has a length of 6. We then reshape the array with its .reshape method into an array with two rows and three columns. This doesn’t look immediately useful, but imagine we did an experiment under a control and an experimental condition with three replicates each. You’d clearly want a structure that represents this. Also, we went from a vector to a matrix with just one line of code. The most important part of array shape is that we can apply array methods to specific dimensions only. To do so we just need to pass the axis argument.

my_array = np.array([[3, 2, 5],
                     [6, 3, 4]])
dim0_sum = my_array.sum(axis=0)
array([9, 5, 9])
dim1_sum = my_array.sum(axis=1)
array([10, 13])

Remember that we start out with a (2, 3) array, 2 rows and 3 columns. When we call sum(axis=0) on that array the 0th dimension is eliminated. The array goes from a (2, 3) shape to a (3,) shape. It does so by calculating the sum across the 0th dimension. Likewise, when we pass sum(axis=1) the 1st dimension gets eliminated in the same way and the array becomes a (2,) array. The same concept of course works for arrays of any dimension. But let's get back to array shapes. An array cannot be converted to just any shape. The number of elements it holds limits the shapes it can take.

my_array = np.arange(30)  # A (30,) array
my_array_reshaped = my_array.reshape((5,6))
my_array_reshaped.shape
(5, 6)
my_array_reshaped = my_array.reshape((5,7))
ValueError: cannot reshape array of size 30 into shape (5,7)

Converting from (30,) to (5, 7) didn’t work for one simple reason. 5 times 7 is 35, not 30. In other words, the new array has more elements than the original array and NumPy will not just invent new elements to make reshaping work. If the number of elements checks out, we can reshape not only to two-dimensional arrays but to any dimension.

my_array = np.arange(30)  # A (30,) array
my_array_reshaped = my_array.reshape((5, 2, 3))
my_array_reshaped.shape
(5, 2, 3)
my_array_reshaped
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23]],

       [[24, 25, 26],
        [27, 28, 29]]])
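
The axis argument from earlier applies to this three-dimensional array as well. The following is a small sketch (the variable names sum_axis0, sum_axis1 and sum_axis2 are just for illustration) showing that summing across an axis removes exactly that axis from the shape.

```python
import numpy as np

my_array_reshaped = np.arange(30).reshape((5, 2, 3))

# Summing across one axis removes that axis from the shape
sum_axis0 = my_array_reshaped.sum(axis=0)  # shape (2, 3)
sum_axis1 = my_array_reshaped.sum(axis=1)  # shape (5, 3)
sum_axis2 = my_array_reshaped.sum(axis=2)  # shape (5, 2)

print(sum_axis0.shape, sum_axis1.shape, sum_axis2.shape)
```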

Of course we can also reshape from higher to lower dimensions.

my_array = np.array([[3, 2, 5],
                     [6, 3, 4]])
my_array.shape
(2, 3)
my_array_reshaped = my_array.reshape((6,))
my_array_reshaped.shape
(6,)

If you want to combine all dimensions into one single dimension, you can use the .flatten method.

my_array = np.arange(30)  # A (30,) array
my_array_reshaped = my_array.reshape((5, 2, 3))
my_array_flattened = my_array_reshaped.flatten()
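
To see what .flatten did, we can check the resulting shape. The sketch below also illustrates one detail that is easy to miss: .flatten returns a copy, so changing the flattened array does not change the array it came from.

```python
import numpy as np

my_array = np.arange(30)  # A (30,) array
my_array_reshaped = my_array.reshape((5, 2, 3))
my_array_flattened = my_array_reshaped.flatten()

print(my_array_flattened.shape)  # (30,)

# .flatten returns a copy, so this assignment leaves the original untouched
my_array_flattened[0] = 99
print(my_array_reshaped[0, 0, 0])  # still 0
```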

Why we need array shapes

We saw how to manipulate array shape and how array methods can use the shape of an array. Let's think a bit about the real-world usage of array shape. Let’s say you are working on an image processing project. You are lucky and the images are already pre-processed in a way that each image has 64 pixels in both dimensions. So each image is an array of shape (64, 64), but your dataset consists of 1000 images. So you want your dataset to be stored as a (1000, 64, 64) array. But then your image processing project becomes a volume processing project. Each volume has 100 slices, so you need a (1000, 100, 64, 64) array. But wait. You are actually working on video files. There are 20000 frames for each volume, so you need a (1000, 20000, 100, 64, 64) array. It is rare that you will have to go beyond five dimensions, but you can. In several fields it is very easy to end up with five-dimensional arrays (think fMRI).
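
Here is a minimal sketch of the image dataset idea. It assumes random pixel values and only 10 images instead of 1000 to keep it small; the names dataset, mean_per_image and mean_image are made up for this illustration.

```python
import numpy as np

# Assumption: 10 random images of 64 x 64 pixels stand in for the real dataset
dataset = np.random.rand(10, 64, 64)

# One value per image: average over both pixel axes
mean_per_image = dataset.mean(axis=(1, 2))  # shape (10,)

# One average image: average over the image axis
mean_image = dataset.mean(axis=0)  # shape (64, 64)

print(mean_per_image.shape, mean_image.shape)
```

Because each axis has a meaning (image, row, column), the axis argument lets us choose whether we summarize within images or across images.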


Here we learned that the shape of an array is useful to store high-dimensional data meaningfully and to have array methods operate only on specific dimensions. The .reshape method is important to change the shape of an existing array and the .flatten method can collapse an array into a single dimension. In the next blog post we will learn about broadcasting. Broadcasting is a mechanism that is triggered whenever we perform an arithmetic operation on two arrays of different shapes. If two arrays have identical shapes the operation is performed element-wise. If they have different shapes, broadcasting performs a series of transformations on the lower-dimensional array to make both arrays identical in shape, and the operation is then performed element-wise.

Arithmetic Operations in NumPy

  • NumPy arrays come with many useful methods
  • All arithmetic operations that are used on arrays are performed element-wise
  • NumPy code is almost always faster than native Python (.append is a notable exception)

NumPy arrays are so useful because they allow us to do math on them very efficiently. For example, NumPy arrays come with many useful methods. One such method is the sum method, which calculates the sum of all values in the array.

import numpy as np
my_array = np.array([4, 3, 1])
my_array.sum()
8

There are many other methods like this and they are extremely useful. Here is a list of the most commonly used methods.

my_array = np.array([4, 3, 1])
my_array.sum()  # Calculate the sum of array values
my_array.mean()  # Calculate the mean of array values
my_array.std()  # Calculate the standard deviation of array values
my_array.max()  # Find the maximum value
my_array.min()  # Find the minimum value

To learn about all array methods you can call the dir() function on any array, which lists all of its attributes and methods. Alternatively, you can check out the documentation for the array.
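
The output of dir() is long and includes many names that start with an underscore. A small sketch that filters those out (the variable public_names is only for illustration):

```python
import numpy as np

my_array = np.array([4, 3, 1])

# Keep only the public names, i.e. those without a leading underscore
public_names = [name for name in dir(my_array) if not name.startswith("_")]

print(len(public_names))
print(public_names[:10])
```

Note that the list contains attributes such as shape alongside methods such as sum and mean.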

Another useful property of arrays is that they do math when they appear together with any of the arithmetic operators (+, -, *, /, **, //, %).

my_array = np.array([4, 3, 1])
my_array_plus = my_array + 2
array([6, 5, 3])

Here, the array appeared together with a scalar value, the single number 2. That number was added to each value. However, we can do the same thing with two arrays, if they have the same shape.

array_one = np.array([4, 3, 1])
array_two = np.array([1, 2, 4])
array_plus_array = array_one + array_two
array([5, 5, 5])

In this case, addition is again performed element-wise. Each element in array_one is added to a corresponding element in array_two. The fact that the array performs useful math in this context might seem unremarkable but remember how the native Python list behaves.

list_one = [4, 3, 1]
list_two = [1, 2, 4]
list_plus_list = list_one + list_two
[4, 3, 1, 1, 2, 4]
array_plus_array = np.array(list_one) + np.array(list_two)
array([5, 5, 5])

If you are in full numerical computation mode this behavior of list might seem stupid to you. But remember: Python is a general purpose programming language and list is a general purpose container to store a sequence of objects. There could be anything in those lists and addition might not be a meaningful operation for those objects. This behavior always works, a list can be concatenated to another list regardless of the objects they store. That’s why we have NumPy. Python has to implement objects in a way that suits its general purpose. NumPy implements behavior in a way that we would expect while we do numerical stuff.
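
For comparison, here is a sketch of what element-wise addition looks like with plain lists. We have to pair up the elements ourselves, whereas the array version only needs the operator.

```python
import numpy as np

list_one = [4, 3, 1]
list_two = [1, 2, 4]

# Element-wise addition with plain lists needs an explicit loop
list_sum = [a + b for a, b in zip(list_one, list_two)]

# With arrays the plus operator does the pairing for us
array_sum = np.array(list_one) + np.array(list_two)

print(list_sum)   # [5, 5, 5]
print(array_sum)  # [5 5 5]
```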

A word on performance

This is one of the rare occasions where it is worthwhile to talk about performance. When you are getting started, I strongly recommend against thinking too much about performance. Write functioning code first, then worry about readability, maintainability, reproducibility etc. etc. and worry about performance last (trust me on this one). But some of you will be working with large amounts of data and you will be delighted to hear that NumPy is much faster than native Python.

my_array = np.random.rand(100000)  # A large array with 100000 elements
my_list = list(my_array)
timeit sum(my_list)
18.1 ms ± 801 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
timeit my_array.sum()
90.3 µs ± 6.86 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

The native Python version of sum is orders of magnitude slower than the NumPy version. You might have noticed that I created a very large array to demonstrate this. The performance difference actually increases with array size, which you can verify for yourself. The take home message here is that whenever you can replace native Python with NumPy, you usually gain performance. But don’t worry about optimizing your NumPy code. One exception is the .append method, but more on that later.
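
If you are not working in IPython or Jupyter, the timeit command above is not available, but the timeit module from the standard library can do a similar job. The following is only a sketch: the helper time_sums is a made-up name, and the exact numbers and ratios will depend on your machine.

```python
import timeit
import numpy as np

def time_sums(size, repeats=20):
    """Hypothetical helper: average seconds per call for both sum versions."""
    my_array = np.random.rand(size)
    my_list = list(my_array)
    list_time = timeit.timeit(lambda: sum(my_list), number=repeats) / repeats
    array_time = timeit.timeit(lambda: my_array.sum(), number=repeats) / repeats
    return list_time, array_time

for size in (1000, 100000):
    list_time, array_time = time_sums(size)
    # Ratio above 1 means the NumPy version was faster in this run
    print(size, list_time / array_time)
```

Running this for a few different sizes is one way to check how the gap between the two versions changes with array size.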


We learned two essential things and one kind of interesting side-note. The first essential lesson is that arrays come with many methods that allow us to do useful math. We learned some of those methods and as you keep working with NumPy those will become second nature. The second thing we learned is that arithmetic operators are applied element-wise to arrays. This means that a scalar value is combined with each element in an array, and whenever two arrays of the same shape appear together with an operator, the operation is applied to each pair of corresponding elements. We will learn the details of array shapes in the next blog post. Finally, we also learned that NumPy code is almost always much faster than native Python code. This is good to know. However, especially in the beginning you should focus on anything but performance.