浏览代码

Fix error when empty choice

Sengxian 2 年之前
父节点
当前提交
eaa689e155
共有 1 个文件被更改,包括 8 次插入3 次删除
  1. 8 3
      evaluation/dataset.py

+ 8 - 3
evaluation/dataset.py

@@ -12,6 +12,7 @@ from itertools import accumulate
 from bisect import bisect_right
 
 from SwissArmyTransformer import get_tokenizer
+from SwissArmyTransformer.mpu import get_model_parallel_rank
 
 from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
@@ -102,8 +103,8 @@ class GenerationTaskDataset(EvaluationDataset):
             token_batch.append(token)
             position_id_batch.append(position_id)
             attention_mask_batch.append(attention_mask)
-            context_length_batch.append(sample['context_length'])
-            target_position_id_batch.append(sample['target_position_id'])
+            context_length_batch.append(sample["context_length"])
+            target_position_id_batch.append(sample["target_position_id"])
         return {
             "tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
             "position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
@@ -138,7 +139,7 @@ class GenerationTaskDataset(EvaluationDataset):
         position_id = np.arange(0, context_length, dtype=np.int64)
         target_position_id = np.arange(context_length, context_length + max_gen_length, dtype=np.int64)
         if not use_task_mask:
-            position_id[context_length - 1:] = mask_position
+            position_id[context_length - 1 :] = mask_position
             target_position_id[:] = mask_position
 
         attention_mask = np.tril(np.ones((context_length, context_length), dtype=np.int64))
@@ -256,6 +257,10 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         attention_mask = [np.ones((len(token), len(token)), dtype=np.int64)]
 
         for choice in choices:
+            if len(choice) == 0:
+                if get_model_parallel_rank() == 0:
+                    print("Empty choice found")
+                choice = [0]
             if use_task_mask == False:
                 position_id = np.concatenate(
                     (