-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathrun.py
116 lines (91 loc) · 4.09 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import print_function, absolute_import
from reid.eug import *
from reid import datasets
from reid import models
import numpy as np
import torch
import argparse
import os
from reid.utils.logging import Logger
import os.path as osp
import sys
from torch.backends import cudnn
from reid.utils.serialization import load_checkpoint
from torch import nn
import time
import pickle
def resume(args):
    """Find the highest-numbered step checkpoint in ``args.logs_dir``.

    Every directory entry whose name contains ``step_<N>.ckpt`` is treated as
    a checkpoint for step ``N``; the one with the largest ``N`` wins.

    Args:
        args: parsed command-line namespace; only ``args.logs_dir`` is read.

    Returns:
        A ``(start_step, ckpt_file)`` tuple: the largest step number found and
        the joined path of its checkpoint file, or ``(-1, "")`` when the
        directory holds no matching file.
    """
    import re
    pattern = re.compile(r'step_(\d+)\.ckpt')
    start_step = -1
    ckpt_file = ""
    # Iterate in sorted order so that, if two names map to the same step
    # number, the choice is deterministic (the first one in sort order wins,
    # because the comparison below is strict).
    for filename in sorted(os.listdir(args.logs_dir)):
        match = pattern.search(filename)
        if match is None:
            # Not a step checkpoint (log files, other artefacts): skip it.
            continue
        iter_ = int(match.group(1))
        if iter_ > start_step:
            start_step = iter_
            ckpt_file = osp.join(args.logs_dir, filename)
    # A non-negative step means a checkpoint was found and will be resumed.
    if start_step >= 0:
        print("continued from iter step", start_step)
    return start_step, ckpt_file
def main(args):
    """Run the progressive "Exploit the Unknown Gradually" training loop.

    Each step trains the model on the current training set (or restores the
    checkpoint found by ``resume`` at the resume step). It then pseudo-labels
    the unlabeled data, asks the EUG object to select ``nums_to_select``
    samples based on the confidence scores, builds the next training set from
    them and evaluates on the query/gallery split. The number of selected
    samples grows by ``args.EF`` percent of the unlabeled set per step.

    Args:
        args: parsed command-line namespace. Reads ``EF``, ``logs_dir``,
            ``dataset``, ``data_dir``, ``resume``, ``arch``, ``batch_size``,
            ``mode`` and ``max_frames``.

    Raises:
        ValueError: if ``args.EF`` is not a positive integer. Zero would
            divide by zero, and a negative value would make the step loop
            empty so that nothing is trained.

    Side effects:
        Enables cuDNN (with benchmark mode) and replaces ``sys.stdout`` with a
        ``Logger`` whose file lives in ``args.logs_dir``.
    """
    # Validate before touching global state (cuDNN flags, sys.stdout).
    if args.EF <= 0:
        raise ValueError("--EF must be a positive integer, got {}".format(args.EF))

    cudnn.benchmark = True
    cudnn.enabled = True

    save_path = args.logs_dir
    # Steps run from 0 to 100 // EF inclusive. Step k trains on the data
    # selected at step k - 1 (step 0 trains on the labeled data only).
    total_step = 100 // args.EF + 1
    # NOTE(review): ':' in the timestamp is not a legal file-name character
    # on Windows; kept unchanged for compatibility with existing log names.
    sys.stdout = Logger(osp.join(args.logs_dir, 'log' + str(args.EF) + time.strftime(".%m_%d_%H:%M:%S") + '.txt'))

    # Get the labeled (one-shot) and unlabeled splits of the training data.
    dataset_all = datasets.create(args.dataset, osp.join(args.data_dir, args.dataset))
    l_data, u_data = get_one_shot_in_cam1(
        dataset_all,
        load_path="./examples/oneshot_{}_used_in_paper.pickle".format(dataset_all.name))

    # Any truthy --resume value triggers a scan of logs_dir for checkpoints;
    # resume_step stays -1 when nothing is found, so training starts fresh.
    resume_step, ckpt_file = -1, ''
    if args.resume:
        resume_step, ckpt_file = resume(args)

    # Initialize the EUG algorithm.
    eug = EUG(model_name=args.arch, batch_size=args.batch_size, mode=args.mode, num_classes=dataset_all.num_train_ids,
              data_dir=dataset_all.images_dir, l_data=l_data, u_data=u_data, save_path=args.logs_dir,
              max_frames=args.max_frames)

    new_train_data = l_data
    for step in range(total_step):
        # Steps before the checkpointed one were already completed.
        if step < resume_step:
            continue

        # Cumulative selection size: (step + 1) * EF percent of the unlabeled
        # set, capped at its full size. Integer arithmetic equals the original
        # int(float) truncation for the non-negative values seen here.
        nums_to_select = min(len(u_data) * (step + 1) * args.EF // 100, len(u_data))
        print("This is running {} with EF={}%, step {}:\t Nums_to_be_select {}, \t Logs-dir {}".format(
            args.mode, args.EF, step, nums_to_select, save_path))

        # Train the model, or restore the checkpoint at the resume step.
        if step == resume_step:
            eug.resume(ckpt_file, step)
        else:
            eug.train(new_train_data, step, epochs=70, step_size=55, init_lr=0.1)

        # Pseudo-labels and confidence scores for the unlabeled data.
        pred_y, pred_score = eug.estimate_label()

        # Select data to add, using the confidence scores.
        selected_idx = eug.select_top_data(pred_score, nums_to_select)

        # Build the training set for the next step.
        new_train_data = eug.generate_new_train_data(selected_idx, pred_y)

        # Evaluate.
        eug.evaluate(dataset_all.query, dataset_all.gallery)
if __name__ == '__main__':
    # Command-line entry point: declare every option, parse argv and pass the
    # resulting namespace to main().
    working_dir = os.path.dirname(os.path.abspath(__file__))  # directory holding this script

    parser = argparse.ArgumentParser(description='Exploit the Unknown Gradually')

    # (flags, add_argument keyword arguments), registered in this exact order
    # so that --help lists the options in the same order as before.
    cli_options = [
        (('-d', '--dataset'), dict(type=str, default='mars', choices=datasets.names())),
        (('-b', '--batch-size'), dict(type=int, default=16)),
        (('-a', '--arch'), dict(type=str, default='avg_pool', choices=models.names())),
        # NOTE(review): --iter-step, --gamma, -l/--l and --continuous are
        # parsed but not read by main() or resume() in this file.
        (('-i', '--iter-step'), dict(type=int, default=5)),
        (('-g', '--gamma'), dict(type=float, default=0.3)),
        (('-l', '--l'), dict(type=float)),
        # Percentage of the unlabeled set added to the selection per step.
        (('--EF',), dict(type=int, default=10)),
        (('--data_dir',), dict(type=str, metavar='PATH', default=os.path.join(working_dir, 'data'))),
        (('--logs_dir',), dict(type=str, metavar='PATH', default=os.path.join(working_dir, 'logs'))),
        # Only truthiness is checked: any non-empty value makes main() resume
        # from the highest-numbered step_<N>.ckpt found in logs_dir.
        (('--resume',), dict(type=str, default=None)),
        (('--continuous',), dict(action="store_true")),
        # NOTE(review): neither a default nor required=True, so omitting
        # --mode leaves args.mode as None (choices are not checked for defaults).
        (('--mode',), dict(type=str, choices=["Classification", "Dissimilarity"])),
        (('--max_frames',), dict(type=int, default=900)),
    ]
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    main(parser.parse_args())