Matplotlib and the Object-Oriented Interface

Matplotlib is a great data visualization library for Python, and there are two ways of using it. The functional interface (also known as the pyplot interface) lets us create simple plots interactively. The object-oriented interface, on the other hand, gives us more control when we create figures that contain multiple plots. Having two interfaces gives us a lot of freedom, but it also causes some confusion. The most common mistake is to use the functional interface when the object-oriented one would be much easier. Beginners tend to overuse the functional interface, so it is now highly recommended that they use the object-oriented interface under most circumstances. I made that mistake myself: I started out with the functional interface, and for a long time it was the only one I knew. Here I will explain the difference between the two, starting with the object-oriented interface. If you have never used it, now is probably the time to start.

Figures, Axes & Methods

When using the object-oriented interface, we create objects and do the plotting with their methods. Methods are the functions that belong to an object. We create both a figure and an axes object with plt.subplots(1). Then we use the ax.plot() method of our axes object to create the plot. We also use two more methods, ax.set_xlabel() and ax.set_ylabel(), to label our axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
y = np.sin(np.pi * x) + x

fig, ax = plt.subplots(1)
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")
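To make the idea of objects more concrete, here is a small check of what plt.subplots() actually hands back. It is a sketch that assumes the non-interactive Agg backend so it runs without a display. The assertions confirm that we get a Figure and an Axes, and that the figure keeps track of the axes living in it.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

fig, ax = plt.subplots(1)

# A single subplot gives us one Figure object and one Axes object
assert isinstance(fig, Figure)
assert isinstance(ax, Axes)

# The figure knows which axes belong to it
assert fig.axes == [ax]

# A method changes the object it is called on, and only that object
ax.set_xlabel("x")
print(repr(ax.get_xlabel()))
```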

The big advantage is that we can easily create multiple plots and naturally keep track of what we are plotting where, because the method that does the plotting belongs to a specific axes object. In the next example, we plot on three different axes, all of which we create with plt.subplots(3).

x = np.arange(0.1, 10, 0.1)  # start at 0.1 so the division curve never computes 0/0
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
ax[0].plot(x,ys[0])
ax[1].plot(x,ys[1])
ax[2].plot(x,ys[2])

ax[0].set_title("Addition")
ax[1].set_title("Multiplication")
ax[2].set_title("Division")

for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

When we create multiple axes objects, they are available to us through the ax array. We can index into it, and we can also loop through all of the axes. We can take the above example even further and move the plotting into the for loop as well.

x = np.arange(0.1, 10, 0.1)  # start at 0.1 so the division curve never computes 0/0
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

fig, ax = plt.subplots(3)
titles = ["Addition", "Multiplication", "Division"]
for idx, a in enumerate(ax):
    a.plot(x, ys[idx])
    a.set_title(titles[idx])
    a.set_xlabel("x")
    a.set_ylabel("y")

This code produces exactly the same three-axes figure as above. There are other ways to use the object-oriented interface. For example, we can create an empty figure without axes using fig = plt.figure(). We can then create subplots in that figure with ax = fig.add_subplot(). The concept is the same as before, but instead of creating the figure and the axes at the same time, we use a figure method to create the axes. I personally prefer fig, ax = plt.subplots(), but fig.add_subplot() is slightly more flexible in how it lets us arrange the axes. plt.subplots(nrows, ncols) always creates one regular grid of nrows rows and ncols columns. fig.add_subplot(nrows, ncols, index) instead lets every call choose its own grid, so we can create one column with 2 axes and another with 3 axes. In the code below, the left column comes from a 2 by 2 grid (positions 1 and 3) and the right column from a 3 by 2 grid (positions 2, 4 and 6).

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax2 = fig.add_subplot(2,2,3)
ax3 = fig.add_subplot(3,2,2)
ax4 = fig.add_subplot(3,2,4)
ax5 = fig.add_subplot(3,2,6)
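To see why plt.subplots() cannot produce this layout, we can ask each axes which grid it was placed on. This is a small sketch that assumes the Agg backend. The list comprehensions are just a compact rewrite of the five add_subplot calls above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt

# The mixed layout from above: a 2x2 grid on the left, a 3x2 grid on the right
fig = plt.figure()
left = [fig.add_subplot(2, 2, i) for i in (1, 3)]
right = [fig.add_subplot(3, 2, i) for i in (2, 4, 6)]

# Each axes remembers the grid it belongs to, as (rows, columns)
for a in left + right:
    print(a.get_subplotspec().get_gridspec().get_geometry())

# plt.subplots always builds a single regular grid and returns it as an array
fig2, grid = plt.subplots(2, 3)
print(grid.shape)  # (2, 3)
```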

Personally, I prefer to avoid these arrangements because tools like tight_layout don’t work with them. Still, they are possible, and plt.subplots() cannot create them. This concludes our overview of the object-oriented interface. Simply remember that you want to do your plotting through the methods of an axes object, which you can create either with fig, ax = plt.subplots() or with fig.add_subplot(). So what is different about the functional interface? Instead of plotting through axes methods, we do all our plotting through functions in the matplotlib.pyplot module.

One pyplot to Rule Them All

The functional interface works entirely through the pyplot module, which we import as plt by convention. In the example below we use it to create the exact same plot as in the beginning. We use plt to create the figure, do the plotting and label the axes.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.1)
y = np.sin(np.pi*x) + x

plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")

You might be wondering why we don’t need to tell plt where to plot and which axes to label. It always works with the currently active figure and axes. If there is no active figure, plt.plot() creates one, including the axes. If a figure is already active, plt.plot() plots into that figure’s current axes and creates them first if the figure has none. This makes the functional interface less explicit and slightly less readable, especially for more complex figures. With the object-oriented interface, every action is tied to a specific object, because a method must be called through an object. With the functional interface, it can be a guessing game where the plotting happens, and our code becomes highly dependent on the order of lines in our script. The line on which we call plt.plot() becomes crucial. Let’s recreate the three-subplot example with the functional interface.

x = np.arange(0.1, 10, 0.1)  # start at 0.1 so the division curve never computes 0/0
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x, ys[0])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Addition")
plt.subplot(3, 1, 2)
plt.plot(x, ys[1])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Multiplication")
plt.subplot(3, 1, 3)
plt.plot(x, ys[2])
plt.xlabel("x")
plt.ylabel("y")
plt.title("Division")

This version is much longer than the object-oriented code because we cannot simply label all the axes in a separate loop: every plt call acts on whichever subplot is currently active. To be fair, in this particular example we can put all of our plotting and labeling into a single for loop. However, everything that belongs to a subplot then has to happen inside that loop, in the right order. This is what I mean when I say the functional interface is less flexible.

x = np.arange(0.1, 10, 0.1)  # start at 0.1 so the division curve never computes 0/0
ys = [np.sin(np.pi*x) + x,
      np.sin(np.pi*x) * x,
      np.sin(np.pi*x) / x]

plt.figure()
titles = ["Addition", "Multiplication", "Division"]
for idx, y in enumerate(ys):
    plt.subplot(3,1,idx+1)
    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(titles[idx])
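Behind the scenes, pyplot keeps track of a current figure and a current axes, which plt.gcf() and plt.gca() return. Here is a minimal sketch of one way to steer that state, assuming the Agg backend. We keep the axes that plt.subplot() returns and re-activate each one with plt.sca(), so labeling can happen in its own loop after all the plotting is done. This is one possible workaround, not the only one.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0.1, 10, 0.1)
ys = [np.sin(np.pi * x) + x,
      np.sin(np.pi * x) * x,
      np.sin(np.pi * x) / x]
titles = ["Addition", "Multiplication", "Division"]

plt.figure()
axes = []
for idx, y in enumerate(ys):
    # plt.subplot creates an axes, makes it the current one and returns it
    axes.append(plt.subplot(3, 1, idx + 1))
    plt.plot(x, y)  # lands in whatever plt.gca() is at this moment

# The last subplot created is still the current one
assert plt.gca() is axes[-1]

# Label in a separate loop by switching the current axes with plt.sca
for a, title in zip(axes, titles):
    plt.sca(a)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(title)
```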

Both interfaces are very similar. You might have noticed that many methods in the object-oriented API have the form set_attribute, such as set_xlabel() and set_title(). This is by design and follows an object-oriented convention in which methods that change attributes have a set_ prefix. Methods that don’t change attributes but create entirely new objects have an add_ prefix, for example add_subplot(). Now that we have seen both APIs at work, why is the object-oriented API recommended?
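The naming convention is easy to see in action. Most set_ methods have a get_ counterpart that reads the value back, and every axes also has a generic set() method that forwards keyword arguments to the matching set_ methods. Here is a small sketch, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt

fig = plt.figure()
# add_ methods create a new object and attach it to the figure
ax = fig.add_subplot()

# set_ methods change an attribute, get_ methods read it back
ax.set_title("Addition")
ax.set_xlim(0, 10)
print(ax.get_title(), ax.get_xlim())

# ax.set() forwards each keyword to the matching set_ method
ax.set(xlabel="x", ylabel="y")
print(ax.get_xlabel(), ax.get_ylabel())
```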

Advantages of the Object-Oriented API

First of all, Matplotlib is internally object-oriented. The pyplot interface hides that fact behind a functional layer in an effort to make usage more MATLAB-like. If we avoid plt and work with the object-oriented interface directly, our plotting can be marginally faster because we skip that extra layer. More importantly, the object-oriented interface is considered more readable and explicit, and both qualities matter a lot when we write Python. Readability can be somewhat subjective, but I hope the code above convinced you that going through the plotting methods of an axes object makes it much clearer where we are plotting. We also get more flexibility in how we structure our code. Because plt depends on the order of the plotting calls, we are constrained. With the object-oriented interface, we can structure our code more clearly. For example, we can split plotting, labeling and other tasks into their own code blocks.
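One pattern this enables is a plotting helper that takes the axes to draw on as a parameter, so the same function works for any subplot in any figure. The helper name plot_curve below is hypothetical. It is a sketch of the idea and not something Matplotlib provides. The Agg backend is also an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

def plot_curve(ax, x, y, title):
    """Hypothetical helper: draw one curve on the given axes and label it."""
    (line,) = ax.plot(x, y)
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    return line  # returning the artist lets the caller tweak it later

x = np.arange(0.1, 10, 0.1)
fig, axes = plt.subplots(1, 2)
# Plotting and labeling live in the helper; the layout is decided here
line_add = plot_curve(axes[0], x, np.sin(np.pi * x) + x, "Addition")
line_mul = plot_curve(axes[1], x, np.sin(np.pi * x) * x, "Multiplication")
line_mul.set_color("k")
fig.tight_layout()
```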

In summary, I hope you will be able to use the object-oriented interface of Matplotlib now. Simply remember to create axes with fig, ax = plt.subplots() and then most of the work happens through the ax object. Finally, the object-oriented interface is recommended because it is more efficient, readable, explicit and flexible.

Plotting 2D Vectors with Matplotlib

Vectors are extremely important in linear algebra and beyond. One of the most common visual representations of a vector is an arrow. Here we will learn how to plot vectors with Matplotlib. The title image shows two vectors and their sum. As a first step, we will plot the vectors starting at the origin [0, 0], as shown below.

import matplotlib.pyplot as plt
import numpy as np

vectors = np.array(([2, 0], [3, 2]))
vector_addition = vectors[0] + vectors[1]
vectors = np.append(vectors, vector_addition[None,:], axis=0)

tail = np.zeros((2, len(vectors)))  # x and y of every arrow's tail, all at [0, 0]
fig, ax = plt.subplots(1)
ax.quiver(*tail,
           vectors[:, 0],
           vectors[:, 1],
           scale=1,
           scale_units='xy',
           angles = 'xy',
           color=['g', 'r', 'k'])

ax.set_xlim((-1, vectors[:,0].max()+1))
ax.set_ylim((-1, vectors[:,1].max()+1))

We have two vectors stored in our vectors array: [2, 0] and [3, 2], both in [x, y] order, as you can see from the image. We can perform vector addition between the two by simply adding vectors[0] + vectors[1]. Then we use np.append so that all three vectors are in the same array. vector_addition[None, :] turns the sum into a row of shape (1, 2) so it can be appended along axis 0. Next, we define the origin in tail, because we want the tail of each arrow to be located at [0, 0]. Then we create the figure and the axes to plot in with plt.subplots(). The plotting itself takes a single call to the ax.quiver method. But it is quite the call, with a lot of parameters, so let’s go through it.

First, we need to define where the arrows start, so we pass *tail. Why the asterisk? ax.quiver takes two separate parameters for the arrow positions, X and Y, and the asterisk unpacks tail into those two parameters. Next, we pass the x components (vectors[:, 0]) and then the y components (vectors[:, 1]) of our vectors. The next three parameters, scale, scale_units and angles, are necessary to make the arrows match the actual numbers. By default, quiver autoscales the arrow lengths based on the average vector length and the number of vectors. Setting scale=1 together with scale_units='xy' makes one unit of vector length equal one unit on the x and y axes. angles='xy' makes each arrow point from (x, y) to (x + u, y + v) in data coordinates. Try removing some of those parameters to get a better idea of what I mean. Finally, we pass three colors, one for each arrow.
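Because scale, scale_units and angles have to be repeated on every call, it can help to wrap them in a small helper. The sketch below uses a hypothetical function plot_vectors that takes one tail per vector, and it assumes the Agg backend. It also builds the vectors array with np.vstack, which is an alternative to the np.append call in the original code.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

def plot_vectors(ax, tails, vectors, colors):
    """Hypothetical helper: draw each vector as an arrow starting at its tail.

    scale=1 with scale_units='xy' and angles='xy' makes every arrow end
    exactly at tail + vector in data coordinates.
    """
    tails = np.asarray(tails, dtype=float)
    vectors = np.asarray(vectors, dtype=float)
    return ax.quiver(tails[:, 0], tails[:, 1],
                     vectors[:, 0], vectors[:, 1],
                     scale=1, scale_units='xy', angles='xy',
                     color=colors)

vectors = np.array([[2, 0], [3, 2]])
# Stack the sum as a third row, same result as the np.append approach
vectors = np.vstack([vectors, vectors.sum(axis=0)])
print(vectors.tolist())  # [[2, 0], [3, 2], [5, 2]]

fig, ax = plt.subplots(1)
q = plot_vectors(ax, np.zeros_like(vectors), vectors, ['g', 'r', 'k'])
ax.set_xlim((-1, vectors[:, 0].max() + 1))
ax.set_ylim((-1, vectors[:, 1].max() + 1))
```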

So what do we need to plot the head-to-tail aligned vectors from the title image? We just need to place the tail of each vector at the head of the other vector.

ax.quiver(vectors[1::-1,0],
          vectors[1::-1,1],
          vectors[:2,0],
          vectors[:2,1],
          scale=1,
          scale_units='xy',
          angles = 'xy',
          color=['g', 'r'])

This is simple because it is the same quiver method, but it is complicated because of the indexing syntax. We no longer unpack *tail. Instead, we pass the x and y positions of the tails separately. In vectors[1::-1, 0], the 0 selects the x coordinates. The slice 1::-1 starts at row 1 and steps backwards, so it returns the second vector first and then the first vector, and it leaves out the summed vector in row 2. The reversed order is what matters. Without it, each vector would start at its own end point instead of at the end point of the other vector. vectors[1::-1, 1] gives us the y coordinates in the same order. Finally, we just need to leave out the summed vector when we pass the x and y components with vectors[:2, 0] and vectors[:2, 1]. The rest is the same.
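If the reversed slice feels cryptic, NumPy’s fancy indexing with an explicit list of row numbers picks the same tails and reads more directly. This small sketch uses only NumPy and checks that the two approaches agree and that both arrow heads meet at the sum.

```python
import numpy as np

vectors = np.array([[2, 0], [3, 2], [5, 2]])  # two vectors and their sum

# The slice from the text: start at row 1 and step backwards
tails_slice = vectors[1::-1]
# Fancy indexing: ask for row 1, then row 0, explicitly
tails_fancy = vectors[[1, 0]]

print(tails_slice.tolist())  # [[3, 2], [2, 0]]
assert np.array_equal(tails_slice, tails_fancy)

# Each vector starts where the other one ends, so both heads meet at the sum
heads = tails_fancy + vectors[:2]
print(heads.tolist())  # [[5, 2], [5, 2]]
```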

So that’s it. Unfortunately, ax.quiver on a regular 2D axes only works for 2D vectors (3D axes from mpl_toolkits.mplot3d provide their own quiver method). It also isn’t specifically made to present vectors that share a common origin. Its main use case is plotting vector fields, which is why some of the plotting here feels clunky. There is also ax.arrow, which is more straightforward but only creates one arrow per method call. I hope this post was helpful for you. Let me know if you have other ways to plot vectors.
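For comparison, here is a sketch of the same three vectors drawn with ax.arrow, one call per arrow, assuming the Agg backend. The head_width value is an arbitrary choice. length_includes_head=True makes the arrow tip land exactly on the end point of the vector.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

vectors = np.array([[2, 0], [3, 2], [5, 2]])
colors = ['g', 'r', 'k']

fig, ax = plt.subplots(1)
arrows = []
for (dx, dy), color in zip(vectors, colors):
    # ax.arrow(x, y, dx, dy) draws from (x, y) to (x + dx, y + dy)
    arrows.append(ax.arrow(0, 0, dx, dy,
                           length_includes_head=True,
                           head_width=0.15,  # arbitrary, in data units
                           color=color))

# Same limits as in the quiver version
ax.set_xlim((-1, vectors[:, 0].max() + 1))
ax.set_ylim((-1, vectors[:, 1].max() + 1))
ax.set_aspect('equal')  # equal scaling so the arrow heads are not distorted
```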

Animations with Matplotlib

Anything that can be plotted with Matplotlib can also be animated. This is especially useful when data changes over time. Animations allow us to see the dynamics in our data, which is nearly impossible with most static plots. Here we will learn how to animate with Matplotlib by producing this traveling wave animation.

This is the code to make the animation. It creates the traveling wave, defines two functions that handle the animation and creates the animation with the FuncAnimation class. Let’s take it step by step.

import numpy as np
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

# Create the traveling wave
def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0,4,0.01)[np.newaxis,:]
t = np.arange(0,2,0.01)[:,np.newaxis]
wavelength = 1
speed = 1
yt = wave(x, t, wavelength, speed)  # shape is [t, x]

# Create the figure and axes to animate
fig, ax = plt.subplots(1)
# init_func() is called at the beginning of the animation
def init_func():
    ax.clear()

# update_plot(i) is called once per frame and draws frame i
def update_plot(i):
    ax.clear()
    ax.plot(x[0,:], yt[i,:], color='k')

# Create animation
anim = FuncAnimation(fig,
                     update_plot,
                     frames=np.arange(0, len(t[:,0])),
                     init_func=init_func)

# Save animation
anim.save('traveling_wave.mp4',
          dpi=150,
          fps=30,
          writer='ffmpeg')

On the first three lines, we import NumPy, the FuncAnimation class and Matplotlib’s pyplot module. FuncAnimation takes center stage in our code, because later on it creates the animation by combining all the parts we need. On lines 5-13, we create the traveling wave. I don’t want to go into too much detail, as it is just a toy example for the animation. The important part is that we get the array yt, which defines the wave at each time point: its rows correspond to time and its columns to position. So yt[0] contains the wave at the first time point t[0], yt[1] the wave at t[1], and so on. This matters because we will be iterating over time during the animation. If you want to learn more about the traveling wave, you can change wavelength and speed and play around with the wave() function.
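A single call to wave() yields the whole animation because of NumPy broadcasting. x is a row vector and t is a column vector, so x - speed*t becomes a 2D grid with one row per time point. Here is a short check of the shapes, reusing the same definitions as the animation code:

```python
import numpy as np

def wave(x, t, wavelength, speed):
    return np.sin((2*np.pi)*(x-speed*t)/wavelength)

x = np.arange(0, 4, 0.01)[np.newaxis, :]  # one row of positions
t = np.arange(0, 2, 0.01)[:, np.newaxis]  # one column of time points
yt = wave(x, t, wavelength=1, speed=1)

print(x.shape, t.shape, yt.shape)
# Row i of yt is the wave at time t[i]; compare with a direct 1D evaluation
i = 50
assert np.allclose(yt[i], wave(x[0], t[i, 0], 1, 1))
```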

Now that we have our wave, we can start preparing the animation. We create the figure and the axes we want to use with plt.subplots(1). Then we define init_func(). It is called whenever the animation starts or repeats. In this particular example it is not really needed, but I include it here because it is a useful feature for more complex animations.

Now we get to update_plot(), the heart of our animation. This function is called for every frame and determines what we see on that frame. It is the most important function, and it is shockingly simple. The parameter i is an integer that tells us which frame we are at. We use that integer as an index into the first dimension of yt, so we plot the wave as it looks at the time point t[i]. Importantly, we must clean up our axes with ax.clear() first. If we forgot to clear them, every frame would be drawn on top of the previous ones, and our plot would quickly become all black, filled with waves.
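Clearing and redrawing works, but a common alternative is to create the line once and only swap its y data in each frame. Here is a minimal sketch of that variant. It assumes the Agg backend, rebuilds the same wave with wavelength 1 and speed 1, and uses the hypothetical name update_line. Because the axes are never cleared, the y limits are fixed by hand. Returning the changed artists also allows blit=True.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

x = np.arange(0, 4, 0.01)
t = np.arange(0, 2, 0.01)
# Same traveling wave as above with wavelength=1 and speed=1, shape [t, x]
yt = np.sin(2*np.pi*(x[np.newaxis, :] - t[:, np.newaxis]))

fig, ax = plt.subplots(1)
# Create the line once; only its data changes between frames
(line,) = ax.plot(x, yt[0], color='k')
ax.set_ylim(-1.1, 1.1)  # fixed limits, since the axes are never cleared

def update_line(i):
    """Hypothetical update function: move the existing line to frame i."""
    line.set_ydata(yt[i])
    return (line,)  # blit=True needs the list of changed artists

anim = FuncAnimation(fig, update_line, frames=len(t), blit=True)
```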

Now FuncAnimation is where it all comes together. We pass it fig, update_plot and init_func. We also pass frames, the values that i will take on during the animation. Technically, this gets the animation going in an interactive Python session, but most of the time we want to save our animation. We do that by calling anim.save(). We pass it the file name as a string, the resolution in dpi, the frames per second and finally the writer used to generate the animation, given here by name as 'ffmpeg'. Not all writers work with all file formats. I prefer .mp4 with the ffmpeg writer. If there are issues with saving, the most common problem is that the writer we are trying to use is not installed. To find out whether the ffmpeg writer is available on your machine, you can call matplotlib.animation.FFMpegWriter.isAvailable(). It returns True if the writer is available and False otherwise. If you are using Anaconda, you can install ffmpeg with conda, for example from the conda-forge channel.
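Building on that availability check, here is a sketch of a hypothetical helper choose_writer. It prefers ffmpeg for .mp4 and falls back to PillowWriter, which produces a .gif and only needs the Pillow package. The fallback choice is an assumption and not part of the original setup. When a writer object is passed to anim.save(), the frame rate is already set on the writer, so fps is not passed again.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.animation as animation

def choose_writer(fps=30):
    """Hypothetical helper: prefer ffmpeg (.mp4), else fall back to Pillow (.gif)."""
    if animation.FFMpegWriter.isAvailable():
        return animation.FFMpegWriter(fps=fps), ".mp4"
    # PillowWriter ships with Matplotlib and writes animated GIFs
    return animation.PillowWriter(fps=fps), ".gif"

writer, extension = choose_writer(fps=30)
print(type(writer).__name__, extension)
# With anim from the example above, one would then call:
# anim.save('traveling_wave' + extension, dpi=150, writer=writer)
```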

This wraps up our tutorial. This particular example is very simple, but anything that can be plotted can also be animated. I hope you are now on your way to creating your own animations. I will leave you with a more involved animation I created.