-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfast_demo.py
More file actions
42 lines (34 loc) · 1.63 KB
/
Copy pathfast_demo.py
File metadata and controls
42 lines (34 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import random

import numpy as np
import torch

from data_loader.data_loaders import OmniglotDataLoaderCreator
from model import accuracy, accuracy_oneshot, top_k_acc
from model.model import CnnKoch2015
from trainer import *
def seed_everything(seed: int) -> None:
    """Seed every RNG this script touches and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch, then
    configures cuDNN so repeated runs select the same kernels.

    Args:
        seed: Value handed to each seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Reproducibility over speed: deterministic kernels only, and no per-run
    # autotuning (benchmark mode may pick different algorithms between runs).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def select_device() -> torch.device:
    """Return the ``cuda`` device when at least one GPU is visible, else ``cpu``."""
    if torch.cuda.device_count() > 0:
        return torch.device("cuda")
    return torch.device("cpu")


if __name__ == '__main__':
    seed_everything(72)

    # Tip: restrict which GPUs are visible with e.g. `export CUDA_VISIBLE_DEVICES=1`
    # before launching; the counts printed below reflect that setting.
    print(torch.cuda.device_count())
    device = select_device()
    device_ids = list(range(torch.cuda.device_count()))
    print(device_ids)
    print("Using device", device)

    # NOTE(review): the positional arguments here are forwarded to
    # OmniglotDataLoaderCreator / load_* whose signatures are not visible in
    # this file; confirm their meaning against data_loader.data_loaders.
    omniglot_dataloader_creator = OmniglotDataLoaderCreator("./data/", 300, 100, 40, 40, False, 8)
    train_loader = omniglot_dataloader_creator.load_train(128, True)
    val_loader = omniglot_dataloader_creator.load_validation(False, 128)
    val_oneshot_loader = omniglot_dataloader_creator.load_validation(True, 3)
    test_oneshot_loader = omniglot_dataloader_creator.load_test(3)

    koch2015 = CnnKoch2015()
    learning_rate = 0.00006
    epochs = 500
    # BCEWithLogitsLoss applies the sigmoid itself, so the model is expected
    # to emit raw logits (presumably what CnnKoch2015 returns — confirm).
    criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(koch2015.parameters(), lr=learning_rate)
    metric_ftns = [accuracy]
    metric_ftns_oneshot = [accuracy_oneshot, top_k_acc]

    # NOTE(review): `device_ids` is computed above but a literal `[]` is passed
    # in the 7th positional slot; check OmniglotTrainer's signature to confirm
    # whether that slot was meant to receive `device_ids`.
    trainer = OmniglotTrainer(koch2015, criterion, metric_ftns, metric_ftns_oneshot, optimizer, device, [], epochs,
                              "./saved", "max val accuracy_oneshot", train_loader, val_loader, val_oneshot_loader,
                              test_oneshot_loader)
    koch2015.summary(device.type)
    trainer.train()