@@ -40,9 +40,8 @@ def batch_filling_sequence(
         # Now, we want to generate seq[counter + 1],
         # token[:, index: counter+1] needs forwarding.
         # forward
-        if num_beams > 1:
-            tokens = tokens.reshape(batch_size * num_beams, -1)
-            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+        tokens = tokens.reshape(batch_size * num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
         logits, *output_per_layers = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
@@ -58,11 +57,12 @@ def batch_filling_sequence(
         logits = logits[:, -1]
         counter += 1
         index = counter
+        # if torch.distributed.get_rank() == 0:
+        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")
         # sampling
-        if num_beams > 1:
-            logits = logits.reshape(batch_size, num_beams, -1)
-            tokens = tokens.reshape(batch_size, num_beams, -1)
-            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        logits = logits.reshape(batch_size, num_beams, -1)
+        tokens = tokens.reshape(batch_size, num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
         tokens, mems = strategy.forward(logits, tokens, mems)
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
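
For context on why removing the `if num_beams > 1:` guards is behavior-preserving: flattening and unflattening a singleton beam dimension is a no-op in both directions, while the new `if mems is not None` guard covers the first decoding step, where no key/value cache exists yet. Below is a minimal, self-contained sketch of the reshape logic; all shapes and values are hypothetical and not taken from the patch.

    import torch

    # Hypothetical shapes, for illustration only.
    batch_size, num_beams, seq_len, vocab = 2, 1, 5, 100

    # --- forward step ---
    tokens = torch.zeros(batch_size, num_beams, seq_len, dtype=torch.long)
    mems = None  # first iteration: no key/value cache exists yet

    # With num_beams == 1 the flatten is a no-op, so the old
    # `if num_beams > 1:` branch around it was redundant; the
    # `if mems is not None` guard handles the first iteration.
    tokens = tokens.reshape(batch_size * num_beams, -1)
    mems = mems.reshape(mems.shape[0], batch_size * num_beams,
                        mems.shape[-2], mems.shape[-1]) if mems is not None else None
    assert tokens.shape == (batch_size * num_beams, seq_len)  # (2, 5)

    # --- sampling step ---
    # Unflattening just reinserts a singleton beam dimension when
    # num_beams == 1, so the branch-free version feeds the strategy
    # the same shapes as before for both greedy and beam decoding.
    logits = torch.randn(batch_size * num_beams, vocab)
    logits = logits.reshape(batch_size, num_beams, -1)
    tokens = tokens.reshape(batch_size, num_beams, -1)
    assert logits.shape == (batch_size, num_beams, vocab)   # (2, 1, 100)
    assert tokens.shape == (batch_size, num_beams, seq_len) # (2, 1, 5)

Making the reshapes unconditional also means `strategy.forward` always receives `(batch, beams, ...)` tensors, which appears to be what lets the trailing `if len(tokens.shape) == 3 and num_beams == 1:` check detect a strategy that expands a single beam into several and update `num_beams` accordingly.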