@@ -7,8 +7,8 @@ from functools import partial
 from typing import List, Tuple
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
+from evaluation.model import batch_filling_sequence
+from generation import BeamSearchStrategy, BaseStrategy
 from generation import BeamSearchStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
@@ -31,16 +31,16 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
-    tokens = seq.unsqueeze(0)
-
-    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+    context_length = seq.shape[1]
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
     if not gmask:
         position_ids[context_length - 1 :] = mask_position
 
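The rewritten helper is easiest to read as a standalone function. Below is a minimal sketch mirroring the `+` lines of the hunk above; the trailing `unsqueeze`/`return` and the toy usage at the end are assumptions added for illustration, since the hunk does not show the rest of the function. The prompt now arrives batched as `[1, context_length]`, gets right-padded with `-1` placeholders to reserve `max_gen_length` slots, the attention mask is causal except that the first `context_length - 1` columns (the prompt) stay visible to every position, and for `[MASK]` infilling (`gmask=False`) every generated token reuses the mask slot's position id.

```python
import torch

def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    # `seq` is a LongTensor of shape [1, context_length] (prompt plus the sop token).
    context_length = seq.shape[1]
    # Right-pad with -1 placeholders so the buffer has room for the generated tokens.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
    # Lower-triangular (causal) mask over the whole buffer...
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    # ...but the prompt itself (everything before the sop token) is visible to all positions.
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()  # True marks positions to be masked out

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # [MASK] infilling: generated tokens all carry the position of the mask slot.
        position_ids[context_length - 1 :] = mask_position
    # Assumed tail of the function (not part of the hunk shown above).
    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids

# Toy example: a 5-token prompt with room for 3 generated tokens.
toks, mask, pos = get_masks_and_position_ids(torch.arange(5).unsqueeze(0), mask_position=2, max_gen_length=3)
print(toks.shape, mask.shape, pos.shape)  # [1, 8], [1, 1, 8, 8], [1, 8]
```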
@@ -99,25 +99,25 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
         output_list = []
 
         input_seq = torch.cuda.LongTensor(
-            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            [seq + [tokenizer.get_command("sop")]],
             device=args.device,
         )
-        output, _ = filling_sequence(
+        output, _ = batch_filling_sequence(
             model,
             input_seq,
-            batch_size=num_output,
+            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
             strategy=strategy,
-            log_attention_weights=None,
             get_masks_and_position_ids=partial(
                 get_masks_and_position_ids,
                 mask_position=mask_position,
-                context_length=len(seq) + 1,
+                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                 gmask=use_gmask,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
             output = list(output)
-
+        else:
+            output = output[0]
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
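At the call site the change is mostly about shapes: instead of one flat sequence pre-padded with `-1` out to `args.out_seq_length`, the prompt is now passed as a batch of size one plus an explicit context-length tensor, and the generation budget moves into `max_gen_length`. A rough sketch of the values involved (the token ids, `sop_id`, and `out_seq_length` below are made up for illustration, and CPU tensors stand in for `torch.cuda.LongTensor`):

```python
import torch

seq = [53, 871, 24, 99, 130001]   # hypothetical prompt ids with the mask token appended
sop_id = 130004                   # hypothetical stand-in for tokenizer.get_command("sop")
out_seq_length = 32

# New style: a [1, len(seq) + 1] batch plus an explicit context-length tensor,
# with the remaining generation room passed separately as max_gen_length.
input_seq = torch.tensor([seq + [sop_id]], dtype=torch.long)             # shape [1, 6]
context_lengths = torch.tensor([input_seq.shape[-1]], dtype=torch.long)  # tensor([6])
max_gen_length = out_seq_length - input_seq.shape[-1]                    # 26 slots to generate

# Old style, for comparison: one flat sequence already padded with -1 placeholders.
old_input_seq = torch.tensor(
    seq + [sop_id] + [-1] * (out_seq_length - len(seq) - 1), dtype=torch.long
)  # shape [32]

print(input_seq.shape, context_lengths.tolist(), max_gen_length, old_input_seq.shape)
```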
@@ -160,9 +160,10 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
+            1,
             args.num_beams,
             length_penalty=args.length_penalty,
             consider_end=True,