Eventplots with Color in Matplotlib

Here we will learn how to color individual events in a Matplotlib eventplot. An eventplot most commonly shows the timing of events (x axis) from different sources (y axis). Assigning different colors to the sources can be done easily with the colors parameter of the plt.eventplot function. However, color coding each event requires us to create the eventplot and then manually assign the colors with event_collection.set_colors. Here is the code and the resulting output.

import matplotlib as mpl
import matplotlib.pyplot as plt
import shelve
import os
import numpy as np

dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'evenplot_demo_spikes')

data = shelve.open(data_path)

fig, ax = plt.subplots(1)

event_collection = ax.eventplot(data['spikes'])
phases_flattened = np.hstack(data['phases'])
phase_min = np.quantile(phases_flattened, 0.05)
phase_max = np.quantile(phases_flattened, 0.95)

phases_norm = mpl.colors.Normalize(vmin=phase_min, vmax=phase_max)
normalized_phases = phases_norm(data['phases'])
viridis = mpl.colormaps['viridis']

for idx, col in enumerate(event_collection):
    col.set_colors(viridis(normalized_phases[idx]))

plt.colorbar(mpl.cm.ScalarMappable(norm=phases_norm, cmap=viridis),
             ax=ax, label="Phase", fraction=0.05)

We start out by loading the data with the shelve module. The os module helps us find the data relative to wherever our script is located (os.path.dirname(__file__)). Next, we create the figure and axes and plot the event times into the axes. The events will be plotted with the default blue of Matplotlib. We have to make sure to keep a reference to all the events by assigning the return value to event_collection.

We want to color the events according to the data in data['phases']. That data contains a single floating point number for each event and we need to convert those numbers to an RGB color. For that, we first normalize our data to values between zero and one with mpl.colors.Normalize(vmin=phase_min, vmax=phase_max). I use a trick here where I take the 5% and 95% quantiles instead of the actual minimum and maximum of the data. This should not be done for a scientific figure, because all values above vmax (or below vmin) get the same color although they are different, which can be misleading. Once we have our normalized values, we get a colormap object with mpl.colormaps['viridis'] and we can get colors from it by passing it the normalized values. We then loop through our event collection (each entry in the event collection is a row in our plot) and set the colors with col.set_colors(viridis(normalized_phases[idx])). Finally, we create a colorbar with plt.colorbar. And that is it. Let me know if there is an easier way to do this that I missed.

Wind and Solar Power Generation

As the world tries to stop planet-warming emissions, solar and wind power have taken center stage. One of the reasons solar and wind go well together is their complementary seasonal pattern. The magnitude of this pattern depends on the location, but in Europe winter means less solar power but more wind power. Summer, on the other hand, means more solar but slightly less wind. So if you want to have renewable energy all year round, it’s a good idea to build both solar and wind capacity.

On https://energy-charts.info/ raw data has recently become available for download and I have been doing data visualization to show the complementary seasonal pattern with actual data from several countries. Here I show some examples and the code. The code I am showing can be adapted for different countries. Here is the code and an example figure.

import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'data')

country = 'Deutschland'

raw_files = [os.path.join(data_path, f)
             for f in os.listdir(data_path) if country in f]

loaded_files = [pd.read_csv(f) for f in raw_files]

df = pd.concat(loaded_files, ignore_index=True)
df['Datum (UTC)'] = pd.to_datetime(df['Datum (UTC)'])
df['day_of_year'] = df['Datum (UTC)'].dt.dayofyear
df['year'] = df['Datum (UTC)'].dt.year
df['month'] = df['Datum (UTC)'].dt.month
df['day'] = df['Datum (UTC)'].dt.day
df['Wind'] = df['Wind Offshore'] + df['Wind Onshore']


# Average the 15 minute data to get one value per day
df_daily = df.groupby(by=['year', 'month', 'day']).mean()
df_daily = df_daily.reset_index()

# Calculate daily theta
for year in df['year'].unique():
    boolean_year = df_daily['year'] == year
    days_zero = (df_daily[boolean_year]['day_of_year']
                 - df_daily[boolean_year]['day_of_year'].min())
    theta = days_zero / days_zero.max()
    
    df_daily.loc[boolean_year, 'theta_days'] = theta * 2 * np.pi

font = {'size'   : 16}
matplotlib.rc('font', **font)

solar_colors = ['#fef0d9','#fdd49e','#fdbb84','#fc8d59',
                '#ef6548','#d7301f','#990000']
fig, ax = plt.subplots(1, 2, subplot_kw={'projection': 'polar'})

for idx, y in enumerate(df_daily['year'].unique()):
    r = df_daily[df_daily['year'] == y]["Solar"] / 1000 # MW to GW
    theta = df_daily[df_daily['year'] == y]['theta_days']
    ax[0].plot(theta, r, alpha=0.8, color=solar_colors[idx])
    r = df_daily[df_daily['year'] == y]["Wind"] / 1000 # MW to GW
    ax[1].plot(theta, r, alpha=0.8, color=solar_colors[idx])

# Find month starting theta
month_thetas = []
for i in range(1,13):
    idx = (df_daily['month'] == i).idxmax()
    month_thetas.append(df_daily['theta_days'][idx])
month_labels = ['January', 'February', 'March', 'April', 'May',
                'June', 'July', 'August', 'September', 'October',
                'November', 'December']

for i in [0,1]:
    tl = np.array(month_thetas) / (2*np.pi) * 360
    ax[i].set_thetagrids(tl, month_labels)
    ax[i].xaxis.set_tick_params(pad=28)
    ax[i].set_theta_direction(-1)

ax[0].set_rticks([5, 7.5, 10, 12.5])
ax[1].set_rticks([20, 30, 40])
ax[0].set_rlabel_position(-10)  # Move radial labels
ax[1].set_rlabel_position(-170)
ax[1].legend(df_daily['year'].unique(), bbox_to_anchor=(1.1, 1.2)) 

fig.suptitle("Average daily power (in GW) from solar (left) and " +
             "wind (right) in Germany for years 2015-2021.\n" +
             "Data from energy-charts.info. Plot by Daniel" +
             " Müller-Komorowska.")

The code itself is pretty simple because the data already comes well structured. Every year is in an individual file, so we need to find all of those in the data directory. Probably the biggest preprocessing step is to calculate the daily average. The original data has a temporal resolution of 15 minutes, which is a bit too noisy for the type of plot we are making. Once we have extracted year, month and day, we can do it in one line:

df_daily = df.groupby(by=['year', 'month', 'day']).mean().reset_index()
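To see the daily averaging and the angle calculation in isolation, here is a small sketch on synthetic data. The column names time and Solar, the constant fake power values and the single year 2021 are assumptions for illustration; the real files use the column names shown in the code above.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the download: one year at 15 minute resolution
time = pd.date_range("2021-01-01", "2021-12-31 23:45", freq="15min")
df = pd.DataFrame({"time": time, "Solar": np.ones(len(time))})
df["year"] = df["time"].dt.year
df["month"] = df["time"].dt.month
df["day"] = df["time"].dt.day
df["day_of_year"] = df["time"].dt.dayofyear

# 96 quarter-hour values per day collapse into one daily mean
df_daily = (df.groupby(["year", "month", "day"])[["Solar", "day_of_year"]]
              .mean()
              .reset_index())

# Map the first day of the year to 0 and the last day to 2*pi
days_zero = df_daily["day_of_year"] - df_daily["day_of_year"].min()
df_daily["theta_days"] = days_zero / days_zero.max() * 2 * np.pi
```

With more than one year in the data, the angle has to be calculated per year, which is what the loop over df['year'].unique() in the full script does.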

Next, we calculate theta for each day so we can distribute the datapoints along the polar plot. Once we have that we are basically done. The rest is matplotlib styling. Some of these styling aspects are hardcoded because they are hard to automate for different countries. Here are two more examples:

Matplotlib and the Object-Oriented Interface

Matplotlib is a great data visualization library for Python and there are two ways of using it. The functional interface (also known as the pyplot interface) allows us to interactively create simple plots. The object-oriented interface, on the other hand, gives us more control when we create figures that contain multiple plots. While having two interfaces gives us a lot of freedom, it also causes some confusion. The most common mistake is to use the functional interface when the object-oriented one would be much easier. Beginners are now encouraged to use the object-oriented interface under most circumstances, because they tend to overuse the functional one. I made that mistake myself: I started out with the functional interface and only knew that one for a long time. Here I will explain the difference between the two, starting with the object-oriented interface. If you have never used it, now is probably the time to start.

Figures, Axes & Methods

When using the object-oriented interface, we create objects and do the plotting with their methods. Methods are the functions that come with the object. We create both a figure and an axes object with plt.subplots(1). Then we use the ax.plot() method from our axes object to create the plot. We also use two more methods, ax.set_xlabel() and ax.set_ylabel() to label our axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
y = np.sin(np.pi * x) + x

fig, ax = plt.subplots(1)
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")

The big advantage is that we can very easily create multiple plots and we can very naturally keep track of where we are plotting what, because the method that does the plotting is associated with a specific axes object. In the next example we will plot on three different axes that we create all with plt.subplots(3).

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
ax[0].plot(x,ys[0])
ax[1].plot(x,ys[1])
ax[2].plot(x,ys[2])

ax[0].set_title("Addition")
ax[1].set_title("Multiplication")
ax[2].set_title("Division")

for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

When we create multiple axes objects, they are available to us through the ax array. We can index into them and we can also loop through all of them. We can take the above example even further and pack even the plotting into the for loop.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
titles = ["Addition", "Multiplication", "Division"]
for idx, a in enumerate(ax):
    a.plot(x, ys[idx])
    a.set_title(titles[idx])
    a.set_xlabel("x")
    a.set_ylabel("y")

This code produces exactly the same three-axes figure as above. There are other ways to use the object-oriented interface. For example, we can create an empty figure without axes using fig = plt.figure(). We can then create subplots in that figure with ax = fig.add_subplot(). The concept is the same as before, but instead of creating figure and axes at the same time, we use a figure method to create the axes. I personally prefer fig, ax = plt.subplots(), but fig.add_subplot() is slightly more flexible in the way it allows us to arrange the axes. For example, plt.subplots(x, y) allows us to create a figure with axes arranged in x rows and y columns. Using fig.add_subplot() we could create one column with 2 axes and another with 3 axes.

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax2 = fig.add_subplot(2,2,3)
ax3 = fig.add_subplot(3,2,2)
ax4 = fig.add_subplot(3,2,4)
ax5 = fig.add_subplot(3,2,6)
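For comparison, this is a minimal sketch of the regular grid we get from plt.subplots(x, y). The 2 by 3 layout, the title text and the Agg backend are arbitrary choices of mine. With more than one row and column, ax becomes a two-dimensional array that we index with [row, column].

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3)  # 2 rows, 3 columns
grid_shape = ax.shape         # ax is a 2D NumPy array of axes objects

ax[0, 2].set_title("Row 0, column 2")  # index as [row, column]

# ax.flat lets us loop over all axes regardless of the grid shape
for a in ax.flat:
    a.set_xlabel("x")
    a.set_ylabel("y")
```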

Personally I prefer to avoid these arrangements, because things like tight_layout don’t work, but it is doable and cannot be done with plt.subplots(). This concludes our overview of the object-oriented interface. Simply remember that you want to do your plotting through the methods of an axes object, which you can create either with fig, ax = plt.subplots() or with fig.add_subplot(). So what is different about the functional interface? Instead of plotting through axes methods, we do all our plotting through functions in the matplotlib.pyplot module.

One pyplot to Rule Them All

The functional interface works entirely through the pyplot module, which we import as plt by convention. In the example below we use it to create the exact same plot as in the beginning. We use plt to create the figure, do the plotting and label the axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.1)
y = np.sin(np.pi*x) + x

plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")

You might be wondering why we don’t need to tell plt where to plot and which axes to label. It always works with the currently active figure or axes object. If there is no active figure, plt.plot() creates its own, including the axes. If a figure is already active, it creates an axes in that figure or plots into already existing axes. This makes the functional interface less explicit and slightly less readable, especially for more complex figures. With the object-oriented interface, there is a specific object for any action, because a method must be called through an object. With the functional interface, it can be a guessing game where the plotting happens, and we make ourselves highly dependent on the location in our script. The line on which we call plt.plot() becomes crucial. Let’s recreate the three subplots example with the functional interface.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x, ys[0])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Addition")
plt.subplot(3, 1, 2)
plt.plot(x, ys[1])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Multiplication")
plt.subplot(3, 1, 3)
plt.plot(x, ys[2])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Division")

This one is much longer than the object-oriented code because we cannot label the axes in a for loop. To be fair, in this particular example we can put the entirety of our plotting and labeling into a for loop, but we have to put either everything or nothing into the loop. This is what I mean when I say the functional interface is less flexible.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
titles = ["Addition", "Multiplication", "Division"]
for idx, y in enumerate(ys):
    plt.subplot(3,1,idx+1)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(titles[idx])
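To make the idea of the currently active axes a bit more concrete, here is a small sketch that asks pyplot which axes it is working on. It uses plt.gca() to get the active axes and plt.sca() to switch to another one. The variable names and the Agg backend are my own choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt

fig = plt.figure()
first = plt.subplot(2, 1, 1)
second = plt.subplot(2, 1, 2)

# The most recently created axes is the active one
active_after_creation = plt.gca()

# plt.title() would now label `second`; we switch back to `first` explicitly
plt.sca(first)
plt.title("First")
```

The title ends up on the first axes only because of the plt.sca() call right before it. Move plt.title() one line up and it labels the second axes instead, which is the order dependence described above.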

Both interfaces are very similar. You might have noticed that many methods in the object-oriented API have the form set_attribute. This is by design and follows an object-oriented convention, where methods that change attributes have a set prefix. Methods that don’t change attributes but create entirely new objects have an add prefix, for example add_subplot. Now that we have seen both APIs at work, why is the object-oriented API recommended?

Advantages of the Object-Oriented API

First of all, Matplotlib is internally object-oriented. The pyplot interface masks that fact in an effort to make the usage more MATLAB-like by putting a functional layer on top. If we avoid plt and instead work with the object-oriented interface, our plotting becomes slightly faster. More importantly, the object-oriented interface is considered more readable and explicit. Both are very important when we write Python. Readability can be somewhat subjective, but I hope the code could convince you that going through the plotting methods of an axes object makes it much clearer where we are plotting. We also get more flexibility to structure our code. Because plt depends on the order of plotting, we are constrained. With the object-oriented interface we can structure our code more clearly. We can for example split plotting, labeling and other tasks into their own code blocks.
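As a rough sketch of what splitting into code blocks can look like, the three subplots example can be arranged so that plotting, labeling and layout each get their own block. The block order is a choice, not a requirement. Starting x at 0.1 and using the Agg backend are my assumptions, the first one to avoid the division by zero in the third curve.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0.1, 10, 0.1)  # assumption: start above 0 to avoid dividing by zero
ys = [np.sin(np.pi * x) + x,
      np.sin(np.pi * x) * x,
      np.sin(np.pi * x) / x]
titles = ["Addition", "Multiplication", "Division"]

fig, ax = plt.subplots(3)

# Block 1: plotting only
for a, y in zip(ax, ys):
    a.plot(x, y)

# Block 2: labeling only, could just as well live much further down
for a, title in zip(ax, titles):
    a.set_title(title)
    a.set_xlabel("x")
    a.set_ylabel("y")

# Block 3: layout
fig.tight_layout()
```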

In summary, I hope you will be able to use the object-oriented interface of Matplotlib now. Simply remember to create axes with fig, ax = plt.subplots() and then most of the work happens through the ax object. Finally, the object-oriented interface is recommended because it is more efficient, readable, explicit and flexible.

Plotting 2D Vectors with Matplotlib

Vectors are extremely important in linear algebra and beyond. One of the most common visual representations of a vector is the arrow. Here we will learn how to plot vectors with Matplotlib. The title image shows two vectors and their sum. As a first step we will plot the vectors originating at 0, shown below.

import matplotlib.pyplot as plt
import numpy as np

vectors = np.array(([2, 0], [3, 2]))
vector_addition = vectors[0] + vectors[1]
vectors = np.append(vectors, vector_addition[None,:], axis=0)

tail = [0, 0]
fig, ax = plt.subplots(1)
ax.quiver(*tail,
          vectors[:, 0],
          vectors[:, 1],
          scale=1,
          scale_units='xy',
          angles='xy',
          color=['g', 'r', 'k'])

ax.set_xlim((-1, vectors[:,0].max()+1))
ax.set_ylim((-1, vectors[:,1].max()+1))

We have two vectors stored in our vectors array. Those are [2, 0] and [3, 2]. Both in order of [x, y] as you can see from the image. We can perform vector addition between the two by simply adding vectors[0] + vectors[1]. Then we use np.append so we have all three vectors in the same array. Now we define the origin in tail, because we will want the tail of the arrow to be located at [0, 0]. Then we create the figure and axes to plot in with plt.subplots(). The plotting itself can be done with one call to the ax.quiver method. But it is quite the call, with a lot of parameters so let’s go through it.

First, we need to define the origin, so we pass *tail. Why the asterisk? ax.quiver really takes two parameters for the origin, X and Y. The asterisk causes [0, 0] to be unpacked into those two parameters. Next, we pass the x coordinates (vectors[:, 0]) and then the y coordinates (vectors[:, 1]) of our vectors. The next three parameters scale, scale_units and angles are necessary to make the arrow length match the actual numbers. By default, the arrows are scaled, based on the average of all plotted vectors. We get rid of that kind of scaling. Try removing some of those to get a better idea of what I mean. Finally, we pass three colors, one for each arrow.

So what do we need to plot the head to tail aligned vectors as in the title image? We just need to pass the vectors where the origin is the other vector.

ax.quiver(vectors[1::-1,0],
          vectors[1::-1,1],
          vectors[:2,0],
          vectors[:2,1],
          scale=1,
          scale_units='xy',
          angles='xy',
          color=['g', 'r'])

This is simple because it is the same quiver method, but it is complicated because of the indexing syntax. Now, we no longer unpack *tail. Instead we pass x and y origins separately. In vectors[1::-1,0] the 0 selects the x coordinates. The step of -1 walks through the array backward. If we did not reverse the order, each vector would be its own origin. The start index 1 means we begin at the second vector, so the third row, which holds the summed vector, is never reached. vectors[1::-1,1] gives us the y coordinates in the same way. Finally, we just need to skip the summed vector when we pass the x and y components with vectors[:2,0] and vectors[:2,1]. The rest is the same.
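The slicing is easier to follow when we look at the arrays on their own. This sketch uses the same three vectors as above. The names tails, components and heads are mine and only serve the illustration.

```python
import numpy as np

# The two vectors and their sum, as in the example above
vectors = np.array([[2, 0], [3, 2], [5, 2]])

tails = vectors[1::-1]    # start at row 1 and step backward: rows 1, 0
components = vectors[:2]  # rows 0, 1, the summed vector is left out
heads = tails + components  # where each arrow ends

print(tails)   # [[3 2]
               #  [2 0]]
print(heads)   # [[5 2]
               #  [5 2]]
```

Both arrows end at [5, 2], which is the summed vector. That is what head to tail alignment means here.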

So that’s it. Unfortunately, ax.quiver only works for 2D vectors. It also isn’t specifically made to present vectors that have a common origin. Its main use case is to plot vector fields. This is why some of the plotting here feels clunky. There is also ax.arrow which is more straightforward but only creates one arrow per method call. I hope this post was helpful for you. Let me know if you have other ways to plot vectors.
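For completeness, here is a minimal sketch of the ax.arrow alternative mentioned above, with one call per vector. The head_width value, the length_includes_head setting and the Agg backend are my choices and not taken from the original figure.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt
import numpy as np

vectors = np.array([[2, 0], [3, 2], [5, 2]])  # two vectors and their sum
colors = ['g', 'r', 'k']

fig, ax = plt.subplots(1)
arrows = []
for (dx, dy), color in zip(vectors, colors):
    # ax.arrow(x, y, dx, dy) draws one arrow from (x, y) to (x + dx, y + dy)
    arrow = ax.arrow(0, 0, dx, dy, color=color,
                     head_width=0.15, length_includes_head=True)
    arrows.append(arrow)

ax.set_xlim((-1, vectors[:, 0].max() + 1))
ax.set_ylim((-1, vectors[:, 1].max() + 1))
```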

Animations with Matplotlib

Anything that can be plotted with Matplotlib can also be animated. This is especially useful when data changes over time. Animations allow us to see the dynamics in our data, which is nearly impossible with most static plots. Here we will learn how to animate with Matplotlib by producing this traveling wave animation.

This is the code to make the animation. It creates the traveling wave, defines two functions that handle the animation and creates the animation with the FuncAnimation class. Let’s take it step by step.

import numpy as np
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

# Create the traveling wave
def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0,4,0.01)[np.newaxis,:]
t = np.arange(0,2,0.01)[:,np.newaxis]
wavelength = 1
speed = 1
yt = wave(x, t, wavelength, speed)  # shape is [t, x]

# Create the figure and axes to animate
fig, ax = plt.subplots(1)
# init_func() is called at the beginning of the animation
def init_func():
    ax.clear()

# update_plot() is called between frames
def update_plot(i):
    ax.clear()
    ax.plot(x[0,:], yt[i,:], color='k')

# Create animation
anim = FuncAnimation(fig,
                     update_plot,
                     frames=np.arange(0, len(t[:,0])),
                     init_func=init_func)

# Save animation
anim.save('traveling_wave.mp4',
          dpi=150,
          fps=30,
          writer='ffmpeg')

On the first three lines we import NumPy, Matplotlib and most importantly the FuncAnimation class. It will take center stage in our code, as it will create the animation later on by combining all the parts we need. On lines 5-13 we create the traveling wave. I don’t want to go into too much detail, as it is just a toy example for the animation. The important part is that we get the array yt, which defines the wave at each time point. So yt[0] contains the wave at t0, yt[1] at t1 and so on. This is important, since we will be iterating over time during the animation. If you want to learn more about the traveling wave, you can change wavelength and speed and play around with the wave() function.
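If the shape of yt is not obvious, this short sketch shows how NumPy broadcasting produces it. It repeats the wave() function from above. The variable same_after_one_period is my addition to check that the wave repeats after one period, which is 1 for wavelength 1 and speed 1.

```python
import numpy as np

def wave(x, t, wavelength, speed):
    return np.sin((2 * np.pi) * (x - speed * t) / wavelength)

x = np.arange(0, 4, 0.01)[np.newaxis, :]  # row vector, shape (1, n_x)
t = np.arange(0, 2, 0.01)[:, np.newaxis]  # column vector, shape (n_t, 1)

# x - speed * t broadcasts to shape (n_t, n_x): one row per time point
yt = wave(x, t, wavelength=1, speed=1)
print(x.shape, t.shape, yt.shape)

# t[100] is one full period after t[0], so the two rows should match
same_after_one_period = np.allclose(yt[0], yt[100])
```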

Now that we have our wave, we can start preparing the animation. We create the figure and the axes we want to use with plt.subplots(1). Then we create init_func(). This one will be called whenever the animation starts or repeats. In this particular example it is pretty useless. I include it here because it is a useful feature for more complex animations.

Now we get to update_plot(), the heart of our animation. This function updates our figure between frames. It determines what we see on each frame. It is the most important function and it is shockingly simple. The parameter i is an integer that defines what frame we are at. We use that integer as an index into the first dimension of yt, so we plot the wave as it looks at time point i. Importantly, we must clean up our axes with ax.clear(). If we forgot to clear, our plot would quickly become all black, filled with waves.

Now FuncAnimation is where it all comes together. We pass it fig, update_plot and init_func. We also pass frames, which are the values that i will take on during the animation. Technically, this gets the animation going in your interactive Python console, but most of the time we want to save our animation. We do that by calling anim.save(). We pass it the file name as a string, the resolution in dpi, the frames per second and finally the writer class used for generating the animation. Not all writers work for all file formats. I prefer .mp4 with the ffmpeg writer. If there are issues with saving, the most common problem is that the writer we are trying to use is not installed. If you want to find out if the ffmpeg writer is available on your machine, you can type matplotlib.animation.FFMpegWriter().isAvailable(). It returns True if the writer is available and False otherwise. If you are using Anaconda you can install the codec from here.
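One way to guard against a missing writer is to check the writer registry before saving. This is only a sketch: the helper name pick_writer and the fallback to a GIF through the pillow writer are my assumptions, not something the code above does.

```python
from matplotlib import animation

def pick_writer():
    """Return (writer_name, file_name) for a writer that is available here."""
    if animation.writers.is_available("ffmpeg"):
        return "ffmpeg", "traveling_wave.mp4"
    # Assumption: fall back to a GIF written with Pillow when ffmpeg is missing
    return "pillow", "traveling_wave.gif"

writer_name, file_name = pick_writer()
print(writer_name, file_name)
```

With the anim object from above we could then call anim.save(file_name, dpi=150, fps=30, writer=writer_name).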

This wraps up our tutorial. This particular example is very simple, but anything that can be plotted can also be animated. I hope you are now on your way to creating your own animations. I will leave you with a more involved animation I created.