@@ -142,7 +142,7 @@ def forward(
142142 :param torch.tensor target: Target values [batch_size]
143143 :return: Loss [1] if reduction is "mean" else [9]
144144 """
145- return self .CrossEntropyLoss (predicted . view ( - 1 , 9 ), target . view ( - 1 ). long () )
145+ return self .CrossEntropyLoss (predicted , target )
146146
147147
148148class CrossEntropyLossImageReorder (torch .nn .Module ):
@@ -994,6 +994,7 @@ def __init__(
994994 learning_rate : float = 1e-5 ,
995995 weight_decay : float = 1e-3 ,
996996 label_smoothing : float = 0.0 ,
997+ accelerator : str = None ,
997998 ):
998999 """
9991000 INIT
@@ -1052,6 +1053,7 @@ def __init__(
10521053 self .learning_rate = learning_rate
10531054 self .weight_decay = weight_decay
10541055 self .label_smoothing = label_smoothing
1056+ self .accelerator = accelerator
10551057
10561058 if self .encoder_type == "transformer" :
10571059 self .model = TEDD1104Transformer (
@@ -1190,21 +1192,31 @@ def training_step(self, batch, batch_idx):
11901192 preds = self .model (x )
11911193 loss = self .criterion (preds , y )
11921194 self .total_batches += 1
1193- self .running_loss += loss .item ()
1194- self .log ("Train/loss" , loss , sync_dist = True )
1195- self .log (
1196- "Train/running_loss" , self .running_loss / self .total_batches , sync_dist = True
1197- )
1195+ if self .accelerator != "tpu" :
1196+ self .running_loss += loss .item ()
1197+ self .log ("Train/loss" , loss , sync_dist = True )
1198+ self .log (
1199+ "Train/running_loss" ,
1200+ self .running_loss / self .total_batches ,
1201+ sync_dist = True ,
1202+ )
1203+ else :
1204+ if self .total_batches % 200 == 0 :
1205+ self .log ("Train/loss" , loss , sync_dist = True )
11981206
1199- return {"preds" : preds .detach (), "y" : y , "loss" : loss }
1207+ return (
1208+ {"preds" : preds .detach (), "y" : y , "loss" : loss }
1209+ if self .accelerator != "tpu"
1210+ else {"loss" : loss }
1211+ )
12001212
12011213 def training_step_end (self , outputs ):
12021214 """
12031215 Training step end.
12041216
12051217 :param outputs: outputs of the training step
12061218 """
1207- if self .control_mode == "keyboard" :
1219+ if self .accelerator != "tpu" and self . control_mode == "keyboard" :
12081220 self .train_accuracy (outputs ["preds" ], outputs ["y" ])
12091221 self .log (
12101222 "Train/acc_k@1_macro" ,
@@ -1338,6 +1350,7 @@ def __init__(
13381350 learning_rate : float = 1e-5 ,
13391351 weight_decay : float = 1e-3 ,
13401352 encoder_type : str = "transformer" ,
1353+ accelerator : str = None ,
13411354 ):
13421355
13431356 """
@@ -1376,6 +1389,7 @@ def __init__(
13761389 self .learning_rate = learning_rate
13771390 self .weight_decay = weight_decay
13781391 self .encoder_type = encoder_type
1392+ self .accelerator = accelerator
13791393
13801394 self .model = TEDD1104TransformerForImageReordering (
13811395 cnn_model_name = self .cnn_model_name ,
@@ -1427,24 +1441,37 @@ def training_step(self, batch, batch_idx):
14271441 preds = self .model (x )
14281442 loss = self .criterion (preds , y )
14291443 self .total_batches += 1
1430- self .running_loss += loss .item ()
1431- self .log ("Train/loss" , loss , sync_dist = True )
1432- self .log (
1433- "Train/running_loss" , self .running_loss / self .total_batches , sync_dist = True
1444+
1445+ if self .accelerator != "tpu" :
1446+ self .running_loss += loss .item ()
1447+ self .log ("Train/loss" , loss , sync_dist = True )
1448+ self .log (
1449+ "Train/running_loss" ,
1450+ self .running_loss / self .total_batches ,
1451+ sync_dist = True ,
1452+ )
1453+ else :
1454+ if self .total_batches % 200 == 0 :
1455+ self .log ("Train/loss" , loss , sync_dist = True )
1456+
1457+ return (
1458+ {"preds" : torch .argmax (preds .detach (), dim = - 1 ), "y" : y , "loss" : loss }
1459+ if self .accelerator != "tpu"
1460+ else {"loss" : loss }
14341461 )
1435- return {"preds" : torch .argmax (preds .detach (), dim = - 1 ), "y" : y , "loss" : loss }
14361462
14371463 def training_step_end (self , outputs ):
14381464 """
14391465 Training step end.
14401466
14411467 :param outputs: outputs of the training step
14421468 """
1443- self .train_accuracy (outputs ["preds" ], outputs ["y" ])
1444- self .log (
1445- "Train/acc" ,
1446- self .train_accuracy ,
1447- )
1469+ if self .accelerator != "tpu" :
1470+ self .train_accuracy (outputs ["preds" ], outputs ["y" ])
1471+ self .log (
1472+ "Train/acc" ,
1473+ self .train_accuracy ,
1474+ )
14481475
14491476 def validation_step (self , batch , batch_idx ):
14501477 """
0 commit comments