Eventplots with Color in Matplotlib

Here we will learn how to color individual events in a Matplotlib eventplot. An eventplot most commonly shows the timing of events (x axis) from different sources (y axis). Giving each source its own color is easy with the colors parameter of plt.eventplot, which accepts one color per row. Coloring each event individually, however, requires us to create the eventplot first and then assign the colors manually by calling set_colors on each of the returned event collections. Here is the code and the resulting output.

import matplotlib as mpl
import matplotlib.pyplot as plt
import shelve
import os
import numpy as np

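# Locate the data file next to this script, regardless of the working directory
dirname = os.path.dirname(__file__)
data_path = os.path.join(dirname, 'evenplot_demo_spikes')

data = shelve.open(data_path)

fig, ax = plt.subplots(1)

# eventplot returns a list with one EventCollection per row (source)
event_collection = ax.eventplot(data['spikes'])

# Color limits from the 5th and 95th percentiles of all phases
phases_flattened = np.hstack(data['phases'])
phase_min = np.quantile(phases_flattened, 0.05)
phase_max = np.quantile(phases_flattened, 0.95)

phases_norm = mpl.colors.Normalize(vmin=phase_min, vmax=phase_max)
# Normalize row by row, because rows usually hold different numbers of events
normalized_phases = [phases_norm(row) for row in data['phases']]
data.close()
viridis = mpl.colormaps['viridis']

# Give every event in a row its own color
for idx, col in enumerate(event_collection):
    col.set_colors(viridis(normalized_phases[idx]))

# The collections carry explicit colors, not data values, so a standalone
# ScalarMappable supplies the norm and colormap for the colorbar
plt.colorbar(mpl.cm.ScalarMappable(norm=phases_norm, cmap=viridis),
             ax=ax, label="Phase", fraction=0.05)
plt.show()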
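The shelve file evenplot_demo_spikes is not part of this post, so here is a self-contained sketch for trying the technique without it. It assumes randomly generated spike times and phases (purely hypothetical data), and the helper color_events_by_value is my own illustrative name, not a Matplotlib function. It follows the same steps as the script: percentile limits, Normalize, the viridis colormap, and set_colors on each row. Each row is normalized separately so that rows with different numbers of events are handled.

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


def color_events_by_value(ax, positions, values, cmap_name="viridis",
                          quantiles=(0.05, 0.95)):
    """Hypothetical helper: draw an eventplot and color each event by a value.

    positions and values are lists with one 1-D array per row; values[i]
    must have the same length as positions[i].
    """
    collections = ax.eventplot(positions)

    # Color limits from percentiles of all values (see the caveat in the text)
    all_values = np.concatenate(values)
    vmin, vmax = np.quantile(all_values, quantiles)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    cmap = mpl.colormaps[cmap_name]

    # One RGBA color per event, normalized row by row
    for col, row_values in zip(collections, values):
        col.set_colors(cmap(norm(np.asarray(row_values))))

    # Standalone mappable that carries the norm and colormap for a colorbar
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    return collections, mappable


# Hypothetical data: 4 sources with random spike times and phases in [0, 2*pi)
rng = np.random.default_rng(42)
counts = (15, 8, 22, 11)
spikes = [np.sort(rng.uniform(0, 10, size=n)) for n in counts]
phases = [rng.uniform(0, 2 * np.pi, size=n) for n in counts]

fig, ax = plt.subplots()
collections, mappable = color_events_by_value(ax, spikes, phases)
fig.colorbar(mappable, ax=ax, label="Phase", fraction=0.05)
plt.show()
```

Returning the ScalarMappable instead of drawing the colorbar inside the helper leaves it to the caller where the colorbar goes, for example when several axes should share one color scale.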

We start by loading the data with the shelve module. The os module lets us find the data file wherever the script is located, via os.path.dirname(__file__). Next, we create the figure and axis and plot the event times into the axis. At this point all events are drawn in Matplotlib's default blue. ax.eventplot returns a list with one EventCollection per row, and we keep that list in event_collection so that we can recolor the events afterwards.

We want to color the events according to the data in data['phases']. That data contains a single floating-point number for each event, and we need to convert those numbers to RGB colors. For that, we first normalize the data with mpl.colors.Normalize(vmin=phase_min, vmax=phase_max), which maps phase_min to 0 and phase_max to 1. I use a trick here: instead of the actual minimum and maximum of the data, I take the 5th and 95th percentiles. This should not be done for a scientific figure, because all values above vmax get the same color (and likewise all values below vmin) even though they differ, which can be misleading.

Once we have our normalized values, we get the viridis colormap with mpl.colormaps['viridis'] and obtain colors from it by calling it with the normalized values. We then loop through event_collection (each entry is one row of the plot) and set the colors with col.set_colors(viridis(normalized_phases[idx])). Finally, we add a colorbar with plt.colorbar. Because the colors were assigned by hand, we pass it a ScalarMappable built from the same norm and colormap. And that is it. Let me know if there is an easier way to do this that I missed.
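The caveat about the percentile limits is easy to check with a few made-up numbers. In this sketch vmin and vmax stand in for phase_min and phase_max. Normalize maps vmin to 0 and vmax to 1 and does not clip by default. The colormap then draws every value below 0 with its lowest color and every value above 1 with its highest color (the default under and over colors).

```python
import matplotlib as mpl
import numpy as np

# Made-up values; vmin and vmax stand in for the percentile limits
values = np.array([-5.0, 0.0, 5.0, 10.0, 15.0])
norm = mpl.colors.Normalize(vmin=0.0, vmax=10.0)
viridis = mpl.colormaps["viridis"]

normalized = norm(values)     # (value - vmin) / (vmax - vmin), not clipped
colors = viridis(normalized)  # one RGBA row per value

print(normalized.tolist())                # [-0.5, 0.0, 0.5, 1.0, 1.5]
print(np.allclose(colors[0], colors[1]))  # True: -5 gets the same color as 0
print(np.allclose(colors[3], colors[4]))  # True: 15 gets the same color as 10
```

Passing clip=True to Normalize would move the out-of-range values to exactly 0 and 1 before the lookup. The colors stay the same, so the information is lost either way.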