import torch
from torch import nn
import torch.nn.functional as F
from abc import abstractmethod
from torch import Tensor
from typing import List, Any


class BaseVAE(nn.Module):

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass


class MLPAE(BaseVAE):
    def __init__(
        self,
        in_channels: int,
        semantic_latent_dim: int,
        truthful_latent_dim: int,
        semantic_hidden_dims: List = None,
        truthful_hidden_dims: List = None,
        decoder_hidden_dims: List = None,
        **kwargs
    ) -> None:
        super(MLPAE, self).__init__()

        self.semantic_latent_dim = semantic_latent_dim

        if semantic_hidden_dims is None:
            semantic_hidden_dims = []

        # Semantic encoder: an MLP that maps the input hidden state to the
        # semantic latent space.
        semantic_encoder_modules = []
        flat_size = in_channels
        for h_dim in semantic_hidden_dims:
            semantic_encoder_modules.append(
                nn.Sequential(
                    nn.Linear(flat_size, h_dim), nn.LayerNorm(h_dim), nn.LeakyReLU()
                )
            )
            flat_size = h_dim
        semantic_encoder_modules.append(
            nn.Sequential(
                nn.Linear(flat_size, semantic_latent_dim),
                nn.LayerNorm(semantic_latent_dim),
                nn.LeakyReLU(),
            )
        )

        self.semantic_encoder = nn.Sequential(*semantic_encoder_modules)

        if truthful_hidden_dims is None:
            truthful_hidden_dims = []

        # Truthful encoder: an MLP that maps the input hidden state to the
        # truthful latent space; stages whose input and output widths match
        # use an identity instead of a linear projection.
        truthful_encoder_modules = []
        flat_size = in_channels
        for h_dim in truthful_hidden_dims:
            truthful_encoder_modules.append(
                nn.Sequential(
                    (
                        nn.Linear(flat_size, h_dim)
                        if flat_size != h_dim
                        else nn.Identity()
                    ),
                    nn.LayerNorm(h_dim),
                    nn.LeakyReLU(),
                )
            )
            flat_size = h_dim
        truthful_encoder_modules.append(
            nn.Sequential(
                (
                    nn.Linear(flat_size, truthful_latent_dim)
                    if flat_size != truthful_latent_dim
                    else nn.Identity()
                ),
                nn.LayerNorm(truthful_latent_dim),
                nn.LeakyReLU(),
            )
        )

        self.truthful_encoder = nn.Sequential(*truthful_encoder_modules)

        # Cross-attention between the semantic latent (query) and the truthful
        # latent (key/value); the truthful latent is projected up to the
        # semantic width when the two dimensions differ.
        self.num_heads = 1
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=semantic_latent_dim, num_heads=self.num_heads
        )

        self.proj = None
        if semantic_latent_dim != truthful_latent_dim:
            self.proj = nn.Linear(truthful_latent_dim, semantic_latent_dim, bias=False)

        # Decoder: an MLP that reconstructs the input hidden state from the
        # fused latent representation.
        if decoder_hidden_dims is None:
            decoder_hidden_dims = []

        decoder_modules = []
        if len(decoder_hidden_dims) > 0:
            flat_size = semantic_latent_dim
            for h_dim in decoder_hidden_dims:
                decoder_modules.append(
                    nn.Sequential(
                        nn.Linear(flat_size, h_dim), nn.LayerNorm(h_dim), nn.LeakyReLU()
                    )
                )
                flat_size = h_dim

            flat_size = decoder_hidden_dims[-1]
            self.decoder = nn.Sequential(*decoder_modules)
        else:
            self.decoder = None
            flat_size = semantic_latent_dim
        self.final_layer = nn.Sequential(nn.Linear(flat_size, in_channels))

    def encode_semantic(self, input: Tensor) -> Tensor:
        semantic_latent_rep = self.semantic_encoder(input)
        return semantic_latent_rep

    def encode_truthful(self, input: Tensor) -> Tensor:
        truthful_latent_rep = self.truthful_encoder(input)
        truthful_latent_rep = F.normalize(truthful_latent_rep, p=2, dim=-1)

        return truthful_latent_rep

    def attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        if self.proj is not None and query.size(-1) != key.size(-1):
            key = self.proj(key)
            value = self.proj(value)
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim); treat each
        # latent vector as a length-1 sequence.
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)

        output, _ = self.cross_attention(query, key, value)

        return output[0]

    def decode(self, z: Tensor) -> Tensor:
        result = z
        if self.decoder is not None:
            result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(
        self, input: Tensor, truthful_latent_rep=None, **kwargs
    ) -> List[Tensor]:
        semantic_latent_rep = self.encode_semantic(input)
        if truthful_latent_rep is None:
            truthful_latent_rep = self.encode_truthful(input)
        truthful_latent_rep = truthful_latent_rep.reshape(
            -1, truthful_latent_rep.size(-1)
        )
        z = semantic_latent_rep + self.attention(
            semantic_latent_rep,
            truthful_latent_rep.contiguous(),
            truthful_latent_rep.contiguous(),
        )
        output = self.decode(z)

        return [output, input, semantic_latent_rep, truthful_latent_rep]

    def forward_decoder(self, input, semantic_latent_rep, truthful_latent_rep):
        z = semantic_latent_rep + self.attention(
            semantic_latent_rep, truthful_latent_rep, truthful_latent_rep
        )
        output = self.decode(z)
        return [output, input, semantic_latent_rep, truthful_latent_rep]

    def get_semantic_latent_rep(self, input: Tensor, **kwargs) -> Tensor:
        semantic_latent_rep = self.encode_semantic(input)
        return semantic_latent_rep

    def get_truthful_latent_rep(self, input: Tensor, **kwargs) -> Tensor:
        truthful_latent_rep = self.encode_truthful(input)
        return truthful_latent_rep

    def loss_function(self, *args, **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        recons_loss = F.mse_loss(recons, input)

        loss = recons_loss
        return {"loss": loss, "Reconstruction_Loss": recons_loss.detach()}


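# Illustrative sketch, not part of the original module: builds a small MLPAE
# and runs one forward pass plus the reconstruction loss. All dimensions and
# the batch size below are placeholder values chosen for the example.
def _example_mlpae_forward():
    hidden_size = 4096  # assumed width of the LLM hidden states
    model = MLPAE(
        in_channels=hidden_size,
        semantic_latent_dim=1024,
        truthful_latent_dim=64,
        semantic_hidden_dims=[2048],
        truthful_hidden_dims=[512],
        decoder_hidden_dims=[2048],
    )
    x = torch.randn(8, hidden_size)  # a batch of 8 hidden-state vectors
    recons, original, semantic_z, truthful_z = model(x)
    losses = model.loss_function(recons, original)
    return losses["loss"]

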
class TruthX:
    def __init__(self, model_path, hidden_size, edit_strength=1.0, top_layers=10):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        checkpoint = torch.load(model_path, map_location=device)
        args = checkpoint["args"]

        semantic_latent_dim = args.semantic_latent_dim
        truthful_latent_dim = args.truthful_latent_dim
        semantic_hidden_dims = (
            [int(_) for _ in args.semantic_hidden_dims.split(",")]
            if args.semantic_hidden_dims != ""
            else []
        )
        truthful_hidden_dims = (
            [int(_) for _ in args.truthful_hidden_dims.split(",")]
            if args.truthful_hidden_dims != ""
            else []
        )
        decoder_hidden_dims = (
            [int(_) for _ in args.decoder_hidden_dims.split(",")]
            if args.decoder_hidden_dims != ""
            else []
        )

        ae_model = MLPAE(
            in_channels=hidden_size,
            semantic_latent_dim=semantic_latent_dim,
            truthful_latent_dim=truthful_latent_dim,
            semantic_hidden_dims=semantic_hidden_dims,
            truthful_hidden_dims=truthful_hidden_dims,
            decoder_hidden_dims=decoder_hidden_dims,
        ).to(device)

        ae_model.load_state_dict(checkpoint["state_dict"])

        # Per-layer centers in the truthful latent space that define the
        # editing direction.
        ae_model.pos_center = checkpoint["pos_center"].to(device)
        ae_model.neg_center = checkpoint["neg_center"].to(device)
        ae_model.eval()
        self.ae_model = ae_model

        # Rank of each attention/FFN output; only entries ranked within
        # `top_layers` are edited.
        self.rank = checkpoint["rank"]

        self.top_layers = top_layers
        self.edit_strength = edit_strength
        # Set externally before each call to edit(), e.g. "<layer_idx>.attn"
        # or "<layer_idx>.mlp".
        self.cur_layer_id = 0
        self.prompt_length = None
        self.mc = False  # multiple-choice mode flag

    @torch.inference_mode()
    def edit(self, X):
        # `cur_layer_id` has the form "<layer_idx>.attn" or "<layer_idx>.mlp";
        # attention and FFN outputs of layer l are indexed as 2*l and 2*l + 1.
        layer_id = int(self.cur_layer_id.split(".")[0])
        if self.cur_layer_id.endswith("attn"):
            layer_id = 2 * layer_id
        else:
            layer_id = 2 * layer_id + 1

        # Only the top-ranked layers are edited.
        if self.rank[layer_id] > self.top_layers:
            return X

        # Flatten to (bsz * s_len, d) and match the autoencoder's dtype.
        bsz, s_len, d = X.size()
        x = (
            X.contiguous()
            .view(-1, d)
            .type_as(self.ae_model.semantic_encoder[0][0].weight)
        )
        x_truthful = self.ae_model.get_truthful_latent_rep(
            X.type_as(self.ae_model.semantic_encoder[0][0].weight)
        )

        pos_center = self.ae_model.pos_center[layer_id].unsqueeze(0)
        neg_center = self.ae_model.neg_center[layer_id].unsqueeze(0)

        # Offset between the positive and negative centers for this layer.
        delta = (pos_center - neg_center).unsqueeze(0)
        # Reconstruct the hidden state with the truthful latent shifted toward
        # the positive center and toward the negative center, respectively.
        recon_x_pos = (
            self.ae_model(
                x,
                truthful_latent_rep=F.normalize(
                    x_truthful + delta, p=2, dim=-1
                ).type_as(x),
            )[0]
            .contiguous()
            .view(bsz, s_len, d)
        )
        recon_x_neg = (
            self.ae_model(
                x,
                truthful_latent_rep=F.normalize(
                    x_truthful - delta, p=2, dim=-1
                ).type_as(x),
            )[0]
            .contiguous()
            .view(bsz, s_len, d)
        )
        # The edit vector is the difference of the two reconstructions,
        # rescaled to the norm of the original hidden state.
        Delta = recon_x_pos - recon_x_neg
        Delta = Delta.contiguous().to(X.dtype)
        Delta = F.normalize(Delta, p=2, dim=-1).type_as(X) * torch.norm(
            X, p=2, dim=-1
        ).unsqueeze(2)

        mask = torch.ones((bsz, s_len), device=Delta.device)

        if self.mc:
            # Multiple-choice mode: leave the prompt untouched and weight each
            # remaining position by how much closer its truthful latent is to
            # the negative center than to the positive one.
            mask[:, : self.prompt_length + 1] = 0

            probing = (
                torch.nn.functional.cosine_similarity(
                    x_truthful, neg_center.unsqueeze(1), dim=-1
                )
                - torch.nn.functional.cosine_similarity(
                    x_truthful, pos_center.unsqueeze(1), dim=-1
                )
            ).clamp(0, 999)
            mask = mask * probing

        else:
            # Open-ended generation: only edit the hidden state of the token
            # currently being generated.
            mask[:, :-1] = 0
            mask[:, -1:] = 1

        new_X = X + (Delta.type_as(X)) * self.edit_strength * mask.unsqueeze(2).type_as(X)
        return new_X


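# Illustrative sketch, not from the original integration code: TruthX.edit is
# intended to be applied to the attention / FFN outputs inside a decoder-only
# LM, with `cur_layer_id` set to "<layer_idx>.attn" or "<layer_idx>.mlp" before
# each call. The checkpoint path, hidden size, and shapes are placeholders.
def _example_truthx_edit():
    truthx = TruthX(
        "truthx_model.pt",  # placeholder checkpoint path
        hidden_size=4096,  # assumed LLM hidden-state width
        edit_strength=1.0,
        top_layers=10,
    )
    hidden_states = torch.randn(1, 16, 4096)  # (batch, seq_len, hidden)
    truthx.cur_layer_id = "0.attn"  # editing the attention output of layer 0
    edited = truthx.edit(hidden_states)
    return edited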