@@ -19,14 +19,14 @@ class RISE:
19
19
required_labels = ('batch' , 'channels' )
20
20
21
21
def __init__ (self , n_masks = 1000 , feature_res = 8 , p_keep = 0.5 , # pylint: disable=too-many-arguments
22
- axes_labels = None , preprocess_function = None ):
22
+ axis_labels = None , preprocess_function = None ):
23
23
"""RISE initializer.
24
24
25
25
Args:
26
26
n_masks (int): Number of masks to generate.
27
27
feature_res (int): Resolution of features in masks.
28
28
p_keep (float): Fraction of image to keep in each mask
29
- axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
29
+ axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
30
30
If a list, the name of each axis where the index
31
31
in the list is the axis index
32
32
preprocess_function (callable, optional): Function to preprocess input data with
@@ -37,7 +37,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=t
37
37
self .preprocess_function = preprocess_function
38
38
self .masks = None
39
39
self .predictions = None
40
- self .axes_labels = axes_labels if axes_labels is not None else []
40
+ self .axis_labels = axis_labels if axis_labels is not None else []
41
41
42
42
def explain_text (self , model_or_function , input_text , labels = (0 ,), batch_size = 100 ):
43
43
"""Runs the RISE explainer on text.
@@ -136,7 +136,7 @@ def explain_image(self, model_or_function, input_data, labels=None, batch_size=1
136
136
Explanation heatmap for each class (np.ndarray).
137
137
"""
138
138
# convert data to xarray
139
- input_data = utils .to_xarray (input_data , self .axes_labels , RISE .required_labels )
139
+ input_data = utils .to_xarray (input_data , self .axis_labels , RISE .required_labels )
140
140
# batch axis should always be first
141
141
input_data = utils .move_axis (input_data , 'batch' , 0 )
142
142
input_data , full_preprocess_function = self ._prepare_image_data (input_data )
0 commit comments