|
@@ -80,20 +80,21 @@ class ModelForEvaluation(torch.nn.Module):
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
self.model = model
|
|
self.model = model
|
|
|
|
+ self.device = next(self.model.parameters()).device
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
- def process_data(batch):
|
|
|
|
|
|
+ def process_data(batch, device):
|
|
return (
|
|
return (
|
|
- batch["tokens"].to(device=torch.cuda.current_device()).long(),
|
|
|
|
- batch["position_ids"].to(device=torch.cuda.current_device()).long(),
|
|
|
|
- batch["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1),
|
|
|
|
|
|
+ batch["tokens"].to(device=device).long(),
|
|
|
|
+ batch["position_ids"].to(device=device).long(),
|
|
|
|
+ batch["attention_mask"].to(device=device).bool().unsqueeze(1),
|
|
)
|
|
)
|
|
|
|
|
|
def cond_log_prob(self, batch) -> List[List[float]]:
|
|
def cond_log_prob(self, batch) -> List[List[float]]:
|
|
"""
|
|
"""
|
|
@return: Conditional log probability of each option
|
|
@return: Conditional log probability of each option
|
|
"""
|
|
"""
|
|
- tokens, position_ids, attention_mask = self.process_data(batch)
|
|
|
|
|
|
+ tokens, position_ids, attention_mask = self.process_data(batch, self.device)
|
|
choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
|
|
choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
|
|
is_single_token = batch["is_single_token"]
|
|
is_single_token = batch["is_single_token"]
|
|
|
|
|
|
@@ -123,7 +124,7 @@ class ModelForEvaluation(torch.nn.Module):
|
|
@return: A list of text model generated, sorted by score in descending order
|
|
@return: A list of text model generated, sorted by score in descending order
|
|
"""
|
|
"""
|
|
|
|
|
|
- seqs = sample["tokens"].to(device=torch.cuda.current_device()).long()
|
|
|
|
|
|
+ seqs = sample["tokens"].to(device=self.device).long()
|
|
context_lengths = sample["context_length"].long()
|
|
context_lengths = sample["context_length"].long()
|
|
|
|
|
|
def get_masks_and_position_ids(seq):
|
|
def get_masks_and_position_ids(seq):
|
|
@@ -131,8 +132,8 @@ class ModelForEvaluation(torch.nn.Module):
|
|
max_gen_length = sample['target_position_ids'].shape[-1]
|
|
max_gen_length = sample['target_position_ids'].shape[-1]
|
|
tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
|
|
tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
|
|
position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
|
|
position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
|
|
- position_ids = position_ids.to(device=torch.cuda.current_device()).long()
|
|
|
|
- attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device())
|
|
|
|
|
|
+ position_ids = position_ids.to(device=self.device).long()
|
|
|
|
+ attention_mask = sample["attention_mask"].to(device=self.device)
|
|
context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1,
|
|
context_mask = attention_mask[torch.arange(batch_size), context_lengths - 1].unsqueeze(1).repeat(1,
|
|
max_gen_length,
|
|
max_gen_length,
|
|
1)
|
|
1)
|
|
@@ -178,10 +179,10 @@ class ModelForEvaluation(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
def calculate_loss(self, batch) -> List[float]:
|
|
def calculate_loss(self, batch) -> List[float]:
|
|
- tokens, position_ids, attention_mask = self.process_data(batch)
|
|
|
|
|
|
+ tokens, position_ids, attention_mask = self.process_data(batch, self.device)
|
|
targets, loss_masks = (
|
|
targets, loss_masks = (
|
|
- batch["targets"].to(device=torch.cuda.current_device()).long(),
|
|
|
|
- batch["loss_masks"].to(device=torch.cuda.current_device()).long(),
|
|
|
|
|
|
+ batch["targets"].to(device=self.device).long(),
|
|
|
|
+ batch["loss_masks"].to(device=self.device).long(),
|
|
)
|
|
)
|
|
|
|
|
|
original_parallel_output = self.model.transformer.parallel_output
|
|
original_parallel_output = self.model.transformer.parallel_output
|