|
@@ -28,10 +28,9 @@ class BaseStrategy:
|
|
for invalid_slice in self.invalid_slices:
|
|
for invalid_slice in self.invalid_slices:
|
|
logits[..., invalid_slice] = -65504
|
|
logits[..., invalid_slice] = -65504
|
|
|
|
|
|
- # logits = top_k_logits(logits, self.topk, self.top_p)
|
|
|
|
- # probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
|
|
|
|
- # pred = torch.multinomial(probs, num_samples=1)
|
|
|
|
- pred = torch.argmax(logits, dim=-1)
|
|
|
|
|
|
+ logits = top_k_logits(logits, self.topk, self.top_p)
|
|
|
|
+ probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
|
|
|
|
+ pred = torch.multinomial(probs, num_samples=1)
|
|
for i in range(self.batch_size):
|
|
for i in range(self.batch_size):
|
|
if i >= batch_size:
|
|
if i >= batch_size:
|
|
self._is_done[i] = True
|
|
self._is_done[i] = True
|