Skip to content

Commit 23464c9

Browse files
Add an example
1 parent a0af209 commit 23464c9

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

README.md

+28
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,34 @@ Using **ForwardHookManager**, you can extract intermediate representations in mo
3434
[This example notebook](https://github.com/yoshitomo-matsubara/torchdistill/tree/main/demo/extract_intermediate_representations.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yoshitomo-matsubara/torchdistill/blob/main/demo/extract_intermediate_representations.ipynb) [![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/yoshitomo-matsubara/torchdistill/blob/main/demo/extract_intermediate_representations.ipynb)
3535
will give you a better idea of the usage such as knowledge distillation and analysis of intermediate representations.
3636

37+
E.g., extract intermediate representations (feature map) of ResNet-18 for a random input batch
38+
```python
39+
import torch
40+
from torchvision import models
41+
from torchdistill.core.forward_hook import ForwardHookManager
42+
43+
# Define a model and choose torch device
44+
model = models.resnet18(pretrained=False)
45+
device = torch.device('cpu')
46+
47+
# Register forward hooks for modules of your interest
48+
forward_hook_manager = ForwardHookManager(device)
49+
forward_hook_manager.add_hook(model, 'conv1', requires_input=True, requires_output=False)
50+
forward_hook_manager.add_hook(model, 'layer1.0.bn2', requires_input=True, requires_output=True)
51+
forward_hook_manager.add_hook(model, 'fc', requires_input=False, requires_output=True)
52+
53+
# Define a random input batch and run the model
54+
x = torch.rand(32, 3, 224, 224)
55+
y = model(x)
56+
57+
# Extract input and/or output of the modules
58+
io_dict = forward_hook_manager.pop_io_dict()
59+
conv1_input = io_dict['conv1']['input']
60+
layer1_0_bn2_input = io_dict['layer1.0.bn2']['input']
61+
layer1_0_bn2_output = io_dict['layer1.0.bn2']['output']
62+
fc_output = io_dict['fc']['output']
63+
```
64+
3765

3866
## 1 experiment → 1 declarative PyYAML config file
3967
In ***torchdistill***, many components and PyTorch modules are abstracted e.g., models, datasets, optimizers, losses,

0 commit comments

Comments
 (0)