# utils.py
import torch
import torch.distributed as dist

from SwissArmyTransformer import mpu, get_tokenizer
  4. def print_rank_0(*args, **kwargs):
  5. if torch.distributed.get_rank() == 0:
  6. print(*args, **kwargs)
  7. def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, collate_fn=None):
  8. # Sampler.
  9. world_size = mpu.get_data_parallel_world_size()
  10. rank = mpu.get_data_parallel_rank()
  11. sampler = torch.utils.data.distributed.DistributedSampler(
  12. dataset, num_replicas=world_size, rank=rank, shuffle=False
  13. )
  14. # Data loader. Note that batch size is the per GPU batch size.
  15. data_loader = torch.utils.data.DataLoader(
  16. dataset,
  17. batch_size=micro_batch_size,
  18. sampler=sampler,
  19. shuffle=False,
  20. num_workers=num_workers,
  21. drop_last=drop_last,
  22. pin_memory=True,
  23. collate_fn=collate_fn,
  24. )
  25. return data_loader
  26. def gather_result(prediction, total_length, micro_batch_size):
  27. """
  28. @param prediction: Local predictions with order defined by distributed sampler
  29. @param total_length: Total sample num
  30. @return: [sample_0, sample_1, ..., sample_{total_length-1}]
  31. """
  32. torch.cuda.empty_cache()
  33. world_size = mpu.get_data_parallel_world_size()
  34. prediction_gathered = [None for _ in range(world_size)]
  35. dist.all_gather_object(prediction_gathered, prediction, group=mpu.get_data_parallel_group())
  36. prediction = []
  37. for i in range(len(prediction_gathered[0])):
  38. for j in range(micro_batch_size):
  39. for k in range(world_size):
  40. if j < len(prediction_gathered[k][i]):
  41. prediction.append(prediction_gathered[k][i][j])
  42. prediction = prediction[:total_length]
  43. return prediction
  44. def get_tokenized_input(item, key):
  45. if key in item:
  46. return item[key]
  47. tokenizer = get_tokenizer()
  48. pretokenized_key = key + "_pretokenized"
  49. assert pretokenized_key in item
  50. if isinstance(item[pretokenized_key], list):
  51. result = []
  52. for raw in item[pretokenized_key]:
  53. result.append(tokenizer.tokenize(raw))
  54. return result
  55. else:
  56. return tokenizer.tokenize(item[pretokenized_key])