Browse Source

Fix sampling in BaseStrategy

Zhengxiao Du 3 năm trước cách đây
mục cha
commit
1241f03ec2
1 tập tin đã thay đổi với 3 bổ sung4 xóa
  1. 3 4
      generation/strategies.py

+ 3 - 4
generation/strategies.py

@@ -28,10 +28,9 @@ class BaseStrategy:
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -65504
 
-        # logits = top_k_logits(logits, self.topk, self.top_p)
-        # probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
-        # pred = torch.multinomial(probs, num_samples=1)
-        pred = torch.argmax(logits, dim=-1)
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
+        pred = torch.multinomial(probs, num_samples=1)
         for i in range(self.batch_size):
             if i >= batch_size:
                 self._is_done[i] = True