Use the ** operator in Python to clean up matplotlib plotting code

Problem 1:

Plotting in Matplotlib is not without its challenges. I spend way too much time tweaking little styling parameters (colorbars, tick sizes, etc.), and I normally end up with bloated matplotlib commands that are stuffed with arguments. This might look something like this:

plt.scatter(x,y, color='red', alpha=.5, marker='d', zorder=1, s=.1)

This is of course fine. But oftentimes you might want to make multiple subplots using the same styling, in which case you have to repeat typing (or copy/pasting) this same line several times. And later on, when you inevitably decide to change alpha from .5 to .6 (or whatever), you have to make that change in every copy. And the conference presentation is in 2 hours!

Enter the ** Operator

This is by no means a hidden feature in Python, but I don’t see it used all that often in the beginner python-plotting-for-science literature, which is why I’m making this post! The ** in a function call unpacks the key:value pairs of a Python dictionary and passes them into the function as keyword arguments. Thus we can define a dictionary with all of the keywords that we want and then use ** to pass those values into the plot function:

cool_style = {"color": "red",
              "alpha": .5,
              "marker": "d",
              "zorder": 1,
              "s": .1}

plt.scatter(x,y, **cool_style)
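
To see what the unpacking does without opening a figure, here is a minimal sketch in plain Python. The function fake_scatter is a hypothetical stand-in for plt.scatter (it is not part of matplotlib) that simply reports the keyword arguments it received; its default values are made up for illustration. The sketch also shows one way to change a single setting for one plot without editing the shared dictionary.

```python
def fake_scatter(x, y, color="blue", alpha=1.0, marker="o", zorder=0, s=20):
    """Hypothetical stand-in for plt.scatter: return the style it was given."""
    return {"color": color, "alpha": alpha, "marker": marker,
            "zorder": zorder, "s": s}


cool_style = {"color": "red", "alpha": 0.5, "marker": "d", "zorder": 1, "s": 0.1}

# Spelling the keywords out by hand...
explicit = fake_scatter([1, 2], [3, 4],
                        color="red", alpha=0.5, marker="d", zorder=1, s=0.1)

# ...is equivalent to unpacking the dictionary with **.
unpacked = fake_scatter([1, 2], [3, 4], **cool_style)
print(explicit == unpacked)  # True

# Build a variant of the style with one value changed.
# Later keys win, and cool_style itself is left untouched.
faint_style = {**cool_style, "alpha": 0.2}
print(fake_scatter([1, 2], [3, 4], **faint_style)["alpha"])  # 0.2
print(cool_style["alpha"])  # 0.5
```

The same {**cool_style, "alpha": 0.2} trick should work when the result is unpacked into a real plotting call, since it is ordinary Python dictionary syntax rather than anything matplotlib-specific.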

Problem 2:

The second problem I encounter arises when wrapping together multiple plotting commands into a single, new function. For example, I commonly want to make a scatter plot and add a linear regression plot on top of that. If I want to do this for multiple x~y data combinations, then it makes sense to have a function that wraps everything up as one command. So I would end up with a function that looks something like:

from matplotlib import pyplot as plt
import numpy as np
from scipy import stats


def scatter_with_linregress(ax, x, y):
    # draw a scatterplot
    ax.scatter(x, y)

    # compute linear regression
    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

    # plot the linear regression line
    ax.plot(x, intercept + slope*x)

This will make the plot that we want. BUT… What about adding different styles for each time I apply “scatter_with_linregress”? I could add an entry for the style parameters that I care about:


def scatter_with_linregress(ax, x, y, ls, linreg_color, linreg_ls, ....):
...

Where I then pass the arguments to the respective plotting command inside the function. This becomes untenable very quickly. There are many style arguments to choose from, and inevitably I will want to change one down the road that I haven’t made an argument of scatter_with_linregress. The script below offers one solution for getting around this.

from matplotlib import pyplot as plt
import numpy as np
from scipy import stats


def scatter_with_linregress(ax, x, y, **kwargs):
    # Read in the optional style dictionaries from the keyword arguments...
    # Fall back to a blank dictionary if one is not given
    linreg_style = kwargs.get('linreg_style', {})
    scatter_style = kwargs.get('scatter_style', {})

    # draw a scatterplot
    ax.scatter(x, y, **scatter_style)

    # compute linear regression
    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

    # plot the linear regression line
    ax.plot(x, intercept + slope*x, **linreg_style)


# Only run the lines below if we execute
# (versus import) this script
if __name__ == '__main__':
    # create figure
    fig, ax = plt.subplots()

    # create random data...
    x = np.random.rand(100)
    y = np.random.rand(100)

    # define styles
    cool_style = {"color": "red",
                  "alpha": .5,
                  "ls": "-"}

    other_style = {"color": "black",
                   "alpha": .7,
                   "marker": "x"}

    # make the plots
    scatter_with_linregress(ax, x, y, linreg_style=cool_style, scatter_style=other_style)
    plt.show()

The top of the new function might look a bit odd, and there might be a more elegant way to do this. Regardless, it lets us pass optional style arguments, which we call “linreg_style” and “scatter_style”, into the plotting function. If we don’t pass in either of those dictionaries, the kwargs.get lines return a blank dictionary, which is still safely unpacked into the plotting commands.
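
One possibly tidier variation, sketched here as an assumption rather than as the method used above, is to name the two style dictionaries as explicit keyword parameters that default to None. The helper resolve_styles and the DEFAULT_SCATTER / DEFAULT_LINE dictionaries are hypothetical names with made-up values; the helper only merges styles and does no plotting, so it runs without matplotlib.

```python
# Hypothetical default styles, chosen only for illustration.
DEFAULT_SCATTER = {"color": "black", "alpha": 0.7, "marker": "x"}
DEFAULT_LINE = {"color": "red", "alpha": 0.5, "ls": "-"}


def resolve_styles(scatter_style=None, linreg_style=None):
    """Merge user-supplied style dictionaries over the defaults.

    None (rather than {}) is used as the default value so that no
    mutable object is shared between calls.
    """
    scatter_kwargs = {**DEFAULT_SCATTER, **(scatter_style or {})}
    line_kwargs = {**DEFAULT_LINE, **(linreg_style or {})}
    return scatter_kwargs, line_kwargs


# Override only the line color; everything else falls back to the defaults.
scatter_kwargs, line_kwargs = resolve_styles(linreg_style={"color": "blue"})
print(scatter_kwargs)  # {'color': 'black', 'alpha': 0.7, 'marker': 'x'}
print(line_kwargs)     # {'color': 'blue', 'alpha': 0.5, 'ls': '-'}

# Inside a plotting wrapper, the results would then be unpacked as before:
#     ax.scatter(x, y, **scatter_kwargs)
#     ax.plot(x, intercept + slope*x, **line_kwargs)
```

The trade-off is that named parameters are stricter than **kwargs: a misspelled keyword such as linreg_styl raises a TypeError straight away, whereas kwargs.get would quietly fall back to the blank dictionary and the style would never be applied.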

TL;DR

Placing ** in front of a dictionary in a function call passes each key:value pair as a keyword argument. This can make matplotlib easier to work with.