@@ -4,7 +4,8 @@ import torch.nn.functional as F
 from SwissArmyTransformer.generation.sampling_strategies.base_strategy import top_k_logits

 class BaseStrategy:
-    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
+    def __init__(self, batch_size, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
+                 deterministic=False):
         self.batch_size = batch_size
         self.invalid_slices = invalid_slices
         self.temperature = temperature
@@ -14,6 +15,8 @@ class BaseStrategy:
         if end_tokens is None:
             end_tokens = []
         self.end_tokens = end_tokens
+        self.deterministic = deterministic
+        print(self.deterministic)
         self._is_done = np.zeros(self.batch_size, dtype=np.bool)

     @property
@@ -30,8 +33,11 @@ class BaseStrategy:
             logits[..., invalid_slice] = -65504

         logits = top_k_logits(logits, self.topk, self.top_p)
-        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
-        pred = torch.multinomial(probs, num_samples=1)
+        if self.deterministic:
+            pred = logits.max(dim=-1)[1]
+        else:
+            probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
+            pred = torch.multinomial(probs, num_samples=1)
         for i in range(self.batch_size):
             if i >= batch_size:
                 self._is_done[i] = True
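
The patch adds a deterministic flag that switches token selection from multinomial sampling to greedy argmax. A minimal standalone sketch of the two branches (not part of the patch; the tensor sizes are made up for illustration) highlights one behavioral detail worth noting: the greedy branch returns indices of shape [batch_size], while torch.multinomial(..., num_samples=1) returns [batch_size, 1], so downstream code must tolerate both shapes (for example by reshaping pred before any concatenation).

import torch
import torch.nn.functional as F

batch_size, vocab_size = 2, 8                # toy sizes for illustration
logits = torch.randn(batch_size, vocab_size)

# Deterministic branch: greedy argmax, reproducible across runs.
pred_greedy = logits.max(dim=-1)[1]                      # shape: [2]

# Sampling branch: draw one token per row from the softmax distribution;
# .float() upcasts before softmax to sidestep half-precision issues.
probs = F.softmax(logits.float(), dim=-1)
pred_sampled = torch.multinomial(probs, num_samples=1)   # shape: [2, 1]

print(pred_greedy.shape, pred_sampled.shape)  # torch.Size([2]) torch.Size([2, 1])

Greedy selection trades sampling diversity for run-to-run reproducibility, which is the usual motivation for a flag like this in generation code.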