Eventplots with Color in Matplotlib

Here we will learn how to color individual events in a Matplotlib eventplot. An eventplot most commonly shows the timing of events (x axis) from different sources (y axis). Assigning different colors to the sources can be done easily with the colors parameter of the plt.eventplot function. However, color coding each event requires us to create the eventplot and then manually assign the colors with event_collection.set_colors. Here is the code and the resulting output.

import matplotlib as mpl
import matplotlib.pyplot as plt
import shelve
import os
import numpy as np

dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'evenplot_demo_spikes')

data = shelve.open(data_path)

fig, ax = plt.subplots(1)

event_collection = ax.eventplot(data['spikes'])
phases_flattened = np.hstack(data['phases'])
phase_min = np.quantile(phases_flattened, 0.05)
phase_max = np.quantile(phases_flattened, 0.95)

phases_norm = mpl.colors.Normalize(vmin=phase_min, vmax=phase_max)
normalized_phases = phases_norm(data['phases'])
viridis = mpl.colormaps['viridis']

for idx, col in enumerate(event_collection):
    col.set_colors(viridis(normalized_phases[idx]))

plt.colorbar(mpl.cm.ScalarMappable(norm=phases_norm, cmap=viridis),
             ax=ax, label="Phase", fraction=0.05)

We start by loading the data with the shelve module; the os module helps us find the data relative to wherever our script is located (os.path.dirname(__file__)). Next, we create the figure and axis and plot the event times into the axis. The events will be plotted with the default blue of Matplotlib. We have to make sure to keep a reference to all the events by assigning the return value to event_collection. We want to color the events according to the data in data['phases']. That data contains a single floating point number for each event and we need to convert those numbers to an RGB color. For that, we first normalize our data to values between zero and one with mpl.colors.Normalize(vmin=phase_min, vmax=phase_max). I use a trick here where I take the 5% and 95% quantiles instead of the actual minimum and maximum of the data. This should not be done for a scientific figure, because all values above vmax (or below vmin) get the same color although they have different values, which can be misleading. Once we have our normalized values we get a colormap object with mpl.colormaps['viridis'] and we can get colors from it by passing it the normalized values. We then loop through our event collection (each entry in event_collection is a row in our plot) and set the colors with col.set_colors(viridis(normalized_phases[idx])). Finally, we create a colorbar with plt.colorbar. And that is it. Let me know if there is an easier way to do this that I missed.

Balanced spiking neural networks with NumPy

Balanced spiking neural networks are a cornerstone of computational neuroscience. They make for a nice introduction to spiking neuronal networks and they can be trained to store information (Nicola & Clopath, 2017). Here I will present a Python port I made of a MATLAB implementation by Nicola & Clopath and I will go through some of its features. We only need NumPy.

import numpy as np
from numpy.random import rand, randn
import matplotlib.pyplot as plt


def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000, 
                             td=0.02, tr=0.002, p=0.1, 
                             offset=-40.00, g=0.04, seed=100, 
                             nrec=10):
    """Simulate a balanced spiking neuronal network

    Parameters
    ----------
    dt : float
        Sampling interval of the simulation.
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons.
    tm : float
        Time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike.
    vpeak : float
        The voltage above which a spike is triggered
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant.
    tr : float
        Synaptic rise time constant.
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength
    seed : int
        The seed makes NumPy random number generator deterministic.
    nrec : int
        The number of neurons to record.

    Returns
    -------
    ndarray
        A 2D array of recorded voltages. Rows are time points,
        columns are the recorded neurons. Shape: (int(T/dt), nrec).
    """

    np.random.seed(seed)  # Seeding randomness for reproducibility

    """Setup weight matrix"""
    w = g * (randn(n, n)) * (rand(n, n) < p) / (np.sqrt(n) * p)
    # Set the row mean to zero
    row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]
    row_means = np.repeat(row_means, w.shape[0], axis=1)
    w[np.abs(w) > 0] = w[np.abs(w) > 0] - row_means[np.abs(w) > 0]

    """Preinitialize recording"""
    nt = round(T/dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    """Initial conditions"""
    ipsc = np.zeros(n)  # Post synaptic current storage variable
    hm = np.zeros(n)  # Storage variable for filtered firing rates
    tlast = np.zeros((n))  # Used to set  the refractory times
    v = vreset + rand(n)*(30-vreset)  # Initialize neuron voltage

    """Start integration loop"""
    for i in np.arange(0, nt, 1):
        inp = ipsc + offset  # Total input current

        # Voltage equation with refractory period
        # Only change if voltage outside of refractory time period
        dv = (dt * i > tlast + tref) * (-v + inp) / tm
        v = v + dt*dv

        index = np.argwhere(v >= vpeak)[:, 0]  # Spiked neurons

        # Get the weight matrix column sum of spikers
        if len(index) > 0:
            # Compute the increase in current due to spiking
            jd = w[:, index].sum(axis=1)

        else:
            jd = 0*ipsc

        # Used to set the refractory period of LIF neurons
        tlast = (tlast + (dt * i - tlast) *
                 np.array(v >= vpeak, dtype=int))

        ipsc = ipsc * np.exp(-dt / tr) + hm * dt

        # Integrate the current
        hm = (hm * np.exp(-dt / td) + jd *
              (int(len(index) > 0)) / (tr * td))

        v = v + (30 - v) * (v >= vpeak)

        rec[i, :] = v[0:nrec]  # Record the first nrec neurons
        v = v + (vreset - v) * (v >= vpeak)

    return rec


if __name__ == '__main__':
    rec = balanced_spiking_network()
    """PLOTTING"""
    fig, ax = plt.subplots(1)
    ax.plot(rec[:, 0] - 100.0)
    ax.plot(rec[:, 1])
    ax.plot(rec[:, 2] + 100.0)
Three neurons in the balanced spiking neural network.

The weight matrix

At the core of any balanced network is the weight matrix. We define it on lines 54 to 58. Initializing it from a normal distribution and setting the mean of the nonzero weights in each row to zero makes sure that excitation and inhibition are in balance. That is what keeps the network spiking irregularly although the input to the network remains constant. The constant input to the network is the offset parameter.
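To see the balancing step in isolation, here is a small sketch that builds a weight matrix in the same spirit and checks that every row sums to zero. It is a variation, not the listing above: it uses NumPy's Generator API and an explicit boolean connectivity mask (both assumptions of this sketch), and a smaller n to keep it quick.

```python
import numpy as np

rng = np.random.default_rng(100)
n, p, g = 200, 0.1, 0.04  # smaller network than in the listing above

# Sparse random connectivity: a connection exists with probability p
mask = rng.random((n, n)) < p
w = g * rng.standard_normal((n, n)) * mask / (np.sqrt(n) * p)

# Subtract the mean of the existing connections from each row
counts = np.maximum(mask.sum(axis=1, keepdims=True), 1)  # avoid division by zero
row_means = w.sum(axis=1, keepdims=True) / counts
w = w - row_means * mask  # absent connections stay exactly zero

# Largest absolute row sum; should be zero up to floating point error
print(np.abs(w.sum(axis=1)).max())
```

Each row now contains positive (excitatory) and negative (inhibitory) weights that cancel, which is the balance the text refers to.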

Refractory period

The refractory period is a time window during which no action potential can be generated. We achieve this by setting the voltage to a low value right after the spike and then not updating the voltage of that neuron for a given time. This time window is given by tref. We update the voltage on line 76. On the same line we check how long ago the last spike occurred with the expression (dt * i > tlast + tref). Therefore, we need to track the most recent spike time with tlast. Of course we have some other things to do when a neuron reaches the spiking threshold vpeak. First we set the voltage to a value well above the threshold on line 99. This is purely visual, to give a spiky appearance in the recording. Right after we have recorded on line 101, we set the voltage to its reset value vreset.
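The refractory logic is easier to follow for a single neuron. The sketch below reuses dt, tref, tm, vreset and vpeak from the function defaults, but drives the neuron with a constant input of -30 (a hypothetical value chosen so that it fires regularly) instead of network currents.

```python
import numpy as np

dt, tref, tm = 0.00005, 0.002, 0.01
vreset, vpeak = -65.0, -40.0
inp = -30.0  # hypothetical constant input, above vpeak so the neuron spikes
nt = 4000    # 0.2 s of simulated time

v = vreset
tlast = 0.0  # time of the most recent spike
spike_times = []

for i in range(nt):
    t = dt * i
    # Only integrate the voltage outside of the refractory window
    if t > tlast + tref:
        v = v + dt * (-v + inp) / tm
    if v >= vpeak:
        spike_times.append(t)
        tlast = t   # remember when the spike happened
        v = vreset  # reset; v now stays put until t > tlast + tref

isi = np.diff(spike_times)  # inter-spike intervals
print(len(spike_times), isi.min())
```

No inter-spike interval can be shorter than tref, because the voltage is frozen at vreset for that long after every spike.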

Play around with some of the parameters. You can find the code here: https://gist.github.com/danielmk/9adc7409f40a076ffec0cdf85dea4519

Spiking Neuronal Networks in Python

Spiking neural networks (SNNs) turn some input into an output much like artificial neural networks (ANNs), which are already widely used today. Both achieve the same goal in different ways. The units of an ANN are single floating-point numbers that represent the activity levels of the units for a given input. Neuroscientists loosely understand this number as the average spike rate of a unit. In ANNs this number is usually the result of multiplying the input with a weight matrix. SNNs work differently in that they simulate units as spiking point processes. How often and when a unit spikes depends on the input and the connections between neurons. A spike causes a discontinuity in other connected units. SNNs are not yet widely used outside of laboratories but they are important to help neuroscientists model brain circuitry. Here we will create spiking point models and connect them. We will be using Brian2, a Python package to simulate SNNs. You can install it with either

conda install -c conda-forge brian2
or
pip install brian2

We will start by defining our spiking unit model. There are many different models of spiking. Here we will define a conductance based version of the leaky integrate-and-fire model.

from brian2 import *
import numpy as np
import matplotlib.pyplot as plt

start_scope()

# Neuronal Parameters
c = 100*pF
vl = -70*mV
gl = 5*nS

# Synaptic Parameters
ge_tau = 20*ms
ve = 0*mV
gi_tau = 100*ms
vi = -80*mV
w_ge = 1.0*nS
w_gi = 0.0*nS

lif = '''
dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi *(v - vi) - I)/c : volt
dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens
I : amp
'''

The intended way to import Brian2 is from brian2 import *. This feels a bit dangerous but it is very important to keep the code readable, especially when dealing with physical units. Speaking of units, each of our parameters has a physical unit. We have picofarad (pF), millivolt (mV) and nanosiemens (nS). These are all imported from Brian2 and we assign units with the multiplication operator *. To find out what each of those parameters does, we can look at our actual model. The model is defined by the string we assign to lif. The first line of the string is the model of our spiking unit:

dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi *(v - vi) - I)/c : volt

This is a differential equation that describes the change of voltage with respect to time. In electrical terms, voltage is changed by currents. These physical terms are not strictly necessary for the computational function of SNNs but they are helpful to understand the biological background behind them. Four currents flow in our model.

The first one is gl * (v - vl). This current is given by the conductance (gl), the current voltage (v) and the equilibrium voltage (vl). It flows whenever the voltage differs from vl. gl is just a scaling constant that determines how strong the drive towards vl is. This is why vl is called the resting potential: v does not change when it is equal to vl. This term makes our model a leaky integrate-and-fire model as opposed to an integrate-and-fire model. Of course the voltage only remains at rest if it is not otherwise changed. There are three other currents that can do that. Two of them correspond to excitatory and inhibitory synaptic inputs. ge * (v - ve) drives the voltage towards ve. Under most circumstances, this will increase the voltage as ve equals 0mV. On the other hand gi *(v - vi) drives the voltage towards vi, which at -80mV is even slightly lower than vl. This current keeps the voltage low. Finally there is I, which is the input current that the model receives. We will define this later for each neuron. Currents don’t change the voltage immediately. They are all slowed down by the capacitance c. Therefore, we divide the sum of all currents by c.
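To get a feeling for these currents without Brian2, the voltage equation can be integrated by hand. The following is a rough sketch in plain NumPy with forward Euler steps and all quantities as floats in SI units (both assumptions of this sketch). It uses the parameter values and the conductance decay equations from the lif string and leaves out the spike threshold. The simulate function and its arguments are made up for illustration.

```python
import numpy as np

# Parameters from the model above, as plain floats in SI units
c, vl, gl = 100e-12, -70e-3, 5e-9
ve, vi = 0.0, -80e-3
ge_tau, gi_tau = 20e-3, 100e-3
dt = 0.1e-3  # hypothetical Euler time step of 0.1 ms


def simulate(ge0, gi0, I=0.0, steps=1000):
    """Integrate the subthreshold equations starting from rest."""
    v, ge, gi = vl, ge0, gi0
    trace = np.empty(steps)
    for k in range(steps):
        dv = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I) / c
        v = v + dt * dv
        ge = ge - dt * ge / ge_tau  # excitatory conductance decays to zero
        gi = gi - dt * gi / gi_tau  # inhibitory conductance decays to zero
        trace[k] = v
    return trace


exc = simulate(ge0=1e-9, gi0=0.0)  # 1 nS of excitation at t = 0
inh = simulate(ge0=0.0, gi0=1e-9)  # 1 nS of inhibition at t = 0

print(exc.max() * 1e3, inh.min() * 1e3)  # extreme voltages in mV
```

The excitatory conductance pushes the voltage above vl towards ve, the inhibitory one pulls it below vl towards vi, and in both cases the leak brings the voltage back to rest once the conductance has decayed.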

There are two more differential equations that describe our model:

dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens

These describe the change of the excitatory and inhibitory synaptic conductances. We did not yet implement the way spiking changes ge or gi. However, these equations tell us that both ge and gi will decay towards zero with the time constants ge_tau and gi_tau. So far so good, but what about spiking? That comes up next, when we turn the string that represents our model into something that Brian2 can actually simulate.

G = NeuronGroup(3, lif, threshold='v > -40*mV',
                reset='v = vl', method='euler')
G.I = [0.7, 0.5, 0]*nA
G.v = [-70, -70, -70]*mV

The NeuronGroup creates for us three units that are defined by the equations in lif. The threshold parameter gives the condition to register a spike. In this case, we register a spike when the voltage is larger than -40mV. When a spike is registered, an event is triggered, defined by reset. Once a spike is registered, we reset the voltage to the resting voltage vl. The method parameter gives the integration method to solve the differential equations.

Once our units are defined, we can interface with some parameters of the neurons. For example, G.I = [0.7, 0.5, 0]*nA sets the input current of the zeroth neuron to 0.7nA, the first neuron to 0.5nA and the last neuron to 0nA. Not all parameters are accessible like this. I is available because we defined I : amp in our lif string. This says that I is available for changes. Next, we define the initial state of our neurons. G.v = [-70, -70, -70]*mV sets all of them to the resting voltage. A good place to start.

You might be disappointed by this implementation of spiking. Where is the sodium? Where is the amplification of depolarization? The leaky integrate-and-fire model doesn’t feature a spiking mechanism, except for the discontinuity at the threshold. If you are interested in incorporating spike-like mechanisms you should look for the exponential leaky integrate-and-fire or a Hodgkin-Huxley like model.

We are only missing one more ingredient for an actual network model: the synapses. And we are in a great position, because our model already defines the excitatory and the inhibitory conductances. Now we just need to make use of them.

Se = Synapses(G, G, on_pre='ge_post += w_ge')
Se.connect(i=0, j=2)

Si = Synapses(G, G, on_pre='gi_post += w_gi')
Si.connect(i=1, j=2)

First we create an excitatory connection from our neurons onto themselves. The connection is excitatory because it increases the conductance ge of the postsynaptic neuron by w_ge. We then call Se.connect(i=0, j=2) to define which neurons actually connect. In this case, the zeroth neuron connects to the second neuron. This means, spikes in the zeroth neuron cause ge to increase in the second neuron. We then create the inhibitory connection and make the first neuron inhibit the second neuron. Now we are ready to run the actual simulation and plot the result. Remember that for now w_gi = 0.0*nS, meaning that we will only see the excitatory connection.

M = StateMonitor(G, 'v', record=True)

run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
Voltage traces from three neurons. Spiking of N0 increases voltage in N2. Spiking of N1 does nothing in this simulation because its weight onto N2 was set to 0nS.

The first two neurons are regularly spiking. N0 is slightly faster because it receives a larger input current than N1. N2 on the other hand rests at -70mV because it does not receive an input current. When N0 spikes, it causes the voltage of N2 to increase as we expected, because N0 increases the excitatory conductance. N1 is not doing anything here, because its weight is set to 0. Continued activity of N0 eventually causes N2 to reach the threshold of -40mV, making it register its own spike and reset the voltage. What happens if we introduce the inhibitory weight?

w_gi = 0.5*nS
run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
The same as above but now N1 has a weight of 0.5nS. This inhibits N2 and prevents a spike.

With a weight of 0.5nS, N1 inhibits N2 to an extent that prevents the spike. We have now built a network of leaky integrate-and-fire neurons that features both excitatory and inhibitory synapses. This is just the start to getting a functional network that does something interesting with the input. We will need to increase the number of neurons, decide on a connectivity rule between them, initialize or even learn weights, decide on a coding scheme for the output and much more. Many of these I will cover in later blog posts so stay tuned.

Differential Equations with SciPy – odeint or solve_ivp

As a scientist I have the privilege to always learn and discover new things. I have recently started a new blog called Microbe Food, about single cell food production, which you might be interested in. This leaves my programming blog mostly dormant for the moment but I hope the content is still useful.

SciPy features two different interfaces to solve differential equations: odeint and solve_ivp. The newer one is solve_ivp and it is recommended but odeint is still widespread, probably because of its simplicity. Here I will go through the difference between both with a focus on moving to the more modern solve_ivp interface. The primary advantage is that solve_ivp offers several methods for solving differential equations whereas odeint is restricted to one. We get started by setting up our system of differential equations and some parameters of the simulation.

UPDATE 07.02.2021: Note that I am focusing here on usage of the different interfaces and not on benchmarking accuracy or performance. Faruk Krecinic has contacted me and noted that assessing accuracy would require a system with a known solution as a benchmark. This is beyond the scope of this blog post. He also pointed out to me that the hmax parameter of odeint is important for the extent to which both interfaces give similar results. You can learn more about the parameters of odeint in the docs.

import numpy as np
from scipy.integrate import odeint, solve_ivp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    
    return [dx, dy, dz]

sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0

p = (sigma, beta, rho)  # Parameters of the system

y0 = [1.0, 1.0, 1.0]  # Initial state of the system

We will be using the Lorenz system. We can directly move on to solving the system with both odeint and solve_ivp.

t_span = (0.0, 40.0)
t = np.arange(0.0, 40.0, 0.01)

result_odeint = odeint(lorenz, y0, t, p, tfirst=True)
result_solve_ivp = solve_ivp(lorenz, t_span, y0, args=p)

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(result_odeint[:, 0],
        result_odeint[:, 1],
        result_odeint[:, 2])
ax.set_title("odeint")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(result_solve_ivp.y[0, :],
        result_solve_ivp.y[1, :],
        result_solve_ivp.y[2, :])
ax.set_title("solve_ivp")
Simulation results from odeint and solve_ivp. Note that the solve_ivp solution looks very different, primarily because of the default temporal resolution that is applied. Changing the temporal resolution and getting results very similar to odeint is easy and shown below.

The first thing that sticks out is that the solve_ivp solution is less smooth. That is because it is calculated at fewer time points, which in turn has to do with the difference between t_span and t. The odeint interface expects t, an array of time points for which we want to calculate the solution. The temporal resolution of the system is given by the interval between time points. The solve_ivp interface on the other hand expects t_span, a tuple that gives the start and end of the simulation interval. solve_ivp determines the temporal resolution by itself, depending on the integration method and the desired accuracy of the solution. We can confirm that the temporal resolution of solve_ivp is lower in this example by inspecting the output of both functions.

t.shape
# (4000,)

result_odeint.shape
# (4000, 3)

result_solve_ivp.t.shape
# (467,)

The t array has 4000 elements and therefore the result of odeint has 4000 rows, each row being a time point defined by t. The result of solve_ivp is different. It has its own time array as an attribute and it has only 467 elements. This tells us that solve_ivp indeed calculated fewer time points, affecting temporal resolution. So how can we increase the number of time points in solve_ivp? There are three ways: 1. We can manually define the time points to integrate, similar to odeint. 2. We can decrease the error of the solution we are willing to tolerate. 3. We can change to a more accurate integration method. We will first change the integration method. The default integration method of solve_ivp is RK45 and we will compare it to the default method of odeint, which is LSODA.
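The three options can be tried side by side. This is a small sketch that only compares how many time points each call returns; the tolerance values are arbitrary picks for illustration, not recommendations.

```python
import numpy as np
from scipy.integrate import solve_ivp


def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t_span = (0.0, 40.0)

# Baseline: default method (RK45) and default tolerances
default = solve_ivp(lorenz, t_span, y0, args=p)

# 1. Ask for the solution at explicit time points
t_eval = np.arange(0.0, 40.0, 0.01)
dense = solve_ivp(lorenz, t_span, y0, args=p, t_eval=t_eval)

# 2. Tolerate less error (illustrative values), which forces smaller steps
tight = solve_ivp(lorenz, t_span, y0, args=p, rtol=1e-8, atol=1e-10)

# 3. Switch the integration method
lsoda = solve_ivp(lorenz, t_span, y0, args=p, method='LSODA')

print(default.t.size, dense.t.size, tight.t.size, lsoda.t.size)
```

Note that t_eval only controls where the solution is reported. The solver still chooses its own internal steps and interpolates to the requested time points, so option 1 changes the resolution of the output rather than the accuracy of the integration.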

solve_ivp_rk45 = solve_ivp(lorenz, t_span, y0, args=p,
                            method='RK45')
solve_ivp_lsoda = solve_ivp(lorenz, t_span, y0, args=p,
                           method='LSODA')

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(solve_ivp_rk45.y[0, :],
        solve_ivp_rk45.y[1, :],
        solve_ivp_rk45.y[2, :])
ax.set_title("RK45")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(solve_ivp_lsoda.y[0, :],
        solve_ivp_lsoda.y[1, :],
        solve_ivp_lsoda.y[2, :])
ax.set_title("LSODA")
Comparison between RK45 and LSODA integration methods of solve_ivp.

The LSODA solution already looks smoother than the RK45 solution, but it is still not as smooth as the solution we got from odeint. That is because we made odeint return the solution at an even higher temporal resolution when we passed it t. To get the solution from solve_ivp at the same time points we used for odeint, we must pass it those time points with the t_eval parameter.

t = np.arange(0.0, 40.0, 0.01)
result_odeint = odeint(lorenz, y0, t, p, tfirst=True)
result_solve_ivp = solve_ivp(lorenz, t_span, y0, args=p,
                             method='LSODA', t_eval=t)

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(result_odeint[:, 0],
        result_odeint[:, 1],
        result_odeint[:, 2])
ax.set_title("odeint")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(result_solve_ivp.y[0, :],
        result_solve_ivp.y[1, :],
        result_solve_ivp.y[2, :])
ax.set_title("solve_ivp LSODA")
odeint and solve_ivp with identical integration method and temporal resolution.

Now both solutions have identical temporal resolution. But the solutions are still not identical. I was unable to confirm why that is the case but I suspect very small floating point errors. The Lorenz attractor is a chaotic system and even small errors can make trajectories diverge. The following plot shows the first variable of the system for odeint and solve_ivp from the above simulation. It is consistent with my suspicion that floating point accuracy is to blame, but I was unable to confirm the source.
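One thing that can be checked, and which is not part of the original comparison, is the error tolerance. The SciPy documentation lists different defaults for the two interfaces: rtol=1e-3 and atol=1e-6 for solve_ivp, and roughly 1.49e-8 for both in odeint. The sketch below passes the same tight tolerances (arbitrary illustrative values) to both and measures how far the first variable drifts apart. It does not settle the question, it only shows how to run the comparison.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp


def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t = np.arange(0.0, 40.0, 0.01)
tol = dict(rtol=1e-10, atol=1e-12)  # same tolerances for both interfaces

ode = odeint(lorenz, y0, t, p, tfirst=True, **tol)
ivp = solve_ivp(lorenz, (0.0, 40.0), y0, args=p,
                method='LSODA', t_eval=t, **tol)

# Absolute difference in the first state variable over time
diff = np.abs(ode[:, 0] - ivp.y[0, :])
early = diff[t < 5.0].max()  # before chaos has amplified small errors
late = diff.max()            # over the whole simulation
print(early, late)
```

Because the system is chaotic, any remaining difference grows over time, so the early part of the simulation is the fairer place to compare the two solvers.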

Solutions from odeint and solve_ivp diverge even for identical temporal resolution and integration method.

There is one more way to make the solution smoother: decrease the tolerated error. We will not go into that here because it does not relate to the difference between odeint and solve_ivp. It can be used in both interfaces. There are some more subtle differences. For example, by default, odeint expects the first parameter of the problem to be the state and the second parameter to be time t. This is the other way around for solve_ivp. To make odeint accept a problem definition where time is the first parameter we set the parameter tfirst=True.
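The argument order is easy to verify. In this sketch the same system is written once in each order (the two function names are made up for illustration) and odeint gives the same result for both, as long as tfirst matches the signature.

```python
import numpy as np
from scipy.integrate import odeint


def lorenz_state_first(state, t, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


def lorenz_time_first(t, state, sigma, beta, rho):
    # Same system, but with the argument order that solve_ivp expects
    return lorenz_state_first(state, t, sigma, beta, rho)


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t = np.arange(0.0, 5.0, 0.01)

a = odeint(lorenz_state_first, y0, t, p)              # odeint's default order
b = odeint(lorenz_time_first, y0, t, p, tfirst=True)  # solve_ivp's order

print(np.allclose(a, b))
```

Writing the system time-first and passing tfirst=True means the same function can be handed to both interfaces.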

In summary, solve_ivp offers several integration methods while odeint only uses LSODA.

Differential Equations in Python with SciPy

Differential equations are special because they don’t tell us the value of a variable straight up. Instead, they tell us by how much the variable will change with respect to the change of another variable. Usually that other variable is time. To numerically solve a system of differential equations we need to track the system’s change over time starting at an initial state. This process is called numerical integration and there is a SciPy function for it called odeint. We will learn how to use this function by simulating the ‘hello world’ of differential equations: the Lorenz system.

Here is the first part of the code where we define the function that describes the dynamics of the system.

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from mpl_toolkits.mplot3d import Axes3D

def lorenz(state, t, sigma, beta, rho):
    x, y, z = state
    
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    
    return [dx, dy, dz]

We start with some imports. Of course we need NumPy, and odeint is imported from scipy.integrate. Matplotlib will be used to plot the result of our simulation. After that we define the system of differential equations that defines our Lorenz system. It consists of three differential equations that we fit into one function called lorenz. This function needs a specific call signature (lorenz(state, t, sigma, beta, rho)) because we will later pass it to odeint, which expects specific parameters in specific places. Most importantly, the first parameter must be the state of the system. The state of the Lorenz system is defined by three variables: x, y, z. Our state object has to be a sequence with an order that reflects this.

Inside the lorenz function, the first thing we do is to unpack the state into the three state variables. This is followed by the three differential equations that describe the dynamic changes of the state variables. The fact that the variable t does not show up in any of these equations is a common point of confusion. The amount of change certainly depends on the amount of time. So why can we ignore t here? The answer is that our numerical integrator will keep track of t for us. For this particular system we could actually build a function that does not take the parameter t but I include it because it can be useful if you want to add discontinuities that depend on t.
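As an illustration of a discontinuity that depends on t, here is a sketch with a different, much simpler system: a decaying variable that receives a constant drive which switches on at a given time. The function driven_decay and its parameter t_on are made up for this example. The hmax argument limits the largest step the integrator may take, a cautious choice here so that it does not step far past the switch.

```python
import numpy as np
from scipy.integrate import odeint


def driven_decay(state, t, t_on):
    """dx/dt = -x + drive(t), where the drive switches on at t_on."""
    drive = 1.0 if t >= t_on else 0.0  # the discontinuity depends on t
    return -state + drive


t = np.linspace(0.0, 10.0, 1001)
result = odeint(driven_decay, [0.0], t, args=(1.0,), hmax=0.05)
x = result[:, 0]

print(x[0], x[-1])
```

Before t_on the variable stays at zero, afterwards it relaxes towards the new drive. The same pattern works for the Lorenz system if one of its parameters should change at a specific time.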

While t does not appear in the equations, sigma, beta & rho do. They are the parameters of the system and the system’s properties depend on them. We will set those parameters next.

sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0

p = (sigma, beta, rho)

These are the parameters Lorenz himself used and they are known to produce the type of dynamic that the Lorenz system is most known for: the Lorenz attractor. It is important that we store these parameters in a tuple in this exact order because of our function’s structure. It must be a tuple rather than another type of collection because odeint expects it. Now that our parameters are defined, we will move on to define the initial values of the system. This is a critical part of solving differential equations. These equations tell us by how much the system state changes but they cannot tell us where to start.

y0 = [1.0, 1.0, 1.0]

Our system will start with all variables at 1.0. Now we can solve the system and plot the result.

t = np.arange(0.0, 40.0, 0.01)

result = odeint(lorenz, y0, t, p)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # fig.gca(projection=...) no longer works in recent Matplotlib
ax.plot(result[:, 0], result[:, 1], result[:, 2])
The Lorenz attractor.

We solve the system with a simple call to odeint and we pass it the function that defines our system, the initial state, the time points t for which we want to solve the system and the parameters p. result is a two-dimensional array where the rows are the time points and the columns are the state variables at those time points. And this is how we can solve differential equations with SciPy.

Getting Started with Pandas DataFrame

A DataFrame is a spreadsheet like data structure. We can think of it as a collection of rows and columns. This row-column structure is useful for many different kinds of data. The most widely used DataFrame implementation in Python is from the Pandas package. First we will learn how to create DataFrames. We will also learn how to do some basic data analysis with them. Finally, we will compare the DataFrame to the ndarray data structure and learn why data frames are useful in other packages such as Seaborn.

How to Create a DataFrame

There are two major ways to create a DataFrame. We can directly call DataFrame() and pass it data in a dictionary, list or array. Alternatively we can use several functions to load data from a file directly into a DataFrame. While it is very common in data science to load data from a file, there are also many occasions where we need to create a DataFrame from other data structures. We will first learn how to create a DataFrame from a dictionary.

import pandas as pd
d = {"Frequency": [20, 50, 8],
     "Location": [2, 3, 1],
     "Cell Type": ["Interneuron", "Interneuron", "Pyramidal"]}
row_names = ["C1", "C2", "C3"]
df = pd.DataFrame(d, index=row_names)
print(df)

"""
    Frequency  Location    Cell Type
C1         20         2  Interneuron
C2         50         3  Interneuron
C3          8         1    Pyramidal
"""

In our dictionary the keys are used as the column names. The data under each key then becomes the column. The row names are defined separately by passing a collection to the index parameter of DataFrame. We can get column and row names with the columns and index attributes.

df.columns
# Index(['Frequency', 'Location', 'Cell Type'], dtype='object')
df.index
# Index(['C1', 'C2', 'C3'], dtype='object')

We can also change column and row names through those same attributes.

df.index = ["Cell_1", "Cell_2", "Cell_3"]
df.columns = ["Freq (Hz)", "Loc (cm)", "Cell Type"]
"""
        Freq (Hz)  Loc (cm)    Cell Type
Cell_1         20         2  Interneuron
Cell_2         50         3  Interneuron
Cell_3          8         1    Pyramidal
"""

These names are useful because they give us a descriptive way of indexing into columns and rows. If we use indexing syntax on the DataFrame, we can get individual columns.

df['Freq (Hz)']
"""
Cell_1    20
Cell_2    50
Cell_3     8
Name: Freq (Hz), dtype: int64
"""

Row names are not found this way and using a row key will raise an error. However, we can get rows with the df.loc attribute.

df['Cell_1']
# KeyError: 'Cell_1'
df.loc['Cell_1']
"""
Freq (Hz)             20
Loc (cm)               2
Cell Type    Interneuron
Name: Cell_1, dtype: object
"""

We could also create a DataFrame from other kinds of collections that are not dictionaries. For example we can use a list.

d = [[20, 2, "Interneuron"],
     [50, 3, "Interneuron"],
     [8, 1, "Pyramidal"]]
column_names = ["Frequency", "Location", "Cells"]
row_names = ["C1", "C2", "C3"]
df = pd.DataFrame(d, columns=column_names, index=row_names)
print(df)
"""
    Frequency  Location        Cells
C1         20         2  Interneuron
C2         50         3  Interneuron
C3          8         1    Pyramidal
"""

In that case there are no dictionary keys that could be used to infer the column names. This means we need to pass the column_names to the columns parameter. Almost anything that structures our data in a two-dimensional way can be used to create a DataFrame. Next we will learn about functions that allow us to load different file types as a DataFrame.
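A two-dimensional NumPy array is another such structure. This short sketch uses only the numeric columns of the example, because a NumPy array that mixes numbers and strings would store everything as strings.

```python
import numpy as np
import pandas as pd

# Rows are cells, columns are the two numeric variables
arr = np.array([[20, 2],
                [50, 3],
                [8, 1]])

df = pd.DataFrame(arr,
                  columns=["Frequency", "Location"],
                  index=["C1", "C2", "C3"])
print(df)
print(df.dtypes)
```

As with the list, the column and row names have to be passed explicitly because the array itself carries no labels.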

Loading Files as a DataFrame

The list of file types Pandas can read and write is rather long and you can find it in the Pandas documentation. I only want to cover the most commonly used one here: the .csv file. These files have the particular advantage that they can also be read by humans, because they are essentially text files. They are also widely supported by a variety of languages and programs. First, let’s create our file. Because it is a text file, we can write a literal string to file.

text_file = open("example.csv", "w")
text_file.write(""",Frequency,Location,Cell Type
                 C1,20,2,Interneuron
                 C2,50,3,Interneuron
                 C3,8,1,Pyramidal""")
text_file.close()

In this file columns are separated by commas and rows are separated by new lines. This is what .csv stands for: comma-separated values. To load this file into a DataFrame we need to pass the file name and which column contains the row names. Pandas assumes by default that the first row contains the column names.

df = pd.read_csv("example.csv", index_col=0)
print(df)
"""
     Frequency  Location    Cell Type
 C1         20         2  Interneuron
 C2         50         3  Interneuron
 C3          8         1    Pyramidal
"""

There are many more parameters we can specify for read_csv in case we have a file that is structured differently. In fact we can load files that have a value delimiter other than the comma, by specifying the delimiter parameter.

text_file = open("example.csv", "w")
text_file.write("""-Frequency-Location-Cell Type
                 C1-20-2-Interneuron
                 C2-50-3-Interneuron
                 C3-8-1-Pyramidal""")
text_file.close()
df = pd.read_csv("example.csv", index_col=0, delimiter='-')
print(df)
"""
     Frequency  Location    Cell Type
 C1         20         2  Interneuron
 C2         50         3  Interneuron
 C3          8         1    Pyramidal
"""

We specify '-' as the delimiter and it also works. Although the function is called read_csv, it is not strictly bound to comma-separated values. We can also skip rows and columns and specify many more options that you can learn about from the documentation. For well-structured .csv files, however, we need very few arguments as shown above. Next we will learn how to do basic calculations with the DataFrame.
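
As a small sketch of skipping rows, the text below has one extra line before the header. To keep the example self-contained it reads from an in-memory string through io.StringIO instead of a file on disk; the content of the extra line is invented for illustration.

```python
import io
import pandas as pd

# One extra line before the header that we do not want to parse
raw = """some note that is not part of the table
,Frequency,Location,Cell Type
C1,20,2,Interneuron
C2,50,3,Interneuron
C3,8,1,Pyramidal"""

# skiprows=1 drops the first line, so the second line becomes the header
df_skip = pd.read_csv(io.StringIO(raw), index_col=0, skiprows=1)
```
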

Basic Math with DataFrame

A variety of functions such as df.mean(), df.median() and df.std() are available to do basic statistics on our DataFrame. By default they all return values per column. That is because columns are assumed to contain our variables (or features) and each row contains a sample.

df.mean()
"""
Freq (Hz)    26.0
Loc (cm)      2.0
dtype: float64
"""

df.median()
"""
Freq (Hz)    20.0
Loc (cm)      2.0
dtype: float64
"""

df.std()
"""
Freq (Hz)    21.633308
Loc (cm)      1.000000
dtype: float64
"""

One big advantage of the column is that within a column the data type is clearly defined. Within a row, on the other hand, different data types can exist. In our case we have two numeric columns and a string column. When we call these statistical methods, non-numeric columns are ignored. In our case that is 'Cell Type'. Technically we can also use the axis parameter to calculate these statistics for each sample but this is not always useful and has to again ignore one of the columns.

df.mean(axis=1)
"""
C1    11.0
C2    26.5
C3     4.5
dtype: float64
"""

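A caveat on versions: the outputs above come from a Pandas version that silently drops non-numeric columns. As far as I know, Pandas 2.0 and later raise a TypeError in that situation unless we restrict the calculation to numeric columns ourselves. A minimal sketch, assuming the three-cell DataFrame with the column names from the .csv example:

```python
import pandas as pd

df = pd.DataFrame(
    {"Frequency": [20, 50, 8],
     "Location": [2, 3, 1],
     "Cell Type": ["Interneuron", "Interneuron", "Pyramidal"]},
    index=["C1", "C2", "C3"],
)

# Ask explicitly for numeric columns only when averaging per column
col_means = df.mean(numeric_only=True)

# For per-row statistics, select the numeric columns first
numeric = df.select_dtypes(include="number")
row_means = numeric.mean(axis=1)
```
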
We can also use other mathematical operators. They are applied element-wise and their effect will depend on the data type of the value.

print(df * 3)
"""
         Frequency  Location                      Cell Type
 C1         60         6  InterneuronInterneuronInterneuron
 C2        150         9  InterneuronInterneuronInterneuron
 C3         24         3        PyramidalPyramidalPyramidal
"""

Often these operations make more sense for individual columns. As explained above we can use indexing to get individual columns and we can even assign new results to an existing or new column.

# Take the mean of the numeric column directly, so the string column is never involved
norm_freq = df['Frequency'] / df['Frequency'].mean()
norm_freq
"""
 C1    0.769231
 C2    1.923077
 C3    0.307692
Name: Frequency, dtype: float64
"""
df['Norm Freq'] = norm_freq
print(df)
"""
     Frequency  Location    Cell Type  Norm Freq
 C1         20         2  Interneuron   0.769231
 C2         50         3  Interneuron   1.923077
 C3          8         1    Pyramidal   0.307692
"""

If you are familiar with NumPy, most of these DataFrame operations will seem very familiar because they mostly work like array operations. Because Pandas builds on NumPy, most NumPy functions (for example np.sin) work on numeric columns. I don’t want to go deeper here and will instead move on to visualizing DataFrames with Seaborn.
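
As a quick sketch of a NumPy function applied to a numeric column: the result is again a Pandas Series that keeps the row names. The names sin_loc and log_freq are chosen for this example only.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {"Frequency": [20, 50, 8], "Location": [2, 3, 1]},
    index=["C1", "C2", "C3"],
)

# NumPy functions are applied element-wise and return a Series
sin_loc = np.sin(df["Location"])
log_freq = np.log10(df["Frequency"])

# The result can be stored as a new column like any other Series
df["Log Freq"] = log_freq
```
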

Seaborn for Data Visualization

Seaborn is a high-level data visualization package that builds on Matplotlib. It does not necessarily require a DataFrame. It can work with other data structures such as ndarray but it is particularly convenient with DataFrame. First, let us get a more interesting data set. Luckily Seaborn comes with some nice example data sets and they conveniently load into Pandas DataFrame.

import seaborn as sns
df = sns.load_dataset('iris')
type(df)
# pandas.core.frame.DataFrame
print(df)
"""
     sepal_length  sepal_width  petal_length  petal_width    species
0             5.1          3.5           1.4          0.2     setosa
1             4.9          3.0           1.4          0.2     setosa
2             4.7          3.2           1.3          0.2     setosa
3             4.6          3.1           1.5          0.2     setosa
4             5.0          3.6           1.4          0.2     setosa
..            ...          ...           ...          ...        ...
145           6.7          3.0           5.2          2.3  virginica
146           6.3          2.5           5.0          1.9  virginica
147           6.5          3.0           5.2          2.0  virginica
148           6.2          3.4           5.4          2.3  virginica
149           5.9          3.0           5.1          1.8  virginica

[150 rows x 5 columns]
"""

print(df.columns)
"""
Index(['sepal_length', 'sepal_width', 'petal_length', 'petal_width',
       'species'],
      dtype='object')
"""

The Iris data set contains information about different species of iris plants. It contains 150 samples and 5 features. The 'species' feature tells us what species a particular sample belongs to. The names of those columns are very useful when we structure our plots in Seaborn. Let’s first try a basic bar graph.

sns.set(context='paper',
        style='whitegrid',
        palette='colorblind',
        font='Arial',
        font_scale=2,
        color_codes=True)
# sns.barplot returns a Matplotlib axes object, so we call it ax
ax = sns.barplot(x='species', y='sepal_length', data=df)

We use sns.barplot and we have to pass our DataFrame to the data parameter. Then for x and y we define which column name should appear there. We put 'species' on the x-axis so that is how data is aggregated inside the bars. Setosa, versicolor and virginica are the different species. The sns.set() function defines multiple parameters of Seaborn and forces a certain style on the plots that I personally prefer. Bar graphs have grown out of fashion and for good reason. They are not very informative about the distribution of their underlying values. I prefer the violin plot to get a better idea of the distribution.

ax = sns.violinplot(x='species', y='sepal_length', data=df)

We even get a small box plot within the violin plot for free. Seaborn works its magic through the DataFrame column names. This makes plotting more convenient but also makes our code more descriptive than it would be with pure NumPy. Our code literally tells us that 'species' will be on the x-axis.
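
Because Seaborn builds on Matplotlib, functions like sns.violinplot accept an existing axes object through the ax parameter and return that same object, so we can keep customizing the plot with Matplotlib methods. The sketch below uses a small synthetic stand-in DataFrame (df_demo, with invented values) instead of the Iris data so that it runs without downloading anything.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data, not the real Iris measurements
rng = np.random.default_rng(0)
df_demo = pd.DataFrame({
    "species": np.repeat(["a", "b"], 30),
    "sepal_length": np.concatenate([rng.normal(5.0, 0.3, 30),
                                    rng.normal(6.5, 0.5, 30)]),
})

fig, ax = plt.subplots()
# Draw into our own axes; Seaborn returns the axes it drew on
returned = sns.violinplot(x="species", y="sepal_length", data=df_demo, ax=ax)

# Continue with ordinary Matplotlib methods
ax.set_ylabel("Sepal length")
```
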

Summary

We learned that we can create a DataFrame from a dictionary or another kind of collection. The most important features are the column and row names. Columns organize features and rows organize samples by convention. We can also load files into a DataFrame. For example we can use read_csv to load .csv or other text based files. We can also use methods like df.mean() to get basic statistics of our DataFrame. Finally, Seaborn is very useful to visualize a DataFrame.

Getting Started Programming Julia

To get us started with Julia we cover two basics: arithmetic operators and name assignment.

Arithmetic operators

The standard arithmetic operators are addition (+), subtraction (-), multiplication (*), division (/), power (^) and remainder (%). They work as expected and the only one that is different for the Python crowd is power, which follows the MATLAB convention. The normal precedence of operations applies. First power. Then multiplication, division and remainder. Finally addition and subtraction. Parentheses can be used to change the order of operation.

1 + 3 * 2
# 7
(1 + 3) * 2
# 8
2 * 3 ^ 2
# 18

In those examples, both sides of the operator are scalars. It gets a little more interesting when at least one of them is a vector or a matrix. Not all of the above operations are defined between vectors and scalars. Only division and multiplication are defined. We create a vector using square brackets ([]) with the elements separated by commas.

[3, 1, 4] * 2
# 3-element Array{Int64,1}:
#  6
#  2
#  8

[3, 1, 4] / 2
# 3-element Array{Float64,1}:
# 1.5
# 0.5
# 2.0

The other operations are not defined and throw an error.

[3, 1, 4] ^ 2
"""
MethodError: no method matching ^(::Array{Int64,1}, ::Int64)
Closest candidates are:
  ^(!Matched::Float16, ::Integer) at math.jl:885
  ^(!Matched::Regex, ::Integer) at regex.jl:712
  ^(!Matched::Missing, ::Integer) at missing.jl:155
  ...

Stacktrace:
 [1] macro expansion at .\none:0 [inlined]
 [2] literal_pow(::typeof(^), ::Array{Int64,1}, ::Val{2}) at .\none:0
 [3] top-level scope at In[52]:1
"""

The reasons for this design choice have to do with Julia's focus on linear algebra and are not important here. If we want this operation to work in an element-wise manner, we have to force it explicitly. We can do so using the dot (.). This way we can explicitly force every operator to be applied element-wise.

[3, 1, 4] .^ 2
# 3-element Array{Int64,1}:
#  9
#  1
# 16
[3, 1, 4] .+ 2
# 3-element Array{Int64,1}:
#  5
#  3
#  6

Now that we have our arithmetic operators, let’s move on to name assignment so we can store the results of our operations.

Name Assignment

We assign names to values with the = operator using the syntax name = value. Once a name is assigned, we can use the name instead of the value in operations.

result_one = 2 + 2
result_two = result_one + 3
# 7

Once a name is assigned to a value we can reassign that same name to a different value without problem.

result = 2 + 2
result = 10
# 10

If we want to assign a name that is not supposed to be reassigned, we can use the const keyword. If we try to reassign a constant name we get an error.

const a = 8.3144621
a = 3
"""
invalid redefinition of constant a

Stacktrace:
 [1] top-level scope at In[73]:2
"""

There are a few rules about the names we can assign. Generally, Unicode characters (UTF-8) are allowed. This means we can do something like this.

δt = 0.0001

Here we are using the special character delta. If you want to quickly generate such a character, many Julia environments allow you to do this by typing \delta and hitting tab. I recommend using these sparingly, as they might confuse people transitioning from other languages that don’t allow Unicode names. On the other hand they might be useful to make your code style more mathy. Built-in keywords that have special meaning are not allowed as names, and trying to assign to them will result in an error.

if = 3
# syntax: unexpected "="

If you are interested in more details about variables and name assignments you can take a look at the official documentation. In the next blog post we will take a look at the type system of Julia.

A Curve Fitting Guide for the Busy Experimentalist

Curve fitting is an extremely useful analysis tool to describe the relationship between variables or discover a trend within noisy data. Here I’ll focus on a pragmatic introduction to curve fitting: how to do it in Python, why it can fail and how to interpret the results. Finally, I will also give a brief glimpse at the larger themes behind curve fitting, such as mathematical optimization, to the extent that I think is useful for the casual curve fitter.

Curve Fitting Made Easy with SciPy

We start by creating a noisy exponential decay function. The exponential decay function has two parameters: the time constant tau and the initial value at the beginning of the curve init. We’ll evenly sample from this function and add some white noise. We then use curve_fit to fit parameters to the data.

import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

# The exponential decay function
def exp_decay(x, tau, init):
    return init*np.e**(-x/tau)

# Parameters for the exp_decay function
real_tau = 30
real_init = 250

# Sample exp_decay function and add noise
np.random.seed(100)
dt=0.1
x = np.arange(0,100,dt)
noise=np.random.normal(scale=50, size=x.shape[0])
y = exp_decay(x, real_tau, real_init)
y_noisy = y + noise

# Use scipy.optimize.curve_fit to fit parameters to noisy data
popt, pcov = scipy.optimize.curve_fit(exp_decay, x, y_noisy)
fit_tau, fit_init = popt

# Sample exp_decay with optimized parameters
y_fit = exp_decay(x, fit_tau, fit_init)

fig, ax = plt.subplots(1)
ax.scatter(x, y_noisy,
           alpha=0.8,
           color= "#1b9e77",
           label="Exponential Decay + Noise")
ax.plot(x, y,
        color="#d95f02",
        label="Exponential Decay")
ax.plot(x, y_fit,
        color="#7570b3",
        label="Fit")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Curve Fit Exponential Decay")

Our fit parameters are almost identical to the actual parameters. We get 30.60 for fit_tau and 245.03 for fit_init, both very close to the real values of 30 and 250. All we had to do was call scipy.optimize.curve_fit and pass it the function we want to fit, the x data and the y data. The function we are passing should have a certain structure. The first argument must be the input data. All other arguments are the parameters to be fit. From the call signature of def exp_decay(x, tau, init) we can see that x is the input data while tau and init are the parameters to be optimized such that the difference between the function output and y_noisy is minimal. Technically this can work for any number of parameters and any kind of function. It also works when the sampling is much more sparse. Below is a fit on 20 randomly chosen data points.
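
The code for the sparse fit is not shown here, so the following is only a sketch of how such a fit could be set up. It makes a few assumptions of its own: a new random generator with an arbitrary seed, np.exp in place of np.e**, and a rough starting guess p0 (initial parameters are the topic of the next section). The numbers it produces will therefore differ from the ones quoted above.

```python
import numpy as np
import scipy.optimize

def exp_decay(x, tau, init):
    return init * np.exp(-x / tau)

# Same kind of noisy decay as above, with an arbitrarily chosen seed
rng = np.random.default_rng(100)
x = np.arange(0, 100, 0.1)
y_noisy = exp_decay(x, 30, 250) + rng.normal(scale=50, size=x.shape[0])

# Pick 20 sample indices at random, without repeats, and sort them
idx = np.sort(rng.choice(x.shape[0], size=20, replace=False))
x_sparse, y_sparse = x[idx], y_noisy[idx]

# Fit on the sparse subset only; p0 is an assumed rough starting guess
popt, pcov = scipy.optimize.curve_fit(exp_decay, x_sparse, y_sparse,
                                      p0=[10, 100])
fit_tau, fit_init = popt
```
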

Of course the accuracy will decrease with the sampling. So why would this ever fail? The most common failure mode in my opinion is bad initial parameters.

Choosing Good Initial Parameters

The initial parameters of a function are the starting parameters before being optimized. The initial parameters are very important because most optimization methods don’t just look for the best fit randomly. That would take too long. Instead, they start with the initial parameters, change them slightly and check if the fit improves. When changing the parameters shows very little improvement, the fit is considered done. That makes it very easy for the method to stop with bad parameters if it stops in a local minimum or a saddle point. Let’s look at an example of a bad fit. We will change our tau to a negative number, which will result in exponential growth.

In this case fitting didn’t work. For a real_tau and real_init of -30 and 20 we get a fit_tau and fit_init of 885223976.9 and 106.4, both way off. So what happened? Because we never specified the initial parameters (p0), curve_fit chose its default value of 1 for both fit_tau and fit_init. Starting from 1, curve_fit never finds good parameters. So what happens if we choose better parameters? Looking at our exp_decay definition and the exponential growth in our noisy data, we know for sure that our tau has to be negative. Let’s see what happens when we choose a negative initial value of -5.

p0 = [-5, 1]
popt, pcov = scipy.optimize.curve_fit(exp_decay, x, y_noisy, p0=p0)
fit_tau, fit_init = popt
y_fit = exp_decay(x, fit_tau, fit_init)
fig, ax = plt.subplots(1)
ax.scatter(x, y_noisy,
           alpha=0.8,
           color= "#1b9e77",
           label="Exponential Decay + Noise")
ax.plot(x, y,
        color="#d95f02",
        label="Exponential Decay")
ax.plot(x, y_fit,
        color="#7570b3",
        label="Fit")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Curve Fit Exponential Growth Good Initials")

With an initial parameter of -5 for tau we get good parameters of -30.4 for tau and 20.6 for init (real values were -30 and 20). The key point is that initial conditions are extremely important because they can change the result we get. This is an extreme case, where the fit works almost perfectly for some initial parameters or completely fails for others. In more subtle cases different initial conditions might result in slightly better or worse fits that could still be relevant to our research question. But what does it mean for a fit to be better or worse? In our example we can always compare it to the actual function. In more realistic settings we can only compare our fit to the noisy data.

Interpreting Fitting Results

In most research settings we don’t know our exact parameters. If we did, we would not need to do fitting at all. So to compare the goodness of different parameters we need to compare our fit to the data. How do we calculate the error between our data and the prediction of the fit? There are many different measures but among the simplest ones is the sum of squared residuals (SSR).

def ssr(y, fy):
    """Sum of squared residuals"""
    return ((y - fy) ** 2).sum()

We take the difference between our data (y) and the output of our function given a parameter set (fy). We square that difference and sum it up. In fact this is what curve_fit optimizes. Its whole purpose is to find the parameters that give the smallest value of this function, which is why the method is called least squares. The parameters that give the smallest SSR are considered the best fit. We saw that this process can fail, depending on the function and the initial parameters, but let’s assume for a moment it worked. If we found the smallest SSR, does that mean we found the perfect fit? Unfortunately not. What we found was a good estimate for the best fitting parameters given our function. There are probably other functions out there that can fit our data better. We can use the SSR to find better fitting functions in a process called cross-validation. Instead of comparing different parameters of the same function we compare different functions. However, if we increase the number of parameters we run into a problem called overfitting. I will not get into the details of overfitting here because it is beyond our scope.
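
To make the comparison of parameter sets concrete, here is a small sketch that scores two candidate parameter sets against the same noisy data and picks the one with the smaller SSR. The data follow the exponential growth example (tau of -30, init of 20); the seed and the second candidate are arbitrary choices for illustration.

```python
import numpy as np

def exp_decay(x, tau, init):
    return init * np.exp(-x / tau)

def ssr(y, fy):
    """Sum of squared residuals"""
    return ((y - fy) ** 2).sum()

# Noisy exponential growth, seed chosen arbitrarily
rng = np.random.default_rng(0)
x = np.arange(0, 100, 0.1)
y_noisy = exp_decay(x, -30.0, 20.0) + rng.normal(scale=50, size=x.shape[0])

# Two candidate (tau, init) pairs: the real parameters and a worse guess
candidates = {"real": (-30.0, 20.0), "guess": (-20.0, 20.0)}
scores = {name: ssr(y_noisy, exp_decay(x, tau, init))
          for name, (tau, init) in candidates.items()}

# The candidate with the smallest SSR is the better fit to this data
best = min(scores, key=scores.get)
```
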

The main point is that we must stay clear of misinterpretations of best fit. We are always fitting the parameters and not the function. If our fitting works, we get a good estimate for the best fitting parameters. But sometimes our fitting doesn’t work. This is because our fitting method did not converge to the minimum SSR and in the final chapter we will find out why that might happen in our example.

The Error Landscape of Exponential Decay

To understand why fitting can fail depending on the initial conditions we should consider the landscape of our sum of squared residuals (SSR). We will calculate it by assuming that we already know the init parameter, so we keep it constant. Then we calculate the SSR for many values of tau smaller than zero and many values for tau larger than zero. Plotting the SSR against the guessed tau will hopefully show us how the SSR looks around the ideal fit.

real_tau = -30.0
real_init = 20.0

noise=np.random.normal(scale=50, size=x.shape[0])
y = exp_decay(x, real_tau, real_init)
y_noisy = y + noise
dtau = 0.1
guess_tau_n = np.arange(-60, -4.9, dtau)
guess_tau_p = np.arange(1, 60, dtau)

# The SSR function
def ssr(y, fy):
    """Sum of squared residuals"""
    return ((y - fy) ** 2).sum()

loss_arr_n = [ssr(y_noisy, exp_decay(x, tau, real_init)) 
              for tau in guess_tau_n]
loss_arr_p = [ssr(y_noisy, exp_decay(x, tau, real_init))
              for tau in guess_tau_p]

"""Plotting"""
fig, ax = plt.subplots(1,2)
ax[0].scatter(guess_tau_n, loss_arr_n)
real_tau_loss = ssr(y_noisy, exp_decay(x, real_tau, real_init))
ax[0].scatter(real_tau, real_tau_loss, s=100)
ax[0].scatter(guess_tau_n[-1], loss_arr_n[-1], s=100)
ax[0].set_yscale("log")
ax[0].set_xlabel("Guessed Tau")
ax[0].set_ylabel("SSR Standard Log Scale")
ax[0].legend(("All Points", "Real Minimum", "-5 Initial Guess"))

ax[1].scatter(guess_tau_p, loss_arr_p)
ax[1].scatter(guess_tau_p[0], loss_arr_p[0], s=100)
ax[1].set_xlabel("Guessed Tau")
ax[1].set_ylabel("SSR")
ax[1].legend(("All Points", "1 Initial Guess"))

On the left we see the SSR landscape for tau smaller than 0. Here we see that towards zero, the error becomes extremely large (note the logarithmic y scale). This is because towards zero the exponential growth becomes ever faster. As we move to more negative values we find a minimum near -30 (orange), our real tau. This is the parameter curve_fit would find if it only optimized tau and started initially at -5 (green). The optimization method does not move to more negative values from -30 because there the SSR becomes worse, it increases.

On the right side we get a picture of why optimization failed when we started at 1. There is no local minimum. The SSR just keeps decreasing with larger values of tau. That is why the tau was so large when fitting failed (885223976.9). If we set our initial parameter anywhere in this part of the SSR landscape, this is where tau will go. Now there are other optimization methods that can overcome bad initial parameters. But few are completely immune to this issue.

Easy to Learn Hard to Master.

Curve fitting is a very useful technique and it is really easy in Python with SciPy, but there are some pitfalls. First of all, be aware of the initial values. They can lead to complete fitting failure or affect results in more subtle systematic ways. We should also remind ourselves that even with decent fitting results, there might be a more suitable function out there that can fit our data even better. In this particular example we always knew what the underlying function was. This is rarely the case in real research settings. Most of the time it is much more productive to think more deeply about possible underlying functions than to look for more complicated fitting methods.

Finally, we barely scratched the surface here. Mathematical optimization is an entire field in itself and it is relevant to many areas such as statistics, machine learning, deep learning and many more. I tried to give the most pragmatic introduction to the topic here. If you want to go deeper into the topic I recommend this SciPy lecture and of course the official SciPy documentation for optimization and root finding.

Matplotlib and the Object-Oriented Interface

Matplotlib is a great data visualization library for Python and there are two ways of using it. The functional interface (also known as pyplot interface) allows us to interactively create simple plots. The object-oriented interface on the other hand gives us more control when we create figures that contain multiple plots. While having two interfaces gives us a lot of freedom, they also cause some confusion. The most common error is to use the functional interface when using the object-oriented would be much easier. For beginners it is now highly recommended to use the object-oriented interface under most circumstances because they have a tendency to overuse the functional one. I made that mistake myself for a long time. I started out with the functional interface and only knew that one for a long time. Here I will explain the difference between both, starting with the object-oriented interface. If you have never used it, now is probably the time to start.

Figures, Axes & Methods

When using the object-oriented interface, we create objects and do the plotting with their methods. Methods are the functions that come with the object. We create both a figure and an axes object with plt.subplots(1). Then we use the ax.plot() method from our axes object to create the plot. We also use two more methods, ax.set_xlabel() and ax.set_ylabel() to label our axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
y = np.sin(np.pi * x) + x

fig, ax = plt.subplots(1)
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")

The big advantage is that we can very easily create multiple plots and we can very naturally keep track of where we are plotting what, because the method that does the plotting is associated with a specific axes object. In the next example we will plot on three different axes that we create all with plt.subplots(3).

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
ax[0].plot(x,ys[0])
ax[1].plot(x,ys[1])
ax[2].plot(x,ys[2])

ax[0].set_title("Addition")
ax[1].set_title("Multiplication")
ax[2].set_title("Division")

for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

When we create multiple axes objects, they are available to us through the ax array. We can index into them and we can also loop through all of them. We can take the above example even further and pack even the plotting into the for loop.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
titles = ["Addition", "Multiplication", "Division"]
for idx, a in enumerate(ax):
    a.plot(x, ys[idx])
    a.set_title(titles[idx])
    a.set_xlabel("x")
    a.set_ylabel("y")

This code produces exactly the same three axes figure as above. There are other ways to use the object-oriented interface. For example, we can create an empty figure without axes using fig = plt.figure(). We can then create subplots in that figure with ax = fig.add_subplot(). This is exactly the same concept as always but instead of creating figure and axes at the same time, we use the figure method to create axes. I personally prefer fig, ax = plt.subplots() but fig.add_subplot() is slightly more flexible in the way it allows us to arrange the axes. For example, plt.subplots(x, y) allows us to create a figure with axes arranged in x rows and y columns. Using fig.add_subplot() we could create a column with 2 axes and another with 3 axes.

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax2 = fig.add_subplot(2,2,3)
ax3 = fig.add_subplot(3,2,2)
ax4 = fig.add_subplot(3,2,4)
ax5 = fig.add_subplot(3,2,6)

Personally I prefer to avoid these arrangements because things like tight_layout don’t work, but they are doable and cannot be created with plt.subplots(). This concludes our overview of the object-oriented interface. Simply remember that you want to do your plotting through the methods of an axes object that you can create either with fig, ax = plt.subplots() or fig.add_subplot(). So what is different about the functional interface? Instead of plotting through axes methods, we do all our plotting through functions in the matplotlib.pyplot module.

One pyplot to Rule Them All

The functional interface works entirely through the pyplot module, which we import as plt by convention. In the example below we use it to create the exact same plot as in the beginning. We use plt to create the figure, do the plotting and label the axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.1)
y = np.sin(np.pi*x) + x

plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")

You might be wondering why we don’t need to tell plt where to plot and which axes to label. It always works with the currently active figure or axes object. If there is no active figure, plt.plot() creates its own, including the axes. If a figure is already active, it creates an axes in that figure or plots into already existing axes. This makes the functional interface less explicit and slightly less readable, especially for more complex figures. For the object-oriented interface, there is a specific object for any action, because a method must be called through an object. With the functional interface, it can be a guessing game where the plotting happens and we make ourselves highly dependent on the location in our script. The line we call plt.plot() on becomes crucial. Let’s recreate the three subplots example with the functional interface.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x, ys[0])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Addition")
plt.subplot(3, 1, 2)
plt.plot(x, ys[1])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Multiplication")
plt.subplot(3, 1, 3)
plt.plot(x, ys[2])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Division")

This one is much longer than the object-oriented code because we cannot label the axes in a for loop. To be fair, in this particular example we can put the entirety of our plotting and labeling into a for loop, but we have to put either everything or nothing into the loop. This is what I mean when I say the functional interface is less flexible.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
titles = ["Addition", "Multiplication", "Division"]
for idx, y in enumerate(ys):
    plt.subplot(3,1,idx+1)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(titles[idx])
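
To see that the functional calls really act on the currently active objects, we can compare them with the objects we created ourselves. A small sketch, using plt.gca() and plt.gcf() to ask pyplot for the current axes and figure; the variable names are only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# A functional call: no axes is named, so pyplot uses the current one
plt.xlabel("x")

# The current axes and figure are the objects we just created
same_axes = plt.gca() is ax
same_fig = plt.gcf() is fig

# The label set through plt is visible through the axes object
label = ax.get_xlabel()
```
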

Both interfaces are very similar. You might have noticed that the methods in the object-oriented API have the form set_attribute. This is by design and follows from an object-oriented convention, where methods that change attributes have a set prefix. Methods that don’t change attributes but create entirely new objects have an add prefix, for example add_subplot. Now that we have seen both APIs at work, why is the object-oriented API recommended?

Advantages of the Object-Oriented API

First of all, Matplotlib is internally object-oriented. The pyplot interface masks that fact in an effort to make the usage more MATLAB-like by putting a functional layer on top. If we avoid plt and instead work with the object-oriented interface, our plotting becomes slightly faster. More importantly, the object-oriented interface is considered more readable and explicit. Both are very important when we write Python. Readability can be somewhat subjective but I hope the code could convince you that going through the plotting methods of an axes object makes it much clearer where we are plotting. We also get more flexibility to structure our code. Because plt depends on the order of plotting, we are constrained. With the object-oriented interface we can structure our code more clearly. We can for example split plotting, labeling and other tasks into their own code blocks.
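
Here is a sketch of what such a split could look like for the three subplots example. Plotting and labeling live in separate loops, which works because every call names the axes object it acts on. As a small deviation from the earlier examples, x starts at 0.1 here to avoid dividing by zero.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0.1, 10, 0.1)  # start at 0.1 to avoid division by zero
ys = [np.sin(np.pi * x) + x,
      np.sin(np.pi * x) * x,
      np.sin(np.pi * x) / x]
titles = ["Addition", "Multiplication", "Division"]

fig, axes = plt.subplots(3)

# Block 1: plotting only
for a, y in zip(axes, ys):
    a.plot(x, y)

# Block 2: labeling only, independent of when the plotting happened
for a, title in zip(axes, titles):
    a.set_title(title)
    a.set_xlabel("x")
    a.set_ylabel("y")

fig.tight_layout()
```
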

In summary, I hope you will be able to use the object-oriented interface of Matplotlib now. Simply remember to create axes with fig, ax = plt.subplots() and then most of the work happens through the ax object. Finally, the object-oriented interface is recommended because it is more efficient, readable, explicit and flexible.

Image Segmentation with scikit-image

Image segmentation is one of the most important steps in most image analysis pipelines. It separates the background from the features of our images. It can also determine the number of distinct features and their location. Our ability to segment determines what we can analyze. We’ll look at a basic but complete segmentation pipeline with scikit-image. You can see the result in the title image where we segment four cells. First, we will need to threshold the image into a binary version where the background is 0 and the foreground is 1.

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, filters, morphology, color

image = io.imread("example_image.tif")  # Load Image
threshold = filters.threshold_otsu(image)  # Calculate threshold
image_thresholded = image > threshold  # Apply threshold

# Show the results
fig, ax = plt.subplots(1, 2)
ax[0].imshow(image, 'gray')
ax[1].imshow(image_thresholded, 'gray')
ax[0].set_title("Intensity")
ax[1].set_title("Thresholded")

We calculate the threshold with the threshold_otsu function and apply it with a boolean operator. This threshold method works very well but there are two problems. First, there are very small particles that have nothing to do with our cell. To take care of those, we will apply morphological erosion. Second, there are holes in our cells. We will close those with morphological dilation.
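
If you do not have an example image at hand, the thresholding step can be tried on synthetic data. The following is a minimal sketch: the image, its intensity values (background around 50, one bright square around 200) and the random seed are all made up for illustration.

```python
import numpy as np
from skimage import filters

rng = np.random.default_rng(0)

# Hypothetical test image: dark, noisy background with one bright square
synthetic = rng.normal(loc=50, scale=10, size=(100, 100))
synthetic[30:70, 30:70] = rng.normal(loc=200, scale=10, size=(40, 40))

threshold = filters.threshold_otsu(synthetic)  # One global threshold
mask = synthetic > threshold  # Boolean array: True = foreground

print(f"Otsu threshold: {threshold:.1f}")
print(f"Foreground pixels: {mask.sum()} of {mask.size}")
```

Because the two intensity populations are far apart, the threshold should land somewhere between them and the mask should cover the bright square. Real images are rarely this clean, which is why the cleanup steps below are needed.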

# Apply 2 times erosion to get rid of background particles
n_erosion = 2
image_eroded = image_thresholded
for x in range(n_erosion):
    image_eroded = morphology.binary_erosion(image_eroded)

# Apply 14 times dilation to close holes
n_dilation = 14
image_dilated = image_eroded
for x in range(n_dilation):
    image_dilated = morphology.binary_dilation(image_dilated)

# Apply 4 times erosion to recover original size
n_erosion = 4
image_eroded_two = image_dilated
for x in range(n_erosion):
    image_eroded_two = morphology.binary_erosion(image_eroded_two)

fig, ax = plt.subplots(2,2)
ax[0,0].imshow(image_thresholded, 'gray')
ax[0,1].imshow(image_eroded, 'gray')
ax[1,0].imshow(image_dilated, 'gray')
ax[1,1].imshow(image_eroded_two, 'gray')
ax[0,0].set_title("Thresholded")
ax[0,1].set_title("Eroded 2x")
ax[1,0].set_title("Dilated 14x")
ax[1,1].set_title("Eroded 4x")

Erosion turns any pixel black that is in contact with another black pixel. This is how erosion can get rid of small particles. In our case we need to apply erosion twice. Once those particles have disappeared, we can use dilation to close the holes in our cells. Dilation is the opposite operation: it turns any pixel white that is in contact with a white pixel. To close all the holes, we have to slightly over-dilate, which makes the cells slightly bigger than they actually are. To recover the original morphology we apply some more erosions. Here is an example that shows how erosion and dilation work in detail. It also illustrates what being “in contact” with another pixel means by default.

cross = np.array([[0,0,0,0,0], [0,0,1,0,0], [0,1,1,1,0], 
                [0,0,1,0,0], [0,0,0,0,0]], dtype=np.uint8)
cross_eroded = morphology.binary_erosion(cross)
cross_dilated = morphology.binary_dilation(cross)
fig, ax = plt.subplots(1,3)
ax[0].imshow(cross, 'gray')
ax[1].imshow(cross_eroded, 'gray')
ax[2].imshow(cross_dilated, 'gray')
ax[0].set_title("Cross")
ax[1].set_title("Cross Eroded")
ax[2].set_title("Cross Dilated")
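
To spell out the “in contact” rule without any plotting, here is a minimal NumPy-only sketch of erosion and dilation. It assumes a 4-neighborhood (up, down, left, right), which matches the cross-shaped default footprint scikit-image uses for 2D images. The helpers shifted_neighbors, erode4 and dilate4 are hypothetical names written for this illustration and are not part of scikit-image. Pixels outside the image are padded so that they never change the result.

```python
import numpy as np


def shifted_neighbors(img, fill):
    """Return the image and its four edge-neighbor shifts, padded with `fill`."""
    padded = np.pad(img, 1, constant_values=fill)
    return [
        padded[1:-1, 1:-1],  # the pixel itself
        padded[:-2, 1:-1],   # neighbor above
        padded[2:, 1:-1],    # neighbor below
        padded[1:-1, :-2],   # neighbor to the left
        padded[1:-1, 2:],    # neighbor to the right
    ]


def erode4(img):
    """A pixel stays white only if it and all four edge neighbors are white."""
    return np.logical_and.reduce(shifted_neighbors(img.astype(bool), True))


def dilate4(img):
    """A pixel turns white if it or any of its four edge neighbors is white."""
    return np.logical_or.reduce(shifted_neighbors(img.astype(bool), False))


cross = np.array([[0, 0, 0, 0, 0],
                  [0, 0, 1, 0, 0],
                  [0, 1, 1, 1, 0],
                  [0, 0, 1, 0, 0],
                  [0, 0, 0, 0, 0]], dtype=np.uint8)

print(erode4(cross).astype(int))   # only the center pixel survives
print(dilate4(cross).astype(int))  # the cross grows into a diamond
```

Note that diagonal neighbors do not count as contact here: the corners of the 5x5 array stay black after one dilation.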

Now we are essentially done segmenting foreground and background. But we also want to assign distinct labels to our objects.

labels = morphology.label(image_eroded_two)
labels_rgb = color.label2rgb(labels,
                             colors=['greenyellow', 'green',
                                     'yellow', 'yellowgreen'],
                             bg_label=0)
image.shape
# (342, 382)
labels.shape
# (342, 382)
fig, ax = plt.subplots(2,2)
ax[0,0].imshow(labels==1, 'gray')
ax[0,1].imshow(labels==2, 'gray')
ax[1,0].imshow(labels==3, 'gray')
ax[1,1].imshow(labels_rgb)
ax[0,0].set_title("label == 1")
ax[0,1].set_title("label == 2")
ax[1,0].set_title("label == 3")
ax[1,1].set_title("All labels RGB")

We use morphology.label to generate a label for each connected feature. This returns an array that has the same shape as our original image, but the pixels are no longer just zero or one. The background is zero and each feature gets its own integer. All pixels belonging to the first label are equal to 1, pixels of the second label are equal to 2 and so on. To visualize those labels all in one image, we call color.label2rgb to get color representations for each label in RGB space. And that’s it.
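
Once we have labels, counting the features and finding their locations takes only a few lines. The sketch below is one possible way to do it, on a made-up binary image with two rectangles. It uses measure.label (as far as I know the same labeling function that morphology.label exposes) with return_num=True, and measure.regionprops to get the area and centroid of each label.

```python
import numpy as np
from skimage import measure

# Hypothetical binary image with two separate rectangles
binary = np.zeros((10, 12), dtype=bool)
binary[1:4, 1:5] = True    # first object, 3 x 4 pixels
binary[6:9, 7:11] = True   # second object, 3 x 4 pixels

labels, n_objects = measure.label(binary, return_num=True)
print("Number of objects:", n_objects)

# One region object per label, with geometric properties attached
for region in measure.regionprops(labels):
    row, col = region.centroid
    print(f"label {region.label}: area={region.area}, "
          f"centroid=({row:.1f}, {col:.1f})")
```

The centroid is given in (row, column) order, so it has to be swapped to (x, y) before plotting it on top of an image.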

Segmentation is crucial for image analysis and I hope this tutorial got you off to a good start with your own segmentation in scikit-image. This pipeline is not perfect but illustrates the concept well. There are many more functions in the morphology module to filter binary images, and many of them come down to a sequence of erosions and dilations. If you want to adapt this approach for your own images, I recommend playing around with the number of erosions and dilations. Let me know how it worked for you.