Skip to content

Commit 4515317

Browse files
Improve TPU compatibility
1 parent eff53ee commit 4515317

6 files changed

Lines changed: 65 additions & 28 deletions

File tree

model.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def forward(
142142
:param torch.tensor target: Target values [batch_size]
143143
:return: Loss [1] if reduction is "mean" else [9]
144144
"""
145-
return self.CrossEntropyLoss(predicted.view(-1, 9), target.view(-1).long())
145+
return self.CrossEntropyLoss(predicted, target)
146146

147147

148148
class CrossEntropyLossImageReorder(torch.nn.Module):
@@ -994,6 +994,7 @@ def __init__(
994994
learning_rate: float = 1e-5,
995995
weight_decay: float = 1e-3,
996996
label_smoothing: float = 0.0,
997+
accelerator: str = None,
997998
):
998999
"""
9991000
INIT
@@ -1052,6 +1053,7 @@ def __init__(
10521053
self.learning_rate = learning_rate
10531054
self.weight_decay = weight_decay
10541055
self.label_smoothing = label_smoothing
1056+
self.accelerator = accelerator
10551057

10561058
if self.encoder_type == "transformer":
10571059
self.model = TEDD1104Transformer(
@@ -1190,21 +1192,31 @@ def training_step(self, batch, batch_idx):
11901192
preds = self.model(x)
11911193
loss = self.criterion(preds, y)
11921194
self.total_batches += 1
1193-
self.running_loss += loss.item()
1194-
self.log("Train/loss", loss, sync_dist=True)
1195-
self.log(
1196-
"Train/running_loss", self.running_loss / self.total_batches, sync_dist=True
1197-
)
1195+
if self.accelerator != "tpu":
1196+
self.running_loss += loss.item()
1197+
self.log("Train/loss", loss, sync_dist=True)
1198+
self.log(
1199+
"Train/running_loss",
1200+
self.running_loss / self.total_batches,
1201+
sync_dist=True,
1202+
)
1203+
else:
1204+
if self.total_batches % 200 == 0:
1205+
self.log("Train/loss", loss, sync_dist=True)
11981206

1199-
return {"preds": preds.detach(), "y": y, "loss": loss}
1207+
return (
1208+
{"preds": preds.detach(), "y": y, "loss": loss}
1209+
if self.accelerator != "tpu"
1210+
else {"loss": loss}
1211+
)
12001212

12011213
def training_step_end(self, outputs):
12021214
"""
12031215
Training step end.
12041216
12051217
:param outputs: outputs of the training step
12061218
"""
1207-
if self.control_mode == "keyboard":
1219+
if self.accelerator != "tpu" and self.control_mode == "keyboard":
12081220
self.train_accuracy(outputs["preds"], outputs["y"])
12091221
self.log(
12101222
"Train/acc_k@1_macro",
@@ -1338,6 +1350,7 @@ def __init__(
13381350
learning_rate: float = 1e-5,
13391351
weight_decay: float = 1e-3,
13401352
encoder_type: str = "transformer",
1353+
accelerator: str = None,
13411354
):
13421355

13431356
"""
@@ -1376,6 +1389,7 @@ def __init__(
13761389
self.learning_rate = learning_rate
13771390
self.weight_decay = weight_decay
13781391
self.encoder_type = encoder_type
1392+
self.accelerator = accelerator
13791393

13801394
self.model = TEDD1104TransformerForImageReordering(
13811395
cnn_model_name=self.cnn_model_name,
@@ -1427,24 +1441,37 @@ def training_step(self, batch, batch_idx):
14271441
preds = self.model(x)
14281442
loss = self.criterion(preds, y)
14291443
self.total_batches += 1
1430-
self.running_loss += loss.item()
1431-
self.log("Train/loss", loss, sync_dist=True)
1432-
self.log(
1433-
"Train/running_loss", self.running_loss / self.total_batches, sync_dist=True
1444+
1445+
if self.accelerator != "tpu":
1446+
self.running_loss += loss.item()
1447+
self.log("Train/loss", loss, sync_dist=True)
1448+
self.log(
1449+
"Train/running_loss",
1450+
self.running_loss / self.total_batches,
1451+
sync_dist=True,
1452+
)
1453+
else:
1454+
if self.total_batches % 200 == 0:
1455+
self.log("Train/loss", loss, sync_dist=True)
1456+
1457+
return (
1458+
{"preds": torch.argmax(preds.detach(), dim=-1), "y": y, "loss": loss}
1459+
if self.accelerator != "tpu"
1460+
else {"loss": loss}
14341461
)
1435-
return {"preds": torch.argmax(preds.detach(), dim=-1), "y": y, "loss": loss}
14361462

14371463
def training_step_end(self, outputs):
14381464
"""
14391465
Training step end.
14401466
14411467
:param outputs: outputs of the training step
14421468
"""
1443-
self.train_accuracy(outputs["preds"], outputs["y"])
1444-
self.log(
1445-
"Train/acc",
1446-
self.train_accuracy,
1447-
)
1469+
if self.accelerator != "tpu":
1470+
self.train_accuracy(outputs["preds"], outputs["y"])
1471+
self.log(
1472+
"Train/acc",
1473+
self.train_accuracy,
1474+
)
14481475

14491476
def validation_step(self, batch, batch_idx):
14501477
"""

train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def train(
8686
)
8787
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
8888

89+
model.accelerator = accelerator
90+
8991
trainer = pl.Trainer(
9092
devices=devices,
9193
accelerator=accelerator,
92-
precision=precision,
94+
precision=precision if precision == "bf16" else int(precision),
9395
strategy=strategy,
9496
val_check_interval=val_check_interval,
9597
accumulate_grad_batches=accumulation_steps,
@@ -212,6 +214,7 @@ def train_new_model(
212214
weight_decay=weight_decay,
213215
weights=variable_weights,
214216
label_smoothing=label_smoothing,
217+
accelerator=accelerator,
215218
)
216219

217220
else:
@@ -334,10 +337,12 @@ def continue_training(
334337
)
335338
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
336339

340+
model.accelerator = accelerator
341+
337342
trainer = pl.Trainer(
338343
devices=devices,
339344
accelerator=accelerator,
340-
precision=precision,
345+
precision=precision if precision == "bf16" else int(precision),
341346
strategy=strategy,
342347
val_check_interval=val_check_interval,
343348
accumulate_grad_batches=accumulation_steps,

train_reorder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def train(
8484
trainer = pl.Trainer(
8585
devices=devices,
8686
accelerator=accelerator,
87-
precision=precision,
87+
precision=precision if precision == "bf16" else int(precision),
8888
strategy=strategy,
8989
val_check_interval=val_check_interval,
9090
accumulate_grad_batches=accumulation_steps,
@@ -181,6 +181,7 @@ def train_new_model(
181181
sequence_size=sequence_size,
182182
learning_rate=learning_rate,
183183
weight_decay=weight_decay,
184+
accelerator=accelerator,
184185
)
185186

186187
train(
@@ -290,10 +291,12 @@ def continue_training(
290291
)
291292
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
292293

294+
model.accelerator = accelerator
295+
293296
trainer = pl.Trainer(
294297
devices=devices,
295298
accelerator=accelerator,
296-
precision=precision,
299+
precision=precision if precision == "bf16" else int(precision),
297300
strategy=strategy,
298301
val_check_interval=val_check_interval,
299302
accumulate_grad_batches=accumulation_steps,

training_scripts/GPU/TEDD_1140_large.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ python3 train.py --train_new \
77
--accumulation_steps 4 \
88
--max_epochs 40 \
99
--cnn_model_name efficientnet_b7 \
10-
--num_layers_encoder 6 \
10+
--num_layers_encoder 4 \
1111
--embedded_size 512 \
1212
--learning_rate 1e-5 \
1313
--mask_prob 0.2 \
1414
--dropout_cnn_out 0.3 \
15-
--dropout_encoder 0.15 \
15+
--dropout_encoder 0.1 \
1616
--dropout_encoder_features 0.3 \
1717
--control_mode keyboard \
1818
--val_check_interval 0.5 \

training_scripts/TPU/TEDD_1140_base.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ python3 train.py --train_new \
2121
--val_check_interval 0.5 \
2222
--hide_map_prob 0.4 \
2323
--devices 8 \
24-
--accelerator tpu
24+
--accelerator tpu \
25+
--report_to tensorboard
2526

2627

training_scripts/TPU/TEDD_1140_large.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@ python3 train.py --train_new \
1010
--accumulation_steps 1 \
1111
--max_epochs 40 \
1212
--cnn_model_name efficientnet_b7 \
13-
--num_layers_encoder 6 \
13+
--num_layers_encoder 4 \
1414
--embedded_size 512 \
1515
--learning_rate 1e-5 \
1616
--mask_prob 0.2 \
1717
--dropout_cnn_out 0.3 \
18-
--dropout_encoder 0.15 \
18+
--dropout_encoder 0.1 \
1919
--dropout_encoder_features 0.3 \
2020
--control_mode keyboard \
2121
--val_check_interval 0.5 \
2222
--hide_map_prob 0.4 \
2323
--devices 8 \
24-
--accelerator tpu
24+
--accelerator tpu \
25+
--report_to wandb
2526

2627

0 commit comments

Comments
 (0)