anonymous-good committed on
Commit
72eba74
·
1 Parent(s): 9a1e185
Files changed (2) hide show
  1. build_model.py +80 -0
  2. models/build_model.py +80 -0
build_model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.layers import trunc_normal_
7
+ from copy import deepcopy
8
+ import os
9
+ import torch.backends.cudnn as cudnn
10
+
11
+ import models.vision_transformer as vits
12
+
13
class vit(nn.Module):
    """Vision Transformer feature extractor.

    Wraps a ViT backbone built from ``models.vision_transformer``,
    optionally freezes all of its parameters, and optionally loads
    pretrained weights from a DINO/iBOT-style checkpoint (handling the
    ``teacher``/``model`` sub-dicts and the ``teacher.``/``backbone.``
    key prefixes produced by the multi-crop training wrapper).
    """

    # Embedding dimension of each supported backbone size.
    _EMBEDDING_SIZES = {
        "vit_small": 384,
        "vit_base": 768,
        "vit_large": 1024,
        "giant": 1536,
    }

    def __init__(self, model_size="base", freeze_transformer=True, pretrained_weights=None):
        """
        Args:
            model_size: name of a constructor in ``models.vision_transformer``
                (e.g. ``"vit_base"``).
            freeze_transformer: if True, disable gradients on every backbone
                parameter.
            pretrained_weights: optional path to a checkpoint file; silently
                ignored when the path does not exist.

        Raises:
            ValueError: if ``model_size`` is not one of the supported sizes.
        """
        # BUG FIX: the original called super(ibotvit, self).__init__(), but no
        # class named `ibotvit` exists anywhere in this file, so every
        # instantiation raised NameError.
        super().__init__()
        self.model_size = model_size
        self.freeze_transformer = freeze_transformer
        self.pretrained_weights = pretrained_weights

        # NOTE: the original defined an unused local `n_register_tokens = 4`
        # ("Loading a model with registers"); it was never read, so it is
        # dropped here.

        # BUG FIX: the original if/elif chain silently left `embedding_size`
        # unset for unrecognized sizes (including the default "base"),
        # deferring the failure to an obscure AttributeError later.  Fail
        # fast with a clear message instead.
        if model_size not in self._EMBEDDING_SIZES:
            raise ValueError(
                f"Unknown model_size {model_size!r}; expected one of "
                f"{sorted(self._EMBEDDING_SIZES)}"
            )
        self.embedding_size = self._EMBEDDING_SIZES[model_size]

        # Build the backbone and keep an independent copy of it.
        model = vits.__dict__[model_size](patch_size=16)
        self.transformer = deepcopy(model)

        # Freeze the backbone if requested.
        if self.freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

        if self.pretrained_weights and os.path.isfile(self.pretrained_weights):
            state_dict = torch.load(self.pretrained_weights, map_location="cpu")
            # Checkpoints may nest the weights under a sub-key.
            if 'teacher' in state_dict:
                state_dict = state_dict['teacher']
            elif 'model' in state_dict:
                state_dict = state_dict['model']

            # Strip the `teacher.` / `backbone.` prefixes induced by the
            # multi-crop wrapper so keys match the bare backbone.
            state_dict = {
                (k[len("teacher."):] if k.startswith("teacher.") else k): v
                for k, v in state_dict.items()
            }
            state_dict = {
                (k[len("backbone."):] if k.startswith("backbone.") else k): v
                for k, v in state_dict.items()
            }
            # strict=False: tolerate missing/unexpected keys; `msg` reports them.
            msg = self.transformer.load_state_dict(state_dict, strict=False)
            print(model_size, msg)

    def forward(self, x):
        """Run the backbone on `x` and return its output unchanged."""
        return self.transformer(x)
70
+
71
+
72
+
73
def build_model(args):
    """Build a frozen ViT-Base feature extractor and move it to the GPU.

    Args:
        args: namespace that provides ``pretrained_weights`` (path to a
            checkpoint file, or None).

    Returns:
        The CUDA-resident ``vit`` module.
    """
    network = vit(
        "vit_base",
        freeze_transformer=True,
        pretrained_weights=args.pretrained_weights,
    )
    network.cuda()
    return network
models/build_model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.layers import trunc_normal_
7
+ from copy import deepcopy
8
+ import os
9
+ import torch.backends.cudnn as cudnn
10
+
11
+ import models.vision_transformer as vits
12
+
13
class vit(nn.Module):
    """Vision Transformer feature extractor.

    Wraps a ViT backbone built from ``models.vision_transformer``,
    optionally freezes all of its parameters, and optionally loads
    pretrained weights from a DINO/iBOT-style checkpoint (handling the
    ``teacher``/``model`` sub-dicts and the ``teacher.``/``backbone.``
    key prefixes produced by the multi-crop training wrapper).
    """

    # Embedding dimension of each supported backbone size.
    _EMBEDDING_SIZES = {
        "vit_small": 384,
        "vit_base": 768,
        "vit_large": 1024,
        "giant": 1536,
    }

    def __init__(self, model_size="base", freeze_transformer=True, pretrained_weights=None):
        """
        Args:
            model_size: name of a constructor in ``models.vision_transformer``
                (e.g. ``"vit_base"``).
            freeze_transformer: if True, disable gradients on every backbone
                parameter.
            pretrained_weights: optional path to a checkpoint file; silently
                ignored when the path does not exist.

        Raises:
            ValueError: if ``model_size`` is not one of the supported sizes.
        """
        # BUG FIX: the original called super(ibotvit, self).__init__(), but no
        # class named `ibotvit` exists anywhere in this file, so every
        # instantiation raised NameError.
        super().__init__()
        self.model_size = model_size
        self.freeze_transformer = freeze_transformer
        self.pretrained_weights = pretrained_weights

        # NOTE: the original defined an unused local `n_register_tokens = 4`
        # ("Loading a model with registers"); it was never read, so it is
        # dropped here.

        # BUG FIX: the original if/elif chain silently left `embedding_size`
        # unset for unrecognized sizes (including the default "base"),
        # deferring the failure to an obscure AttributeError later.  Fail
        # fast with a clear message instead.
        if model_size not in self._EMBEDDING_SIZES:
            raise ValueError(
                f"Unknown model_size {model_size!r}; expected one of "
                f"{sorted(self._EMBEDDING_SIZES)}"
            )
        self.embedding_size = self._EMBEDDING_SIZES[model_size]

        # Build the backbone and keep an independent copy of it.
        model = vits.__dict__[model_size](patch_size=16)
        self.transformer = deepcopy(model)

        # Freeze the backbone if requested.
        if self.freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

        if self.pretrained_weights and os.path.isfile(self.pretrained_weights):
            state_dict = torch.load(self.pretrained_weights, map_location="cpu")
            # Checkpoints may nest the weights under a sub-key.
            if 'teacher' in state_dict:
                state_dict = state_dict['teacher']
            elif 'model' in state_dict:
                state_dict = state_dict['model']

            # Strip the `teacher.` / `backbone.` prefixes induced by the
            # multi-crop wrapper so keys match the bare backbone.
            state_dict = {
                (k[len("teacher."):] if k.startswith("teacher.") else k): v
                for k, v in state_dict.items()
            }
            state_dict = {
                (k[len("backbone."):] if k.startswith("backbone.") else k): v
                for k, v in state_dict.items()
            }
            # strict=False: tolerate missing/unexpected keys; `msg` reports them.
            msg = self.transformer.load_state_dict(state_dict, strict=False)
            print(model_size, msg)

    def forward(self, x):
        """Run the backbone on `x` and return its output unchanged."""
        return self.transformer(x)
70
+
71
+
72
+
73
def build_model(args):
    """Build a frozen ViT-Base feature extractor and move it to the GPU.

    Args:
        args: namespace that provides ``pretrained_weights`` (path to a
            checkpoint file, or None).

    Returns:
        The CUDA-resident ``vit`` module.
    """
    network = vit(
        "vit_base",
        freeze_transformer=True,
        pretrained_weights=args.pretrained_weights,
    )
    network.cuda()
    return network