Structural Causal Models to Clarify Causality in Neuroscience

Neuroscience is full of causal questions. Is this brain area necessary for a behavior? Does this drug decrease seizure frequency in epileptic patients? Does the firing rate of one neuron influence the firing rate of another? Beyond the many different causal questions, there are also slightly different notions of causality. Barack et al. 2022 rightly call for more clarity when asking and answering causal questions. Here I want to make a point that will hopefully help add clarity to causal reasoning in neuroscience. It concerns the power of structural causal models (SCMs). (Judea Pearl has taken note of the paper and pointed out that it does not mention SCMs. Find his tweet here. I recommend The Book of Why to anyone.)

Specifically, I want to make two main points:

  • A structural causal model (SCM) is extremely useful and adds clarity to causal reasoning.
  • I agree that multiple concepts of causality will be useful in neuroscience, but all of them become much clearer when explained with an SCM.

I will start by defining what an SCM is and then try to make some neuroscience questions clearer.

The Structural Causal Model

A structural causal model (SCM) consists of three sets. The set U contains the error terms, which are outside (exogenous to) the model. The set V contains the variables inside (endogenous to) the model, whose causal relations we are interested in. Finally, the set E contains functions that describe each variable in V in terms of other variables in V and of error terms in U. Thereby, E describes the causal relationships in the SCM. The mathematical description is neat, but SCMs also have a very clear graphical representation in which each variable in V is a circle and arrows between circles show causal relationships. For example, the SCM below proposes causal influences on neuronal activity during optogenetic perturbation.

Fig 1.

This is extremely useful. It tells us that neuronal activity is determined by multiple variables and that an optogenetic construct we introduce becomes just one more of those factors. Gaining full optogenetic control over neuronal activity would therefore be hard, because cutting all other arrows would probably be impossible. But we could exploit the non-linearity of the system and look for optogenetic stimulation strengths at which the other variables become negligible. (A small disclaimer: you might be missing the error terms U. By convention, they are omitted from the graph when each variable in V is associated with exactly one error term.) I don't want to dwell on this model, because it is only for illustration. Instead, I want to talk about the different definitions of causality.
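Before moving on, here is a minimal Python sketch of the three sets and of the non-linearity argument. The model is hypothetical. Apart from previous activity, the variable names, the weights and the sigmoid output are my own assumptions, not the contents of Fig 1. Adding a strong optogenetic drive pushes the saturating output to its ceiling, so the other variables barely move activity any more.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sample_activity(n, opto_strength=0.0, seed=0):
    """Draw n samples of activity from a toy, hypothetical SCM.

    U: one exogenous error term per endogenous variable (u_prev, u_syn, u_act).
    V: previous_activity, synaptic_input and activity.
    E: the assignments below, each writing one variable in V as a function
       of other variables in V and its own error term.
    opto_strength is an input set by the experimenter (assumed additive).
    """
    rng = np.random.default_rng(seed)
    # U: exogenous error terms
    u_prev = rng.normal(size=n)
    u_syn = rng.normal(size=n)
    u_act = rng.normal(scale=0.5, size=n)

    # E: structural assignments (weights are made up for illustration)
    previous_activity = u_prev
    synaptic_input = u_syn
    drive = previous_activity + 1.5 * synaptic_input + opto_strength + u_act
    activity = sigmoid(drive)  # saturating, i.e. non-linear output
    return activity


baseline = sample_activity(10_000)
strong_opto = sample_activity(10_000, opto_strength=10.0)

# Spread of activity that is due to the non-optogenetic variables
print(baseline.std(), strong_opto.std())
```

With strong stimulation, the spread caused by the other variables is tiny compared with baseline. That is the regime we would be looking for experimentally, assuming the real neuron saturates in a similar way.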

The Path Definition of Causality

The causal definition I usually work with goes like this: if there is a directed path from x to y, x has a causal influence on y. Also: if there is no directed path from x to y, x does not have a causal influence on y. Below are some SCMs where in the upper three (A, B, C) x has a causal influence on y, whereas in the lower three (D, E, F), x does not have a causal influence on y.

Fig 2.

In A, there is a directed path, but z also acts as a confounder. In B, z is a mediator. In C, there are two causal pathways, a direct one and one mediated by z. In D, there is a path but it is not directed (it collides on z). In E, there is a path but it is not directed (this is the SCM that gives you correlations between ice cream sales and violent crime). In F, the path points in the wrong direction. So we can use SCMs to clarify any kind of causal relation. One thing the graphical representation of an SCM does not tell us is the exact form of the causal relationship. This is not necessarily a bad thing, because it means our definition of causality does not depend on linearity. But sometimes we need to know whether y is continuous (a neuronal firing rate) or nominal (the mouse going left, going right or staying). That is usually easy to clarify in writing. A bigger issue is that causal inference in SCMs works best when there are no cycles between variables, that is, when the graph is a directed acyclic graph (DAG). Cycles, however, are pretty normal in neuroscience (I work on the hippocampus, where a lot of information flows in a loop).
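The path definition is mechanical enough to check in code. The sketch below encodes my reading of the six graphs in Fig 2 as adjacency lists. The exact arrows are assumptions based on the descriptions above. It then searches for a directed path from x to y with breadth-first search. It only follows arrow directions and says nothing about the strength or form of an effect.

```python
from collections import deque


def has_directed_path(graph, source, target):
    """Return True if a directed path source -> ... -> target exists.

    graph maps each node to the list of nodes its arrows point to.
    Breadth-first search that only moves along arrow directions.
    """
    queue = deque([source])
    visited = {source}
    while queue:
        node = queue.popleft()
        for child in graph.get(node, []):
            if child == target:
                return True
            if child not in visited:
                visited.add(child)
                queue.append(child)
    return False


# My reading of the six SCMs in Fig 2 (arrow lists are assumptions)
scms = {
    "A": {"z": ["x", "y"], "x": ["y"]},  # direct effect plus confounder z
    "B": {"x": ["z"], "z": ["y"]},       # z mediates
    "C": {"x": ["y", "z"], "z": ["y"]},  # direct and mediated paths
    "D": {"x": ["z"], "y": ["z"]},       # collider on z
    "E": {"z": ["x", "y"]},              # common cause only
    "F": {"y": ["x"]},                   # arrow in the wrong direction
}

for name, graph in scms.items():
    print(name, has_directed_path(graph, "x", "y"))
```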

So what do we do about cycles? This is where time comes in. We can define the SCM at a timescale at which the cycle is negligible, for example by treating the activity at an earlier time point as a separate variable that causes the current activity. In Fig 1, this is the variable "Previous Activity" (for an interesting use of previous activity as an instrumental variable, see Lepperød et al. 2022). If that doesn't work for you because you are interested in a timescale at which the previous activity is still being influenced by the current activity (a truly unbreakable cycle), then there is probably no way around simulating the dynamical system over time.
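Here is a minimal sketch of unrolling a cycle in time. It assumes two hypothetical units, a and b, that excite each other with made-up weights. Once every variable carries a time index, each arrow points from t-1 to t and the graph has no cycles. Clamping a is a crude stand-in for an intervention, and it shifts b one step later, as the path definition predicts.

```python
import numpy as np


def simulate_loop(n_steps, w_ab=0.5, w_ba=0.4, clamp_a=None, seed=0):
    """Simulate two mutually coupled units, a and b, in discrete time.

    a -> b and b -> a form a cycle, but indexed by time the graph is acyclic:
    a[t] depends only on b[t-1] and b[t] only on a[t-1].
    w_ab is the weight of the arrow a -> b, w_ba that of b -> a.
    If clamp_a is given, a is held at that value for t >= 1, ignoring its
    usual cause, which mimics an intervention on a.
    """
    rng = np.random.default_rng(seed)
    a = np.zeros(n_steps)
    b = np.zeros(n_steps)
    for t in range(1, n_steps):
        if clamp_a is None:
            a[t] = w_ba * b[t - 1] + rng.normal(scale=0.1)
        else:
            a[t] = clamp_a
        b[t] = w_ab * a[t - 1] + rng.normal(scale=0.1)
    return a, b


_, b_free = simulate_loop(1000)
_, b_clamped = simulate_loop(1000, clamp_a=1.0)

# Without intervention b fluctuates around 0; with a clamped to 1.0,
# b settles around w_ab * 1.0 from t = 2 on.
print(b_free[2:].mean(), b_clamped[2:].mean())
```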

In summary, I believe that any concept of causality becomes clearer when we draw an SCM. If you can think of a concept that does not work in SCMs, let me know. Another concept that does not come up in Barack et al. 2022 is do-calculus, which also becomes very clear when shown with an SCM.

do-calculus, interventions and counterfactuals would be interesting to write about at some point but for now I’m out of time.

Wind and Solar Power Generation

As the world tries to cut planet-warming emissions, solar and wind power have taken center stage. One of the reasons solar and wind go well together is their complementary seasonal pattern. The magnitude of this pattern depends on the location, but in Europe winter means less solar power and more wind power, while summer means more solar but slightly less wind. So if you want renewable energy all year round, it's a good idea to build both solar and wind capacity.

Raw data have recently become available for download on https://energy-charts.info/, and I have been visualizing them to show the complementary seasonal pattern with actual data from several countries. Here is the code, which can be adapted for other countries, and an example figure.

import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'data')

country = 'Deutschland'

raw_files = [os.path.join(data_path, f)
             for f in os.listdir(data_path) if country in f]

loaded_files = [pd.read_csv(f) for f in raw_files]

df = pd.concat(loaded_files, ignore_index=True)
df['Datum (UTC)'] = pd.to_datetime(df['Datum (UTC)'])
df['day_of_year'] = df['Datum (UTC)'].dt.dayofyear
df['year'] = df['Datum (UTC)'].dt.year
df['month'] = df['Datum (UTC)'].dt.month
df['day'] = df['Datum (UTC)'].dt.day
df['Wind'] = df['Wind Offshore'] + df['Wind Onshore']


# Average the 15-minute data per day. numeric_only=True restricts the
# mean to numeric columns; pandas 2.x raises an error for columns such
# as text that it cannot average.
df_daily = df.groupby(by=['year', 'month', 'day']).mean(numeric_only=True)
df_daily = df_daily.reset_index()

# Calculate daily theta
for year in df['year'].unique():
    boolean_year = df_daily['year'] == year
    days_zero = (df_daily[boolean_year]['day_of_year']
                 - df_daily[boolean_year]['day_of_year'].min())
    theta = days_zero / days_zero.max()
    
    df_daily.loc[boolean_year, 'theta_days'] = theta * 2 * np.pi

font = {'size'   : 16}
matplotlib.rc('font', **font)

solar_colors = ['#fef0d9','#fdd49e','#fdbb84','#fc8d59',
                '#ef6548','#d7301f','#990000']
fig, ax = plt.subplots(1, 2, subplot_kw={'projection': 'polar'})

for idx, y in enumerate(df_daily['year'].unique()):
    r = df_daily[df_daily['year'] == y]["Solar"] / 1000 # MW to GW
    theta = df_daily[df_daily['year'] == y]['theta_days']
    ax[0].plot(theta, r, alpha=0.8, color=solar_colors[idx])
    r = df_daily[df_daily['year'] == y]["Wind"] / 1000 # MW to GW
    ax[1].plot(theta, r, alpha=0.8, color=solar_colors[idx])

# Find month starting theta
month_thetas = []
for i in range(1,13):
    idx = (df_daily['month'] == i).idxmax()
    month_thetas.append(df_daily['theta_days'][idx])
month_labels = ['January', 'February', 'March', 'April', 'May',
                'June', 'July', 'August', 'September', 'October',
                'November', 'December']

for i in [0,1]:
    tl = np.array(month_thetas) / (2*np.pi) * 360
    ax[i].set_thetagrids(tl, month_labels)
    ax[i].xaxis.set_tick_params(pad=28)
    ax[i].set_theta_direction(-1)

ax[0].set_rticks([5, 7.5, 10, 12.5])
ax[1].set_rticks([20, 30, 40])
ax[0].set_rlabel_position(-10)  # Move radial labels
ax[1].set_rlabel_position(-170)
ax[1].legend(df_daily['year'].unique(), bbox_to_anchor=(1.1, 1.2)) 

fig.suptitle("Average daily power (in GW) from solar (left) and " +
             "wind (right) in Germany for years 2015-2021.\n" +
             "Data from energy-charts.info. Plot by Daniel" +
             " Müller-Komorowska.")
plt.show()

The code itself is pretty simple because the data already come well structured. Every year is in a separate file, so we collect all files for the chosen country from the data directory. Probably the biggest preprocessing step is calculating the daily average. The original data have a temporal resolution of 15 minutes, which is too noisy for the type of plot we are making. Once we have extracted year, month and day, we can do it in one line:

df_daily = df.groupby(by=['year', 'month', 'day']).mean(numeric_only=True).reset_index()
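To check the averaging step without downloading the data, here is a sketch on a synthetic two-day series. It reuses the column names from above, but the values are made up. It also shows resample() as an alternative.

It also shows a different way to compute the theta described next in the text: dividing the day of the year by the length of that year instead of normalizing by the first and last observed day. With the min/max normalization, the last day of a year lands at 2π, on top of the first day, and a partial year would be stretched around the full circle. Dividing by 365 or 366 avoids both.

```python
import calendar

import numpy as np
import pandas as pd

# Synthetic stand-in for one energy-charts file: two days of 15-minute
# values (96 per day). Only the column names mimic the real data.
timestamps = pd.date_range("2021-06-01", periods=2 * 96, freq="15min")
df = pd.DataFrame({
    "Datum (UTC)": timestamps,
    "Solar": np.r_[np.full(96, 10_000.0), np.full(96, 20_000.0)],  # MW
})
df["year"] = df["Datum (UTC)"].dt.year
df["month"] = df["Datum (UTC)"].dt.month
df["day"] = df["Datum (UTC)"].dt.day
df["day_of_year"] = df["Datum (UTC)"].dt.dayofyear

# One row per calendar day; numeric_only=True leaves out the timestamps
df_daily = (df.groupby(by=["year", "month", "day"])
              .mean(numeric_only=True)
              .reset_index())

# Alternative with the same result: resample on a DatetimeIndex
via_resample = df.set_index("Datum (UTC)")["Solar"].resample("D").mean()

# Alternative theta: divide by the length of the year, so 1 January maps
# to 0 and 31 December stays just below 2*pi instead of overlapping it
days_in_year = df_daily["year"].map(lambda y: 366 if calendar.isleap(y) else 365)
df_daily["theta_days"] = 2 * np.pi * (df_daily["day_of_year"] - 1) / days_in_year

print(df_daily[["year", "month", "day", "Solar", "theta_days"]])
print(via_resample.tolist())  # [10000.0, 20000.0]
```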

Next, we calculate theta for each day so we can distribute the data points around the polar plot. Once we have that, we are basically done: the rest is Matplotlib styling. Some of these styling choices are hardcoded because they are hard to automate across countries. Here are two more examples:

Measuring and Visualizing GPU Power Usage in Real Time with asyncio and Matplotlib

In this post we will learn how to periodically measure the power usage of our GPU and plot it in real time with a single Python program. For this we need concurrency between the measuring and the plotting parts of our code. Here, concurrency means that the measuring coroutine goes to sleep after each measurement. While it is asleep, the plotting coroutine can update the plot and then go to sleep as well. After a defined amount of time, the measuring coroutine wakes up and measures again as soon as the event loop gets to it, then the plotting coroutine runs, and so on. Both run in the same thread and take turns rather than running in parallel. We achieve concurrency with asyncio, and the plotting is done with Matplotlib. To measure the GPU power we use pynvml (Python bindings for the NVIDIA Management Library). Before we get into the code, here is a video showing the interface in action.

This video shows the power draw of my GPU in watts over a 20-second time window. Measurements are taken every 100 milliseconds and the plot is updated every 200 milliseconds. Everything happens in one Python script, which I ran by simply passing it to python on the command line.
import pynvml
import matplotlib.pyplot as plt
import time
import numpy as np
import asyncio

"""Initialize GPU measurement and parameters"""
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
measurement_interval = 0.1  # in seconds
plotting_interval = 0.2  # in seconds
time_span = 20  # time span on the plot x-axis in seconds
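n_samples = int(time_span / measurement_interval)
m = np.full(n_samples, np.nan)  # buffer for measured power (W)
t = np.full(n_samples, np.nan)  # separate buffer for elapsed time (s)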
mW_to_W = 1e3

"""Initialize the plot"""
plt.ion()
plt.rcParams.update({'font.size': 18})
figure, ax = plt.subplots(figsize=(8,6))
line1, = ax.plot(t, m, linewidth=3)
ax.set_xlabel("Time (s)")
ax.set_ylabel("GPU Power (W)")

async def measure():
    while True:
        measure = pynvml.nvmlDeviceGetPowerUsage(handle) / mW_to_W
        dt = time.time() - ts
        m[:-1] = m[1:]
        m[-1] = measure
        t[:-1] = t[1:]
        t[-1] = dt
        await asyncio.sleep(measurement_interval)

async def plot():
    while True:
        line1.set_data(t, m)
        tmin, tmax = np.nanmin(t), np.nanmax(t)
        mmin, mmax = np.nanmin(m), np.nanmax(m)
        margin = (np.abs(mmax - mmin) / 10) + 0.1
        ax.set_xlim((tmin, tmax + 1))
        ax.set_ylim((mmin - margin, mmax + margin))
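        figure.canvas.draw_idle()  # schedule a redraw with the new data
        figure.canvas.flush_events()  # let the GUI process pending events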
        await asyncio.sleep(plotting_interval)

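async def main():
    # Schedule both coroutines and wait on them together; they run
    # until the program is stopped.
    t1 = loop.create_task(measure())
    t2 = loop.create_task(plot())
    await asyncio.gather(t1, t2)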
    
if __name__ == "__main__":
    ts = time.time()
    loop = asyncio.new_event_loop()
    loop.run_until_complete(main())

We will start with the coroutines async def measure() and async def plot(), since they are central to the program. First, note that neither of them is an ordinary function, because of the async keyword. The async def syntax was added in Python 3.5. In earlier versions we would have decorated the functions with @asyncio.coroutine and used yield from instead of await (that decorator was removed in Python 3.11). The async keyword turns our function into a coroutine function, which allows us to use the await keyword inside. With await asyncio.sleep(measurement_interval) we put the coroutine to sleep. While it is asleep, the asyncio event loop can run other coroutines that are not asleep. More on the asyncio event loop later. Because we want to keep measuring until someone terminates the program, we wrap everything in measure in an infinite while True: loop.
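To see the turn-taking without a GPU, here is a self-contained sketch in which two hypothetical coroutines, fake_measure and fake_plot, replace the pynvml call and the plotting. It uses asyncio.run (Python 3.7+), which creates and closes an event loop for us, instead of the explicit loop in the script above. The printed log shows the coroutines interleaving at each await.

```python
import asyncio
import random

events = []  # (coroutine name, value) in the order the steps ran


async def fake_measure(n, interval):
    """Stand-in for the GPU measurement: records n random 'readings'."""
    for _ in range(n):
        events.append(("measure", random.uniform(50.0, 60.0)))  # fake watts
        await asyncio.sleep(interval)  # sleep and hand control to the loop


async def fake_plot(n, interval):
    """Stand-in for plotting: records how many readings exist so far."""
    for _ in range(n):
        n_readings = sum(1 for name, _ in events if name == "measure")
        events.append(("plot", n_readings))
        await asyncio.sleep(interval)


async def main():
    # Both coroutines run in one thread and take turns at each await
    await asyncio.gather(fake_measure(6, 0.01), fake_plot(3, 0.02))


asyncio.run(main())
for name, value in events:
    print(name, value)
```

Because both coroutines sleep at the same time, the total run time stays close to the longer of the two schedules (about 0.06 s here) rather than their sum.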

So what do we do while measuring? Outside of the coroutine we define two arrays, m and t. One holds the measured power and the other holds the elapsed time. Measuring time is important because energy is power integrated over time, and we generally need to be sure that the coroutine isn't getting stuck asleep much longer than we want it to. When we measure a value, we shift the current elements of the measurement array one position to the left with m[:-1] = m[1:]. We then assign the newly measured value to the rightmost position with m[-1] = measure. The time array t is updated the same way. That is all there is to our measurements.
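The shifting buffer and the link between power and energy can be sketched in a few lines. push and energy_joules are hypothetical helpers, not part of the script above. The energy estimate uses the trapezoidal rule and skips the NaN slots that are still empty at the start.

```python
import numpy as np


def push(buffer, value):
    """Shift a fixed-size buffer one to the left and put value on the right."""
    buffer[:-1] = buffer[1:]
    buffer[-1] = value


def energy_joules(t, p):
    """Trapezoidal estimate of energy (J) from time (s) and power (W) arrays.

    Entries that are NaN (buffer slots not filled yet) are ignored.
    """
    valid = ~np.isnan(t) & ~np.isnan(p)
    t, p = t[valid], p[valid]
    return float(np.sum(0.5 * (p[1:] + p[:-1]) * np.diff(t)))


t = np.full(5, np.nan)
m = np.full(5, np.nan)
for i, power in enumerate([100.0, 100.0, 100.0]):  # readings 0.1 s apart
    push(t, i * 0.1)
    push(m, power)

# The two leftmost slots are still NaN; 100 W for 0.2 s is about 20 J
print(t, energy_joules(t, m))
```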

Our plot coroutine works just like the measure coroutine, except that it plots whatever is in the time and measurement arrays before it goes to sleep. The plotting itself is basic Matplotlib, but figure.canvas.flush_events() is critical for updating the plot in real time, because it lets the GUI process pending events. Furthermore, calling plt.ion() when we initialize the plot turns on interactive mode, so the figure window shows up without blocking the rest of the program.

Coroutines are not run like normal functions. Calling one only creates a coroutine object, which does its work as a task within an asyncio event loop. The event loop knows which coroutines are asleep and decides which one works next. That may seem manageable with two coroutines, but with three it already becomes tedious. When one coroutine goes to sleep, two others may be awake and waiting to get to work, and the event loop has to decide which one goes next. Luckily, asyncio takes care of the details for us and we can focus on the work we want to get done. However, we need to create an event loop with loop = asyncio.new_event_loop() and then start it with loop.run_until_complete(main()). The coroutines only get to work once the loop is running. Both of our coroutines are scheduled in main(), so both become part of the event loop. Because of the event loop, I recommend running the code from the command line. Running it in interactive environments such as Jupyter can cause problems, because an event loop might already be running there.

With that, we have covered the most important parts of the code. There are several things we could do differently, and some of them might make the code better. For one, we could use a technique called blitting (explained here) to improve plotting performance. We could also do the plotting with FuncAnimation (explained here) instead of writing our own coroutine. I tried that for a while but was not able to make the animation and the measure() coroutine work together in the same event loop. There probably is a way to do it that I did not find. Let me know if you have other points for improvement.

You can find pynvml here. asyncio is part of the Python installation and you can find the docs here. I was inspired to do this project by a package called codecarbon that you can find here. It estimates the carbon footprint of computation and I plan to blog about it soon.

Interactive data dashboards in Jupyter notebook with ipywidgets and Bokeh

In this post I will go through the code for a simple data dashboard that visualizes the Iris dataset. It features two dropdown menus and three checkboxes. The dropdown menus choose the features on the x- and y-axes, while the checkboxes show or hide samples based on their species. Below is a screenshot and a video of the dashboard. Before I start with the code, I give a brief rationale for using ipywidgets and Bokeh.

Combining ipywidgets with Bokeh

A main advantage of ipywidgets is that it is designed specifically for Jupyter notebooks and the IPython kernel. Bokeh, on the other hand, can build data dashboards for a variety of more complex web deployment contexts. This makes it more powerful, and technically it could be used to build the entire dashboard. In a notebook context, however, I prefer the simplicity of ipywidgets over the power of Bokeh. I actually started out doing the plotting with Matplotlib and seaborn. This worked well in general, but I ran into some problems when directing plots to specific places in the dashboard. I did not encounter any such issues with Bokeh, for reasons I do not yet understand. Bokeh also adds some nice interactive features to the figure, such as panning and zooming, without any extra effort. I think it simply lends itself better to these interactive dashboards than Matplotlib. If you don't want to learn Bokeh and already know Matplotlib, ipywidgets plus Matplotlib is definitely a good option, and most of the ipywidgets principles I show here apply either way.

The dashboard code

Here is the code that generates the dashboard when executed in a Jupyter notebook.

from sklearn import datasets
import pandas as pd
import numpy as np
from bokeh.plotting import figure, show, output_notebook
import ipywidgets as widgets
from IPython.display import display, clear_output
output_notebook()

"""Load Iris dataset and transform the pandas DataFrame"""
iris = datasets.load_iris()
data = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
                     columns= iris['feature_names'] + ['target'])

"""Define callback function for the UI"""
def var_dropdown(x):
    """This function is executed when a dropdown value is changed.
    It creates a new figure according to the new dropdown values."""
    p = create_figure(
        x_dropdown.children[0].value,
        y_dropdown.children[0].value,
        data)
    fig[0] = p

    # Keep species visibility in sync with the current checkbox values
    for species, checkbox in species_checkboxes.items():
        check = checkbox.children[0].value
        fig[0].select_one({'name': species}).visible = check

    # Redirect the new figure to the output widget
    with output_figure:
        clear_output(True)
        show(fig[0])

    return x

def f_species_checkbox(x, q):
    """This function is executed when a checkbox is clicked.
    It directly changes the visibility of the current figure."""
    fig[0].select_one({'name': q}).visible = x
    with output_figure:
        clear_output(True)
        show(fig[0])
    return x

def create_figure(x_var, y_var, data):
    """This is a helper function that creates a new figure and 
    plots values from all three species. x_var and y_var control
    the features on each axis."""
    species_colors=['coral', 'deepskyblue', 'darkblue']
    p = figure(title="",
               x_axis_label=x_var,
               y_axis_label=y_var)
    species_nr = 0
    for species in iris['target_names']:
        curr_dtps = data['target'] == species_nr
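        # scatter() draws circle markers by default; recent Bokeh
        # releases deprecate circle() for screen-sized markers
        p.scatter(
            data[x_var][curr_dtps],
            data[y_var][curr_dtps],
            line_width=2,
            color=species_colors[species_nr],
            name=species
            )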
        species_nr += 1
    return p

# The output widget is where we direct our figures
output_figure = widgets.Output()

# Create the default figure
fig = []  # Storing the figure in a singular list is a bit of a 
          # hack. We need it to properly mutate the current
          # figure in our callbacks.
p = create_figure(
    iris['feature_names'][0],
    iris['feature_names'][1],
    data)
fig.append(p)
with output_figure:
    show(fig[0])

# Checkboxes to select visible species.
species_checkboxes = {}
for species in iris['target_names']:
    curr_cb = widgets.interactive(f_species_checkbox,
                                  x=True,
                                  q=widgets.fixed(species))
    curr_cb.children[0].description = species
    species_checkboxes[species] = curr_cb
    
"""Create the widgets in the menu"""
# Dropdown menu for x-axis feature.
x_dropdown = widgets.interactive(var_dropdown,
                                 x=iris['feature_names']);
x_dropdown.children[0].description = 'x-axis'
x_dropdown.children[0].value = iris['feature_names'][0]

# Dropdown menu for y-axis feature.
y_dropdown = widgets.interactive(var_dropdown,
                                 x=iris['feature_names']);
y_dropdown.children[0].description = 'y-axis'
y_dropdown.children[0].value = iris['feature_names'][1]



# This creates the menu 
menu=widgets.VBox([x_dropdown,
                   y_dropdown,
                   *species_checkboxes.values()])

"""Create the full app with menu and output"""
# The Layout adds some styling to our app.
# You can add Layout to any widget.
app_layout = widgets.Layout(display='flex',
                flex_flow='row nowrap',
                align_items='center',
                border='none',
                width='100%',
                margin='5px 5px 5px 5px')

# The final app is just a box
app=widgets.Box([menu, output_figure], layout=app_layout)

# Display the app
display(app)

Loading the Iris data

The Iris dataset contains 150 samples. Each sample belongs to one of three species and four features are measured for each sample: sepal length, sepal width, petal length and petal width, all in cm. The goal of the dashboard is to show a scatterplot of two features at a time and an option to turn visibility for each species on or off. I load the Iris data from the sklearn package but it is a widely used toy dataset and you can get it from other places. We also convert the dataset to a Pandas DataFrame. That just makes the data easier to handle. The callback functions that make the UI interactive are defined next.

Callback functions

Callback functions are executed when something happens in the UI. For now, the functions are just defined; later they are connected to widgets. We define two callback functions, var_dropdown(x) and f_species_checkbox(x, q). create_figure is not a callback function but a helper that creates a new figure. var_dropdown(x) is responsible for the two dropdown menus, which determine what is displayed on the figure axes. When a user changes a dropdown value, var_dropdown(x) creates a new figure whose x- and y-axis features are determined by the new dropdown values. The first parameter x is the new value that the user chose. It is not used here, because the same callback function serves two different dropdown menus, so we don't know which menu x refers to. Instead, we directly access the dropdown values with x_dropdown.children[0].value (x_dropdown is defined later). The same goes for the checkboxes, whose values we access in the for loop with checkbox.children[0].value. Each species has a checkbox, and we will create them later in the code. Finally, we display our figure in the with output_figure: context, which directs the figure to the output_figure widget.

The f_species_checkbox(x, q) callback is similar. It additionally has a parameter q, a fixed parameter that identifies the checkbox that triggered the callback. We use it to determine which parts of the figure to make visible or invisible with fig[0].select_one({'name': q}).visible = x. Whenever we change the look of the figure, we must show it again inside with output_figure: to make the changes visible.

Those are our callbacks. The figure creation in create_figure(x_var, y_var, data) is straightforward thanks to Bokeh. figure() creates the figure, and then the for species in iris['target_names']: loop creates the points for each species. Next up is the actual widget creation.

Creating widgets

The first widget we create is output_figure = widgets.Output(), which will display the figure. Next, we create a default figure and direct it to output_figure. Storing the figure in a single-element list with fig.append(p) is a bit of a hack, and it has to do with scope inside the callback functions. If we write p = figure() inside a function, Python treats p as a local variable, so only the p inside the function changes. If we write fig[0] = figure() instead, we do not rebind the name fig. We modify the existing list object, and because lists are mutable, the change is visible outside the function.
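The scope behavior behind the single-element list can be demonstrated in plain Python, without any widgets. The strings stand in for Bokeh figures, and the function names are made up for this illustration.

```python
fig = [None]  # single-element list, as in the dashboard code
p = None


def reassign():
    # Assigning to p makes p local to this function;
    # the global p is left untouched.
    p = "new figure"
    return p


def mutate():
    # Item assignment changes the existing list object in place,
    # so the change is visible outside the function.
    fig[0] = "new figure"


def rebind_global():
    # Alternative: declare the name global to rebind it explicitly.
    global p
    p = "new figure"


reassign()
p_after_reassign = p  # still None
mutate()
rebind_global()
print(p_after_reassign, fig[0], p)
```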

Now that we have our figure, we create a checkbox for each species in the for species in iris['target_names']: loop and store it in a dictionary, so we can access each one by species name. Most of the magic happens in widgets.interactive(f_species_checkbox, x=True, q=widgets.fixed(species)). It creates the widget and links it to the f_species_checkbox callback, all in one line. It creates a checkbox because of x=True. Booleans produce checkboxes, while a list of values produces a dropdown menu, which is how we create the dropdown menus below, also with widgets.interactive. The additional parameter q=widgets.fixed(species) passes a constant value without creating a widget for it, which tells f_species_checkbox which checkbox called it.

The following two dropdown widgets are very similar. Nothing special there.

Now that we have our widgets, we need to assemble them into our UI. We do that with menu=widgets.VBox([x_dropdown, y_dropdown, *species_checkboxes.values()]). This creates a box in which the widgets are arranged in a column; the V in VBox stands for vertical. We are almost done. The specifications in widgets.Layout() are not critical, but I want to show them here. Layout is not a widget itself, but we can pass it to a widget to change its style. widgets.Layout() exposes properties you might know from CSS.

To finish up, we create the full app with app=widgets.Box([menu, output_figure], layout=app_layout). Setting flex_flow='row nowrap' in app_layout arranges the menu and the figure output in a row, with the menu on the left and the figure on the right. Finally, we display our app in the output below our cell with display(app).

Possible improvements

To improve this code, I think the callbacks should depend less on global variables. A closer look at widgets.interactive might help here. Alternatively, the global variables that are used inside functions could be made explicit with the global keyword. Finally, in create_figure(), I create a counter variable species_nr = 0. This is unnecessary in Python, but I did not have time to think through the Pythonic way to do it. I hope this has been useful for you. Let me know what kind of data dashboards you are building.
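One Pythonic replacement for the manual counter is enumerate, optionally combined with zip to pair each species with its color. The sketch below uses made-up target codes instead of the Iris data so that it runs on its own. Inside create_figure(), the loop body would select points with data['target'] == species_nr as before.

```python
target_names = ["setosa", "versicolor", "virginica"]
species_colors = ["coral", "deepskyblue", "darkblue"]
targets = [0, 0, 1, 2, 1, 2, 2]  # made-up target codes, one per sample

rows = []
# enumerate yields (index, item) pairs, replacing the manual counter;
# zip pairs each species name with its color.
for species_nr, (species, color) in enumerate(zip(target_names, species_colors)):
    n_samples = sum(1 for t in targets if t == species_nr)
    rows.append((species_nr, species, color, n_samples))

for row in rows:
    print(*row)
```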

How to organize your research data during analysis

As researchers, we often have to manage the entire data lifecycle from generation to the final report. Raw data are a major part of any analysis pipeline, and it is important to store them properly. Many guidelines on data management focus on raw data but stop there. However, raw data without analysis are pretty much worthless, and there are major data management decisions to be made during analysis. Which intermediate results should we save? In what format? Where? In cloud storage? Should we build a database? Every decision affects the actual analysis code and can determine how easily errors can be found and fixed. Here I want to give a detailed overview of the different options we have for organizing our data during analysis. Python is my main programming language, and this is reflected here: I prefer open-source tools, and I have little experience with some of the commercial solutions that might exist. However, I will also cover the advantages and disadvantages of human-readable data formats and tools such as Microsoft Excel. We get started with raw data.

Raw data

Raw data are like a good hypothesis: just as a hypothesis is perfect until you ruin it with data, raw data are perfect until you ruin them with a thorough analysis. Raw data are the truth; this is how it happened. Therefore, we never change raw data, and we always keep multiple copies of them. Even after we have published the results of our analysis, it is a good idea to keep the raw data for at least 10 years. Depending on your funding and where you publish, you might be legally and/or ethically required to keep raw data for a defined period.

Anything we do with our raw data during the analysis we do in read-only mode. The only thing we might want to change is the file name. However, we should check whether there is a way to automate the file naming in the program that generates them. Automating file names is preferred and avoids human error. If we receive data from collaborators we should also make sure that we are not deleting or changing data that they placed in the file name or the directory hierarchy.

From an analysis point of view, we might not need any information in the file name at all; we could generate it randomly. However, we almost always want to store important metadata in the file name, because it has two major advantages: 1. The file name is human-readable. You need neither programming knowledge nor knowledge about the file format to read it. 2. The file name is readable by the operating system. This means our programs can read it even if they don't know the file format. For example, you can use file name information to decide whether or not to load a file during analysis and save loading time.

So what do we put in the file name? The first rule is that it is better to save too much information than too little. But most file systems limit the length of file names and paths (I hit that limit many times in Windows 10). When you choose your file name format, ask yourself: what would I want a human to know about the files before they even open them, and what would be good to have easily accessible during analysis? Many guides will tell you that you must save the date and time. That is useful because it makes it easier to relate a file to your lab book. Beyond that, I only recommend giving each file a unique identifier. This can be a number or a string of characters. I prefer numbers because they can show the recording order at a glance. Other information in the file names is up to you.

Once you have decided which other information to put in your file names, choose one delimiter to separate items and another one for readability. I use underscore to separate items and minus for readability. That allows me to do things like: “001_YYYY-MM-DD_version-2.fmt”. Avoid empty spaces in file names. Whatever decision you make, you must remain consistent.
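Automated naming and parsing follow directly from this convention. make_filename and parse_filename are hypothetical helpers for the example format above: id, date, version, and the placeholder extension .fmt. The key property is that the same delimiter rules used to build a name are enough to take it apart again.

```python
from datetime import date


def make_filename(file_id, recorded, version, ext="fmt"):
    """Build a name like '001_2021-03-15_version-2.fmt'.

    Underscores separate items, hyphens only improve readability
    within an item, and there are no spaces.
    """
    return f"{file_id:03d}_{recorded.isoformat()}_version-{version}.{ext}"


def parse_filename(name):
    """Split a file name built by make_filename back into its items."""
    stem, ext = name.rsplit(".", 1)
    file_id, recorded, version = stem.split("_")
    return {
        "id": int(file_id),
        "date": date.fromisoformat(recorded),
        "version": int(version.split("-")[1]),
        "ext": ext,
    }


name = make_filename(1, date(2021, 3, 15), 2)
print(name)  # 001_2021-03-15_version-2.fmt
print(parse_filename(name))
```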

When it comes to directories, I recommend keeping all raw files in the same directory. I used to encode some information in a directory hierarchy, but nowadays I would put all of it into the file name. If you have a lot of files, some directory structure can help humans navigate them, but during analysis I almost always prefer a single directory for simplicity. The file name is certainly a good place for metadata, but there are other ways to save them. In the next section we will discuss metadata more generally.
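With all raw files in one directory, file name information alone can decide what to load, as described above. This sketch creates empty placeholder files in a temporary directory and selects only the version 2 recordings. The directory and the file names are made up.

```python
import tempfile
from pathlib import Path

names = ["001_2021-03-15_version-1.fmt",
         "002_2021-03-15_version-2.fmt",
         "003_2021-04-02_version-2.fmt"]

with tempfile.TemporaryDirectory() as tmp:
    raw_dir = Path(tmp)  # stand-in for the single raw-data directory
    for name in names:
        (raw_dir / name).touch()  # empty placeholder files

    # Decide what to load from the names alone; no file is opened here
    selected = sorted(
        path.name for path in raw_dir.glob("*.fmt")
        if path.stem.split("_")[2] == "version-2"
    )

print(selected)
```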

Metadata

At the raw data stage, the definition of metadata is clear. It is any data that describes the context of the raw data but does not fit any of the raw data formats. Images are a good example. The raw data of an image are the pixel values that make up the digital image. Metadata of the image can be anything that relates to the context where the image was taken. The time it was taken, the place it was taken, the exposure time, the model of the lens used, the temperature that day, the person who took the image, the width, the height and much, much more. As the analysis progresses, metadata may become data. Let’s say your analysis pipeline counts the number of birds in each image and you want to find out whether there are more early birds than late birds. In that case the time of day the picture was taken becomes an independent variable.

So what kind of metadata should you save? That is up to you, your research question and your field but it is usually better to save too much metadata than too little. Metadata you did not save is almost impossible to reconstruct later, so think carefully when you decide that something is not worth saving.

How to store metadata? If you are lucky, all the metadata you need is already inside your raw data files. Most of the time, however, you will have to add some metadata manually. In that case I prefer to create a single .csv file where each row describes a single raw data file. One of the advantages of .csv files is that they can be read both by humans and by scripts, with a wide variety of programs.

The same metadata.csv file opened in four different programs. OpenOffice Calc (top left), Notepad (top right), Spyder IDE (lower left) and Notepad++ (lower right).
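Here is a minimal sketch of writing and reading such a metadata file with Python's built-in csv module. The column names and values are hypothetical, and an in-memory buffer stands in for a real metadata.csv so that the example runs anywhere.

```python
import csv
import io

# Hypothetical metadata for three images, one row per raw data file
rows = [
    {"id": 1, "date": "2021-03-15", "time": "06:30", "location": "park", "exposure_ms": 4},
    {"id": 2, "date": "2021-03-15", "time": "18:10", "location": "park", "exposure_ms": 8},
    {"id": 3, "date": "2021-03-16", "time": "07:05", "location": "lake", "exposure_ms": 4},
]
fieldnames = list(rows[0])

# Write the table. In a real project: open("metadata.csv", "w", newline="")
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
print(buffer.getvalue())

# Read it back and index each row by its unique id.
# Note that csv returns every value as a string.
buffer.seek(0)
metadata = {int(row["id"]): row for row in csv.DictReader(buffer)}
print(metadata[3]["location"])  # lake
```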

The id uniquely identifies the row and thereby the raw data file it refers to. A drawback of a .csv file is that it only works well for a single table. If you need multiple related tables (for example because you have multiple raw data sources), you might need a more complex format such as JSON, or even a database. More on the reasons for and against building a database later. Next, we will assume that we found a way to process our raw data and need to decide how to store the results for analysis.

Analysis and storage of transformed raw data

Analysis of raw data most often means iterating through all raw data files and doing some computation on each. In our fictional example, we take images and then extract birds from those images. Once we have extracted the birds, our unit of analysis changes from images to birds. This has consequences for the metadata, which describes the raw data files. We need to relate the metadata of the raw data files to the birds, and there are two major options to do this. The first is to generate a spreadsheet where each row is a bird, with an image_id column that stores the unique id of the image the bird was extracted from.

The extracted birds are related to a raw image by the fact that a specific bird was extracted from a single specific image. The tables in the top left and top right contain all information but during analysis we would need to open two files and relate the information. The table at the bottom has the advantage that both tables are merged into one. This can offer convenience but it increases the storage size of the file.

This is a relational data structure. The image_id in our birds table refers to a unique id in our metadata, which allows us to figure out where, when and with what exposure a given bird was photographed. This is a great way to structure data. The only downside is that we have to load two files and then merge them during the analysis. The second option, which I personally prefer, is to create a single file where both tables are already merged. The downside of this structure is that it is larger in storage size, because the metadata columns are repeated for every bird from the same image. Whether or not this is an issue for you depends on the total size of your data, the number of rows extracted per image and your available disk space. If you decide on the relational data structure, you might also want to consider building a database, because databases are particularly well suited for relational storage. Next, we will discuss some advantages and disadvantages of databases.
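To see what merging the two tables involves, here is a minimal sketch in plain Python. The tables, column names and the merge_on_image_id helper are hypothetical. Every bird must refer to an existing image id; a missing id raises a KeyError instead of silently producing a row without metadata.

```python
# Hypothetical tables: metadata has one row per image, birds one row per bird.
metadata = [
    {"id": 1, "time": "06:30", "location": "park"},
    {"id": 2, "time": "18:10", "location": "lake"},
]
birds = [
    {"bird_id": 1, "image_id": 1, "species": "robin", "size_cm": 14},
    {"bird_id": 2, "image_id": 1, "species": "crow", "size_cm": 45},
    {"bird_id": 3, "image_id": 2, "species": "heron", "size_cm": 90},
]


def merge_on_image_id(birds, metadata):
    """Return one row per bird with the metadata of its image appended.

    The image columns are repeated for every bird from the same image,
    which is exactly why the merged file takes more storage.
    """
    by_id = {row["id"]: row for row in metadata}
    merged = []
    for bird in birds:
        image = by_id[bird["image_id"]]  # KeyError flags a bird without an image
        extra = {key: value for key, value in image.items() if key != "id"}
        merged.append({**bird, **extra})
    return merged


merged = merge_on_image_id(birds, metadata)
for row in merged:
    print(row)
```

If you already use pandas, pandas.merge(birds, metadata, left_on="image_id", right_on="id") performs a similar join on DataFrames.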

You probably don’t want to build a database. Unless you do.

A database is any structured collection of information. This means that a collection of spreadsheets like the ones above could technically be considered a database. However, a database usually implies stricter rules of storage and access than a spreadsheet can guarantee. These rules give databases some of their major advantages. For example, multiple people and programs can work with a database without messing things up. If two people open the same spreadsheet in Excel and work on it, they are almost guaranteed to end up with two conflicting versions. Databases have many other advantages. When you decide for or against a database, you should anticipate whether you will be able to make use of those advantages. Here is my list.

You should consider building a database if:

  • You will have to integrate multiple raw data sources.
    • Multiple raw data sources usually mean more relationships. Multiple relationships become harder to manage with spreadsheets but are the perfect use case for a relational database.
  • You expect to store a massive amount of data.
    • A database usually stores data more efficiently, which can save you storage space.
    • A database can also make access more efficient, because you don’t load the entire database but request smaller packages of information at a time. This avoids running into RAM issues.
  • You will be adding data continuously to the project over a long period of time.
    • Having a fixed logic for adding information to the database is a massive advantage over long time periods and can avoid a lot of errors.
    • During very long projects you might want to train other people to work with the database or even hand the project over to someone else. Relational rules that prevent errors are very helpful here, as opposed to a spreadsheet where anyone can change anything.
  • You want to make your data available through a web interface.
    • A web server interfacing with a spreadsheet would be very inefficient.
  • Multiple people need read/write access simultaneously.
    • As mentioned above, multi-user access works much better for databases.
  • You have security concerns.
    • This can be a concern regarding the physical integrity of the data or malicious/unauthorized access. Databases are not perfect, but they provide better safeguards than a spreadsheet.
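To make the relational idea concrete, here is a minimal sketch with Python's built-in sqlite3 module. The table and column names are hypothetical. It illustrates two points from the list: the database enforces the relationship between birds and images, and the tables are joined at query time instead of storing a merged copy.

```python
import sqlite3

# In-memory database for the example; pass a file path to persist it.
con = sqlite3.connect(":memory:")
con.execute("PRAGMA foreign_keys = ON")  # SQLite only enforces foreign keys when enabled

con.executescript("""
CREATE TABLE images (
    id INTEGER PRIMARY KEY,
    taken_at TEXT NOT NULL,
    location TEXT NOT NULL,
    exposure_ms REAL
);
CREATE TABLE birds (
    id INTEGER PRIMARY KEY,
    image_id INTEGER NOT NULL REFERENCES images(id),
    species TEXT,
    size_cm REAL
);
""")

con.executemany("INSERT INTO images VALUES (?, ?, ?, ?)",
                [(1, "2021-03-15 06:30", "park", 4.0),
                 (2, "2021-03-15 18:10", "lake", 8.0)])
con.executemany("INSERT INTO birds VALUES (?, ?, ?, ?)",
                [(1, 1, "robin", 14.0), (2, 1, "crow", 45.0), (3, 2, "heron", 90.0)])
con.commit()

# The database refuses a bird that points to a non-existent image.
rejected = False
try:
    con.execute("INSERT INTO birds VALUES (4, 99, 'owl', 35.0)")
except sqlite3.IntegrityError as err:
    rejected = True
    print("rejected:", err)

# Relate the tables at query time instead of storing a merged copy.
query = """
SELECT birds.species, birds.size_cm, images.location, images.taken_at
FROM birds JOIN images ON birds.image_id = images.id
ORDER BY birds.id
"""
rows = con.execute(query).fetchall()
for row in rows:
    print(row)
con.close()
```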

Those advantages sound amazing. So why would you ever not build a database? Well, many research projects are structured in a way that you cannot make use of any of these advantages. If a single person generates all their data in one week, processes the raw data the next week and prepares a final report in the third week, a database is overkill. The rules that database architectures follow provide major advantages, but they also make databases less flexible. Building and maintaining a database comes with extra work that you only want to take on when you can benefit from some of the advantages. Generally speaking, the larger a project is (in number of people involved, number of data sources, time frame and so on), the more worthwhile a database becomes.

Before we move on to data analysis, I want to briefly mention cloud storage. Cloud storage has some advantages, but most of them can also be achieved in other ways. In my opinion, the most important reason to consider cloud storage is if you regularly need access to the data from many different locations that are well connected to the internet. If you have the budget, a commercial cloud storage provider can also take care of services that might be annoying to maintain yourself. On the other hand, as a researcher you may not be allowed (because of funding or institutional policies) to use certain cloud services. Other details of cloud storage are beyond the scope of this guide.

Visualization and statistical analysis

So far we have extracted samples (birds) from our images and quantified features of those birds (species and size). Some additional features are given by the image metadata. There might be additional intermediate processing steps whose results you want to save. That’s up to you. For example, maybe you had to apply some filters to the images when you extracted the birds. Do you want to save these filtered images? If the filtering takes very long and you want to visually inspect the result, you probably want to save them. Otherwise, having the script that performs the filtering is sufficient.
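One way to handle expensive intermediate steps is a small cache: save the result the first time it is computed and load it on later runs. This is a sketch with a hypothetical helper (cached) built on NumPy's np.save and np.load; slow_filter is a made-up stand-in for an expensive image filter.

```python
import tempfile
from pathlib import Path

import numpy as np


def cached(path, compute):
    """Return the array stored at `path`, computing and saving it on first use.

    `path` should end in .npy (np.save appends .npy otherwise).
    """
    path = Path(path)
    if path.exists():
        return np.load(path)
    result = compute()
    path.parent.mkdir(parents=True, exist_ok=True)
    np.save(path, result)
    return result


calls = []  # records how often the expensive step actually runs


def slow_filter():
    """Hypothetical stand-in for an expensive image filter."""
    calls.append(1)
    image = np.arange(16, dtype=float).reshape(4, 4)
    return image - image.mean()


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "filtered" / "image_001.npy"
    first = cached(target, slow_filter)   # computes and saves
    second = cached(target, slow_filter)  # loads from disk, filter not rerun

print(len(calls))  # 1
```

Deleting the saved file is enough to force a recomputation, for example after changing the filter.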

The plots you draw are something you definitely want to save. You also want to save the results of statistical analyses. At this stage it is important to have a clear relationship between analysis scripts and their outputs. I hate having to look through multiple scripts to find the code that generated a specific figure I want to change. You could put everything into one script. I have done this before, but depending on the size of the project this becomes unwieldy very quickly and you will find yourself scrolling through thousands of lines of code. There is no perfect recipe here, so I think a practical example is in order.

A practical example for a project structure

Here is an example of how the bird project could look as a directory hierarchy.

The directory structure of an example project.

The project folder itself contains only Python scripts and a README.txt, where we can give information that anyone looking at our project should be aware of. Having everything else in specific folders makes it easier to locate a Python script. There is one more advantage: if you use git to version control or share your projects, you can exclude data and figures from tracking. Data is often too large to push to a remote repository, or we don’t want to make unpublished data public. Therefore we simply add the data folders we want to keep local to a .gitignore file. However, managing your project with git is not a must, and for smaller projects it does not always pay off.

Now for the scripts themselves. Each script should have a very well defined input and output. extract_birds.py for example loops through all raw files in ./raw and creates the file ./data/birds.csv. The other two scripts both use that same birds.csv file but they answer different questions and have different outputs. analyze_location.py creates the figure 02_figure_location.png and the statistical output location_anova.txt. On the other hand analyze_species_size.py creates 01_figure_species_size.png and species_ttest.txt.
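Here is a hypothetical sketch of what extract_birds.py could look like with such a well-defined input and output. find_birds is a made-up placeholder for the actual image analysis, the .fmt extension follows the file name example from earlier, and the image id is parsed from the first item of the file name. When you save it as a script, call extract_birds() from an if __name__ == "__main__": guard instead of the demo at the bottom.

```python
import csv
import tempfile
from pathlib import Path

FIELDS = ["bird_id", "image_id", "species", "size_cm"]


def find_birds(raw_file):
    """Hypothetical placeholder for the actual bird detection.

    Returns a list of (species, size_cm) tuples for one raw file. Here it
    pretends every image contains one robin so the sketch runs end to end.
    """
    return [("robin", 14.0)]


def extract_birds(raw_dir=Path("raw"), out_file=Path("data") / "birds.csv"):
    """Input: every raw file in raw_dir. Output: one CSV with one row per bird."""
    out_file.parent.mkdir(parents=True, exist_ok=True)
    bird_id = 0
    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(FIELDS)
        for raw_file in sorted(raw_dir.glob("*.fmt")):
            # The unique id is the first item of the file name, e.g. "001_..."
            image_id = int(raw_file.name.split("_")[0])
            for species, size_cm in find_birds(raw_file):
                bird_id += 1
                writer.writerow([bird_id, image_id, species, size_cm])
    return out_file


# Demo on a throwaway directory with two empty placeholder raw files
with tempfile.TemporaryDirectory() as tmp:
    raw = Path(tmp) / "raw"
    raw.mkdir()
    for name in ("001_2021-03-15_version-1.fmt", "002_2021-03-16_version-1.fmt"):
        (raw / name).touch()
    out = extract_birds(raw, Path(tmp) / "data" / "birds.csv")
    rows = out.read_text().splitlines()

print(rows)
```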

In this example the same script creates the visualization and performs the statistical test. You could split this up further and have one script for visualization and another for stats. I prefer to combine them because both tasks require loading the same data. On another note, I prefer to keep things like reports, papers and presentations in a separate folder. In my experience they clutter the project space.

Planning a project in detail is important but there are also diminishing returns. At some point the only way forward is to actually start the project. I hope you learned some useful tricks to manage your own data more effectively.

Balanced spiking neural networks with NumPy

Balanced spiking neural networks are a cornerstone of computational neuroscience. They make a nice introduction to spiking neuronal networks, and they can be trained to store information (Nicola & Clopath, 2017). Here I will present a Python port I made of a MATLAB implementation by Nicola & Clopath and go through some of its features. We only need NumPy, plus Matplotlib for plotting.

import numpy as np
from numpy.random import rand, randn
import matplotlib.pyplot as plt


def balanced_spiking_network(dt=0.00005, T=2.0, tref=0.002, tm=0.01,
                             vreset=-65.0, vpeak=-40.0, n=2000, 
                             td=0.02, tr=0.002, p=0.1, 
                             offset=-40.00, g=0.04, seed=100, 
                             nrec=10):
    """Simulate a balanced spiking neuronal network

    Parameters
    ----------
    dt : float
        Sampling interval of the simulation.
    T : float
        Duration of the simulation.
    tref : float
        Refractory time of the neurons.
    tm : float
        Time constant of the neurons.
    vreset : float
        The voltage neurons are set to after a spike.
    vpeak : float
        The voltage above which a spike is triggered
    n : int
        The number of neurons.
    td : float
        Synaptic decay time constant.
    tr : float
        Synaptic rise time constant.
    p : float
        Connection probability between neurons.
    offset : float
        A constant input into all neurons.
    g : float
        Scaling factor of synaptic strength
    seed : int
        The seed makes NumPy random number generator deterministic.
    nrec : int
        The number of neurons to record.

    Returns
    -------
    ndarray
        A 2D array of recorded voltages. Rows are time points,
        columns are the recorded neurons. Shape: (round(T/dt), nrec).
    """

    np.random.seed(seed)  # Seeding randomness for reproducibility

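    # Set up the weight matrix: sparse Gaussian weights, connection prob. p
    w = g * (randn(n, n)) * (rand(n, n) < p) / (np.sqrt(n) * p)
    # Subtract from each row its mean over the nonzero entries, so every row
    # sums to zero (np.mean with where= requires NumPy >= 1.20)
    row_means = np.mean(w, axis=1, where=np.abs(w) > 0)[:, None]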
    row_means = np.repeat(row_means, w.shape[0], axis=1)
    w[np.abs(w) > 0] = w[np.abs(w) > 0] - row_means[np.abs(w) > 0]

    # Preallocate the recording array
    nt = round(T/dt)  # Number of time steps
    rec = np.zeros((nt, nrec))

    # Initial conditions
    ipsc = np.zeros(n)  # Post synaptic current storage variable
    hm = np.zeros(n)  # Storage variable for filtered firing rates
    tlast = np.zeros(n)  # Time of the most recent spike, for refractoriness
    v = vreset + rand(n)*(30-vreset)  # Initialize neuron voltage

    # Start integration loop
    for i in range(nt):
        inp = ipsc + offset  # Total input current

        # Voltage equation with refractory period
        # Only change if voltage outside of refractory time period
        dv = (dt * i > tlast + tref) * (-v + inp) / tm
        v = v + dt*dv

        index = np.argwhere(v >= vpeak)[:, 0]  # Spiked neurons

        # Get the weight matrix column sum of spikers
        if len(index) > 0:
            # Compute the increase in current due to spiking
            jd = w[:, index].sum(axis=1)

        else:
            jd = 0*ipsc

        # Used to set the refractory period of LIF neurons
        tlast = (tlast + (dt * i - tlast) *
                 np.array(v >= vpeak, dtype=int))

        # Double-exponential synaptic filter: spikes (jd) drive hm, which
        # decays with td and feeds ipsc, which decays with tr
        ipsc = ipsc * np.exp(-dt / tr) + hm * dt

        hm = (hm * np.exp(-dt / td) + jd *
              (int(len(index) > 0)) / (tr * td))

        v = v + (30 - v) * (v >= vpeak)

        rec[i, :] = v[0:nrec]  # Record the voltages of the first nrec neurons
        v = v + (vreset - v) * (v >= vpeak)

    return rec


if __name__ == '__main__':
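    dt = 0.00005
    rec = balanced_spiking_network(dt=dt)

    # Plot three recorded neurons, offset by 100 mV so they do not overlap
    t = np.arange(rec.shape[0]) * dt
    fig, ax = plt.subplots(1)
    ax.plot(t, rec[:, 0] - 100.0)
    ax.plot(t, rec[:, 1])
    ax.plot(t, rec[:, 2] + 100.0)
    ax.set_xlabel('time (s)')
    ax.set_ylabel('voltage (mV, offset)')
    plt.show()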
Three neurons in the balanced spiking neural network.

The weight matrix

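At the core of any balanced network is the weight matrix. We define it right after seeding the random number generator. Each pair of neurons is connected with probability p, and the weights are drawn from a normal distribution, so roughly half of them are excitatory and half inhibitory. We then subtract from each row its mean over the nonzero entries. Because row i holds the weights onto neuron i, every neuron ends up with incoming weights that sum to zero, so excitation and inhibition are in balance. That is what keeps the network spiking irregularly although the input to the network remains constant. The constant input to the network is the offset parameter.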
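A quick way to check what the normalization does is to build a smaller matrix the same way and look at its row sums. This sketch uses NumPy's newer Generator API (default_rng) and an explicit boolean connection mask instead of np.abs(w) > 0; both choices are mine and not part of the original code, and the result is equivalent up to the random numbers drawn.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, g = 200, 0.1, 0.04  # smaller n than the simulation, same p and g

# Sparse Gaussian weights: rows are postsynaptic, columns presynaptic neurons
mask = rng.random((n, n)) < p
w = g * rng.standard_normal((n, n)) * mask / (np.sqrt(n) * p)

# Subtract each row's mean over its existing connections only,
# so that absent connections stay exactly zero
counts = np.maximum(mask.sum(axis=1, keepdims=True), 1)
w = w - (w.sum(axis=1, keepdims=True) / counts) * mask

print("largest |row sum|:", np.abs(w.sum(axis=1)).max())  # floating-point noise
print("fraction excitatory:", (w[mask] > 0).mean())  # close to 0.5
```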

Refractory period

The refractory period is a time window during which no action potential can be generated. We implement it by resetting the voltage to a low value right after the spike and then not updating that neuron's voltage for a given time, set by tref. The voltage update dv = (dt * i > tlast + tref) * (-v + inp) / tm does both at once: the expression (dt * i > tlast + tref) is 1 only if more than tref has passed since the last spike, and 0 otherwise, which freezes the voltage. For this we need to track the most recent spike time of each neuron in tlast. Of course, we have some other things to do when a neuron reaches the spiking threshold vpeak. First, we set the voltage to 30 mV, well above the threshold, with v = v + (30 - v) * (v >= vpeak). This is purely cosmetic and gives the recording a spiky appearance. Right after recording, we set the voltage to its reset value vreset.
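To isolate the refractory mechanism, here is a sketch of a single neuron that uses the same rule (integrate only when the time since the last spike exceeds tref). The constant input inp = -30 is an assumption chosen to lie above the threshold vpeak so that the neuron fires regularly; in the network, the input comes from offset plus the synaptic current instead.

```python
import numpy as np


def lif_spike_times(inp=-30.0, T=0.5, dt=0.00005, tm=0.01, tref=0.002,
                    vreset=-65.0, vpeak=-40.0):
    """Spike times (s) of one leaky integrate-and-fire neuron with constant input.

    Same refractory rule as the network: the voltage is only integrated when
    more than tref has passed since the last spike.
    """
    v = vreset
    tlast = -np.inf  # no spike yet, so the neuron starts outside refractoriness
    spikes = []
    for i in range(round(T / dt)):
        t = dt * i
        if t > tlast + tref:           # outside the refractory window
            v += dt * (-v + inp) / tm  # leak towards the input level
        if v >= vpeak:                 # threshold crossing: spike and reset
            spikes.append(t)
            tlast = t
            v = vreset
    return np.array(spikes)


# Interspike intervals for a short and a long refractory period
isis = {tref: np.diff(lif_spike_times(tref=tref)) for tref in (0.002, 0.01)}
for tref, isi in isis.items():
    print(f"tref = {tref * 1e3:.0f} ms -> mean ISI = {isi.mean() * 1e3:.2f} ms")
```

With these values, each interspike interval should be roughly tref plus the time needed to climb from vreset to vpeak, tm · ln((vreset - inp)/(vpeak - inp)) ≈ 12.5 ms, so raising tref from 2 ms to 10 ms lengthens every interval by about 8 ms.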

Play around with some of the parameters. You can find the code here: https://gist.github.com/danielmk/9adc7409f40a076ffec0cdf85dea4519

Spiking neural networks for a low-energy future

Spiking neural networks (SNNs) have some disadvantages compared to artificial neural networks (ANNs), but they have the potential to run on a fraction of the energy. Whether SNNs will be able to replace ANNs and how much energy they will use depends on many engineering and neuroscience advances. Here I will go through some of the technical background of the SNN energy advantages and some of the current numbers.

Energy efficient SNN features

The energy efficiency of SNNs comes primarily from two features. First, a spike is a discrete event, and energy is only used when a spike occurs. This is probably the most fundamental feature that distinguishes SNNs from ANNs. It means that the energy consumption of an SNN depends not only on the number of neurons but also on the number of spikes the model needs to perform its task. The second feature is local memory. At the heart of all models are parameters. On traditional hardware such as CPUs and GPUs, the part of the chip that performs the calculations is separate from the memory that stores the parameters. Moving the parameters to the part that computes is much more energy intensive than the computation itself. Therefore, storing the parameters locally, right where the computation happens, improves efficiency. This is not unique to spiking neural networks. Some tensor processing units (TPUs) also feature local memory, and they are specifically designed for ANNs. When most people speak about the energy advantages of SNNs, they assume local memory.
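The first feature can be illustrated with a back-of-the-envelope count of operations, in the spirit of the theoretical estimates discussed below. This is a simplified sketch under several assumptions: fully connected layers, one multiply-accumulate (MAC) per connection for the ANN, one accumulate (AC) per spike and outgoing connection for the SNN, no memory access costs at all, and a placeholder energy ratio between an AC and a MAC.

```python
def ann_macs(layer_sizes):
    """Multiply-accumulates for one forward pass of a fully connected ANN.

    Every unit feeds every unit of the next layer, whatever its activity.
    """
    return sum(a * b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))


def snn_acs(layer_sizes, spikes_per_layer):
    """Accumulates for one inference window of a fully connected SNN.

    spikes_per_layer[k] is the number of spikes emitted by layer k. Each
    spike adds its weight to every unit of the next layer (no multiply).
    """
    return sum(s * b for s, b in zip(spikes_per_layer, layer_sizes[1:]))


layers = [784, 100, 10]  # hypothetical network
E_MAC, E_AC = 1.0, 0.2   # assumed relative energy per operation (placeholder)

ann_energy = ann_macs(layers) * E_MAC
for spikes in ([50, 10], [500, 100], [5000, 1000]):
    ratio = snn_acs(layers, spikes) * E_AC / ann_energy
    print(f"spikes per layer {spikes}: SNN/ANN energy = {ratio:.3f}")
```

The SNN estimate grows linearly with the number of spikes, so in this toy count the spiking model only keeps its advantage if it solves the task with sparse activity.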

SNNs also require specialized hardware to run efficiently. This hardware is called neuromorphic, and it makes efficient use of the binary nature of spikes and of local memory. So far, neuromorphic hardware is only available for research purposes, and making it more widely available will be one of the challenges to SNN adoption. Next are some numbers on energy efficiency.

How efficient are we talking?

How much more efficient SNNs are depends on the details of the comparison: the task, the model architectures and the hardware. Projections into the future are even harder, since machine learning advances quickly for both SNNs and ANNs. Projecting the absolute amount of energy that could be saved is harder still, because it requires predictions of AI demand, which can change non-linearly with technical advances. I would be interested in finding formal work on some of these uncertainties, or in working on them myself, but for now here are some numbers.

The Loihi processor from Intel Labs is a recent piece of neuromorphic hardware. Depending on the size of their example problem they find that Loihi is 2.58x, 8.08x or 48.74x more energy efficient than a 1.67-GHz Atom CPU (Davies et al. 2018).

Yin et al. (2020) present a method to train SNNs (backpropagation with surrogate gradients). They calculate the theoretical energy consumption of a spiking recurrent network trained with this method and of several ANN architectures. Depending on the task, their SNN was 126.2x, 935x, 1602x, 1776x or 3353.3x more efficient than a Long Short-Term Memory network (LSTM; the exact factor also depends on details of the LSTM implementation). It was 41.3x more efficient than a recurrent ANN. In a talk, the last author Sander Bohte summarizes the findings as >100x more efficient than the best recurrent ANN and 1000x more efficient than an LSTM. All their calculations assume local memory.

Panda et al. (2019) tried several methods to generate SNNs for image classification and calculated their theoretical energy consumption. Depending on model architecture and parameter space, they estimate that SNNs are 6.52x, 7.7x, 10.6x, 74.9x, 81.3x or 104.8x more energy efficient.

Merolla et al. (2014) present the TrueNorth neuromorphic architecture. They compare synaptic operations per second (SOPS) of their architecture to floating-point operations per second (FLOPS) of traditional chips. According to them, TrueNorth can deliver 46 billion SOPS per watt, whereas the most energy-efficient supercomputer at the time of writing achieved 4.5 billion FLOPS per watt.

These numbers highlight the potential for massive energy savings, but benchmarks are always complicated. Making good comparisons can be hard, especially when the units of computation are fundamentally different, such as synaptic operations versus floating-point operations. Either way, SNNs on neuromorphic hardware are extremely energy efficient, but to actually save energy they must become better at the tasks ANNs already solve.

References

Davies et al. (2018). Loihi: A Neuromorphic Manycore Processor with On-Chip Learning. IEEE Micro. doi:10.1109/MM.2018.112130359

Yin, Corradi & Bohte (2020). Effective and Efficient Computation with Multiple-timescale Spiking Recurrent Neural Networks. arXiv. https://arxiv.org/abs/2005.11633

Panda, Aketi & Roy (2019). Towards Scalable, Efficient and Accurate Deep Spiking Neural Networks with Backward Residual Connections, Stochastic Softmax and Hybridization. arXiv. https://arxiv.org/abs/1910.13931

Merolla et al. (2014). A million spiking-neuron integrated circuit with a scalable communication network and interface. Science. https://science.sciencemag.org/content/345/6197/668

Spiking Neuronal Networks in Python

Spiking neural networks (SNNs) turn an input into an output, much like artificial neural networks (ANNs), which are already widely used today. Both achieve the same goal in different ways. The units of an ANN are single floating-point numbers that represent the activity levels of the units for a given input. Neuroscientists loosely interpret this number as the average spike rate of a unit. In ANNs this number is usually the result of multiplying the input with a weight matrix, followed by a nonlinearity. SNNs work differently in that they simulate units as spiking point processes. How often and when a unit spikes depends on the input and on the connections between units. A spike causes a discontinuous change in the units it connects to. SNNs are not yet widely used outside of laboratories, but they are important for helping neuroscientists model brain circuitry. Here we will create spiking point neuron models and connect them. We will be using Brian2, a Python package to simulate SNNs. You can install it with either

conda install -c conda-forge brian2
or
pip install brian2

We will start by defining our spiking unit model. There are many different models of spiking. Here we will define a conductance-based version of the leaky integrate-and-fire model.

from brian2 import *
import numpy as np
import matplotlib.pyplot as plt

start_scope()

# Neuronal Parameters
c = 100*pF
vl = -70*mV
gl = 5*nS

# Synaptic Parameters
ge_tau = 20*ms
ve = 0*mV
gi_tau = 100*ms
vi = -80*mV
w_ge = 1.0*nS
w_gi = 0.0*nS

lif = '''
dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi * (v - vi) - I)/c : volt
dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens
I : amp
'''

The intended way to import Brian2 is from brian2 import *. This feels a bit dangerous, but it keeps the code readable, especially when dealing with physical units. Speaking of units, each of our parameters has a physical unit: picofarad (pF), millivolt (mV), nanosiemens (nS) and millisecond (ms). These are all imported from Brian2, and we assign units with the multiplication operator *. To find out what each of those parameters does, we can look at the actual model. The model is defined by the string we assign to lif. The first line of the string is the model of our spiking unit:

dv/dt = -(gl * (v - vl) + ge * (v - ve) + gi *(v - vi) - I)/c : volt

This is a differential equation that describes the change of voltage over time. In electrical terms, voltage is changed by currents. These physical terms are not strictly necessary for the computational function of SNNs, but they help to understand the biological background. Four currents flow in our model.

The first one is gl * (v - vl). This current depends on the leak conductance (gl), the current voltage (v) and the equilibrium voltage (vl). It flows whenever the voltage differs from vl, and gl is a scaling constant that determines how strong the drive towards vl is. Because v does not change when it equals vl, vl is called the resting potential. This leak term makes our model a leaky integrate-and-fire model as opposed to a plain integrate-and-fire model. Of course, the voltage only remains at rest if nothing else changes it, and three other currents can do that. Two of them correspond to excitatory and inhibitory synaptic inputs. ge * (v - ve) drives the voltage towards ve. Under most circumstances this increases the voltage, as ve equals 0mV. On the other hand, gi * (v - vi) drives the voltage towards vi, which at -80mV is slightly below vl. This current keeps the voltage low. The last one is I, the input current that the model receives. We will define it later for each neuron. Finally, currents don’t change the voltage immediately. They are all slowed down by the capacitance c, which is why we divide the sum of all currents by c.
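Setting dv/dt = 0 shows where these currents pull the voltage when the conductances and I are held constant: v∞ = (gl·vl + ge·ve + gi·vi + I) / (gl + ge + gi), a conductance-weighted average of the reversal potentials plus I divided by the total conductance. The sketch below evaluates this with the model's parameters as plain SI floats, assuming the inhibitory term gi * (v - vi); the conductance and current values passed in are illustrative.

```python
# Parameters of the Brian2 model as plain SI floats
c = 100e-12   # capacitance (F)
gl = 5e-9     # leak conductance (S)
vl = -70e-3   # resting potential (V)
ve = 0.0      # excitatory reversal potential (V)
vi = -80e-3   # inhibitory reversal potential (V)


def steady_state(ge=0.0, gi=0.0, I=0.0):
    """Voltage where dv/dt = 0 for constant conductances and input current.

    Solves gl*(v - vl) + ge*(v - ve) + gi*(v - vi) - I = 0 for v.
    """
    return (gl * vl + ge * ve + gi * vi + I) / (gl + ge + gi)


def time_constant(ge=0.0, gi=0.0):
    """Effective membrane time constant c / (gl + ge + gi)."""
    return c / (gl + ge + gi)


print(f"{steady_state() * 1e3:.1f} mV")          # -70.0 mV: rest at vl
print(f"{steady_state(ge=5e-9) * 1e3:.1f} mV")   # -35.0 mV: halfway to ve
print(f"{steady_state(gi=5e-9) * 1e3:.1f} mV")   # -75.0 mV: halfway to vi
print(f"{steady_state(I=0.5e-9) * 1e3:.1f} mV")  # 30.0 mV: far above threshold
print(f"{time_constant() * 1e3:.1f} ms")         # 20.0 ms
```

Adding conductance also shortens the effective time constant, so a neuron receiving strong synaptic input reacts faster than one at rest.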

There are two more differential equations that describe our model:

dge/dt = -ge / ge_tau : siemens
dgi/dt = -gi / gi_tau : siemens

These describe the change of the excitatory and inhibitory synaptic conductances. We have not yet implemented the way spiking changes ge or gi. However, these equations tell us that both ge and gi decay towards zero with the time constants ge_tau and gi_tau. So far so good, but what about spiking? That comes up next, when we turn the string that represents our model into something that Brian2 can actually simulate.

G = NeuronGroup(3, lif, threshold='v > -40*mV',
                reset='v = vl', method='euler')
G.I = [0.7, 0.5, 0]*nA
G.v = [-70, -70, -70]*mV

The NeuronGroup creates three units defined by the equations in lif. The threshold parameter gives the condition to register a spike. In this case, we register a spike when the voltage is larger than -40mV. When a spike is registered, the event defined by reset is triggered, which here sets the voltage back to the resting voltage vl. The method parameter gives the integration method used to solve the differential equations, here the forward Euler method.

Once our units are defined, we can access some of their variables. For example, G.I = [0.7, 0.5, 0]*nA sets the input current of the zeroth neuron to 0.7nA, the first neuron to 0.5nA and the last neuron to 0nA. The constants we defined outside the string, such as gl, are shared by all neurons and cannot be set per neuron like this. I can be set per neuron because the line I : amp in our lif string declares it as a parameter of each neuron. Next, we define the initial state of our neurons. G.v = [-70, -70, -70]*mV sets all of them to the resting voltage, a good place to start.
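To see what the threshold and reset do without Brian2, here is a rough plain-Python sketch of a single uncoupled neuron from this model, integrated with forward Euler. It assumes Brian2's default time step of 0.1 ms and ignores the synaptic conductances, so it only approximates what the NeuronGroup above does and is not Brian2's actual generated code.

```python
# Parameters of the Brian2 model as plain SI floats
c = 100e-12        # capacitance (F)
gl = 5e-9          # leak conductance (S)
vl = -70e-3        # resting and reset voltage (V)
v_thresh = -40e-3  # spike threshold (V)
dt = 0.1e-3        # time step (s), assumed equal to Brian2's default


def count_spikes(I, T=0.1):
    """Count spikes of one LIF neuron with constant input current I (A) over T (s)."""
    v = vl
    n_spikes = 0
    for _ in range(round(T / dt)):
        # Forward Euler step of dv/dt = -(gl * (v - vl) - I) / c
        v += dt * -(gl * (v - vl) - I) / c
        if v > v_thresh:  # threshold='v > -40*mV'
            n_spikes += 1
            v = vl        # reset='v = vl'
    return n_spikes


counts = {I: count_spikes(I) for I in (0.7e-9, 0.5e-9, 0.0)}
for I, n in counts.items():
    print(f"I = {I * 1e9:.1f} nA -> {n} spikes in 100 ms")
```

The neuron with 0.7 nA should fire more often than the one with 0.5 nA, and the neuron without input current never reaches threshold, which is the behavior we expect from N0, N1 and N2 before any synapses are added.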

You might be disappointed by this implementation of spiking. Where is the sodium? Where is the amplification of depolarization? The leaky integrate-and-fire model doesn’t feature a spiking mechanism, except for the discontinuity at the threshold. If you are interested in incorporating more spike-like mechanisms, you should look into the exponential integrate-and-fire model or a Hodgkin-Huxley-like model.

We are only missing one more ingredient for an actual network model: the synapses. And we are in a great position, because our model already defines the excitatory and the inhibitory conductances. Now we just need to make use of them.

Se = Synapses(G, G, on_pre='ge_post += w_ge')
Se.connect(i=0, j=2)

Si = Synapses(G, G, on_pre='gi_post += w_gi')
Si.connect(i=1, j=2)

First we create an excitatory synapse from the group G onto itself. The synapse is excitatory because each presynaptic spike increases the conductance ge of the postsynaptic neuron by w_ge. We then call Se.connect(i=0, j=2) to define which neurons actually connect. In this case, the zeroth neuron connects to the second neuron. This means that spikes in the zeroth neuron cause ge to increase in the second neuron. We then create the inhibitory synapse and make the first neuron inhibit the second neuron. Now we are ready to run the actual simulation and plot the result. Remember that for now w_gi = 0.0*nS, meaning that we will only see the effect of the excitatory connection.

M = StateMonitor(G, 'v', record=True)

store()  # Save the initial state so that we can rerun from t = 0 later

run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
Voltage traces from three neurons. Spiking of N0 increases voltage in N2. Spiking of N1 does nothing in this simulation because its weight onto N2 was set to 0 nS.

The first two neurons spike regularly. N0 is slightly faster because it receives a larger input current than N1. N2, on the other hand, rests at -70mV because it does not receive an input current. Each spike of N0 increases the excitatory conductance of N2 and thereby its voltage, as expected. N1 has no effect here, because its weight is set to 0. Continued activity of N0 eventually drives N2 to the threshold of -40mV, making it register its own spike and reset its voltage. What happens if we introduce the inhibitory weight? Brian2 looks up w_gi when a run starts, so we can simply change it. We call restore() first, which returns the network, including the monitor, to the state saved by store(); otherwise the second run would continue from 100 ms.

w_gi = 0.5*nS
restore()  # Back to t = 0 with the state saved by store()
run(100*ms)

fig, ax = plt.subplots(1)
ax.plot(M.t/ms, M.v[0]/mV)
ax.plot(M.t/ms, M.v[1]/mV)
ax.plot(M.t/ms, M.v[2]/mV)
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')
ax.legend(('N0', 'N1', 'N2'))
The same as above but now N1 has a weight of 0.5nS. This inhibits N2 and prevents a spike.

With a weight of 0.5nS, N1 inhibits N2 to an extent that prevents the spike. We have now built a network of leaky integrate-and-fire neurons that features both excitatory and inhibitory synapses. This is just the start to getting a functional network that does something interesting with the input. We will need to increase the number of neurons, decide on a connectivity rule between them, initialize or even learn weights, decide on a coding scheme for the output and much more. Many of these I will cover in later blog posts so stay tuned.

Differential Equations with SciPy – odeint or solve_ivp

SciPy features two different interfaces to solve differential equations: odeint and solve_ivp. The newer one is solve_ivp and it is recommended but odeint is still widespread, probably because of its simplicity. Here I will go through the difference between both with a focus on moving to the more modern solve_ivp interface. The primary advantage is that solve_ivp offers several methods for solving differential equations whereas odeint is restricted to one. We get started by setting up our system of differential equations and some parameters of the simulation.

UPDATE 07.02.2021: Note that I am focusing here on usage of the different interfaces and not on benchmarking accuracy or performance. Faruk Krecinic has contacted me and noted that assessing accuracy would require a system with a known solution as a benchmark. This is beyond the scope of this blog post. He also pointed out to me that the hmax parameter of odeint is important for the extent to which both interfaces give similar results. You can learn more about the parameters of odeint in the docs.

import numpy as np
from scipy.integrate import odeint, solve_ivp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    
    return [dx, dy, dz]

sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0

p = (sigma, beta, rho)  # Parameters of the system

y0 = [1.0, 1.0, 1.0]  # Initial state of the system

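We will be using the Lorenz system. We can move directly to solving the system with both odeint and solve_ivp. Note that lorenz takes the time t as its first argument, which is the convention of solve_ivp. By default, odeint expects the state first, as in func(y, t), so we pass tfirst=True. In both cases the tuple p supplies the extra arguments: odeint takes it as its fourth positional argument (args), solve_ivp through the args keyword.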

t_span = (0.0, 40.0)
t = np.arange(0.0, 40.0, 0.01)

result_odeint = odeint(lorenz, y0, t, p, tfirst=True)
result_solve_ivp = solve_ivp(lorenz, t_span, y0, args=p)

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(result_odeint[:, 0],
        result_odeint[:, 1],
        result_odeint[:, 2])
ax.set_title("odeint")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(result_solve_ivp.y[0, :],
        result_solve_ivp.y[1, :],
        result_solve_ivp.y[2, :])
ax.set_title("solve_ivp")
Simulation results from odeint and solve_ivp. Note that the solve_ivp solution looks very different, primarily because of the default temporal resolution that solve_ivp applies. Changing the temporal resolution to get results very similar to odeint is easy and shown below.

The first thing that sticks out is that the solve_ivp solution is less smooth. That is because it is calculated at fewer time points, which in turn has to do with the difference between t_span and t. The odeint interface expects t, an array of time points for which we want to calculate the solution. The temporal resolution of the returned solution is given by the interval between these time points. The solve_ivp interface, on the other hand, expects t_span, a tuple that gives the start and end of the simulation interval. solve_ivp determines the temporal resolution by itself, depending on the integration method and the desired accuracy of the solution. We can confirm that the temporal resolution of solve_ivp is lower in this example by inspecting the output of both functions.
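Another way to see that solve_ivp picks its own time points is to look at the spacing between consecutive entries of its t attribute. The following is a minimal sketch that assumes the same Lorenz setup, which is redefined here so the snippet runs on its own.

```python
import numpy as np
from scipy.integrate import solve_ivp


def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]

t = np.arange(0.0, 40.0, 0.01)                     # odeint-style fixed grid
sol = solve_ivp(lorenz, (0.0, 40.0), y0, args=p)   # default method RK45

grid_steps = np.diff(t)          # all (up to rounding) equal to 0.01
adaptive_steps = np.diff(sol.t)  # chosen by the solver, varies over time

print(grid_steps.min(), grid_steps.max())
print(adaptive_steps.min(), adaptive_steps.max())
```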

t.shape
# (4000,)

result_odeint.shape
# (4000, 3)

result_solve_ivp.t.shape
# (467,)

The t array has 4000 elements and therefore the result of odeint has 4000 rows, each row being a time point defined by t. The result of solve_ivp is different. It has its own time array as an attribute, and that array has only 467 elements. This tells us that solve_ivp indeed calculated fewer time points, which lowers the temporal resolution. So how can we increase the number of time points in solve_ivp? There are three ways: 1. We can manually define the time points at which the solution is reported, similar to odeint. 2. We can decrease the error of the solution we are willing to tolerate. 3. We can change to a more accurate integration method. We will first change the integration method. The default integration method of solve_ivp is RK45, and we will compare it to the default method of odeint, which is LSODA.
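You can compare the three options side by side by counting how many time points each call returns. This is a minimal sketch. The tolerance values rtol=1e-8 and atol=1e-8 are arbitrary choices for illustration.

```python
import numpy as np
from scipy.integrate import solve_ivp


def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t_span = (0.0, 40.0)
t = np.arange(0.0, 40.0, 0.01)

# Baseline: RK45 with the default rtol=1e-3, atol=1e-6.
default = solve_ivp(lorenz, t_span, y0, args=p)

# 1. Ask for specific output times, like odeint's t.
fixed_grid = solve_ivp(lorenz, t_span, y0, args=p, t_eval=t)

# 2. Tolerate less error, which forces smaller internal steps.
tight = solve_ivp(lorenz, t_span, y0, args=p, rtol=1e-8, atol=1e-8)

# 3. Switch the integration method.
lsoda = solve_ivp(lorenz, t_span, y0, args=p, method="LSODA")

for name, sol in [("default", default), ("t_eval", fixed_grid),
                  ("tight tolerances", tight), ("LSODA", lsoda)]:
    print(f"{name:>16}: {sol.t.size} time points")
```

According to the SciPy docs, t_eval only controls where the computed solution is stored. The solver still chooses its internal steps adaptively.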

solve_ivp_rk45 = solve_ivp(lorenz, t_span, y0, args=p,
                           method='RK45')
solve_ivp_lsoda = solve_ivp(lorenz, t_span, y0, args=p,
                            method='LSODA')

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(solve_ivp_rk45.y[0, :],
        solve_ivp_rk45.y[1, :],
        solve_ivp_rk45.y[2, :])
ax.set_title("RK45")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(solve_ivp_lsoda.y[0, :],
        solve_ivp_lsoda.y[1, :],
        solve_ivp_lsoda.y[2, :])
ax.set_title("LSODA")
Comparison between RK45 and LSODA integration methods of solve_ivp.

The LSODA solution already looks smoother, but it is still not as smooth as the solution we got from odeint. That is because we made odeint report the solution at an even higher temporal resolution when we passed it t. To get the solution from solve_ivp at the same time points we used for odeint, we must pass it those exact time points with the t_eval parameter.

t = np.arange(0.0, 40.0, 0.01)
result_odeint = odeint(lorenz, y0, t, p, tfirst=True)
result_solve_ivp = solve_ivp(lorenz, t_span, y0, args=p,
                             method='LSODA', t_eval=t)

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot(result_odeint[:, 0],
        result_odeint[:, 1],
        result_odeint[:, 2])
ax.set_title("odeint")

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot(result_solve_ivp.y[0, :],
        result_solve_ivp.y[1, :],
        result_solve_ivp.y[2, :])
ax.set_title("solve_ivp LSODA")
odeint and solve_ivp with identical integration method and temporal resolution.

Now both solutions have identical temporal resolution, but the solutions themselves are still not identical. I was unable to confirm why that is the case, but I suspect very small floating point errors. The Lorenz attractor is a chaotic system and even small errors can make solutions diverge. The following plot shows the first variable of the system for odeint and solve_ivp from the above simulation. It is consistent with my suspicion that small numerical errors are to blame, but I was unable to confirm their source.

Solutions from odeint and solve_ivp diverge even for identical temporal resolution and integration method.
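Below is a minimal sketch of how to make this kind of comparison plot. It plots the first state variable x of both solutions against time. The layout and labels are my own choices.

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint, solve_ivp


def lorenz(t, state, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t = np.arange(0.0, 40.0, 0.01)

result_odeint = odeint(lorenz, y0, t, p, tfirst=True)
result_solve_ivp = solve_ivp(lorenz, (0.0, 40.0), y0, args=p,
                             method="LSODA", t_eval=t)

# odeint returns (time, variable); solve_ivp's .y is (variable, time).
x_odeint = result_odeint[:, 0]
x_solve_ivp = result_solve_ivp.y[0, :]

fig, ax = plt.subplots()
ax.plot(t, x_odeint, label="odeint")
ax.plot(t, x_solve_ivp, label="solve_ivp LSODA", linestyle="--")
ax.set_xlabel("t")
ax.set_ylabel("x")
ax.legend()

# Absolute difference between the two solutions at each time point.
difference = np.abs(x_odeint - x_solve_ivp)
```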

There is one more way to make the solution smoother: decrease the tolerated error. We will not go into that here because it does not relate to the difference between odeint and solve_ivp; it can be used in both interfaces. There are also some more subtle differences. For example, by default odeint expects the first parameter of the function that defines the system to be the state and the second parameter to be time t. This is the other way around for solve_ivp. To make odeint accept a function where time is the first parameter, we set tfirst=True.
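Here is a minimal sketch of both points. The SciPy docs list different default tolerances for the two interfaces: odeint uses rtol and atol of about 1.49e-8, while solve_ivp defaults to rtol=1e-3 and atol=1e-6. The sketch passes the same values to both explicitly. The value 1e-8 is an arbitrary choice for illustration, and the two function names are my own.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp


def lorenz_state_first(state, t, sigma, beta, rho):
    """odeint's default convention: state first, then time."""
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


def lorenz_time_first(t, state, sigma, beta, rho):
    """solve_ivp's convention (and odeint's with tfirst=True): time first."""
    return lorenz_state_first(state, t, sigma, beta, rho)


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t = np.arange(0.0, 40.0, 0.01)
tol = 1e-8

# No tfirst needed: this function already takes the state first.
result_odeint = odeint(lorenz_state_first, y0, t, args=p,
                       rtol=tol, atol=tol)

# Same tolerances for solve_ivp, overriding its looser defaults.
result_solve_ivp = solve_ivp(lorenz_time_first, (0.0, 40.0), y0, args=p,
                             method="LSODA", t_eval=t, rtol=tol, atol=tol)
```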

In summary, solve_ivp offers several integration methods while odeint only uses LSODA.

Differential Equations in Python with SciPy

Differential equations are special because they don’t tell us the value of a variable directly. Instead, they tell us by how much the variable will change with respect to the change of another variable. Usually that other variable is time. To numerically solve a system of differential equations, we need to track the system’s change over time, starting at an initial state. This process is called numerical integration, and SciPy has a function for it called odeint. We will learn how to use this function by simulating the ‘hello world’ of differential equations: the Lorenz system.
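To make the idea of tracking change step by step concrete, here is a minimal forward Euler integrator. It only illustrates numerical integration in general. odeint itself uses the more sophisticated LSODA method with adaptive step sizes, not Euler. The function name euler and the test problem dy/dt = -y are my own choices.

```python
import numpy as np


def euler(func, y0, t):
    """Forward Euler: repeatedly add derivative * step to the current state.

    Mirrors odeint's convention: func(state, t) returns the derivatives,
    and the result has one row per entry in t.
    """
    y = np.empty((len(t), len(y0)))
    y[0] = y0
    for i in range(1, len(t)):
        dt = t[i] - t[i - 1]
        y[i] = y[i - 1] + dt * np.asarray(func(y[i - 1], t[i - 1]))
    return y


def decay(state, t):
    """Exponential decay dy/dt = -y, whose exact solution is exp(-t)."""
    return -state


t = np.linspace(0.0, 1.0, 1001)   # 1000 steps of size 0.001
approx = euler(decay, [1.0], t)
print(approx[-1, 0], np.exp(-1.0))  # close, but not identical
```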

Here is the first part of the code where we define the function that describes the dynamics of the system.

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from mpl_toolkits.mplot3d import Axes3D

def lorenz(state, t, sigma, beta, rho):
    x, y, z = state
    
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    
    return [dx, dy, dz]

We start with some imports. Of course we need NumPy, and odeint is imported from scipy.integrate. Matplotlib will be used to plot the result of our simulation. After that we define the system of differential equations that makes up our Lorenz system. It consists of three differential equations that we fit into one function called lorenz. This function needs a specific call signature (lorenz(state, t, sigma, beta, rho)) because we will later pass it to odeint, which expects specific parameters in specific places. Most importantly, the first parameter must be the state of the system. The state of the Lorenz system is defined by three variables: x, y and z. Our state object has to be a sequence with an order that reflects this.
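Because lorenz is an ordinary Python function, you can call it directly to check what odeint will see. The SciPy docs describe the function as being called with the state, then t, then any extra arguments. Unpacking the parameter tuple reproduces that call. The following is a minimal sketch at the initial state used later in this post.

```python
def lorenz(state, t, sigma, beta, rho):
    x, y, z = state
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return [dx, dy, dz]


p = (10.0, 8.0 / 3.0, 28.0)  # sigma, beta, rho
y0 = [1.0, 1.0, 1.0]         # x, y, z

# Same call pattern odeint uses: state first, then t, then *args.
derivatives = lorenz(y0, 0.0, *p)
print(derivatives)  # approximately [0.0, 26.0, -1.6667]
```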

Inside the lorenz function, the first thing we do is to unpack the state into the three state variables. This is followed by the three differential equations that describe the dynamic changes of the state variables. The fact that the variable t does not show up in any of these equations is a common point of confusion. The amount of change certainly depends on the amount of time. So why can we ignore t here? The answer is that our numerical integrator will keep track of t for us. The Lorenz equations themselves do not depend on t, but odeint always passes t to the function, so it has to be part of the signature. Having t available is also useful if you want to add discontinuities that depend on t.
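The Lorenz equations do not depend on t, and you can check this by evaluating the function at two different times. To illustrate the discontinuity case, lorenz_switch below is a hypothetical variant that is not part of the original model. In it, rho jumps to a new value at a chosen switching time. The switching time of 20 and the new rho of 10 are arbitrary.

```python
import numpy as np
from scipy.integrate import odeint


def lorenz(state, t, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


def lorenz_switch(state, t, sigma, beta, rho, t_switch, rho_after):
    """Hypothetical variant: rho changes abruptly once t reaches t_switch."""
    current_rho = rho if t < t_switch else rho_after
    return lorenz(state, t, sigma, beta, current_rho)


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]

# Autonomous system: the same state gives the same derivatives at any t.
same = lorenz(y0, 0.0, *p) == lorenz(y0, 25.0, *p)

# Here t matters: rho is 28 before t=20 and 10 afterwards.
before = lorenz_switch(y0, 5.0, *p, 20.0, 10.0)
after = lorenz_switch(y0, 25.0, *p, 20.0, 10.0)

t = np.arange(0.0, 40.0, 0.01)
result = odeint(lorenz_switch, y0, t, args=p + (20.0, 10.0))
```

Adaptive solvers can usually step across such a jump by shrinking their step size around it. Integrating the two intervals separately is a common alternative.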

While t does not appear in the equations, sigma, beta & rho do. They are the parameters of the system and the system’s properties depend on them. We will set those parameters next.
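For reference, these are the three equations that the lorenz function implements. Here σ, β and ρ correspond to sigma, beta and rho. The right-hand sides contain only the state variables and the parameters, not t.

```latex
\begin{aligned}
\frac{dx}{dt} &= \sigma (y - x) \\
\frac{dy}{dt} &= x (\rho - z) - y \\
\frac{dz}{dt} &= x y - \beta z
\end{aligned}
```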

sigma = 10.0
beta = 8.0 / 3.0
rho = 28.0

p = (sigma, beta, rho)

These are the parameters Lorenz himself used, and they are known to produce the type of dynamics that the Lorenz system is best known for: the Lorenz attractor. It is important that we store these parameters in a tuple in this exact order because of our function’s structure. It must be a tuple rather than another type of collection because odeint expects one. Now that our parameters are defined, we will move on to define the initial values of the system. This is a critical part of solving differential equations: the equations tell us by how much the system state changes, but they cannot tell us where to start.
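The tuple requirement has one common pitfall. A system with a single parameter still needs a one-element tuple, written with a trailing comma. Below is a minimal sketch using exponential decay. The names decay and k are illustrative and not part of the Lorenz example.

```python
import numpy as np
from scipy.integrate import odeint


def decay(state, t, k):
    """dy/dt = -k * y, with exact solution y0 * exp(-k * t)."""
    return -k * state


k = 0.5
t = np.linspace(0.0, 10.0, 101)

# (k,) is a one-element tuple; (k) would just be the float 0.5.
result = odeint(decay, 1.0, t, args=(k,))
exact = np.exp(-k * t)
print(np.max(np.abs(result[:, 0] - exact)))  # small numerical error
```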

y0 = [1.0, 1.0, 1.0]

Our system will start with all variables at 1.0. Now we can solve the system and plot the result.

t = np.arange(0.0, 40.0, 0.01)

result = odeint(lorenz, y0, t, p)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # fig.gca(projection=...) no longer works in recent Matplotlib
ax.plot(result[:, 0], result[:, 1], result[:, 2])
The Lorenz attractor.

We solve the system with a simple call to odeint, passing it the function that defines our system, the initial state, the time points t for which we want to solve the system and the parameters p. result is a two-dimensional array where the rows are the time points and the columns are the state variables at those time points. And this is how we can solve differential equations with SciPy.
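Because result is a plain NumPy array, you select individual variables and time points by indexing. The following is a minimal sketch that re-runs the simulation from this post.

```python
import numpy as np
from scipy.integrate import odeint


def lorenz(state, t, sigma, beta, rho):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


p = (10.0, 8.0 / 3.0, 28.0)
y0 = [1.0, 1.0, 1.0]
t = np.arange(0.0, 40.0, 0.01)
result = odeint(lorenz, y0, t, p)

x, y, z = result.T        # one 1-D array per state variable
first_state = result[0]   # the initial state y0
final_state = result[-1]  # state at t[-1], just under 40 (arange excludes the stop value)

print(result.shape)  # (4000, 3)
```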