Ruslan Mavlyutov · 1 year ago
commit d21645c90b

+ 1 - 6
src/seamless_communication/cli/m4t/classification_head/model.py

@@ -7,15 +7,11 @@ class ClassificationHead(nn.Module):
         super(ClassificationHead, self).__init__()
         self.num_languages = n_classes
         self.num_layers = n_layers
-
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads=n_heads)
         self.layers = nn.ModuleList(
             [
                 nn.Sequential(
                     nn.Linear(embed_dim, embed_dim),
-                    nn.BatchNorm1d(embed_dim),  # normalize batch
-                    nn.ReLU(),  # activation function
-                    nn.Dropout(0.1),  # prevent overfitting
+                    nn.ReLU(),
                 )
                 for _ in range(n_layers)
             ]
@@ -24,7 +20,6 @@ class ClassificationHead(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # (Batch, Seq, Embed)
-        x, _ = self.attn(x, x, x)
         x = x[:, 0]
         for layer in self.layers:
             x = layer(x)
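
For orientation, here is how the trimmed-down head reads after this commit: a plain stack of Linear + ReLU blocks over the first position's embedding, with the attention, BatchNorm1d, and Dropout layers gone. This is a minimal sketch; the constructor signature and the final projection to n_classes are assumptions, since the hunk only shows the middle of the class.

import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    def __init__(self, embed_dim: int, n_layers: int, n_classes: int) -> None:
        super().__init__()
        self.num_languages = n_classes
        self.num_layers = n_layers
        # After this commit: plain Linear + ReLU blocks, with no attention,
        # batch norm, or dropout.
        self.layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
                for _ in range(n_layers)
            ]
        )
        # Assumed final projection to class logits (not shown in the hunk).
        self.proj = nn.Linear(embed_dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (Batch, Seq, Embed); keep only the first position's embedding
        x = x[:, 0]
        for layer in self.layers:
            x = layer(x)
        return self.proj(x)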

+ 28 - 23
src/seamless_communication/cli/m4t/classification_head/train.py

@@ -96,7 +96,7 @@ def init_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=10,
+        default=50,
         help="Batch size for training",
     )
     parser.add_argument(
@@ -129,7 +129,7 @@ def init_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--label_smoothing",
         type=float,
-        default=0.1,
+        default=0.001,
         help=("Label smoothing"),
     )
     parser.add_argument(
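
Taken together, the two CLI defaults this commit changes read as follows after the patch (a standalone sketch; the other options in init_parser are omitted):

import argparse

parser = argparse.ArgumentParser()
# Batch size raised from 10 to 50
parser.add_argument("--batch_size", type=int, default=50, help="Batch size for training")
# Label smoothing lowered from 0.1 to 0.001
parser.add_argument("--label_smoothing", type=float, default=0.001, help="Label smoothing")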
@@ -171,7 +171,6 @@ def plot_losslog(
         plt.show()
 
 
-@torch.no_grad()
 def eval(
     head: torch.nn.Module,
     frozen_model: UnitYModel,
@@ -183,18 +182,18 @@ def eval(
     losses = []
     for batch_idx, (seqs, labels) in enumerate(dataloader.get_dataloader()):
         assert seqs.src_tokens is not None
-        with torch.autocast(device_type=params.device.type, dtype=params.float_dtype):
-            mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
-                params.device
-            )
-            vector, _ = frozen_model.encode(
-                seqs.src_tokens.to(params.device), padding_mask=mask.to(params.device)
-            )
-            logits = head(vector)
+        mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
+            params.device
+        )
+        with torch.no_grad():
+            with torch.autocast(device_type=params.device.type, dtype=params.float_dtype):
+                vector, _ = frozen_model.encode(
+                    seqs.src_tokens.to(params.device), padding_mask=mask.to(params.device)
+                )
+                logits = head(vector)
         loss = torch.nn.functional.cross_entropy(
             logits,
-            labels.to(params.device),
-            label_smoothing=0.1,
+            labels.to(params.device)
         ) / labels.size(0)
         losses.append(loss.item())
         # TODO: remove
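
The structural change in this hunk: eval() loses its @torch.no_grad() decorator, and gradient tracking is instead disabled only around the forward passes, so the PaddingMask construction and the loss bookkeeping run outside the block. The hard-coded label_smoothing=0.1 is also dropped from the eval loss. A sketch of the resulting pattern, with forward_eval as a hypothetical helper name:

import torch

def forward_eval(frozen_model, head, src_tokens, padding_mask, device, dtype):
    # Gradients are disabled only around the two forward passes, replacing
    # the removed @torch.no_grad() decorator on eval() itself.
    with torch.no_grad():
        with torch.autocast(device_type=device.type, dtype=dtype):
            vector, _ = frozen_model.encode(
                src_tokens.to(device), padding_mask=padding_mask
            )
            return head(vector)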
@@ -248,7 +247,7 @@ def train(
                     logits = head(vector)
 
                 loss = torch.nn.functional.cross_entropy(
-                    logits, labels, label_smoothing=0.1
+                    logits, labels, label_smoothing=label_smoothing,
                 ) / labels.size(0)
                 if loss.isnan().any().item():
                     logger.error(seqs)
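
The training loss now takes its smoothing strength from the label_smoothing variable (fed by the CLI flag above) instead of a hard-coded 0.1. As a reminder of what the parameter does: with smoothing s over K classes, the target distribution gives the true class 1 - s + s/K and every other class s/K, so the new default of 0.001 is nearly one-hot. A self-contained sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # (batch, n_classes)
labels = torch.randint(0, 10, (4,))  # class indices

loss_hard = F.cross_entropy(logits, labels, label_smoothing=0.0)    # one-hot targets
loss_soft = F.cross_entropy(logits, labels, label_smoothing=0.001)  # new default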
@@ -257,18 +256,23 @@ def train(
                         "Train loss is NaN! Something is wrong in the model!"
                     )
                 loss_vals.append(loss.item())
-                if update_idx % 100 == 0:
-                    eval_loss = eval(
-                        head=head,
-                        frozen_model=frozen_model,
-                        dataloader=eval_dataloader,
-                        params=params,
-                    )
+                if update_idx % 20 == 0:
+                    if 1:
+                        eval_loss = eval(
+                            head=head,
+                            frozen_model=frozen_model,
+                            dataloader=eval_dataloader,
+                            params=params,
+                        )
+                    else:
+                        eval_loss = 0
+                    head.train()
+                    frozen_model.train()
                     logger.info(
                         f" .. epoch={epoch}, "
                         f"update={update_idx}, "
-                        f"avg_train_loss={(sum(loss_vals) / len(loss_vals)):.3f}, "
-                        f"eval_loss={eval_loss:.3f}"
+                        f"avg_train_loss={(sum(loss_vals) / len(loss_vals)):.5f}, "
+                        f"eval_loss={eval_loss:.5f}"
                     )
                     loss_vals = []
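
The head.train() and frozen_model.train() calls added here restore training-time behavior after each evaluation pass; assuming the modules are switched to inference mode during evaluation, layers such as Dropout would otherwise keep their inference behavior for subsequent updates. A hypothetical wrapper showing the bookkeeping pattern (run_eval_and_resume and evaluate are illustrative names, not from the source):

import torch.nn as nn
from typing import Callable

def run_eval_and_resume(
    head: nn.Module, frozen_model: nn.Module, evaluate: Callable[[], float]
) -> float:
    # Switch to inference mode for the eval pass...
    head.eval()
    frozen_model.eval()
    loss = evaluate()
    # ...then flip back so layers like Dropout behave normally while training.
    head.train()
    frozen_model.train()
    return loss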
 
@@ -320,6 +324,7 @@ def main() -> None:
     # Put model on selected device
     model = model.to(device)
     head = head.to(device)
+    logger.info(f"LID head: {head}")
 
    # Create dataloaders
     train_dataloader = dataloader.UnitYLanguageIDDataLoader(