[Feature] Support dinov2 backbone #1522

Merged
19 commits merged on May 5, 2023
7 changes: 6 additions & 1 deletion .readthedocs.yml
@@ -1,10 +1,15 @@
version: 2

# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.8"

formats:
- epub

python:
version: 3.8
install:
- requirements: requirements/docs.txt
- requirements: requirements/readthedocs.txt
58 changes: 58 additions & 0 deletions configs/dinov2/README.md
@@ -0,0 +1,58 @@
# DINOv2

> [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)

<!-- [ALGORITHM] -->

## Abstract

The recent breakthroughs in natural language processing for model pretraining on large quantities of data have opened the way for similar foundation models in computer vision. These models could greatly simplify the use of images in any system by producing all-purpose visual features, i.e., features that work across image distributions and tasks without finetuning. This work shows that existing pretraining methods, especially self-supervised methods, can produce such features if trained on enough curated data from diverse sources. We revisit existing approaches and combine different techniques to scale our pretraining in terms of data and model size. Most of the technical contributions aim at accelerating and stabilizing the training at scale. In terms of data, we propose an automatic pipeline to build a dedicated, diverse, and curated image dataset instead of uncurated data, as typically done in the self-supervised literature. In terms of models, we train a ViT model (Dosovitskiy et al., 2020) with 1B parameters and distill it into a series of smaller models that surpass the best available all-purpose features, OpenCLIP (Ilharco et al., 2021), on most of the benchmarks at image and pixel levels.

<div align=center>
<img src="https://user-images.githubusercontent.com/36138628/234560516-b495795c-c75c-444c-a712-bb61a3de444e.png" width="70%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('vit-small-p14_dinov2-pre_3rdparty', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```

<!-- [TABS-END] -->

## Models and results

### Pretrained models

| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------------ | :--------: | :-------: | :--------------------------------------------: | :------------------------------------------------------------------------------------------------: |
| `vit-small-p14_dinov2-pre_3rdparty`\* | 22.06 | 46.76 | [config](vit-small-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-small-p14_dinov2-pre_3rdparty_20230426-5641ca5a.pth) |
| `vit-base-p14_dinov2-pre_3rdparty`\* | 86.58 | 152.00 | [config](vit-base-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth) |
| `vit-large-p14_dinov2-pre_3rdparty`\* | 304.00 | 507.00 | [config](vit-large-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-large-p14_dinov2-pre_3rdparty_20230426-f3302d9e.pth) |
| `vit-giant-p14_dinov2-pre_3rdparty`\* | 1136.00 | 1784.00 | [config](vit-giant-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-giant-p14_dinov2-pre_3rdparty_20230426-2934a630.pth) |

*Models with * are converted from the [official repo](https://github.com/facebookresearch/dinov2). The config files of these models are only for inference. We haven't reproduced the training results.*
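
Below is a minimal sketch (not part of the original README) of how one of these headless configs and its converted checkpoint from the table can be combined for feature extraction. It only uses standard MMEngine/MMPretrain APIs (`Config.fromfile`, `MODELS.build`, `load_checkpoint`, `extract_feat`); the input here is random and un-normalized, and the exact output shapes depend on the backbone's `out_type`.

```python
# Sketch only: build the headless DINOv2 ViT-B/14 from its inference config and
# load the converted checkpoint listed in the table above.
import torch
from mmengine.config import Config
from mmengine.runner import load_checkpoint

from mmpretrain.registry import MODELS

cfg = Config.fromfile('configs/dinov2/vit-base-p14_dinov2-pre_headless.py')
model = MODELS.build(cfg.model)
load_checkpoint(
    model,
    'https://download.openmmlab.com/mmpretrain/v1.0/dinov2/'
    'vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth',
    map_location='cpu')
model.eval()

with torch.no_grad():
    feats = model.extract_feat(torch.rand(1, 3, 224, 224))
print([f.shape for f in feats])
```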

## Citation

```bibtex
@misc{oquab2023dinov2,
title={DINOv2: Learning Robust Visual Features without Supervision},
author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
journal={arXiv:2304.07193},
year={2023}
}
```
73 changes: 73 additions & 0 deletions configs/dinov2/metafile.yml
@@ -0,0 +1,73 @@
Collections:
- Name: DINOv2
Metadata:
Architecture:
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
Paper:
Title: 'DINOv2: Learning Robust Visual Features without Supervision'
URL: https://arxiv.org/abs/2304.07193
README: configs/dinov2/README.md
Code:
URL: null
Version: null

Models:
- Name: vit-small-p14_dinov2-pre_3rdparty
Metadata:
FLOPs: 46762000000
Parameters: 22056000
Training Data:
- LVD-142M
In Collection: DINOv2
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-small-p14_dinov2-pre_3rdparty_20230426-5641ca5a.pth
Config: configs/dinov2/vit-small-p14_dinov2-pre_headless.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
Code: https://github.com/facebookresearch/dinov2

- Name: vit-base-p14_dinov2-pre_3rdparty
Metadata:
FLOPs: 152000000000
Parameters: 86580000
Training Data:
- LVD-142M
In Collection: DINOv2
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth
Config: configs/dinov2/vit-base-p14_dinov2-pre_headless.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
Code: https://github.com/facebookresearch/dinov2

- Name: vit-large-p14_dinov2-pre_3rdparty
Metadata:
FLOPs: 507000000000
Parameters: 304000000
Training Data:
- LVD-142M
In Collection: DINOv2
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-large-p14_dinov2-pre_3rdparty_20230426-f3302d9e.pth
Config: configs/dinov2/vit-large-p14_dinov2-pre_headless.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth
Code: https://github.com/facebookresearch/dinov2

- Name: vit-giant-p14_dinov2-pre_3rdparty
Metadata:
FLOPs: 1784000000000
Parameters: 1136000000
Training Data:
- LVD-142M
In Collection: DINOv2
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-giant-p14_dinov2-pre_3rdparty_20230426-2934a630.pth
Config: configs/dinov2/vit-giant-p14_dinov2-pre_headless.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth
Code: https://github.com/facebookresearch/dinov2
20 changes: 20 additions & 0 deletions configs/dinov2/vit-base-p14_dinov2-pre_headless.py
@@ -0,0 +1,20 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=518,
patch_size=14,
layer_scale_init_value=1e-5,
),
neck=None,
head=None)

data_preprocessor = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
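
As a side note (not part of the config file), the backbone dict above can also be instantiated on its own through the registry; with `img_size=518` and `patch_size=14`, the patch grid is 518 / 14 = 37 × 37 tokens plus the class token. A minimal sketch, with randomly initialized weights:

```python
# Sketch only: build just the DINOv2 ViT-B/14 backbone defined in the config above.
import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(
        type='VisionTransformer',
        arch='base',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
    ))
backbone.eval()

with torch.no_grad():
    # Weights are random here; load a converted checkpoint for real use.
    outs = backbone(torch.rand(1, 3, 518, 518))
print([o.shape for o in outs])
```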
21 changes: 21 additions & 0 deletions configs/dinov2/vit-giant-p14_dinov2-pre_headless.py
@@ -0,0 +1,21 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='dinov2-giant',
img_size=518,
patch_size=14,
layer_scale_init_value=1e-5,
layer_cfgs=dict(ffn_type='swiglu_fused'),
),
neck=None,
head=None)

data_preprocessor = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
20 changes: 20 additions & 0 deletions configs/dinov2/vit-large-p14_dinov2-pre_headless.py
@@ -0,0 +1,20 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='large',
img_size=518,
patch_size=14,
layer_scale_init_value=1e-5,
),
neck=None,
head=None)

data_preprocessor = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
20 changes: 20 additions & 0 deletions configs/dinov2/vit-small-p14_dinov2-pre_headless.py
@@ -0,0 +1,20 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='dinov2-small',
img_size=518,
patch_size=14,
layer_scale_init_value=1e-5,
),
neck=None,
head=None)

data_preprocessor = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
48 changes: 37 additions & 11 deletions mmpretrain/models/backbones/vision_transformer.py
@@ -9,8 +9,8 @@
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone


@@ -21,6 +21,8 @@ class TransformerEncoderLayer(BaseModule):
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
layer_scale_init_value (float or torch.Tensor): Init value of layer
scale. Defaults to 0.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
@@ -29,6 +31,7 @@ class TransformerEncoderLayer(BaseModule):
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
@@ -41,11 +44,13 @@ def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
layer_scale_init_value=0.,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
ffn_type='origin',
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
@@ -61,17 +66,27 @@ def __init__(self,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)
qkv_bias=qkv_bias,
layer_scale_init_value=layer_scale_init_value)

self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
if ffn_type == 'origin':
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
layer_scale_init_value=layer_scale_init_value)
elif ffn_type == 'swiglu_fused':
self.ffn = SwiGLUFFNFused(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
layer_scale_init_value=layer_scale_init_value)
else:
raise NotImplementedError

@property
def norm1(self):
@@ -147,6 +162,8 @@ class tokens with shape (B, L, C).
-1 means not freezing any parameters. Defaults to -1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
layer_scale_init_value (float or torch.Tensor): Init value of layer
scale. Defaults to 0.
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
@@ -203,7 +220,7 @@ class tokens with shape (B, L, C).
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
@@ -216,6 +233,13 @@ class tokens with shape (B, L, C).
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
**dict.fromkeys(
['dinov2-g', 'dinov2-giant'], {
'embed_dims': 1536,
'num_layers': 40,
'num_heads': 24,
'feedforward_channels': 6144
}),
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
@@ -235,6 +259,7 @@ def __init__(self,
with_cls_token=True,
frozen_stages=-1,
interpolate_mode='bicubic',
layer_scale_init_value=0.,
patch_cfg=dict(),
layer_cfgs=dict(),
pre_norm=False,
@@ -322,6 +347,7 @@ def __init__(self,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
layer_scale_init_value=layer_scale_init_value,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
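
For context on the two options introduced in this file, the sketch below illustrates the general idea of a SwiGLU feed-forward block combined with LayerScale. It is an illustration only, not the actual `SwiGLUFFNFused` implementation imported from `..utils`; the real class may differ in hidden-size rounding and fusion details.

```python
# Illustrative sketch (NOT mmpretrain's SwiGLUFFNFused): a SwiGLU FFN with an
# optional LayerScale, showing what `ffn_type='swiglu_fused'` and
# `layer_scale_init_value` conceptually add to a transformer block.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFFNSketch(nn.Module):

    def __init__(self, embed_dims, feedforward_channels,
                 layer_scale_init_value=0.):
        super().__init__()
        # A single projection produces both the gated branch and the gate.
        self.w12 = nn.Linear(embed_dims, 2 * feedforward_channels)
        self.w3 = nn.Linear(feedforward_channels, embed_dims)
        if layer_scale_init_value > 0:
            # LayerScale: a learnable per-channel scale on the residual branch.
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(embed_dims))
        else:
            self.gamma = None

    def forward(self, x):
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        out = self.w3(F.silu(x1) * x2)  # SwiGLU: SiLU(x1) gates x2
        return out if self.gamma is None else out * self.gamma


# Usage: drop-in replacement for the MLP inside a pre-norm residual block,
# e.g. x = x + ffn(norm(x)).
ffn = SwiGLUFFNSketch(embed_dims=768, feedforward_channels=3072,
                      layer_scale_init_value=1e-5)
print(ffn(torch.rand(2, 197, 768)).shape)  # torch.Size([2, 197, 768])
```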
3 changes: 3 additions & 0 deletions mmpretrain/models/utils/__init__.py
@@ -22,6 +22,7 @@
build_2d_sincos_position_embedding)
from .res_layer_extra_norm import ResLayerExtraNorm
from .se_layer import SELayer
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .vector_quantizer import NormEMAVectorQuantizer

__all__ = [
@@ -69,4 +70,6 @@
'VideoDataPreprocessor',
'CosineEMA',
'ResLayerExtraNorm',
'SwiGLUFFN',
'SwiGLUFFNFused',
]