Eventplots with Color in Matplotlib

Here we will learn how to color individual events in a Matplotlib eventplot. An eventplot most commonly shows the timing of events (x axis) from different sources (y axis). Giving each source its own color is easy with the colors parameter of the plt.eventplot function. Coloring each event individually, however, requires us to create the eventplot first and then assign the colors manually with the set_colors method of the event collections it returns. Here is the code and the resulting output.

import matplotlib as mpl
import matplotlib.pyplot as plt
import shelve
import os
import numpy as np

dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'evenplot_demo_spikes')

data = shelve.open(data_path)

fig, ax = plt.subplots(1)

event_collection = ax.eventplot(data['spikes'])
phases_flattened = np.hstack(data['phases'])
phase_min = np.quantile(phases_flattened, 0.05)
phase_max = np.quantile(phases_flattened, 0.95)

phases_norm = mpl.colors.Normalize(vmin=phase_min, vmax=phase_max)
# Normalize row by row, because the rows can hold different numbers of events
normalized_phases = [phases_norm(row) for row in data['phases']]
viridis = mpl.colormaps['viridis']

for idx, col in enumerate(event_collection):
    col.set_colors(viridis(normalized_phases[idx]))

plt.colorbar(mpl.cm.ScalarMappable(norm=phases_norm, cmap=viridis),
             ax=ax, label="Phase", fraction=0.05)

data.close()
plt.show()

We start by loading the data with the shelve module. The os module helps us find the data wherever our script is located (os.path.dirname(__file__)). Next, we create the figure and axes and plot the event times into the axes. At this point the events are drawn in Matplotlib’s default blue. We keep the return value of ax.eventplot in event_collection: it is a list with one EventCollection per row, and we need it to recolor the events later.

We want to color the events according to the data in data['phases']. That data contains a single floating point number for each event, and we need to convert those numbers to RGBA colors. For that, we first map our data to values between zero and one with mpl.colors.Normalize(vmin=phase_min, vmax=phase_max). I use a trick here: I take the 5% and 95% quantiles instead of the actual minimum and maximum of the data. This should not be done for a scientific figure, because all values above vmax get the same color even though they differ (and the same goes for values below vmin), which can be misleading.

Once we have our normalized values, we get a colormap object with mpl.colormaps['viridis'], and we can get colors from it by passing it the normalized values. We then loop through event_collection (each entry is one row in our plot) and set the colors with col.set_colors(viridis(normalized_phases[idx])). Finally, we create a colorbar with plt.colorbar, passing it a ScalarMappable with the same norm and colormap so that the colorbar matches the events. And that is it. Let me know if there is an easier way to do this that I missed.
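Since the shelve file from this post is not included here, the following is a minimal, self-contained sketch of the same technique with made-up data. The random spikes and phases lists are hypothetical stand-ins for data['spikes'] and data['phases'], and the Agg backend is selected only so the sketch also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to get a window
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical stand-in data: 5 sources with different numbers of events,
# and one phase value (in radians) per event
spikes = [np.sort(rng.uniform(0, 10, size=n)) for n in (8, 12, 5, 10, 7)]
phases = [rng.uniform(-np.pi, np.pi, size=len(s)) for s in spikes]

fig, ax = plt.subplots(1)
event_collection = ax.eventplot(spikes)  # one EventCollection per row

# Same normalization as above: the color range spans the 5% to 95% quantiles
phases_flattened = np.hstack(phases)
phases_norm = mpl.colors.Normalize(vmin=np.quantile(phases_flattened, 0.05),
                                   vmax=np.quantile(phases_flattened, 0.95))
viridis = mpl.colormaps["viridis"]

# One RGBA color per event, row by row
for col, row_phases in zip(event_collection, phases):
    col.set_colors(viridis(phases_norm(row_phases)))

fig.colorbar(mpl.cm.ScalarMappable(norm=phases_norm, cmap=viridis),
             ax=ax, label="Phase", fraction=0.05)
# With a GUI backend, plt.show() displays the figure; with Agg use fig.savefig(...)
```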

Measuring and Visualizing GPU Power Usage in Real Time with asyncio and Matplotlib

In this post we will learn how to periodically measure the power usage of our GPU and plot it in real time with a single Python program. For this we need concurrency between the measuring and the plotting parts of our code. Here, concurrency means that the measuring part goes to sleep after each measurement. While it is asleep, the plotting part can do the plotting and then goes to sleep as well. After a defined amount of time, the measuring part wakes up and measures again as soon as the program gets back to it, then the plotting part runs, and so on. All of this happens in a single thread. We achieve concurrency with asyncio, and the plotting is done with Matplotlib. To measure the GPU power we use pynvml (Python bindings for the NVIDIA Management Library). Before we get into the code, here is a video showing the interface in action.

This video shows the power drawn by my GPU, in watts, over a twenty-second time window. The measurements are taken every 100 milliseconds and the plot is updated every 200 milliseconds. Everything happens in one Python script. I ran it by simply passing the script below to python on the command line.
import pynvml
import matplotlib.pyplot as plt
import time
import numpy as np
import asyncio

"""Initialize GPU measurement and parameters"""
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
measurement_interval = 0.1  # in seconds
plotting_interval = 0.2  # in seconds
time_span = 20  # time span on the plot x-axis in seconds
m = t = np.array([np.nan]*int(time_span / measurement_interval))
mW_to_W = 1e3

"""Initialize the plot"""
plt.ion()
plt.rcParams.update({'font.size': 18})
figure, ax = plt.subplots(figsize=(8,6))
line1, = ax.plot(t, m, linewidth=3)
ax.set_xlabel("Time (s)")
ax.set_ylabel("GPU Power (W)")

async def measure():
    while True:
        measure = pynvml.nvmlDeviceGetPowerUsage(handle) / mW_to_W
        dt = time.time() - ts
        m[:-1] = m[1:]
        m[-1] = measure
        t[:-1] = t[1:]
        t[-1] = dt
        await asyncio.sleep(measurement_interval)

async def plot():
    while True:
        line1.set_data(t, m)
        tmin, tmax = np.nanmin(t), np.nanmax(t)
        mmin, mmax = np.nanmin(m), np.nanmax(m)
        margin = (np.abs(mmax - mmin) / 10) + 0.1
        ax.set_xlim((tmin, tmax + 1))
        ax.set_ylim((mmin - margin, mmax + margin))
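        figure.canvas.draw_idle()  # request a redraw with the new data and limits
        figure.canvas.flush_events()  # let the GUI process pending events, including the redraw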
        await asyncio.sleep(plotting_interval)

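async def main():
    t1 = asyncio.create_task(measure())
    t2 = asyncio.create_task(plot())
    await asyncio.gather(t1, t2)  # wait on both tasks, not just one

if __name__ == "__main__":
    ts = time.time()
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(main())
    finally:
        pynvml.nvmlShutdown()  # release NVML when the script is interrupted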

We will start with the functions async def measure() and async def plot(), since they are central to the program. First, note that neither of them is an ordinary function, because of the async keyword. This keyword was added in Python 3.5; in earlier Python versions we could instead have decorated the functions with @asyncio.coroutine (that decorator has since been removed, in Python 3.11). The async keyword turns our function into a coroutine function, which allows us to use the await keyword inside it. With await asyncio.sleep(measurement_interval) we put the coroutine to sleep. While it is asleep, the asyncio event loop can run other coroutines that are awake. More on the asyncio event loop later. Because we want to keep measuring until someone terminates the program, we wrap everything in measure into an infinite while True: loop.
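To see the interleaving without a GPU or a plot window, here is a minimal sketch with two toy coroutines. worker() and its parameters are hypothetical stand-ins for measure() and plot(), and the sketch uses asyncio.run() and asyncio.gather() (Python 3.7 and later) instead of creating the event loop by hand.

```python
import asyncio

log = []  # records which coroutine ran, in order

async def worker(name, interval, repeats):
    """Toy coroutine: log its name, then sleep so the other one can run."""
    for _ in range(repeats):
        log.append(name)
        # Hand control back to the event loop; other coroutines may run now
        await asyncio.sleep(interval)

async def main():
    # Run both coroutines concurrently on the same event loop
    await asyncio.gather(worker("measure", 0.01, 4), worker("plot", 0.02, 2))

asyncio.run(main())
print(log)
```

Because both coroutines await asyncio.sleep(), "plot" shows up in the log right after the first "measure", long before the measuring coroutine has finished all of its iterations.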

So what do we do while measuring? Outside of the coroutine we define two arrays, m and t: one holds the measured power and the other the elapsed time. Measuring time is important because energy is power integrated over time, and because we need to be sure that the coroutine isn’t staying asleep much longer than we want it to. When we take a measurement, we move the current elements in the measurement array one position to the left with m[:-1] = m[1:]. We then assign the newly measured value to the rightmost position with m[-1] = measure. The time array is updated the same way. That is all there is to our measurements.
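The shift-and-append update and the link between power and energy can be checked on their own with NumPy. This is a sketch with hypothetical readings: push() is a helper name made up for illustration, and the energy in joules is estimated by integrating power over time with the trapezoidal rule, skipping the NaN slots that have not been filled yet.

```python
import numpy as np

n_samples = 5
m = np.full(n_samples, np.nan)  # power readings in W; NaN means "no reading yet"
t = np.full(n_samples, np.nan)  # time stamps in s, kept in a separate array

def push(buffer, value):
    """Hypothetical helper: shift everything one slot left, put value on the right."""
    buffer[:-1] = buffer[1:]
    buffer[-1] = value

# Hypothetical readings: 50 W, 60 W and 70 W, taken 0.1 s apart
for i, power in enumerate([50.0, 60.0, 70.0]):
    push(m, power)
    push(t, 0.1 * i)

# Energy (J) is power (W) integrated over time (s): trapezoidal rule on filled slots
filled = ~np.isnan(m)
p, ts = m[filled], t[filled]
energy = np.sum((p[1:] + p[:-1]) / 2 * np.diff(ts))
print(energy)  # approximately 12.0
```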

Our plot coroutine works just like the measure coroutine, except that it plots whatever is in the time and measurement arrays before it goes to sleep. The plotting itself is basic Matplotlib, but note that figure.canvas.flush_events() is critical for updating the plot in real time: it lets the GUI process its pending events while our script keeps running. Furthermore, when we initialize the plot, plt.ion() turns on interactive mode, which is important for the window to show up and update without blocking the script.

Coroutines are not called like normal functions. They do their work as tasks within an asyncio event loop. The event loop knows which coroutines are asleep and decides which one works next. This may seem manageable with two coroutines, but with three it already becomes tedious: when one coroutine goes to sleep, the other two may both be awake and waiting, and the event loop has to decide which one goes next. Luckily, asyncio takes care of these details, so we can focus on the work we want to get done. We do, however, need to create an event loop with loop = asyncio.new_event_loop() and then start it with loop.run_until_complete(main()). The coroutines only start working once the loop runs. main() turns both of our coroutines into tasks, so both become part of the event loop. Because of the event loop, I recommend running the code from the command line. Running it in interactive environments such as Jupyter can cause problems, because those environments may already be running their own event loop.
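The following sketch illustrates two points from this paragraph: calling a coroutine function does nothing until the event loop runs it, and an infinite coroutine can be stopped by cancelling its task. The placeholder measure() here only counts, and stopping after a fixed duration is an addition for the demo; the script above simply runs until it is interrupted. It uses asyncio.run() and asyncio.create_task() (Python 3.7 and later).

```python
import asyncio

samples = []  # collected "readings"

async def measure(interval):
    """Placeholder for the real measure(): records a counter instead of GPU power."""
    while True:
        samples.append(len(samples))
        await asyncio.sleep(interval)

# Calling a coroutine function only creates a coroutine object; nothing runs yet
coro = measure(0.01)
print(samples)  # []
coro.close()  # discard it cleanly, since it was never scheduled

async def main(duration):
    task = asyncio.create_task(measure(0.01))  # schedule the coroutine as a task
    await asyncio.sleep(duration)  # while main() sleeps, the loop runs measure()
    task.cancel()  # end the infinite loop
    try:
        await task
    except asyncio.CancelledError:
        pass  # cancellation is the expected way for this task to finish

asyncio.run(main(0.1))  # creates an event loop, runs main(), then closes the loop
print(len(samples), "samples collected")
```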

With that, we have covered the most important parts of the code. There are several things we could do differently, and some of those might make the code better. For one, we could use a technique called blitting (explained here) to improve the performance of the plotting. We could also do the plotting with FuncAnimation (explained here) instead of writing our own coroutine. I tried that for a while but was not able to make the animation and the measure() coroutine work together in the same event loop. There probably is a way to do it that I did not find. Let me know if you have other points for improvement.

You can find pynvml here. asyncio is part of the Python installation and you can find the docs here. I was inspired to do this project by a package called codecarbon that you can find here. It estimates the carbon footprint of computation and I plan to blog about it soon.

Matplotlib and the Object-Oriented Interface

Matplotlib is a great data visualization library for Python, and there are two ways of using it. The functional interface (also known as the pyplot interface) allows us to interactively create simple plots. The object-oriented interface, on the other hand, gives us more control when we create figures that contain multiple plots. While having two interfaces gives us a lot of freedom, it also causes some confusion. The most common error is to use the functional interface when the object-oriented one would be much easier. It is now highly recommended that beginners use the object-oriented interface under most circumstances, because they tend to overuse the functional one. I made that mistake myself: I started out with the functional interface, and for a long time it was the only one I knew. Here I will explain the difference between the two, starting with the object-oriented interface. If you have never used it, now is probably the time to start.

Figures, Axes & Methods

When using the object-oriented interface, we create objects and do the plotting with their methods. Methods are the functions that come with the object. We create both a figure and an axes object with plt.subplots(1). Then we use the ax.plot() method from our axes object to create the plot. We also use two more methods, ax.set_xlabel() and ax.set_ylabel() to label our axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
y = np.sin(np.pi * x) + x

fig, ax = plt.subplots(1)
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")

The big advantage is that we can easily create multiple plots and naturally keep track of where we are plotting what, because the method that does the plotting belongs to a specific axes object. In the next example we will plot on three different axes, all created with plt.subplots(3).

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
ax[0].plot(x,ys[0])
ax[1].plot(x,ys[1])
ax[2].plot(x,ys[2])

ax[0].set_title("Addition")
ax[1].set_title("Multiplication")
ax[2].set_title("Division")

for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

When we create multiple axes objects, they are available to us through the ax array. We can index into it, and we can also loop through all of its axes. We can take the above example even further and move the plotting into the for loop as well.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
titles = ["Addition", "Multiplication", "Division"]
for idx, a in enumerate(ax):
    a.plot(x, ys[idx])
    a.set_title(titles[idx])
    a.set_xlabel("x")
    a.set_ylabel("y")

This code produces exactly the same three-axes figure as above. There are other ways to use the object-oriented interface. For example, we can create an empty figure without axes using fig = plt.figure(). We can then create subplots in that figure with ax = fig.add_subplot(). The concept is the same as before, but instead of creating figure and axes at the same time, we use a figure method to create the axes. I personally prefer fig, ax = plt.subplots(), but fig.add_subplot() is slightly more flexible in how it lets us arrange the axes. plt.subplots(nrows, ncols) creates a figure with axes arranged in a regular grid of nrows rows and ncols columns. Using fig.add_subplot(), we can instead create one column with 2 axes and another with 3 axes.

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax2 = fig.add_subplot(2,2,3)
ax3 = fig.add_subplot(3,2,2)
ax4 = fig.add_subplot(3,2,4)
ax5 = fig.add_subplot(3,2,6)

Personally, I prefer to avoid these arrangements because features such as tight_layout don’t work with them, but they are doable, and they cannot be created with plt.subplots(). This concludes our overview of the object-oriented interface. Simply remember that you want to do your plotting through the methods of an axes object, which you can create with either fig, ax = plt.subplots() or fig.add_subplot(). So what is different about the functional interface? Instead of plotting through axes methods, we do all our plotting through functions in the matplotlib.pyplot module.
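If you do need such a layout, newer Matplotlib versions (3.3 and later) also provide plt.subplot_mosaic, which builds it from a grid of labels. Here is a sketch of the same two-plus-three arrangement on a common six-row grid; the labels A to E are arbitrary. Because all five axes live on one grid, layout helpers may have an easier time with it than with mixed add_subplot grids.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Six rows: A and B span three rows each; C, D and E span two rows each
mosaic = [["A", "C"],
          ["A", "C"],
          ["A", "D"],
          ["B", "D"],
          ["B", "E"],
          ["B", "E"]]
fig, axd = plt.subplot_mosaic(mosaic)  # axd is a dict mapping each label to its Axes

for label, ax in axd.items():
    ax.set_title(label)
```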

One pyplot to Rule Them All

The functional interface works entirely through the pyplot module, which we import as plt by convention. In the example below we use it to create the exact same plot as in the beginning. We use plt to create the figure, do the plotting and label the axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.1)
y = np.sin(np.pi*x) + x

plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")

You might be wondering why we don’t need to tell plt where to plot and which axes to label. It always works with the currently active figure and axes. If there is no active figure, plt.plot() creates one, including the axes. If a figure is already active, plt.plot() plots into its current axes, creating a new axes first if the figure has none. This makes the functional interface less explicit and slightly less readable, especially for more complex figures. With the object-oriented interface, every action belongs to a specific object, because a method must be called through an object. With the functional interface, it can be a guessing game where the plotting happens, and we make ourselves highly dependent on the order of our script: the line on which we call plt.plot() becomes crucial. Let’s recreate the three-subplots example with the functional interface.
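The notion of a current figure and current axes can be made visible with plt.gcf() ("get current figure"), plt.gca() ("get current axes") and plt.sca() ("set current axes"). Here is a small sketch, using the non-interactive Agg backend so it runs anywhere.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)

fig = plt.figure()  # the new figure becomes the current figure
plt.subplot(2, 1, 1)  # creates an axes and makes it the current axes
top = plt.gca()  # "get current axes"
plt.subplot(2, 1, 2)
bottom = plt.gca()

plt.plot(x, np.sin(x))  # goes into the current axes, which is now the bottom one
print(len(top.lines), len(bottom.lines))  # 0 1

plt.sca(top)  # make the top axes current again
plt.plot(x, np.cos(x))  # now goes into the top axes
print(len(top.lines), len(bottom.lines))  # 1 1
print(plt.gcf() is fig)  # True
```

This is also a bridge between the two interfaces: plt.gca() hands you the axes object that pyplot is working on, so you can continue with its methods.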

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x, ys[0])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Addition")
plt.subplot(3, 1, 2)
plt.plot(x, ys[1])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Multiplication")
plt.subplot(3, 1, 3)
plt.plot(x, ys[2])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Division")

This one is much longer than the object-oriented code, because we cannot simply loop over the axes afterwards to label them. To be fair, in this particular example we can put the entirety of our plotting and labeling into a for loop, but we have to put either everything or nothing into the loop. This is what I mean when I say the functional interface is less flexible.

x = np.arange(0,10,0.1)
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
titles = ["Addition", "Multiplication", "Division"]
for idx, y in enumerate(ys):
    plt.subplot(3,1,idx+1)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(titles[idx])

Both interfaces are very similar. You might have noticed that many methods in the object-oriented API have the form set_attribute, like set_xlabel(). This is by design and follows an object-oriented convention, where methods that change attributes have a set_ prefix. Methods that don’t change attributes but create entirely new objects have an add_ prefix, for example add_subplot(). Now that we have seen both APIs at work, why is the object-oriented API recommended?
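A quick sketch of the naming convention: the add_ method creates and returns a new Axes, the set_ methods change it, and the matching get_ methods read the values back. ax.set() is a shortcut that applies several set_ calls in one go.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot()  # add_: creates a new Axes in the figure and returns it
print(len(fig.axes))  # 1

ax.set_xlabel("x")  # set_: changes an attribute of the existing Axes
ax.set_title("Addition")
print(ax.get_xlabel(), ax.get_title())  # x Addition

ax.set(ylabel="y", xlim=(0, 10))  # several set_ calls in one go
print(ax.get_ylabel())  # y
```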

Advantages of the Object-Oriented API

First of all, Matplotlib is object-oriented internally. The pyplot interface masks that fact by putting a functional layer on top, in an effort to make the usage more MATLAB-like. If we avoid plt and work with the object-oriented interface directly, our plotting can become slightly faster, since we skip that extra layer. More importantly, the object-oriented interface is considered more readable and explicit, and both qualities are very important when we write Python. Readability can be somewhat subjective, but I hope the code convinced you that going through the plotting methods of an axes object makes it much clearer where we are plotting. We also get more flexibility to structure our code. Because plt depends on the order of plotting, we are constrained. With the object-oriented interface we can structure our code more clearly; for example, we can split plotting, labeling and other tasks into their own code blocks.
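As an illustration of splitting plotting and labeling into their own code blocks, here is a sketch with two hypothetical helper functions, plot_curves() and label_axes(), that only receive axes objects. It also shows that for a 2 x 2 grid plt.subplots() returns a two-dimensional array of axes, which ax.flat lets us loop over. The Subtraction curve is an extra made-up example, and x starts at 0.1 to avoid dividing by zero.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_curves(axes, x, ys):
    """Hypothetical plotting block: one curve per axes."""
    for a, y in zip(axes, ys):
        a.plot(x, y)

def label_axes(axes, titles):
    """Hypothetical labeling block: independent of when the data is plotted."""
    for a, title in zip(axes, titles):
        a.set(title=title, xlabel="x", ylabel="y")

x = np.arange(0.1, 10, 0.1)  # start at 0.1 so the division below is defined
ys = [np.sin(np.pi * x) + x, np.sin(np.pi * x) - x,
      np.sin(np.pi * x) * x, np.sin(np.pi * x) / x]
titles = ["Addition", "Subtraction", "Multiplication", "Division"]

fig, ax = plt.subplots(2, 2)  # ax is a 2 x 2 NumPy array of Axes
label_axes(ax.flat, titles)   # labeling first ...
plot_curves(ax.flat, x, ys)   # ... plotting second: the order no longer matters
fig.tight_layout()
```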

In summary, I hope you will be able to use the object-oriented interface of Matplotlib now. Simply remember to create axes with fig, ax = plt.subplots() and then most of the work happens through the ax object. Finally, the object-oriented interface is recommended because it is more efficient, readable, explicit and flexible.

Plotting 2D Vectors with Matplotlib

Vectors are extremely important in linear algebra and beyond. One of the most common visual representations of a vector is the arrow. Here we will learn how to plot vectors with Matplotlib. The title image shows two vectors and their sum. As a first step, we will plot the vectors starting at the origin, as shown below.

import matplotlib.pyplot as plt
import numpy as np

vectors = np.array(([2, 0], [3, 2]))
vector_addition = vectors[0] + vectors[1]
vectors = np.append(vectors, vector_addition[None,:], axis=0)

tail = np.zeros((2, len(vectors)))  # x and y tail coordinates, one per arrow, all at 0
fig, ax = plt.subplots(1)
ax.quiver(*tail,
          vectors[:, 0],
          vectors[:, 1],
          scale=1,
          scale_units='xy',
          angles='xy',
          color=['g', 'r', 'k'])

ax.set_xlim((-1, vectors[:,0].max()+1))
ax.set_ylim((-1, vectors[:,1].max()+1))

We have two vectors stored in our vectors array: [2, 0] and [3, 2], both in [x, y] order, as you can see from the image. We can add the two vectors by simply computing vectors[0] + vectors[1]. Then we use np.append so that all three vectors are in the same array. Next, we define the origin in tail, because we want the tails of the arrows to be located at [0, 0]. Then we create the figure and axes to plot in with plt.subplots(). The plotting itself can be done with one call to the ax.quiver method. But it is quite the call, with a lot of parameters, so let’s go through it.

First, we need to define the origin, so we pass *tail. Why the asterisk? ax.quiver takes two separate parameters for the arrow tails, X and Y. The asterisk unpacks tail into those two parameters, giving the x and the y coordinates of the tails. Next, we pass the x components (vectors[:, 0]) and then the y components (vectors[:, 1]) of our vectors. The next three parameters, scale, scale_units and angles, are necessary to make the arrows match the actual numbers. By default, quiver scales the arrows automatically, based on the average vector length and the number of vectors. scale=1 together with scale_units='xy' turns that off, so that the arrows are measured in data units, and angles='xy' makes each arrow point from (x, y) to (x + u, y + v) in data coordinates. Try removing some of those to get a better idea of what I mean. Finally, we pass three colors, one for each arrow.

So what do we need to plot the head-to-tail aligned vectors as in the title image? We plot each of the two original vectors again, this time with its tail at the head of the other vector.

ax.quiver(vectors[1::-1, 0],
          vectors[1::-1, 1],
          vectors[:2, 0],
          vectors[:2, 1],
          scale=1,
          scale_units='xy',
          angles='xy',
          color=['g', 'r'])

This is simple because it is the same quiver method, but it is complicated because of the indexing syntax. This time we no longer unpack *tail; instead, we pass the x and y origins separately. In vectors[1::-1, 0], the 0 selects the x coordinates. The slice 1::-1 starts at row 1 and steps backwards (-1) to row 0, so we get the second vector followed by the first one. The reversed order matters: without it, each vector would be its own origin. Starting at row 1 also leaves out row 2, the summed vector. vectors[1::-1, 1] gives us the corresponding y coordinates. Finally, we skip the summed vector when we pass the x and y components with vectors[:2, 0] and vectors[:2, 1]. The rest is the same.
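The indexing is easier to follow with NumPy alone. This sketch rebuilds the vectors array from above and checks that, with the reversed tails, both arrows end at the tip of the summed vector.

```python
import numpy as np

vectors = np.array([[2, 0], [3, 2]])
vectors = np.append(vectors, (vectors[0] + vectors[1])[None, :], axis=0)
# vectors is now [[2, 0], [3, 2], [5, 2]]

tails = vectors[1::-1]  # start at row 1, step backwards: rows 1 and 0
heads = vectors[:2]     # the two original vectors, summed vector left out
print(tails.tolist())   # [[3, 2], [2, 0]]
print((tails + heads).tolist())  # [[5, 2], [5, 2]]: both end where the sum ends
```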

So that’s it. Unfortunately, ax.quiver on regular axes only draws 2D vectors (3D axes from mpl_toolkits.mplot3d have their own quiver method). It also isn’t specifically made to present vectors that share a common origin; its main use case is plotting vector fields. This is why some of the plotting here feels clunky. There is also ax.arrow, which is more straightforward but only creates one arrow per method call. I hope this post was helpful for you. Let me know if you have other ways to plot vectors.
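For comparison, here is a sketch of the first plot with ax.arrow, one call per vector. The width, head_width and length_includes_head values are illustrative choices, not taken from the original figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

vectors = np.array([[2, 0], [3, 2], [5, 2]])  # two vectors and their sum
colors = ["g", "r", "k"]

fig, ax = plt.subplots(1)
for (dx, dy), color in zip(vectors, colors):
    # One call per arrow: tail at (0, 0), displacement (dx, dy) in data units
    ax.arrow(0, 0, dx, dy, color=color, width=0.02,
             head_width=0.2, length_includes_head=True)

ax.set_xlim((-1, vectors[:, 0].max() + 1))
ax.set_ylim((-1, vectors[:, 1].max() + 1))
ax.set_aspect("equal")  # equal scaling on both axes keeps the arrow angles true
```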

Animations with Matplotlib

Anything that can be plotted with Matplotlib can also be animated. This is especially useful when data changes over time. Animations allow us to see the dynamics in our data, which is nearly impossible with most static plots. Here we will learn how to animate with Matplotlib by producing this traveling wave animation.

This is the code to make the animation. It creates the traveling wave, defines two functions that handle the animation and creates the animation with the FuncAnimation class. Let’s take it step by step.

import numpy as np
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

# Create the traveling wave
def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0,4,0.01)[np.newaxis,:]
t = np.arange(0,2,0.01)[:,np.newaxis]
wavelength = 1
speed = 1
yt = wave(x, t, wavelength, speed)  # shape is (time, x): one row per time point

# Create the figure and axes to animate
fig, ax = plt.subplots(1)
# init_func() is called at the beginning of the animation
def init_func():
    ax.clear()

# update_plot(i) is called to draw frame i
def update_plot(i):
    ax.clear()
    ax.plot(x[0,:], yt[i,:], color='k')

# Create animation
anim = FuncAnimation(fig,
                     update_plot,
                     frames=np.arange(0, len(t[:,0])),
                     init_func=init_func)

# Save animation
anim.save('traveling_wave.mp4',
          dpi=150,
          fps=30,
          writer='ffmpeg')

On the first three lines we import NumPy, Matplotlib and, most importantly, the FuncAnimation class. It takes center stage in our code, as it creates the animation later on by combining all the parts we need. On lines 5-13 we create the traveling wave. I don’t want to go into too much detail, as it is just a toy example for the animation. The important part is that we get the array yt, which defines the wave at each time point: yt[0] contains the wave at t0, yt[1] at t1 and so on. This is important, since we will be iterating over time during the animation. If you want to learn more about the traveling wave, you can change wavelength and speed and play around with the wave() function.
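The broadcasting behind yt and the traveling property of the wave can be checked with NumPy alone. This sketch reuses the wave() function from above; the two checks are additions for illustration.

```python
import numpy as np

def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0, 4, 0.01)[np.newaxis, :]  # shape (1, number of x values): a row
t = np.arange(0, 2, 0.01)[:, np.newaxis]  # shape (number of time points, 1): a column
yt = wave(x, t, wavelength=1, speed=1)    # broadcasting: one row per time point

print(yt.shape == (t.shape[0], x.shape[1]))  # True

# With speed 1 and the same step of 0.01 in x and t, the wave moves exactly one
# x sample per time step: frame i + 1 equals frame i shifted right by one sample
print(np.allclose(yt[1:, 1:], yt[:-1, :-1]))  # True
```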

Now that we have our wave, we can start preparing the animation. We create the figure and the axes we want to use with plt.subplots(1). Then we define init_func(). It is called whenever the animation starts or repeats. In this particular example it does nothing useful; I include it here because it is a useful feature for more complex animations.

Now we get to update_plot(), the heart of our animation. This function updates our figure between frames, so it determines what we see on each frame. It is the most important function, and it is shockingly simple. The parameter i is an integer that tells us which frame we are at. We use it as an index into the first dimension of yt, so we plot the wave as it looks at the time point t[i]. Importantly, we must clean up our axes with ax.clear(). If we forgot to clear, every frame would be drawn on top of the previous ones, and our plot would quickly become all black, filled with waves.
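Clearing and redrawing the axes on every frame is simple but not the only option. A common alternative, sketched here with the same wave, creates the line once and only updates its y data with set_ydata. Because the axes are no longer cleared, the y limits are fixed by hand, and with blit=True both init_func() and update_plot() return the artists they changed so Matplotlib can redraw only those.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

x = np.arange(0, 4, 0.01)
t = np.arange(0, 2, 0.01)
# Same traveling wave as above, with wavelength = speed = 1; shape (time, x)
yt = np.sin(2 * np.pi * (x[np.newaxis, :] - t[:, np.newaxis]))

fig, ax = plt.subplots(1)
(line,) = ax.plot(x, yt[0], color="k")  # create the line once
ax.set_ylim(-1.1, 1.1)  # fixed limits, since the axes are never cleared

def init_func():
    line.set_ydata(yt[0])
    return (line,)  # with blit=True, return the artists that change

def update_plot(i):
    line.set_ydata(yt[i])  # move the existing line instead of redrawing the axes
    return (line,)

anim = FuncAnimation(fig, update_plot, frames=len(t),
                     init_func=init_func, blit=True)
```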

Now FuncAnimation is where it all comes together. We pass it fig, update_plot and init_func. We also pass frames, the values that i takes on during the animation. Technically, this gets the animation going in your interactive Python console, but most of the time we want to save our animation. We do that by calling anim.save(). We pass it the file name as a string, the resolution in dpi, the frames per second and finally the name of the writer used to generate the animation. Not all writers work for all file formats. I prefer .mp4 with the ffmpeg writer. If saving fails, the most common problem is that the writer we are trying to use is not installed. To find out whether the ffmpeg writer is available on your machine, you can call matplotlib.animation.FFMpegWriter.isAvailable(). It returns True if the writer is available and False otherwise. If you are using Anaconda, you can install the codec from here.
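If ffmpeg is missing, one option is to fall back to a GIF written with Pillow, which recent Matplotlib versions already depend on. This sketch checks availability with matplotlib.animation.writers.is_available() and saves a short ten-frame clip to the system’s temporary directory; the file names and the fallback itself are assumptions, not part of the original workflow.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

x = np.arange(0, 4, 0.01)
fig, ax = plt.subplots(1)
(line,) = ax.plot(x, np.sin(2 * np.pi * x), color="k")

def update_plot(i):
    line.set_ydata(np.sin(2 * np.pi * (x - 0.05 * i)))  # shift the wave a bit per frame
    return (line,)

anim = animation.FuncAnimation(fig, update_plot, frames=10)

# Prefer ffmpeg and .mp4; fall back to Pillow and .gif if ffmpeg is not installed
if animation.writers.is_available("ffmpeg"):
    filename, writer = "traveling_wave.mp4", animation.FFMpegWriter(fps=30)
else:
    filename, writer = "traveling_wave.gif", animation.PillowWriter(fps=30)

out_path = os.path.join(tempfile.gettempdir(), filename)  # hypothetical output location
anim.save(out_path, writer=writer, dpi=50)
print("saved", out_path)
```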

This wraps up our tutorial. This particular example is very simple, but anything that can be plotted can also be animated. I hope you are now on your way to create your own animations. I will leave you with a more involved animation I created.