from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Mean of ``tensor`` over ``dim``, counting only positions where ``mask`` is 1.

    Args:
        tensor: values to average.
        mask: same shape as ``tensor``; 1 keeps a position, 0 drops it.
        dim: dimension to reduce (default 1, the sequence dimension).

    Returns:
        Tensor with ``dim`` reduced. The ``1e-8`` guard avoids division by
        zero when a row's mask is all zeros.
    """
    masked = tensor * mask
    total = masked.sum(dim=dim)
    count = mask.sum(dim=dim)
    return total / (count + 1e-8)


class GPTLMLoss(nn.Module):
    """GPT language-model loss: next-token cross entropy on shifted logits/labels."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Shift so tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)),
                         shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """Clipped surrogate policy loss for PPO.

    Args:
        clip_eps: PPO clipping epsilon for the probability ratio.
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # pi_new / pi_old, computed in log space for stability.
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            # Average only over real (unmasked) action positions per sample.
            loss = masked_mean(loss, action_mask)
        return loss.mean()


class ValueLoss(nn.Module):
    """Clipped value-function loss for PPO.

    Args:
        clip_eps: clipping range for how far the new value may move from
            the old value estimate.
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Clip the value update relative to the old estimate (PPO value clipping).
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward) ** 2
        surr2 = (values - reward) ** 2
        loss = torch.max(surr1, surr2)
        # Bug fix: action_mask was previously accepted but ignored, so padded
        # positions leaked into the loss. Mirror PolicyLoss's masking.
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        return 0.5 * loss.mean()


class PPOPtxActorLoss(nn.Module):
    """PPO-ptx actor loss: clipped policy loss plus a weighted pretraining LM loss.

    Details: https://arxiv.org/abs/2203.02155 (InstructGPT).

    Args:
        policy_clip_eps: clipping epsilon forwarded to :class:`PolicyLoss`.
        pretrain_coef: weight of the pretraining LM term (0 disables it).
        pretrain_loss_fn: LM loss module; defaults to :class:`GPTLMLoss`.
            A ``None`` sentinel is used instead of a ``GPTLMLoss()`` default
            so a fresh module is built per instance (avoids the shared
            mutable-default pitfall).
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn if pretrain_loss_fn is not None else GPTLMLoss()

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages,
                                          action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class LogSigLoss(nn.Module):
    """Pairwise reward-model loss: ``-log sigmoid(chosen - reject)``.

    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # F.logsigmoid is numerically stable; log(sigmoid(x)) underflows to
        # -inf for large negative margins.
        return -F.logsigmoid(chosen_reward - reject_reward).mean()


class LogExpLoss(nn.Module):
    """Pairwise reward-model loss: ``log(1 + exp(reject - chosen))``.

    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        return F.softplus(reject_reward - chosen_reward).mean()