Animations with Julia

Creating great-looking animations in Julia is surprisingly easy thanks to the Plots package and some macro magic. Here we will learn how to turn data into high-quality animations. We will learn about the @animate macro, frames and the gif function.

Two Steps to Animations

To create animations we simply generate frames with the @animate macro and then write them to a file with the gif function. Both are part of the Plots package for Julia, so we start with using Plots to make that package available. Next, we create the data we will plot, which is a simple sine wave. Then we generate the frames with the @animate macro and animate them with the gif function. This is the code and the resulting animation.

using Plots

x = collect(1:0.1:30)
y = sin.(x)
df = 2

anim = @animate for i = 1:df:length(x)
    plot(x[1:i], y[1:i], legend=false)
end

gif(anim, "tutorial_anim_fps30.gif", fps = 30)

Macros and Meta-Programming

The @animate macro deserves some extra attention, because it looks like magic. Macros are related to a concept called meta-programming. In Julia, all code is a data structure that can be manipulated in a similar way to all other data structures. This effectively means that we can write code that manipulates our code. That’s what a macro is: a function that modifies code. In our case, the code being modified is the for loop that follows our @animate macro. It is modified so that it captures the frame at the end of each loop iteration and saves it into anim. We can write code that does the same job ourselves.

anim = Plots.Animation()
for i = 1:df:length(x)
    plot(x[1:i], y[1:i], legend=false)
    Plots.frame(anim)
end

gif(anim, "tutorial_anim_fps30.gif", fps = 30)

We use the Plots.Animation() function to create the animation object where we will store our frames. Inside the for loop we then call Plots.frame(anim) to store the current frame in our anim object after each iteration. These are the essential steps that the @animate macro takes care of. If you want to learn what the macro does in detail you can call @macroexpand on it.

@macroexpand @animate for i = 1:df:length(x)
    plot(x[1:i], y[1:i], legend=false)
end

There is another macro that is even more convenient: the @gif macro. It saves us from having to call gif() on our anim object.

@gif for i = 1:df:length(x)
    plot(x[1:i], y[1:i], legend=false)
end

This directly displays the animation if an interactive Julia environment is available. The downside is that we neither save the animation to disk nor keep an anim object around for further animations later. It is most useful for quickly troubleshooting animations interactively.

Beyond plot()

The @animate macro supports animations of anything that can be plotted with Plots. For example, we can animate a heatmap.

anim = @animate for i = 1:100
    mat = rand(0:100, 32, 32)
    heatmap(mat, clim=(0,255))
end

gif(anim, "tutorial_heatmap_anim.gif", fps = 10)

The frame that is captured is the state of the active figure at the end of each loop iteration. The for loop itself gives us a great deal of control over how many frames we create. For example, in the sine wave example above I skipped frames with the df variable.

That’s it for animations. To learn more you can take a look at the official Plots documentation.

The Type System of Julia

Every value in Julia has a type. Like other popular languages such as Python, Julia is dynamically typed. That means we don’t need to explicitly define the type of a value when we create it. However, types are particularly important in Julia because we can use explicit type definitions to speed up calculations. Here we will get an introduction to the type system of Julia.

Dynamic Typing

When we assign a value in Julia we don’t need to specify its type. The type is inferred from the given value. Let’s create an integer, a float and a string.

myint = 3
myfloat = 3.0
mystr = "3.0"
println(typeof(myint), ": ", myint)
println(typeof(myfloat), ": ", myfloat)
println(typeof(mystr), ": ", mystr)
# Int64: 3
# Float64: 3.0
# String: 3.0

We use typeof() to find out what type a value has. To create an integer, we assign a number without a decimal point. To create a float, we simply add a decimal point. For a string we need quotation marks around our value. There is another way to define the type more explicitly. Our float value is a Float64 (it uses 64 bits). What if we wanted a Float32?

myfloat = convert(Float32, 3.0)
typeof(myfloat)
# Float32

We use the convert(Type, value) function to explicitly convert 3.0 to Float32. If we want to ensure that a value is of a certain type, we can use the double colon syntax (::). It evaluates normally if the value has the specified type but throws an error if it has a different type.

3::Int
# 3
3::Float64
# TypeError: in typeassert, expected Float64, got Int64
# 
# Stacktrace:
#  [1] top-level scope at In[8]:1

This type assertion syntax is frequently used when defining functions. A function is a piece of code that takes input arguments and processes them to produce output values. We will learn more about functions later. Because Julia is dynamically typed, we could write our functions in a way that works the same regardless of input types. However, defining functions in a way that is specific to the input types can be good for performance, and Julia makes heavy use of that fact. For a given function, multiple methods might exist. Each method is responsible for a given set of input types. For example, we can inspect all the different methods of the mathematical cos function by calling the methods function on it.

methods(cos)
"""
# 12 methods for generic function cos:
cos(x::BigFloat) in Base.MPFR at mpfr.jl:744
cos(::Missing) in Base.Math at math.jl:1167
cos(a::Complex{Float16}) in Base.Math at math.jl:1115
cos(a::Float16) in Base.Math at math.jl:1114
cos(z::Complex{T}) where T in Base at complex.jl:823
cos(x::T) where T<:Union{Float32, Float64} in Base.Math at special/trig.jl:100
cos(x::Real) in Base.Math at special/trig.jl:124
cos(A::LinearAlgebra.Hermitian{#s661,S} where S<:(AbstractArray{#s662,2} where #s662<:#s661) where #s661<:Complex) in LinearAlgebra at C:\Users\Daniel\AppData\Local\Programs\Julia\Julia-1.4.2\share\julia\stdlib\v1.4\LinearAlgebra\src\symmetric.jl:907
cos(A::Union{LinearAlgebra.Hermitian{#s662,S}, LinearAlgebra.Symmetric{#s662,S}} where S where #s662<:Real) in LinearAlgebra at C:\Users\Daniel\AppData\Local\Programs\Julia\Julia-1.4.2\share\julia\stdlib\v1.4\LinearAlgebra\src\symmetric.jl:903
cos(D::LinearAlgebra.Diagonal) in LinearAlgebra at C:\Users\Daniel\AppData\Local\Programs\Julia\Julia-1.4.2\share\julia\stdlib\v1.4\LinearAlgebra\src\diagonal.jl:561
cos(A::AbstractArray{#s662,2} where #s662<:Real) in LinearAlgebra at C:\Users\Daniel\AppData\Local\Programs\Julia\Julia-1.4.2\share\julia\stdlib\v1.4\LinearAlgebra\src\dense.jl:793
cos(A::AbstractArray{#s662,2} where #s662<:Complex) in LinearAlgebra at C:\Users\Daniel\AppData\Local\Programs\Julia\Julia-1.4.2\share\julia\stdlib\v1.4\LinearAlgebra\src\dense.jl:800
"""

The cos function has 12 different methods. The double colon type assertion syntax checks the input type. Which method is used can depend on the types of all inputs. This is called multiple dispatch: multiple inputs determine which method is dispatched. We will learn more about multiple dispatch in another post. Since arrays are central to scientific computing, we next look at how the type system applies to arrays.

Array Types

There are several different ways to define arrays. One way is to use literal numbers enclosed by square brackets. In that case, the element type is inferred in the same way as above, where we assigned literal numbers.

myintarr = [1, 2, 3]
typeof(myintarr)
# Array{Int64,1}
myfloatarr = [1.0, 2.0, 3.0]
typeof(myfloatarr)
# Array{Float64,1}

For an array, typeof() tells us not only the type of the array elements but also how many dimensions the array has. One important feature of arrays is that all elements must have the same type. So what happens when we create an array from literals with different types? The type is determined by the most complex type.

myintarr = [1, 2, 3]
typeof(myintarr)
# Array{Int64,1}
myfloatarr = [1, 2.0, 3]
typeof(myfloatarr)
# Array{Float64,1}
mystrarr = [1, "2.0", 3]
typeof(mystrarr)
# Array{Any,1}

We see that a single float value makes the entire array Float64. However, when one of the elements is a string, the type becomes Any rather than String. That is because Julia does not convert numbers to strings; we get an error if we call convert(String, 3). If the most complex type is Float64, Julia tries to convert all other values to that type. This works when the other values are integers, because convert(Float64, 3) works. If the other values cannot be converted to the most complex type, all values take on the Any type. This means that values in the array can be of different types. They could literally be anything. This defeats the purpose of an array, so we generally try to avoid it. Besides array literals, there are a variety of functions that create arrays, and they usually allow us to specify the element type.

zerosarr = zeros(Int8, (2, 3))
# 2×3 Array{Int8,2}:
#  0  0  0
#  0  0  0
zerosarr = zeros((2, 3))
# 2×3 Array{Float64,2}:
#  0.0  0.0  0.0
#  0.0  0.0  0.0

If we don’t specify the type we want, the zeros() function defaults to Float64. There are other functions such as ones(), giving an array of ones, or rand(), giving an array of random numbers. The type of the output can be specified for each.

That is it for our basic introduction to types. You can find a more complete introduction to Julia types in the official documentation. In summary, we don’t need to explicitly specify the type of values, but sometimes doing so can make computations more efficient. We will learn more about performance and efficiency in Julia in later posts.

Getting Started Programming Julia

To get us started with Julia we cover two basics: arithmetic operators and name assignment.

Arithmetic operators

The standard arithmetic operators are addition (+), subtraction (-), multiplication (*), division (/), power (^) and remainder (%). They work as expected, and the only one that differs for the Python crowd is power, which is consistent with MATLAB. The usual precedence of operations applies: first power, then multiplication, division and remainder, and finally addition and subtraction. Parentheses can be used to change the order of operations.

1 + 3 * 2
# 7
(1 + 3) * 2
# 8
2 * 3 ^ 2
# 18

In those examples, both sides of the operator are scalars. It gets a little more interesting when at least one of them is a vector or a matrix. Not all of the above operations are defined between vectors and scalars. Only division and multiplication are defined. We create a vector using square brackets ([]) with the elements separated by commas.

[3, 1, 4] * 2
# 3-element Array{Int64,1}:
#  6
#  2
#  8

[3, 1, 4] / 2
# 3-element Array{Float64,1}:
# 1.5
# 0.5
# 2.0

The other operations are not defined and throw an error.

[3, 1, 4] ^ 2
"""
MethodError: no method matching ^(::Array{Int64,1}, ::Int64)
Closest candidates are:
  ^(!Matched::Float16, ::Integer) at math.jl:885
  ^(!Matched::Regex, ::Integer) at regex.jl:712
  ^(!Matched::Missing, ::Integer) at missing.jl:155
  ...

Stacktrace:
 [1] macro expansion at .\none:0 [inlined]
 [2] literal_pow(::typeof(^), ::Array{Int64,1}, ::Val{2}) at .\none:0
 [3] top-level scope at In[52]:1
"""

The reasons for this design choice have to do with Julia’s focus on linear algebra and are not important here. If we want this operation to work in an element-wise manner, we have to force it explicitly. We can do so using the dot (.). This way we can explicitly force every operator to be applied element-wise.

[3, 1, 4] .^ 2
# 3-element Array{Int64,1}:
#  9
#  1
# 16
[3, 1, 4] .+ 2
# 3-element Array{Int64,1}:
#  5
#  3
#  6

Now that we have our arithmetic operators, let’s move on to name assignment so we can store the results of our operations.

Name Assignment

We assign names to values with the = operator using the syntax name = value. Once a name is assigned, we can use the name instead of the value in operations.

result_one = 2 + 2
result_two = result_one + 3
# 7

Once a name is assigned to a value we can reassign that same name to a different value without problem.

result = 2 + 2
result = 10
# 10

If we want to assign a name that is not supposed to be reassigned, we can use the const keyword. If we try to reassign a constant name we get an error.

const a = 8.3144621
a = 3
"""
invalid redefinition of constant a

Stacktrace:
 [1] top-level scope at In[73]:2
"""

There are a few rules about the names we can assign. Generally, Unicode characters (UTF-8) are allowed. This means we can do something like this.

δt = 0.0001

Here we are using the special character delta. If you want to quickly generate such a character, many Julia environments allow you to do this by typing \delta and hitting tab. I recommend using these sparingly, as they might confuse people transitioning from other languages that don’t allow unicode names. On the other hand, they can make your code look more mathematical. Built-in keywords that have special meaning are not allowed as names, and trying to assign to them results in an error.

if = 3
# syntax: unexpected "="

If you are interested in more details about variables and name assignments you can take a look at the official documentation. In the next blog post we will take a look at the type system of Julia.

A Curve Fitting Guide for the Busy Experimentalist

Curve fitting is an extremely useful analysis tool to describe the relationship between variables or discover a trend within noisy data. Here I’ll focus on a pragmatic introduction to curve fitting: how to do it in Python, why it can fail and how to interpret the results. Finally, I will also give a brief glimpse at the larger themes behind curve fitting, such as mathematical optimization, to the extent that I think is useful for the casual curve fitter.

Curve Fitting Made Easy with SciPy

We start by creating a noisy exponential decay function. The exponential decay function has two parameters: the time constant tau and the initial value at the beginning of the curve init. We’ll evenly sample from this function and add some white noise. We then use curve_fit to fit parameters to the data.

import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

# The exponential decay function
def exp_decay(x, tau, init):
    return init*np.e**(-x/tau)

# Parameters for the exp_decay function
real_tau = 30
real_init = 250

# Sample exp_decay function and add noise
np.random.seed(100)
dt=0.1
x = np.arange(0,100,dt)
noise=np.random.normal(scale=50, size=x.shape[0])
y = exp_decay(x, real_tau, real_init)
y_noisy = y + noise

# Use scipy.optimize.curve_fit to fit parameters to noisy data
popt, pcov = scipy.optimize.curve_fit(exp_decay, x, y_noisy)
fit_tau, fit_init = popt

# Sample exp_decay with optimized parameters
y_fit = exp_decay(x, fit_tau, fit_init)

fig, ax = plt.subplots(1)
ax.scatter(x, y_noisy,
           alpha=0.8,
           color= "#1b9e77",
           label="Exponential Decay + Noise")
ax.plot(x, y,
        color="#d95f02",
        label="Exponential Decay")
ax.plot(x, y_fit,
        color="#7570b3",
        label="Fit")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Curve Fit Exponential Decay")

Our fit parameters are almost identical to the actual parameters. We get 30.60 for fit_tau and 245.03 for fit_init, both very close to the real values of 30 and 250. All we had to do was call scipy.optimize.curve_fit and pass it the function we want to fit, the x data and the y data. The function we pass has to have a certain structure: the first argument must be the input data and all other arguments are the parameters to be fit. From the call signature of def exp_decay(x, tau, init) we can see that x is the input data, while tau and init are the parameters to be optimized such that the difference between the function output and y_noisy is minimal. Technically this works for any number of parameters and any kind of function. It also works when the sampling is much more sparse. Below is a fit on 20 randomly chosen data points.
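
The original post shows that sparse fit as an image; a minimal sketch of how it could be produced is shown below. It reuses exp_decay, x and y_noisy from above, while n_samples and sample_idx are names introduced here for illustration.

n_samples = 20
sample_idx = np.sort(np.random.choice(x.shape[0], size=n_samples, replace=False))
x_sparse, y_sparse = x[sample_idx], y_noisy[sample_idx]

# Fit on the sparse sample, then evaluate the fit over the full x range
popt_sparse, pcov_sparse = scipy.optimize.curve_fit(exp_decay, x_sparse, y_sparse)
y_fit_sparse = exp_decay(x, *popt_sparse)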

Of course the accuracy will decrease with sparser sampling. So why would this ever fail? The most common failure mode, in my opinion, is bad initial parameters.

Choosing Good Initial Parameters

The initial parameters of a function are the starting parameters before optimization. They are very important because most optimization methods don’t just look for the best fit randomly. That would take too long. Instead, the method starts with the initial parameters, changes them slightly and checks if the fit improves. When changing the parameters shows very little improvement, the fit is considered done. That makes it very easy for the method to stop with bad parameters if it gets stuck in a local minimum or a saddle point. Let’s look at an example of a bad fit. We will change our tau to a negative number, which will result in exponential growth.
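
The failed fit described below comes from code along these lines. This is a minimal sketch, assuming the same exp_decay, x and noise arrays as above, and relying on the default initial parameters of curve_fit.

# Regenerate the data with a negative tau, which produces exponential growth
real_tau = -30
real_init = 20
y = exp_decay(x, real_tau, real_init)
y_noisy = y + noise

# No p0 given, so curve_fit starts from 1 for both parameters
popt, pcov = scipy.optimize.curve_fit(exp_decay, x, y_noisy)
fit_tau, fit_init = popt
y_fit = exp_decay(x, fit_tau, fit_init)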

In this case fitting didn’t work. For a real_tau and real_init of -30 and 20 we get a fit_tau and fit_init of 885223976.9 and 106.4, both way off. So what happened? Although we never specified the initial parameters (p0), curve_fit uses a default of 1 for both tau and init. Starting from 1, curve_fit never finds good parameters. So what happens if we choose better initial parameters? Looking at our exp_decay definition and the exponential growth in our noisy data, we know for sure that our tau has to be negative. Let’s see what happens when we choose a negative initial value of -5.

p0 = [-5, 1]
popt, pcov = scipy.optimize.curve_fit(exp_decay, x, y_noisy, p0=p0)
fit_tau, fit_init = popt
y_fit = exp_decay(x, fit_tau, fit_init)
fig, ax = plt.subplots(1)
ax.scatter(x, y_noisy,
           alpha=0.8,
           color= "#1b9e77",
           label="Exponential Decay + Noise")
ax.plot(x, y,
        color="#d95f02",
        label="Exponential Decay")
ax.plot(x, y_fit,
        color="#7570b3",
        label="Fit")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Curve Fit Exponential Growth Good Initials")

With an initial parameter of -5 for tau we get good parameters of -30.4 for tau and 20.6 for init (real values were -30 and 20). The key point is that initial conditions are extremely important because they can change the result we get. This is an extreme case, where the fit works almost perfectly for some initial parameters or completely fails for others. In more subtle cases different initial conditions might result in slightly better or worse fits that could still be relevant to our research question. But what does it mean for a fit to be better or worse? In our example we can always compare it to the actual function. In more realistic settings we can only compare our fit to the noisy data.

Interpreting Fitting Results

In most research settings we don’t know our exact parameters. If we did, we would not need to do fitting at all. So to compare the goodness of different parameters we need to compare our fit to the data. How do we calculate the error between our data and the prediction of the fit? There are many different measures, but one of the simplest is the sum of squared residuals (SSR).

def ssr(y, fy):
    """Sum of squared residuals"""
    return ((y - fy) ** 2).sum()

We take the difference between our data (y) and the output of our function given a parameter set (fy). We square that difference and sum it up. In fact, this is what curve_fit optimizes. Its whole purpose is to find the parameters that give the smallest value of this function, the least squares. The parameters that give the smallest SSR are considered the best fit. We saw that this process can fail, depending on the function and the initial parameters, but let’s assume for a moment it worked. If we found the smallest SSR, does that mean we found the perfect fit? Unfortunately not. What we found was a good estimate for the best fitting parameters given our function. There are probably other functions out there that can fit our data better. We can use the SSR to compare different candidate functions, for example in a process called cross-validation. Instead of comparing different parameters of the same function, we compare different functions. However, if we increase the number of parameters we run into a problem called overfitting. I will not get into the details of overfitting here because it is beyond our scope.
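
To make this concrete, here is a small sketch of how we could use the ssr function to compare the optimized parameters to a deliberately perturbed parameter set on the same noisy data; the factor of 2 is an arbitrary choice for illustration.

ssr_fit = ssr(y_noisy, exp_decay(x, fit_tau, fit_init))
ssr_perturbed = ssr(y_noisy, exp_decay(x, fit_tau * 2, fit_init))
print(ssr_fit < ssr_perturbed)  # the optimized parameters should give the smaller SSR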

The main point is that we must steer clear of misinterpreting the best fit. We are always fitting the parameters and not the function. If our fitting works, we get a good estimate for the best fitting parameters. But sometimes our fitting doesn’t work, because our fitting method did not converge to the minimum SSR. In the final section we will find out why that might happen in our example.

The Error Landscape of Exponential Decay

To understand why fitting can fail depending on the initial conditions we should consider the landscape of our sum of squared residuals (SSR). We will calculate it by assuming that we already know the init parameter, so we keep it constant. Then we calculate the SSR for many values of tau smaller than zero and many values for tau larger than zero. Plotting the SSR against the guessed tau will hopefully show us how the SSR looks around the ideal fit.

real_tau = -30.0
real_init = 20.0

noise=np.random.normal(scale=50, size=x.shape[0])
y = exp_decay(x, real_tau, real_init)
y_noisy = y + noise
dtau = 0.1
guess_tau_n = np.arange(-60, -4.9, dtau)
guess_tau_p = np.arange(1, 60, dtau)

# The SSR function
def ssr(y, fy):
    """Sum of squared residuals"""
    return ((y - fy) ** 2).sum()

loss_arr_n = [ssr(y_noisy, exp_decay(x, tau, real_init)) 
              for tau in guess_tau_n]
loss_arr_p = [ssr(y_noisy, exp_decay(x, tau, real_init))
              for tau in guess_tau_p]

"""Plotting"""
fig, ax = plt.subplots(1,2)
ax[0].scatter(guess_tau_n, loss_arr_n)
real_tau_loss = ssr(y_noisy, exp_decay(x, real_tau, real_init))
ax[0].scatter(real_tau, real_tau_loss, s=100)
ax[0].scatter(guess_tau_n[-1], loss_arr_n[-1], s=100)
ax[0].set_yscale("log")
ax[0].set_xlabel("Guessed Tau")
ax[0].set_ylabel("SSR Standard Log Scale")
ax[0].legend(("All Points", "Real Minimum", "-5 Initial Guess"))

ax[1].scatter(guess_tau_p, loss_arr_p)
ax[1].scatter(guess_tau_p[0], loss_arr_p[0], s=100)
ax[1].set_xlabel("Guessed Tau")
ax[1].set_ylabel("SSR")
ax[1].legend(("All Points", "1 Initial Guess"))

On the left we see the SSR landscape for tau smaller than 0. Towards zero, the error becomes extremely large (note the logarithmic y scale). This is because towards zero the exponential growth becomes ever faster. As we move to more negative values we find a minimum near -30 (orange), our real tau. This is the parameter curve_fit would find if it only optimized tau and started at -5 (green). The optimization method does not move to values more negative than -30 because there the SSR increases again.

On the right side we get a picture of why optimization failed when we started at 1. There is no local minimum. The SSR just keeps decreasing with larger values of tau. That is why tau was so large when fitting failed (885223976.9). If we set our initial parameter anywhere in this part of the SSR landscape, this is where tau will go. There are other optimization methods that can overcome bad initial parameters, but few are completely immune to this issue.

Easy to Learn, Hard to Master

Curve fitting is a very useful technique and it is really easy in Python with SciPy, but there are some pitfalls. First of all, be aware of the initial values. They can lead to complete fitting failure or affect results in more subtle, systematic ways. We should also remind ourselves that even with decent fitting results, there might be a more suitable function out there that can fit our data even better. In this particular example we always knew what the underlying function was. This is rarely the case in real research settings. Most of the time it is much more productive to think more deeply about possible underlying functions than to look for more complicated fitting methods.

Finally, we barely scratched the surface here. Mathematical optimization is an entire field in itself and it is relevant to many areas such as statistics, machine learning, deep learning and many more. I tried to give the most pragmatic introduction to the topic here. If you want to go deeper into the topic I recommend this SciPy lecture and of course the official SciPy documentation for optimization and root finding.

The Hodgkin-Huxley Neuron in Julia

The Hodgkin-Huxley model is one of the earliest mathematical descriptions of neural spiking. It was originally developed on data from the squid giant axon. Today, Hodgkin-Huxley-like dynamics are also used to simulate the spiking of a variety of neuron types. I’ve recently written a script to simulate Hodgkin-Huxley dynamics in Julia. Here I am sharing that code and I will go through the most important elements. As I just started to learn Julia, I will also mention some of the things I learned about Julia in the process.

using Plots
gr()

# Hyperparameters
tmin = 0.0
tmax = 1000.0
dt = 0.01
T = tmin:dt:tmax

# Parameters
gK = 35.0
gNa = 40.0
gL = 0.3
Cm = 1.0
EK = -77.0
ENa = 55.0
El = -65.0

# Potassium ion-channel rate functions
alpha_n(Vm) = (0.02 * (Vm - 25.0)) / (1.0 - exp((-1.0 * (Vm - 25.0)) / 9.0))
beta_n(Vm) = (-0.002 * (Vm - 25.0)) / (1.0 - exp((Vm - 25.0) / 9.0))

# Sodium ion-channel rate functions
alpha_m(Vm) = (0.182*(Vm + 35.0)) / (1.0 - exp((-1.0 * (Vm + 35.0)) / 9.0))
beta_m(Vm) = (-0.124 * (Vm + 35.0)) / (1.0 - exp((Vm + 35.0) / 9.0))

alpha_h(Vm) = 0.25 * exp((-1.0 * (Vm + 90.0)) / 12.0)
beta_h(Vm) = (0.25 * exp((Vm + 62.0) / 6.0)) / exp((Vm + 90.0) / 12.0)

# n, m & h steady-states
n_inf(Vm=0.0) = alpha_n(Vm) / (alpha_n(Vm) + beta_n(Vm))
m_inf(Vm=0.0) = alpha_m(Vm) / (alpha_m(Vm) + beta_m(Vm))
h_inf(Vm=0.0) = alpha_h(Vm) / (alpha_h(Vm) + beta_h(Vm))

# Conductances
GK(gK, n) = gK * (n ^ 4.0)
GNa(gNa, m, h) = gNa * (m ^ 3.0) * h
GL(gL) = gL

# Differential equations
function dv(Vm, GK, GNa, GL, Cm, EK, ENa, El, I, dt)
    ((I - (GK * (Vm - EK)) - (GNa * (Vm - ENa)) - (GL * (Vm - El))) / Cm) * dt
end
dn(n, Vm, dt) = ((alpha_n(Vm) * (1.0 - n)) - (beta_n(Vm) * n)) * dt
dm(m, Vm, dt) = ((alpha_m(Vm) * (1.0 - m)) - (beta_m(Vm) * m)) * dt
dh(h, Vm, dt) = ((alpha_h(Vm) * (1.0 - h)) - (beta_h(Vm) * h)) * dt
I = T * 0.002

# Initial conditions and setup
v = -65
m = m_inf(v)
n = n_inf(v)
h = h_inf(v)

v_result = Array{Float64}(undef, length(T))
m_result = Array{Float64}(undef, length(T))
n_result = Array{Float64}(undef, length(T))
h_result = Array{Float64}(undef, length(T))

v_result[1] = v
m_result[1] = m
n_result[1] = n
h_result[1] = h


for t = 1:length(T)-1
    GKt = GK(gK, n)
    GNat = GNa(gNa, m, h)
    GLt = GL(gL)

    dvt = dv(v, GKt, GNat, GLt, Cm, EK, ENa, El, I[t], dt)
    dmt = dm(m, v, dt)
    dnt = dn(n, v, dt)
    dht = dh(h, v, dt)

    global v = v + dvt
    global m = m + dmt
    global n = n + dnt
    global h = h + dht

    v_result[t+1] = v
    m_result[t+1] = m
    n_result[t+1] = n
    h_result[t+1] = h
end

p1 = plot(T, v_result, xlabel="time (ms)", ylabel="voltage (mV)", legend=false, dpi=300)

The Parameters

There are seven parameters that define the Hodgkin-Huxley model: the maximum potassium, sodium and leak conductances gK, gNa and gL; the equilibrium potentials for potassium, sodium and leak, EK, ENa and El; and finally Cm, the membrane capacitance. In a nutshell, the model comes down to calculating the fraction of sodium and potassium channels that are open at a given time point. Together with the maximum conductance, the membrane potential and the equilibrium potential, the fraction of open channels tells us how much current is flowing at that time. The amount of current, filtered by the membrane capacitance, in turn tells us by how much the voltage changes. The fraction of open channels is given by n, m and h.

Differential Equations

We need to track the change of the voltage (v), potassium channel activation (n), sodium channel activation (m) and sodium channel inactivation (h). Let’s consider the potassium channels first. Inactive potassium channels can become active and active potassium channels can become inactive again. Since these potassium channels are voltage gated, the chance that they transition depends on the voltage. The function alpha_n gives the transition rate from inactive to active and beta_n gives the transition rate from active to inactive. For the sodium channels, the situation is almost identical. However, they can also be in a third state that corresponds to depolarization-induced inactivation.

Now that we are able to keep track of the states of our channels, we can calculate the conductances. GK calculates the potassium conductance, GNa the sodium conductance and GL the leak conductance. Those conductances are then used in the dv function to calculate the currents based on the voltage. And that’s basically it. The integration method is simple forward Euler inside the for loop.

Julia Notes From a Beginner

This is my very first Julia script, so I have some thoughts. The script is completely non-optimized and yet extremely fast, despite the for loop. From what I have learned so far, for loops in Julia are known to be fast and they can allegedly outperform vectorized solutions. This is very different from Python, where we strictly avoid for loops, especially when concerned about performance.

For now I am confused by the scope rules and the use of the global keyword. Scope seems to operate similarly to Python, where functions and for loops have seamless access to variables in the outer scope. Assigning to outer variables inside the for loop, on the other hand, seems to require the global keyword.

Generally I am very happy with the Julia syntax. I think I could even code Python and Julia back to back. One major problem is of course indices starting at 1 but I get the difference in convention. I’m looking forward to my next script.

Managing Files and Directories with Python

We cannot always avoid the details of file management, especially when analyzing raw data. It can come in the form of multiple files distributed over multiple directories. Accessing those files and the directory system can be a critical aspect of raw data processing. Luckily, Python has very convenient methods to handle files and directories in the os module. Here we will look at functions that inspect the contents of directories (os.listdir, os.walk) and functions that help us manipulate files and directories (os.rename, os.remove, os.mkdir).

Directory Contents

One of the most important functions to manage files is os.listdir(directory). It returns a list of strings containing the names of all directories and files in directory.

import os

directory = r"C:\Users\Daniel\tutorial"

dir_content = os.listdir(directory)
# ['file_1.txt', 'file_2.txt', 'images']

Now that we know what is in our directory, how do we find out which of these are files and which are subdirectories? You might be tempted to just check for a period (.) inside the string because it separates the file from its extension. This method is error prone, because directories could contain periods and files don’t necessarily have extensions. It is much better to use os.path.isfile() and os.path.isdir().

files = [x for x in dir_content if os.path.isfile(directory + os.sep + x)]
# ['file_1.txt', 'file_2.txt']

dirs = [x for x in dir_content if os.path.isdir(directory + os.sep + x)]
# ['images']

Now we know the contents of directory. We have two files and one sub-directory. In case you were wondering about os.sep, that is the directory separator of the operating system. On my Windows 10 machine that is the backslash. What if we need both the files that are in our directory and those that are in sub-directories? This is the perfect case for os.walk(). It gives us a convenient way to loop through a directory and all its sub-directories.

for root, dirs, files in os.walk(directory):
    print(root)
    print(dirs)
    print(files)

# C:\Users\Daniel\tutorial
# ['images']
# ['file_1.txt', 'file_2.txt']
# C:\Users\Daniel\tutorial\images
# []
# ['plot_1.png', 'plot_2.png']

By default os.walk() goes from the top down. The first name, root, gives us the full path of the directory we are currently at. Printing root tells us that we start at the top directory. While root is a string, both dirs and files are lists. They tell us which files and directories are in the current root. For the first directory we already found out on our own what the contents are: two text files and a sub-directory. Our loop next goes to the sub-directory images. In there are no more sub-directories but two image files. If there were more sub-directories at any level (directory or directory\images), os.walk would go through all of them. Next we will find out how to create, move, rename and delete both files and directories.
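
Before we get to that, one handy pattern is to combine os.walk() with os.path.join() to collect the full path of every file in the directory tree. This is a small sketch; all_files is just a name chosen here.

all_files = []
for root, dirs, files in os.walk(directory):
    for f in files:
        all_files.append(os.path.join(root, f))
# e.g. 'C:\\Users\\Daniel\\tutorial\\images\\plot_1.png' would be among the entries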

Manipulating Files and Directories

Let’s say I want to rename the .txt files. I don’t like the numbering and would prefer them to have a leading zero in case they go into the double digits. We can use os.rename for this job.

directory = r"C:\Users\Daniel\tutorial"

dir_content = os.listdir(directory)

txt_f = [x for x in dir_content
         if os.path.isfile(directory + os.sep + x) and ".txt" in x]
# ['file_1.txt', 'file_2.txt']
for f_name in txt_f:
    f_name_split = f_name.split("_")
    num = f_name_split[1].split(".")[0]
    new_name = f_name_split[0] + "_" + num.zfill(2) + ".txt"
    os.rename(directory + os.sep + f_name,
              directory + os.sep + new_name)
    
os.listdir(directory)
# ['file_01.txt', 'file_02.txt', 'images']

Now that we renamed our files, let’s create another sub-directory for these .txt files. To create new directories we use os.mkdir().

os.listdir(directory)
# ['file_01.txt', 'file_02.txt', 'images']

os.mkdir(directory + os.sep + 'texts')

os.listdir(directory)
# ['file_01.txt', 'file_02.txt', 'images', 'texts']

Now we need to move the .txt files. There is no dedicated move function in the os module. Instead we use os.rename, but rather than changing the name of the file, we change the directory part of its path.

directory = r"C:\Users\Daniel\tutorial"
dir_content = os.listdir(directory)
txt_f = [x for x in dir_content
         if os.path.isfile(directory + os.sep + x) and ".txt" in x]

for f in txt_f:
    old_path = directory + os.sep + f
    new_path = directory + os.sep + "texts" + os.sep + f
    os.rename(old_path, new_path)

os.listdir(directory)
# ['images', 'texts']

os.listdir(directory+os.sep+'texts')
# ['file_01.txt', 'file_02.txt']

Now our .txt files are in the texts sub-directory. Unfortunately there is no copy function in os. Instead we have to use another module called shutil. The signature looks like this.

from shutil import copyfile
copyfile(source, destination)
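
As a purely hypothetical example, copying one of our text files to a new name could look like the snippet below; file_02_copy.txt is a name made up here and the copy is not part of the directory listings that follow.

from shutil import copyfile

src = directory + os.sep + "texts" + os.sep + "file_02.txt"
dst = directory + os.sep + "file_02_copy.txt"  # hypothetical destination
copyfile(src, dst)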

Finally, to remove a file we simply use os.remove().

os.listdir(directory+os.sep+"texts")
# ['file_01.txt', 'file_02.txt']

os.remove(directory+os.sep+'texts'+os.sep+"file_01.txt")

os.listdir(directory+os.sep+"texts")
# ['file_02.txt']

And that’s it. You might have noticed that we did not cover how to create or read files. The os module is technically able to create and read files, but in data science we usually depend on higher-level interfaces to read files. For example, we might want to open a .csv file with pandas’ pd.read_csv. Using the lower-level os functions will rarely be necessary. Thank you for reading and let me know if you have any questions.

In case you want to learn more about the os module, here are the os module docs.

Matplotlib and the Object-Oriented Interface

Matplotlib is a great data visualization library for Python and there are two ways of using it. The functional interface (also known as the pyplot interface) allows us to interactively create simple plots. The object-oriented interface on the other hand gives us more control when we create figures that contain multiple plots. While having two interfaces gives us a lot of freedom, it also causes some confusion. The most common error is to use the functional interface when the object-oriented one would be much easier. Beginners are now usually advised to use the object-oriented interface under most circumstances, because they have a tendency to overuse the functional one. I made that mistake myself for a long time. I started out with the functional interface and it was the only one I knew for a long time. Here I will explain the difference between both, starting with the object-oriented interface. If you have never used it, now is probably the time to start.

Figures, Axes & Methods

When using the object-oriented interface, we create objects and do the plotting with their methods. Methods are the functions that come with the object. We create both a figure and an axes object with plt.subplots(1). Then we use the ax.plot() method from our axes object to create the plot. We also use two more methods, ax.set_xlabel() and ax.set_ylabel() to label our axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
y = np.sin(np.pi * x) + x

fig, ax = plt.subplots(1)
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")

The big advantage is that we can very easily create multiple plots and we can very naturally keep track of where we are plotting what, because the method that does the plotting is associated with a specific axes object. In the next example we will plot on three different axes that we create all at once with plt.subplots(3).

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
ax[0].plot(x,ys[0])
ax[1].plot(x,ys[1])
ax[2].plot(x,ys[2])

ax[0].set_title("Addition")
ax[1].set_title("Multiplication")
ax[2].set_title("Division")

for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

When we create multiple axes objects, they are available to us through the ax array. We can index into them and we can also loop through all of them. We can take the above example even further and put the plotting itself into the for loop as well.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
titles = ["Addition", "Multiplication", "Division"]
for idx, a in enumerate(ax):
    a.plot(x, ys[idx])
    a.set_title(titles[idx])
    a.set_xlabel("x")
    a.set_ylabel("y")

This code produces exactly the same three-axes figure as above. There are other ways to use the object-oriented interface. For example, we can create an empty figure without axes using fig = plt.figure(). We can then create subplots in that figure with ax = fig.add_subplot(). This is exactly the same concept as before, but instead of creating figure and axes at the same time, we use the figure method to create axes. I personally prefer fig, ax = plt.subplots(), but fig.add_subplot() is slightly more flexible in the way it allows us to arrange the axes. For example, plt.subplots(x, y) allows us to create a figure with axes arranged in x rows and y columns. Using fig.add_subplot() we could create one column with 2 axes and another with 3 axes.

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax2 = fig.add_subplot(2,2,3)
ax3 = fig.add_subplot(3,2,2)
ax4 = fig.add_subplot(3,2,4)
ax5 = fig.add_subplot(3,2,6)

Personally I prefer to avoid these arrangements because things like tight_layout don’t work well with them, but it is doable and cannot be done with plt.subplots(). This concludes our overview of the object-oriented interface. Simply remember that you want to do your plotting through the methods of an axes object, which you can create either with fig, ax = plt.subplots() or fig.add_subplot(). So what is different about the functional interface? Instead of plotting through axes methods, we do all our plotting through functions in the matplotlib.pyplot module.

One pyplot to Rule Them All

The functional interface works entirely through the pyplot module, which we import as plt by convention. In the example below we use it to create the exact same plot as in the beginning. We use plt to create the figure, do the plotting and label the axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.1)
y = np.sin(np.pi*x) + x

plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")

You might be wondering why we don’t need to tell plt where to plot and which axes to label. It always works with the currently active figure or axes object. If there is no active figure, plt.plot() creates its own, including the axes. If a figure is already active, it creates an axes in that figure or plots into already existing axes. This makes the functional interface less explicit and slightly less readable, especially for more complex figures. For the object-oriented interface, there is a specific object for every action, because a method must be called through an object. With the functional interface, it can be a guessing game where the plotting happens, and we become highly dependent on the position within our script. The line on which we call plt.plot() becomes crucial. Let’s recreate the three subplots example with the functional interface.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x, ys[0])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Addition")
plt.subplot(3, 1, 2)
plt.plot(x, ys[1])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Multiplication")
plt.subplot(3, 1, 3)
plt.plot(x, ys[2])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Division")

This one is much longer than the object-oriented code because we cannot label the axes in a for loop. To be fair, in this particular example we can put the entirety of our plotting and labeling into a for loop, but we have to put either everything or nothing into the loop. This is what I mean when I say the functional interface is less flexible.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
titles = ["Addition", "Multiplication", "Division"]
for idx, y in enumerate(ys):
    plt.subplot(3,1,idx+1)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(titles[idx])

Both interfaces are very similar. You might have noticed that the methods in the object-oriented API have the form set_attribute. This is by design and follows an object-oriented convention, where methods that change attributes have a set prefix. Methods that don’t change attributes but create entirely new objects have an add prefix, for example add_subplot. Now that we have seen both APIs at work, why is the object-oriented API recommended?

Advantages of the Object-Oriented API

First of all, Matplotlib is internally object-oriented. The pyplot interface masks that fact in an effort to make the usage more MATLAB-like by putting a functional layer on top. If we avoid plt and instead work with the object-oriented interface, our plotting becomes slightly faster. More importantly, the object-oriented interface is considered more readable and explicit. Both are very important when we write Python. Readability can be somewhat subjective, but I hope the code above could convince you that going through the plotting methods of an axes object makes it much clearer where we are plotting. We also get more flexibility to structure our code. Because plt depends on the order of plotting, we are constrained. With the object-oriented interface we can structure our code more clearly. For example, we can split plotting, labeling and other tasks into their own code blocks, as sketched below.
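
As a rough sketch of what that separation could look like, reusing the numpy and pyplot imports and the data from the earlier three-axes example, we can keep the plotting and the labeling in completely separate blocks.

x = np.arange(0, 10, 0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]
titles = ["Addition", "Multiplication", "Division"]

fig, ax = plt.subplots(3)

# Plotting block
for a, y in zip(ax, ys):
    a.plot(x, y)

# Labeling block, independent of the order in which we plotted
for a, title in zip(ax, titles):
    a.set_title(title)
    a.set_xlabel("x")
    a.set_ylabel("y")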

In summary, I hope you will be able to use the object-oriented interface of Matplotlib now. Simply remember to create axes with fig, ax = plt.subplots() and then most of the work happens through the ax object. Finally, the object-oriented interface is recommended because it is more efficient, readable, explicit and flexible.

Image Segmentation with scikit-image

Image segmentation is one of the most important steps in most imaging analysis pipelines. It separates the background from the features of our images. It can also determine the number of distinct features and their locations. Our ability to segment determines what we can analyze. We’ll look at a basic but complete segmentation pipeline with scikit-image. You can see the result in the title image, where we segment four cells. First, we will need to threshold the image into a binary version where the background is 0 and the foreground is 1.

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, filters, morphology, color

image = io.imread("example_image.tif")  # Load Image
threshold = filters.threshold_otsu(image)  # Calculate threshold
image_thresholded = image > threshold  # Apply threshold

# Show the results
fig, ax = plt.subplots(1, 2)
ax[0].imshow(image, 'gray')
ax[1].imshow(image_thresholded, 'gray')
ax[0].set_title("Intensity")
ax[1].set_title("Thresholded")

We calculate the threshold with the threshold_otsu function and apply it with a boolean operator. This threshold method works very well but there are two problems. First, there are very small particles that have nothing to do with our cells. To take care of those, we will apply morphological erosion. Second, there are holes in our cells. We will close those with morphological dilation.

# Apply 2 times erosion to get rid of background particles
n_erosion = 2
image_eroded = image_thresholded
for x in range(n_erosion):
    image_eroded = morphology.binary_erosion(image_eroded)

# Apply 14 times dilation to close holes
n_dilation = 14
image_dilated = image_eroded
for x in range(n_dilation):
    image_dilated = morphology.binary_dilation(image_dilated)

# Apply 4 times erosion to recover original size
n_erosion = 4
image_eroded_two = image_dilated
for x in range(n_erosion):
    image_eroded_two = morphology.binary_erosion(image_eroded_two)

fig, ax = plt.subplots(2,2)
ax[0,0].imshow(image_thresholded, 'gray')
ax[0,1].imshow(image_eroded, 'gray')
ax[1,0].imshow(image_dilated, 'gray')
ax[1,1].imshow(image_eroded_two, 'gray')
ax[0,0].set_title("Thresholded")
ax[0,1].set_title("Eroded 2x")
ax[1,0].set_title("Dilated 14x")
ax[1,1].set_title("Eroded 4x")

Erosion turns any pixel black that is in contact with a black pixel. This is how erosion can get rid of small particles. In our case we need to apply erosion twice. Once those particles have disappeared, we can use dilation to close the holes in our cells. To close all the holes, we have to slightly over-dilate, which makes the cells slightly bigger than they actually are. To recover the original morphology we apply some more erosions. Here is an example that shows how erosion and dilation work in detail. It also illustrates what being “in contact” with another pixel means by default.

cross = np.array([[0,0,0,0,0], [0,0,1,0,0], [0,1,1,1,0], 
                [0,0,1,0,0], [0,0,0,0,0]], dtype=np.uint8)
cross_eroded = morphology.binary_erosion(cross)
cross_dilated = morphology.binary_dilation(cross)
fig, ax = plt.subplots(1,3)
ax[0].imshow(cross, 'gray')
ax[1].imshow(cross_eroded, 'gray')
ax[2].imshow(cross_dilated, 'gray')
ax[0].set_title("Cross")
ax[1].set_title("Cross Eroded")
ax[2].set_title("Cross Dilated")

Now we are essentially done segmenting foreground and background. But we also want to assign distinct labels to our objects.

labels = morphology.label(image_eroded_two)
labels_rgb = color.label2rgb(labels,
                             colors=['greenyellow', 'green',
                                     'yellow', 'yellowgreen'],
                             bg_label=0)
image.shape
# (342, 382)
labels.shape
# (342, 382)
fig, ax = plt.subplots(2,2)
ax[0,0].imshow(labels==1, 'gray')
ax[0,1].imshow(labels==2, 'gray')
ax[1,0].imshow(labels==3, 'gray')
ax[1,1].imshow(labels_rgb)
ax[0,0].set_title("label == 1")
ax[0,1].set_title("label == 2")
ax[1,0].set_title("label == 3")
ax[1,1].set_title("All labels RGB")

We use morphology.label to generate a label for each connected feature. This returns an array that has the same shape as our original image but the pixels are no longer zero or one. The background is zero but each feature gets its own integer. All pixels belonging to the first label are equal to 1, pixels of the second label equal to 2 and so on. To visualize those labels all in one image, we call color.label2rgb to get color representations for each label in RGB space. And that’s it.
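
If we also want to know how large each labeled cell is, we can count the pixels per label. The snippet below is one quick way to do it with np.bincount; skimage.measure.regionprops would be a more feature-rich alternative.

label_sizes = np.bincount(labels.ravel())  # index 0 is the background
print(label_sizes[1:])  # pixel count of each segmented cell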

Segmentation is crucial for image analysis and I hope this tutorial got you well on your way to doing your own segmentation with scikit-image. This pipeline is not perfect but it illustrates the concept well. There are many more functions in the morphology module to filter binary images, but many of them come down to sequences of erosions and dilations. If you want to adapt this approach for your own images, I recommend playing around with the number of erosions and dilations. Let me know how it worked for you.

Contrast Adjustment with scikit-image

Contrast is one of the most important properties of an image and contrast adjustment is one of the easiest things we can do to make our images look better. There are many ways to adjust the contrast and with most of them we have to be careful because they can artificially change our images. The title image shows three basic methods of contrast adjustment and how they affect the image histogram. We will cover all three methods here but first let us consider what it means for an image to have low contrast.

Each pixel in an image has an intensity value. In an 8-bit image the values can be between 0 and 255. Contrast issues occur when the largest intensity value in our image is smaller than 255 or the smallest value is larger than 0. The reason is that we are not using the entire range of possible values; in our example the maximum is far below 255, which makes the image overall darker than it could be. So what about the contrast of our example image? We can use .min() and .max() to find the maximum and minimum intensity.

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, exposure, data

image = io.imread("example_image.tif")
image.max()
# 52
image.min()
# 1

Our maximum of 52 is very far away from 255, which explains why our image is so dark. On the minimum side we are almost perfect. To correct the contrast we can use the exposure module, which gives us the function rescale_intensity.

image_minmax_scaled = exposure.rescale_intensity(image)
image_minmax_scaled.max()
# 255
image_minmax_scaled.min()
# 0

Now both the minimum and the maximum are optimized. All pixels that were equal to the original minimum are now 0 and all pixels equal to the maximum are now 255. But what happens to the values in the middle? Let’s look at a smaller example.

arr = np.array([2, 40, 100, 205, 250], dtype=np.uint8)

arr_rescaled = exposure.rescale_intensity(arr)
# array([  0,  39, 100, 208, 255], dtype=uint8)

As expected, the minimum 2 became 0 and the maximum 250 became 255. In the middle, 40 became smaller, nothing happened to 100 and 205 became larger. We will look at each step to find out how we got there.

arr = np.array([2, 40, 100, 205, 250], dtype=np.uint8)
arr_min, arr_max = arr.min(), arr.max()
arr_subtracted = arr - arr_min  # Subtract the minimum
# array([  0,  38,  98, 203, 248], dtype=uint8)
arr_divided = arr_subtracted / (arr_max - arr_min)  # Divide by the new maximum
# array([0.        , 0.15322581, 0.39516129, 0.81854839, 1.        ])
arr_multiplied = arr_divided * 255  # Multiply by dtype max
# array([  0.        ,  39.07258065, 100.76612903, 208.72983871,
#        255.        ])
# Convert dtype to original uint8
arr_rescaled = np.asarray(arr_multiplied, dtype=arr.dtype)
# array([  0,  39, 100, 208, 255], dtype=uint8)

We can get there in four simple steps: subtract the minimum, divide by the maximum of the new subtracted array, multiply by the maximum value of the data type and finally convert back to the original data type. This works well if we want to rescale by the minimum and maximum of the image, but sometimes we need to use different values. The maximum especially can easily be dominated by noise. For all we know, it could be just one pixel that is not necessarily representative of the entire image. As you can see from the histogram in the title image, very few pixels are near the maximum. This means we can use percentile rescaling with little information loss.

percentiles = np.percentile(image, (0.5, 99.5))
# array([ 1., 28.])
scaled = exposure.rescale_intensity(image,
                                    in_range=tuple(percentiles))

This time we scale from 1 to 28, which means that all values above or equal to 28 become the new maximum 255. As we chose the 99.5th percentile, this affects roughly 0.5% of the brightest pixels. You can see the consequence in the image on the right. The image becomes brighter, but it also means that we lose information. Pixels that were distinguishable now look exactly the same, because they are all 255. It is up to you whether you can afford to lose those features. If you do quantitative image analysis you should perform rescaling with caution and always look out for information loss.
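
If you are unsure how much a given percentile range clips, you can check the affected fraction directly. This is a small sketch using the image and percentiles arrays from above.

clipped_fraction = np.mean(image >= percentiles[1])
# roughly 0.005, possibly more if many pixels share the cutoff value
print(clipped_fraction)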

Plotting 2D Vectors with Matplotlib

Vectors are extremely important in linear algebra and beyond. One of the most common visual representations of a vector is the arrow. Here we will learn how to plot vectors with Matplotlib. The title image shows two vectors and their sum. As a first step we will plot the vectors originating at 0, shown below.

import matplotlib.pyplot as plt
import numpy as np

vectors = np.array(([2, 0], [3, 2]))
vector_addition = vectors[0] + vectors[1]
vectors = np.append(vectors, vector_addition[None,:], axis=0)

tail = [0, 0]
fig, ax = plt.subplots(1)
ax.quiver(*tail,
           vectors[:, 0],
           vectors[:, 1],
           scale=1,
           scale_units='xy',
           angles = 'xy',
           color=['g', 'r', 'k'])

ax.set_xlim((-1, vectors[:,0].max()+1))
ax.set_ylim((-1, vectors[:,1].max()+1))

We have two vectors stored in our vectors array: [2, 0] and [3, 2], both in [x, y] order as you can see from the image. We can perform vector addition between the two by simply adding vectors[0] + vectors[1]. Then we use np.append so we have all three vectors in the same array. Next we define the origin in tail, because we want the tails of the arrows to be located at [0, 0]. Then we create the figure and axes to plot in with plt.subplots(). The plotting itself can be done with one call to the ax.quiver method. But it is quite the call, with a lot of parameters, so let’s go through it.

First, we need to define the origin, so we pass *tail. Why the asterisk? ax.quiver really takes two parameters for the origin, X and Y. The asterisk causes [0, 0] to be unpacked into those two parameters. Next, we pass the x coordinates (vectors[:, 0]) and then the y coordinates (vectors[:, 1]) of our vectors. The next three parameters, scale, scale_units and angles, are necessary to make the arrow lengths match the actual numbers. By default, the arrows are scaled based on the average of all plotted vectors, and we get rid of that kind of scaling. Try removing some of those parameters to get a better idea of what I mean. Finally, we pass three colors, one for each arrow.

So what do we need to plot the head-to-tail aligned vectors as in the title image? We just need to pass each of the two original vectors with the other vector as its origin.

ax.quiver(vectors[1::-1,0],
          vectors[1::-1,1],
          vectors[:2,0],
          vectors[:2,1],
          scale=1,
          scale_units='xy',
          angles = 'xy',
          color=['g', 'r'])

This is simple because it is the same quiver method, but it is complicated because of the indexing syntax. Now we no longer unpack *tail. Instead we pass the x and y origins separately. In vectors[1::-1,0] the 0 selects the x coordinates. The 1::-1 starts at index 1 and steps backwards, which both inverts the order of the first two vectors and excludes the summed vector at index 2. If we did not invert, each vector would be its own origin. vectors[1::-1,1] gives us the y coordinates. Finally, we just need to skip the summed vector when we pass the x and y magnitudes. The rest is the same.

So that’s it. Unfortunately, ax.quiver only works for 2D vectors on a standard axes. It also isn’t specifically made to present vectors that have a common origin; its main use case is plotting vector fields. This is why some of the plotting here feels clunky. There is also ax.arrow, which is more straightforward but only creates one arrow per method call. I hope this post was helpful for you. Let me know if you have other ways to plot vectors.
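
For completeness, here is roughly what the ax.arrow alternative could look like for our two vectors originating at 0; the head sizes are arbitrary values chosen for illustration.

fig, ax = plt.subplots(1)
for (dx, dy), c in zip(vectors[:2], ['g', 'r']):
    ax.arrow(0, 0, dx, dy, color=c, length_includes_head=True,
             head_width=0.1, head_length=0.2)
ax.set_xlim((-1, vectors[:, 0].max() + 1))
ax.set_ylim((-1, vectors[:, 1].max() + 1))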