Spaces:
Running
on
L40S
Running
on
L40S
File size: 1,409 Bytes
ca25718 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from abc import ABC, abstractmethod
from typing import Iterable, Optional

import torch
class BaseRewardLoss(ABC):
    """
    Base class for reward functions implementing a differentiable reward function for optimization.

    Subclasses provide the feature extractors (``get_image_features``,
    ``get_text_features``) and the loss itself (``compute_loss``). Calling an
    instance extracts both feature sets, L2-normalizes each along its last
    dimension via ``process_features``, and returns ``compute_loss`` applied to
    the normalized features.

    Attributes:
        name: Identifier for this reward loss.
        weighting: Weight stored with this loss. It is NOT applied inside
            ``__call__``; presumably callers scale the returned loss by it
            (NOTE(review): verify against callers).
    """

    def __init__(self, name: str, weighting: float):
        self.name = name
        self.weighting = weighting

    @staticmethod
    def freeze_parameters(params: Iterable[torch.nn.Parameter]) -> None:
        """Disable gradient tracking for every parameter in ``params``.

        Any iterable of parameters works, e.g. a ``torch.nn.ParameterList`` or
        the generator returned by ``module.parameters()``.
        """
        for param in params:
            param.requires_grad = False

    @abstractmethod
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return a feature tensor for ``image``.

        ``__call__`` normalizes the result along its last dimension before
        passing it to ``compute_loss``.
        """

    @abstractmethod
    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Return a feature tensor for ``prompt``.

        ``__call__`` normalizes the result along its last dimension before
        passing it to ``compute_loss``.
        """

    @abstractmethod
    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss for already L2-normalized image and text features."""

    def process_features(
        self, features: torch.Tensor, eps: Optional[float] = None
    ) -> torch.Tensor:
        """L2-normalize ``features`` along the last dimension.

        Args:
            features: Tensor whose last dimension holds the feature vectors.
            eps: Lower bound applied to each vector's norm before dividing.
                Defaults to the smallest positive normal number of the norm's
                dtype, so rows with a normal norm are divided exactly as
                before, while an all-zero row maps to zeros instead of NaN
                (0 / 0).

        Returns:
            Tensor of the same shape as ``features`` with unit-norm rows
            (all-zero rows stay zero).
        """
        norm = features.norm(dim=-1, keepdim=True)
        if eps is None:
            # Derive from norm.dtype (real even for complex input); a fixed
            # constant like 1e-12 would underflow to 0 in float16.
            eps = torch.finfo(norm.dtype).tiny
        # Any row with norm < eps has every |element| < eps, so dividing by
        # eps cannot overflow; it just keeps degenerate rows finite.
        return features / norm.clamp_min(eps)

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        """Compute the loss between ``image`` and ``prompt``.

        Extracts both feature sets, normalizes each with
        ``process_features``, and returns ``compute_loss`` on the results.
        """
        image_features_normed = self.process_features(self.get_image_features(image))
        text_features_normed = self.process_features(self.get_text_features(prompt))
        return self.compute_loss(image_features_normed, text_features_normed)
|