Skip to content

Commit 79f4b87

Browse files
authored
Merge pull request #772 from dianna-ai/750-fix-colab-paths
Fix data file paths for Colab
2 parents 2e7a10c + 4db827f commit 79f4b87

17 files changed

+211
-206
lines changed

dianna/dashboard/_movie_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import numpy as np
3-
from _shared import label_directory
3+
from _shared import data_directory
44
from scipy.special import expit as sigmoid
55
from torchtext.vocab import Vectors
66
from dianna import utils
@@ -13,7 +13,7 @@ class MovieReviewsModelRunner:
1313
def __init__(self, model, word_vectors=None, max_filter_size=5):
1414
"""Initializes the class."""
1515
if word_vectors is None:
16-
word_vectors = label_directory / 'movie_reviews_word_vectors.txt'
16+
word_vectors = data_directory / 'movie_reviews_word_vectors.txt'
1717

1818
self.run_model = utils.get_function(model)
1919
self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))

tutorials/explainers/KernelSHAP/kernelshap_geometric_shapes.ipynb

+9-14
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,8 @@
3636
"source": [
3737
"running_in_colab = 'google.colab' in str(get_ipython())\n",
3838
"if running_in_colab:\n",
39-
" # install dianna\n",
40-
" !python3 -m pip install dianna[notebooks]\n",
41-
" \n",
42-
" # download data used in this demo\n",
43-
" import os \n",
44-
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
45-
" paths_to_download = ['./data/shapes.npz', './models/geometric_shapes_model.onnx']\n",
46-
" for path in paths_to_download:\n",
47-
" !wget {base_url + path} -P {os.path.dirname(path)}"
39+
" # install dianna\n",
40+
" !python3 -m pip install dianna[notebooks]"
4841
]
4942
},
5043
{
@@ -64,14 +57,16 @@
6457
},
6558
"outputs": [],
6659
"source": [
60+
"from pathlib import Path\n",
6761
"import warnings\n",
6862
"warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf\n",
6963
"import numpy as np\n",
7064
"import dianna\n",
7165
"import onnx\n",
7266
"from onnx_tf.backend import prepare\n",
7367
"import matplotlib.pyplot as plt\n",
74-
"from pathlib import Path"
68+
"\n",
69+
"root_dir = Path(dianna.__file__).parent"
7570
]
7671
},
7772
{
@@ -108,7 +103,7 @@
108103
"outputs": [],
109104
"source": [
110105
"# load dataset\n",
111-
"data = np.load(Path('..','..','..','dianna', 'data', 'shapes.npz'))\n",
106+
"data = np.load(Path(root_dir, 'data', 'shapes.npz'))\n",
112107
"# load testing data and the related labels\n",
113108
"X_test = data['X_test'].astype(np.float32).reshape([-1, 1, 64, 64])\n",
114109
"y_test = data['y_test']"
@@ -136,7 +131,7 @@
136131
"outputs": [],
137132
"source": [
138133
"# Load saved onnx model\n",
139-
"onnx_model_path = Path('..','..','..','dianna','models', 'geometric_shapes_model.onnx')\n",
134+
"onnx_model_path = Path(root_dir, 'models', 'geometric_shapes_model.onnx')\n",
140135
"onnx_model = onnx.load(onnx_model_path)\n",
141136
"# get the output node\n",
142137
"output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]"
@@ -366,7 +361,7 @@
366361
"hash": "e7604e8ec5f09e490e10161e37a4725039efd3ab703d81b1b8a1e00d6741866c"
367362
},
368363
"kernelspec": {
369-
"display_name": "Python 3",
364+
"display_name": "Python 3 (ipykernel)",
370365
"language": "python",
371366
"name": "python3"
372367
},
@@ -380,7 +375,7 @@
380375
"name": "python",
381376
"nbconvert_exporter": "python",
382377
"pygments_lexer": "ipython3",
383-
"version": "3.7.3"
378+
"version": "3.9.1"
384379
}
385380
},
386381
"nbformat": 4,

tutorials/explainers/KernelSHAP/kernelshap_mnist.ipynb

+9-14
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,8 @@
3636
"source": [
3737
"running_in_colab = 'google.colab' in str(get_ipython())\n",
3838
"if running_in_colab:\n",
39-
" # install dianna\n",
40-
" !python3 -m pip install dianna[notebooks]\n",
41-
" \n",
42-
" # download data used in this demo\n",
43-
" import os \n",
44-
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
45-
" paths_to_download = ['./data/binary-mnist.npz', './models/mnist_model_tf.onnx']\n",
46-
" for path in paths_to_download:\n",
47-
" !wget {base_url + path} -P {os.path.dirname(path)}"
39+
" # install dianna\n",
40+
" !python3 -m pip install dianna[notebooks]"
4841
]
4942
},
5043
{
@@ -64,14 +57,16 @@
6457
},
6558
"outputs": [],
6659
"source": [
60+
"from pathlib import Path\n",
6761
"import warnings\n",
6862
"warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf\n",
6963
"import numpy as np\n",
7064
"import dianna\n",
7165
"import onnx\n",
7266
"from onnx_tf.backend import prepare\n",
7367
"import matplotlib.pyplot as plt\n",
74-
"from pathlib import Path"
68+
"\n",
69+
"root_dir = Path(dianna.__file__).parent"
7570
]
7671
},
7772
{
@@ -108,7 +103,7 @@
108103
"outputs": [],
109104
"source": [
110105
"# load dataset\n",
111-
"data = np.load(Path('..','..','..','dianna','data', 'binary-mnist.npz'))\n",
106+
"data = np.load(Path(root_dir, 'data', 'binary-mnist.npz'))\n",
112107
"# load testing data and the related labels\n",
113108
"X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 255\n",
114109
"y_test = data['y_test']"
@@ -136,7 +131,7 @@
136131
"outputs": [],
137132
"source": [
138133
"# Load saved onnx model\n",
139-
"onnx_model_path = Path('..','..','..','dianna','models', 'mnist_model_tf.onnx')\n",
134+
"onnx_model_path = Path(root_dir, 'models', 'mnist_model_tf.onnx')\n",
140135
"onnx_model = onnx.load(onnx_model_path)\n",
141136
"# get the output node\n",
142137
"output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]"
@@ -333,7 +328,7 @@
333328
"hash": "e7604e8ec5f09e490e10161e37a4725039efd3ab703d81b1b8a1e00d6741866c"
334329
},
335330
"kernelspec": {
336-
"display_name": "Python 3",
331+
"display_name": "Python 3 (ipykernel)",
337332
"language": "python",
338333
"name": "python3"
339334
},
@@ -347,7 +342,7 @@
347342
"name": "python",
348343
"nbconvert_exporter": "python",
349344
"pygments_lexer": "ipython3",
350-
"version": "3.7.3"
345+
"version": "3.9.1"
351346
}
352347
},
353348
"nbformat": 4,

tutorials/explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb

+9-13
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,8 @@
2525
"source": [
2626
"running_in_colab = 'google.colab' in str(get_ipython())\n",
2727
"if running_in_colab:\n",
28-
" # install dianna\n",
29-
" !python3 -m pip install dianna[notebooks]\n",
30-
" \n",
31-
" # download data used in this demo\n",
32-
" import os \n",
33-
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
34-
" paths_to_download = ['./models/penguin_model.onnx']\n",
35-
" for path in paths_to_download:\n",
36-
" !wget {base_url + path} -P {os.path.dirname(path)}"
28+
" # install dianna\n",
29+
" !python3 -m pip install dianna[notebooks]"
3730
]
3831
},
3932
{
@@ -49,6 +42,7 @@
4942
"metadata": {},
5043
"outputs": [],
5144
"source": [
45+
"from pathlib import Path\n",
5246
"import dianna\n",
5347
"import numpy as np\n",
5448
"import pandas as pd\n",
@@ -59,7 +53,9 @@
5953
"from numba.core.errors import NumbaDeprecationWarning\n",
6054
"import warnings\n",
6155
"# silence the Numba deprecation warnings in shap\n",
62-
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)"
56+
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)\n",
57+
"\n",
58+
"root_dir = Path(dianna.__file__).parent"
6359
]
6460
},
6561
{
@@ -308,7 +304,7 @@
308304
],
309305
"source": [
310306
"# load onnx model and check the prediction with it\n",
311-
"model_path = '../../../dianna/models/penguin_model.onnx'\n",
307+
"model_path = Path(root_dir, 'models', 'penguin_model.onnx')\n",
312308
"loaded_model = SimpleModelRunner(model_path)\n",
313309
"predictions = loaded_model(data_instance.reshape(1,-1).astype(np.float32))\n",
314310
"species[np.argmax(predictions)]"
@@ -411,7 +407,7 @@
411407
],
412408
"metadata": {
413409
"kernelspec": {
414-
"display_name": "Python 3",
410+
"display_name": "Python 3 (ipykernel)",
415411
"language": "python",
416412
"name": "python3"
417413
},
@@ -425,7 +421,7 @@
425421
"name": "python",
426422
"nbconvert_exporter": "python",
427423
"pygments_lexer": "ipython3",
428-
"version": "3.7.3"
424+
"version": "3.9.1"
429425
}
430426
},
431427
"nbformat": 4,

tutorials/explainers/KernelSHAP/kernelshap_tabular_weather.ipynb

+8-12
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,8 @@
2727
"source": [
2828
"running_in_colab = 'google.colab' in str(get_ipython())\n",
2929
"if running_in_colab:\n",
30-
" # install dianna\n",
31-
" !python3 -m pip install dianna[notebooks]\n",
32-
" \n",
33-
" # download data used in this demo\n",
34-
" import os\n",
35-
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
36-
" paths_to_download = ['./models/sunshine_hours_regression_model.onnx']\n",
37-
" for path in paths_to_download:\n",
38-
" !wget {base_url + path} -P {os.path.dirname(path)}"
30+
" # install dianna\n",
31+
" !python3 -m pip install dianna[notebooks]"
3932
]
4033
},
4134
{
@@ -51,6 +44,7 @@
5144
"metadata": {},
5245
"outputs": [],
5346
"source": [
47+
"from pathlib import Path\n",
5448
"import dianna\n",
5549
"import numpy as np\n",
5650
"import pandas as pd\n",
@@ -60,7 +54,9 @@
6054
"from numba.core.errors import NumbaDeprecationWarning\n",
6155
"import warnings\n",
6256
"# silence the Numba deprecation warnings in shap\n",
63-
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)"
57+
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)\n",
58+
"\n",
59+
"root_dir = Path(dianna.__file__).parent"
6460
]
6561
},
6662
{
@@ -257,7 +253,7 @@
257253
],
258254
"metadata": {
259255
"kernelspec": {
260-
"display_name": "Python 3",
256+
"display_name": "Python 3 (ipykernel)",
261257
"language": "python",
262258
"name": "python3"
263259
},
@@ -271,7 +267,7 @@
271267
"name": "python",
272268
"nbconvert_exporter": "python",
273269
"pygments_lexer": "ipython3",
274-
"version": "3.7.3"
270+
"version": "3.9.1"
275271
}
276272
},
277273
"nbformat": 4,

0 commit comments

Comments
 (0)