
Fix generate.py

Zhengxiao Du committed 3 years ago
Commit e7d58c7d9d
1 file changed, 15 insertions(+), 14 deletions(-)
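
Below is a minimal, runnable sketch of the reworked `get_masks_and_position_ids` as it appears in the first hunk of this diff. The function body mirrors the added (“+”) lines; the final `return` statement and the toy inputs in the usage example are assumptions for illustration and are not part of the commit itself.

import torch


def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    context_length = seq.shape[1]
    # Pad the batched (1, context_length) prompt on the right with -1 placeholders
    # for the tokens that will be generated, instead of receiving a pre-padded 1D sequence.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
    # Start from a lower-triangular (causal) mask over the full padded length ...
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    # ... then let every position attend to all prompt tokens (everything before the sop token).
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)
    # Invert: True now marks positions that are blocked from attention.
    attention_mask = (attention_mask < 0.5).bool()

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # [MASK]-style infilling: generated tokens reuse the position id of the mask token.
        position_ids[context_length - 1:] = mask_position

    # Assumed return; the hunk ends before the unchanged tail of the function.
    return tokens, attention_mask, position_ids


# Toy usage: a 6-token prompt with the mask at position 3, generating up to 4 tokens.
toks, mask, pos = get_masks_and_position_ids(
    torch.zeros(1, 6, dtype=torch.long), mask_position=3, max_gen_length=4
)
print(toks.shape, mask.shape, pos.shape)  # (1, 10), (1, 1, 10, 10), (10,)
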

generate.py  +15 -14

@@ -7,8 +7,8 @@ from functools import partial
 from typing import List, Tuple
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
-from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
+from evaluation.model import batch_filling_sequence
+from generation import BeamSearchStrategy, BaseStrategy
 from generation import BeamSearchStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from initialize import initialize, initialize_model_and_tokenizer
@@ -31,16 +31,16 @@ def isEnglish(s):
         return True
 
 
-def get_masks_and_position_ids(seq, mask_position, context_length, gmask=False):
-    tokens = seq.unsqueeze(0)
-
-    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
+    context_length = seq.shape[1]
+    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
+    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., : context_length - 1] = 1
     attention_mask.unsqueeze_(1)
     attention_mask = (attention_mask < 0.5).bool()
 
-    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
     if not gmask:
         position_ids[context_length - 1 :] = mask_position
 
@@ -99,25 +99,25 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
         output_list = []
 
         input_seq = torch.cuda.LongTensor(
-            seq + [tokenizer.get_command("sop")] + [-1] * (args.out_seq_length - len(seq) - 1),
+            [seq + [tokenizer.get_command("sop")]],
             device=args.device,
         )
-        output, _ = filling_sequence(
+        output, _ = batch_filling_sequence(
             model,
             input_seq,
-            batch_size=num_output,
+            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
             strategy=strategy,
-            log_attention_weights=None,
             get_masks_and_position_ids=partial(
                 get_masks_and_position_ids,
                 mask_position=mask_position,
-                context_length=len(seq) + 1,
+                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                 gmask=use_gmask,
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
             output = list(output)
-
+        else:
+            output = output[0]
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
@@ -160,9 +160,10 @@ def main(args):
     end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
 
     if args.sampling_strategy == "BaseStrategy":
-        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens)
+        strategy = BaseStrategy(batch_size=1, temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
     elif args.sampling_strategy == "BeamSearchStrategy":
         strategy = BeamSearchStrategy(
+            1,
             args.num_beams,
             length_penalty=args.length_penalty,
             consider_end=True,