@@ -96,7 +96,7 @@ def init_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=10,
+        default=50,
         help="Batch size for training",
     )
     parser.add_argument(
@@ -129,7 +129,7 @@ def init_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--label_smoothing",
         type=float,
-        default=0.1,
+        default=0.001,
         help=("Label smoothing"),
     )
     parser.add_argument(
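
For context on the new 0.001 default: label_smoothing in
torch.nn.functional.cross_entropy mixes the one-hot target with a uniform
distribution over the classes, so 0.1 softens targets noticeably while 0.001
keeps them almost hard. A minimal sketch; shapes are illustrative, not taken
from this script:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)          # (batch, num_classes)
    labels = torch.randint(0, 10, (4,))  # class indices

    # With smoothing s, the target becomes (1 - s) * one_hot + s / num_classes.
    hard = F.cross_entropy(logits, labels, label_smoothing=0.0)
    soft = F.cross_entropy(logits, labels, label_smoothing=0.1)
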
@@ -171,7 +171,6 @@ def plot_losslog(
     plt.show()
 
 
-@torch.no_grad()
 def eval(
     head: torch.nn.Module,
     frozen_model: UnitYModel,
@@ -183,18 +182,18 @@ def eval(
     losses = []
     for batch_idx, (seqs, labels) in enumerate(dataloader.get_dataloader()):
         assert seqs.src_tokens is not None
-        with torch.autocast(device_type=params.device.type, dtype=params.float_dtype):
-            mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
-                params.device
-            )
-            vector, _ = frozen_model.encode(
-                seqs.src_tokens.to(params.device), padding_mask=mask.to(params.device)
-            )
-            logits = head(vector)
+        mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
+            params.device
+        )
+        with torch.no_grad():
+            with torch.autocast(device_type=params.device.type, dtype=params.float_dtype):
+                vector, _ = frozen_model.encode(
+                    seqs.src_tokens.to(params.device), padding_mask=mask.to(params.device)
+                )
+                logits = head(vector)
         loss = torch.nn.functional.cross_entropy(
             logits,
-            labels.to(params.device),
-            label_smoothing=0.1,
+            labels.to(params.device)
         ) / labels.size(0)
         losses.append(loss.item())
         # TODO: remove
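
The eval() change above replaces the removed @torch.no_grad() decorator with
an explicit no_grad() block around the frozen encoder, stacked with autocast
for mixed precision. The pattern in isolation, with a generic module standing
in for the UnitYModel API:

    import torch

    def encode_frozen(model: torch.nn.Module, x: torch.Tensor,
                      device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # no_grad() skips autograd bookkeeping for the frozen weights;
        # autocast() runs eligible ops in the lower-precision dtype.
        with torch.no_grad():
            with torch.autocast(device_type=device.type, dtype=dtype):
                return model(x.to(device))

The two context managers could also be combined into a single with statement;
nesting them keeps the style of the hunk.
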
@@ -248,7 +247,7 @@ def train(
             logits = head(vector)
 
             loss = torch.nn.functional.cross_entropy(
-                logits, labels, label_smoothing=0.1
+                logits, labels, label_smoothing=label_smoothing,
             ) / labels.size(0)
             if loss.isnan().any().item():
                 logger.error(seqs)
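
A side note on the `) / labels.size(0)` normalization that both loss sites
share: torch.nn.functional.cross_entropy defaults to reduction="mean", which
already divides by the batch size, so the extra division scales the loss down
by the batch size a second time. That is at least consistent between train
and eval here, but if a plain per-sample mean is the goal, this equivalence
holds (illustrative shapes):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 5)
    labels = torch.randint(0, 5, (8,))

    # reduction="mean" (the default) already averages over the batch,
    # so these two values are equal:
    mean_loss = F.cross_entropy(logits, labels)
    sum_loss = F.cross_entropy(logits, labels, reduction="sum") / labels.size(0)
    assert torch.allclose(mean_loss, sum_loss)
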
@@ -257,18 +256,23 @@ def train(
|
|
|
"Train loss is NaN! Something is wrong in the model!"
|
|
|
)
|
|
|
loss_vals.append(loss.item())
|
|
|
- if update_idx % 100 == 0:
|
|
|
- eval_loss = eval(
|
|
|
- head=head,
|
|
|
- frozen_model=frozen_model,
|
|
|
- dataloader=eval_dataloader,
|
|
|
- params=params,
|
|
|
- )
|
|
|
+ if update_idx % 20 == 0:
|
|
|
+ if 1:
|
|
|
+ eval_loss = eval(
|
|
|
+ head=head,
|
|
|
+ frozen_model=frozen_model,
|
|
|
+ dataloader=eval_dataloader,
|
|
|
+ params=params,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ eval_loss = 0
|
|
|
+ head.train()
|
|
|
+ frozen_model.train()
|
|
|
logger.info(
|
|
|
f" .. epoch={epoch}, "
|
|
|
f"update={update_idx}, "
|
|
|
- f"avg_train_loss={(sum(loss_vals) / len(loss_vals)):.3f}, "
|
|
|
- f"eval_loss={eval_loss:.3f}"
|
|
|
+ f"avg_train_loss={(sum(loss_vals) / len(loss_vals)):.5f}, "
|
|
|
+ f"eval_loss={eval_loss:.5f}"
|
|
|
)
|
|
|
loss_vals = []
|
|
|
|
|
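
The restored head.train() and frozen_model.train() calls follow the standard
PyTorch convention: assuming eval() switches the modules into inference mode
(which the added calls suggest), the training loop must switch them back so
dropout and similar layers behave correctly again. The general pattern, not
the exact helpers in this file:

    import torch
    import torch.nn.functional as F

    def evaluate(model: torch.nn.Module,
                 batches: list[tuple[torch.Tensor, torch.Tensor]]) -> float:
        model.eval()  # disable dropout, use running batch-norm stats
        with torch.no_grad():
            losses = [
                F.cross_entropy(model(x), y).item() for x, y in batches
            ]
        return sum(losses) / len(losses)

    # Callers restore training mode afterwards:
    # loss = evaluate(model, eval_batches); model.train()
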
@@ -320,6 +321,7 @@ def main() -> None:
     # Put model on selected device
     model = model.to(device)
     head = head.to(device)
+    logger.info(f"LID head: {head}")
 
     # Create dataloaders
     train_dataloader = dataloader.UnitYLanguageIDDataLoader(
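
To connect the new --label_smoothing flag to the train() change above, the
parsed value presumably flows from the argument parser into the training
call. A hypothetical wiring, with parameter names assumed rather than shown
by this diff:

    args = init_parser().parse_args()
    train(
        head=head,
        frozen_model=model,
        dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        params=params,
        label_smoothing=args.label_smoothing,
    )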
|