Skip to content

Commit 3607930

Browse files
committed
enable passing of external Axes to plot on
Adds to the plot_image and plot_tabular functions an optional argument ax: plt.Axes. When given, this is the Axes that will be used to plot on and the internal plt.subplots call is skipped. This is useful for using these functions in custom multi-panel plots (we want to use this in a paper).
1 parent ae3612a commit 3607930

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

dianna/visualization/image.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def plot_image(heatmap,
1616
heatmap_range=(None, None), # (vmin, vmax)
1717
data_cmap=None,
1818
show_plot=True,
19-
output_filename=None):
19+
output_filename=None,
20+
ax=None):
2021
"""Plots a heatmap image.
2122
2223
Optionally, the heatmap (typically a saliency map of an explainer) can be
@@ -38,13 +39,18 @@ def plot_image(heatmap,
3839
show_plot: Shows plot if true (for testing or writing plots to disk
3940
instead).
4041
output_filename: Name of the file to save the plot to (optional).
42+
ax: matplotlib.Axes object to plot on (optional).
4143
4244
Returns:
4345
None
4446
"""
4547
# default cmap depends on shape: grayscale or colour
4648

47-
fig, ax = plt.subplots()
49+
if ax is None:
50+
fig, ax = plt.subplots()
51+
else:
52+
fig = ax.get_figure()
53+
4854
alpha = 1
4955
if original_data is not None:
5056
if len(original_data.shape) == 2 and data_cmap is None:

dianna/visualization/tabular.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def plot_tabular(
1313
num_features: Optional[int] = None,
1414
show_plot: Optional[bool] = True,
1515
output_filename: Optional[str] = None,
16+
ax: Optional[plt.Axes] = None,
1617
) -> plt.Figure:
1718
"""Plot feature importance with segments highlighted.
1819
@@ -26,21 +27,23 @@ def plot_tabular(
2627
plots to disk instead).
2728
output_filename (str, optional): Name of the file to save
2829
the plot to (optional).
30+
ax (matplotlib.Axes, optional): externally created canvas to plot on.
2931
3032
Returns:
3133
plt.Figure
3234
"""
3335
if not num_features:
3436
num_features = len(x)
35-
36-
3737
abs_values = [abs(i) for i in x]
3838
top_values = [x for _, x in sorted(zip(abs_values, x), reverse=True)][:num_features]
3939
top_features = [x for _, x in sorted(zip(abs_values, y), reverse=True)][
4040
:num_features
4141
]
4242

43-
fig, ax = plt.subplots()
43+
if ax is None:
44+
fig, ax = plt.subplots()
45+
else:
46+
fig = ax.get_figure()
4447
colors = ["r" if x >= 0 else "b" for x in top_values]
4548
ax.barh(top_features, top_values, color=colors)
4649
ax.set_xlabel(x_label)

0 commit comments

Comments
 (0)