import torch
import torch.nn as nn
from rex.utils.iteration import windowed_queue_iter
from transformers import AutoModel, BertModel

from src.utils import decode_nnw_nsw_thw_mat, decode_nnw_thw_mat, decode_pointer_mat


class Biaffine(nn.Module):
    """Biaffine transformation

    References:
        - https://github.com/yzhangcs/parser/blob/main/supar/modules/affine.py
        - https://github.com/ljynlp/W2NER
    """

    def __init__(self, n_in, n_out=2, bias_x=True, bias_y=True):
        super().__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.bias_x = bias_x
        self.bias_y = bias_y
        weight = torch.zeros(n_out, n_in + int(bias_x), n_in + int(bias_y))
        nn.init.xavier_normal_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def extra_repr(self):
        s = f"n_in={self.n_in}, n_out={self.n_out}"
        if self.bias_x:
            s += f", bias_x={self.bias_x}"
        if self.bias_y:
            s += f", bias_y={self.bias_y}"
        return s

    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
        s = torch.einsum("bxi,oij,byj->boxy", x, self.weight, y)
        return s

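# A minimal shape sketch (illustrative, not part of the original module):
# Biaffine maps a pair of (batch, seq_len, n_in) tensors to a
# (batch, n_out, seq_len, seq_len) score tensor.
#
#   biaffine = Biaffine(n_in=384, n_out=2, bias_x=True, bias_y=True)
#   x = torch.randn(2, 10, 384)
#   y = torch.randn(2, 10, 384)
#   scores = biaffine(x, y)  # -> torch.Size([2, 2, 10, 10])
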
class LinearWithAct(nn.Module):
    def __init__(self, n_in, n_out, dropout=0) -> None:
        super().__init__()

        self.linear = nn.Linear(n_in, n_out)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        return x


class PointerMatrix(nn.Module):
    """Pointer Matrix Prediction

    References:
        - https://github.com/ljynlp/W2NER
    """

    def __init__(
        self,
        hidden_size,
        biaffine_size,
        cls_num=2,
        dropout=0,
        biaffine_bias=False,
        use_rope=False,
    ):
        super().__init__()
        self.linear_h = LinearWithAct(
            n_in=hidden_size, n_out=biaffine_size, dropout=dropout
        )
        self.linear_t = LinearWithAct(
            n_in=hidden_size, n_out=biaffine_size, dropout=dropout
        )
        self.biaffine = Biaffine(
            n_in=biaffine_size,
            n_out=cls_num,
            bias_x=biaffine_bias,
            bias_y=biaffine_bias,
        )
        self.use_rope = use_rope

    def sinusoidal_position_embedding(self, qw, kw):
        """Apply rotary position embedding (RoPE) to the head/tail features."""
        batch_size, seq_len, output_dim = qw.shape
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)

        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        pos_emb = position_ids * indices
        pos_emb = torch.stack([torch.sin(pos_emb), torch.cos(pos_emb)], dim=-1)
        pos_emb = pos_emb.repeat((batch_size, *([1] * len(pos_emb.shape))))
        pos_emb = torch.reshape(pos_emb, (batch_size, seq_len, output_dim))
        pos_emb = pos_emb.to(qw)

        cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
        sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
        qw2 = torch.cat([-qw[..., 1::2], qw[..., ::2]], -1)
        qw = qw * cos_pos + qw2 * sin_pos
        kw2 = torch.cat([-kw[..., 1::2], kw[..., ::2]], -1)
        kw = kw * cos_pos + kw2 * sin_pos
        return qw, kw

    def forward(self, x):
        h = self.linear_h(x)
        t = self.linear_t(x)
        if self.use_rope:
            h, t = self.sinusoidal_position_embedding(h, t)
        o = self.biaffine(h, t)
        return o

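# A minimal usage sketch (illustrative, not part of the original module):
# PointerMatrix projects token states into head/tail spaces, optionally applies
# rotary position embeddings, and scores every token pair with the biaffine.
#
#   pointer = PointerMatrix(hidden_size=768, biaffine_size=384, cls_num=2, use_rope=True)
#   hidden = torch.randn(2, 16, 768)   # e.g. a PLM's last_hidden_state
#   logits = pointer(hidden)           # -> torch.Size([2, 2, 16, 16])
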
def multilabel_categorical_crossentropy(y_pred, y_true, bit_mask=None):
    """Multi-label categorical cross-entropy.

    References:
        - https://kexue.fm/archives/7359
        - https://github.com/gaohongkui/GlobalPointer_pytorch/blob/main/common/utils.py
    """
    # flip the sign of positive-class scores so that one logsumexp term
    # penalizes low positive scores and the other penalizes high negative scores
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)

    if bit_mask is None:
        return neg_loss + pos_loss
    else:
        raise NotImplementedError

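# A minimal usage sketch (illustrative only): the loss treats each row as a set
# of independent binary targets; `y_true` is a 0/1 tensor with the same shape
# as the raw (unnormalized) scores in `y_pred`.
#
#   y_pred = torch.tensor([[2.0, -1.0, 0.5]])
#   y_true = torch.tensor([[1.0, 0.0, 0.0]])
#   loss = multilabel_categorical_crossentropy(y_pred, y_true)  # shape: (1,)
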
class MrcPointerMatrixModel(nn.Module):
    def __init__(
        self,
        plm_dir: str,
        cls_num: int = 2,
        biaffine_size: int = 384,
        none_type_id: int = 0,
        text_mask_id: int = 4,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.cls_num = cls_num
        self.none_type_id = none_type_id
        self.text_mask_id = text_mask_id

        self.plm = BertModel.from_pretrained(plm_dir)
        hidden_size = self.plm.config.hidden_size

        self.nnw_mat = PointerMatrix(
            hidden_size, biaffine_size, cls_num=2, dropout=dropout
        )
        self.thw_mat = PointerMatrix(
            hidden_size, biaffine_size, cls_num=2, dropout=dropout
        )
        self.criterion = nn.CrossEntropyLoss()

    def input_encoding(self, input_ids, mask):
        attention_mask = mask.gt(0).float()
        plm_outputs = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return plm_outputs.last_hidden_state

    def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # keep pairs (i, j) where both tokens belong to the text span
        bs, seq_len = mask.shape
        mask_mat = (
            mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len))
        )
        bit_mask = (
            torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).long()
        )
        return bit_mask

    def forward(self, input_ids, mask, labels=None, is_eval=False, **kwargs):
        hidden = self.input_encoding(input_ids, mask)
        # (bs, 2, seq_len, seq_len) scores for the NNW and THW link matrices
        nnw_hidden = self.nnw_mat(hidden)
        thw_hidden = self.thw_mat(hidden)

        bs, _, seq_len, seq_len = nnw_hidden.shape

        bit_mask = self.build_bit_mask(mask)

        results = {"logits": {"nnw": nnw_hidden, "thw": thw_hidden}}
        if labels is not None:
            nnw_loss = self.criterion(
                nnw_hidden.permute(0, 2, 3, 1).reshape(-1, 2),
                labels[:, 0, :, :].reshape(-1),
            )
            thw_loss = self.criterion(
                thw_hidden.permute(0, 2, 3, 1).reshape(-1, 2),
                labels[:, 1, :, :].reshape(-1),
            )
            loss = nnw_loss + thw_loss
            results["loss"] = loss

        if is_eval:
            batch_positions = self.decode(nnw_hidden, thw_hidden, bit_mask, **kwargs)
            results["pred"] = batch_positions
        return results

    def decode(
        self,
        nnw_hidden: torch.Tensor,
        thw_hidden: torch.Tensor,
        bit_mask: torch.Tensor,
        **kwargs,
    ):
        nnw_pred = nnw_hidden.argmax(1)
        thw_pred = thw_hidden.argmax(1)

        pred = torch.stack([nnw_pred, thw_pred], dim=1)
        pred = pred * bit_mask

        batch_preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset"))
        return batch_preds

class MrcGlobalPointerModel(nn.Module):
    def __init__(
        self,
        plm_dir: str,
        use_rope: bool = True,
        cls_num: int = 2,
        biaffine_size: int = 384,
        none_type_id: int = 0,
        text_mask_id: int = 4,
        dropout: float = 0.3,
        mode: str = "w2",
    ):
        super().__init__()

        self.cls_num = cls_num
        self.none_type_id = none_type_id
        self.text_mask_id = text_mask_id
        self.use_rope = use_rope

        # "w2": NNW + THW word-pair matrices; "cons": single (start, end) span matrix
        self.mode = mode
        assert self.mode in ["w2", "cons"]

        self.plm = BertModel.from_pretrained(plm_dir)
        self.hidden_size = self.plm.config.hidden_size
        self.biaffine_size = biaffine_size
        self.pointer = PointerMatrix(
            self.hidden_size,
            biaffine_size,
            cls_num=2 if self.mode == "w2" else 1,
            dropout=dropout,
            biaffine_bias=True,
            use_rope=use_rope,
        )

    def input_encoding(self, input_ids, mask):
        attention_mask = mask.gt(0).float()
        plm_outputs = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return plm_outputs.last_hidden_state

    def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # keep pairs (i, j) where both tokens belong to the text span
        bs, seq_len = mask.shape
        mask_mat = (
            mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len))
        )
        bit_mask = (
            torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).float()
        )
        if self.mode == "cons":
            bit_mask = bit_mask.triu()

        return bit_mask

    def forward(
        self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
    ):
        bit_mask = self.build_bit_mask(mask)
        hidden = self.input_encoding(input_ids, mask)

        logits = self.pointer(hidden)
        logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
        logits = logits / (self.biaffine_size**0.5)

        bs, cls_num, seq_len, seq_len = logits.shape

        results = {"logits": logits}
        if labels is not None:
            assert labels.shape == (bs, cls_num, seq_len, seq_len)
            loss = multilabel_categorical_crossentropy(
                logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
            )
            loss = loss.mean()
            results["loss"] = loss

        if is_eval:
            batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
            results["pred"] = batch_positions
        return results

    def calc_path_prob(self, probs, paths):
        """
        Args:
            probs: (2, seq_len, seq_len) | (1, seq_len, seq_len)
            paths: a list of paths in tuple

        Returns:
            [(path: tuple, prob: float), ...]
        """
        assert self.mode in ["w2", "cons"]
        paths_with_prob = []
        for path in paths:
            path_prob = 1.0
            if self.mode == "w2":
                for se in windowed_queue_iter(path, 2, 1, drop_last=True):
                    path_prob *= probs[0, se[0], se[-1]]
                path_prob *= probs[1, path[-1], path[0]]
            elif self.mode == "cons":
                path_prob = probs[0, path[0], path[-1]]
            paths_with_prob.append((path, path_prob))
        return paths_with_prob

    def decode(
        self,
        logits: torch.Tensor,
        top_p: float = 0.5,
        top_k: int = -1,
        **kwargs,
    ):
        assert self.mode in ["w2", "cons"]

        probs = logits.sigmoid()
        pred = (probs > top_p).long()
        if self.mode == "w2":
            preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset"))
        elif self.mode == "cons":
            pred = pred.triu()
            preds = decode_pointer_mat(pred, offsets=kwargs.get("offset"))

        if top_k == -1:
            batch_preds = preds
        else:
            # keep only the top_k most probable paths per instance
            batch_preds = []
            for i, paths in enumerate(preds):
                paths_with_prob = self.calc_path_prob(probs[i], paths)
                paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
                batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])

        return batch_preds

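# A rough end-to-end sketch (illustrative only; the checkpoint name and the
# input tensors below are assumptions, not part of this file):
#
#   model = MrcGlobalPointerModel(plm_dir="bert-base-chinese", mode="w2")
#   out = model(input_ids, mask, labels=labels, is_eval=True, top_p=0.5)
#   out["loss"]    # scalar training loss (only when labels are given)
#   out["logits"]  # (bs, 2, seq_len, seq_len) pair scores
#   out["pred"]    # decoded paths per instance (only when is_eval=True)
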
class SchemaGuidedInstructBertModel(nn.Module):
    def __init__(
        self,
        plm_dir: str,
        vocab_size: int = None,
        use_rope: bool = True,
        biaffine_size: int = 512,
        label_mask_id: int = 4,
        text_mask_id: int = 7,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.label_mask_id = label_mask_id
        self.text_mask_id = text_mask_id
        self.use_rope = use_rope

        self.plm = AutoModel.from_pretrained(plm_dir)
        if vocab_size:
            self.plm.resize_token_embeddings(vocab_size)
        self.hidden_size = self.plm.config.hidden_size
        self.biaffine_size = biaffine_size
        self.pointer = PointerMatrix(
            self.hidden_size,
            biaffine_size,
            cls_num=3,
            dropout=dropout,
            biaffine_bias=True,
            use_rope=use_rope,
        )

    def input_encoding(self, input_ids, mask):
        attention_mask = mask.gt(0).float()
        plm_outputs = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return plm_outputs.last_hidden_state

    def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # keep all non-padding positions; shape (bs, 1, seq_len, seq_len)
        bs, seq_len = mask.shape

        bit_mask = (
            mask.gt(0).unsqueeze(1).unsqueeze(1).expand(bs, 1, seq_len, seq_len).float()
        )

        return bit_mask

    def forward(
        self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
    ):
        bit_mask = self.build_bit_mask(mask)
        hidden = self.input_encoding(input_ids, mask)

        logits = self.pointer(hidden)
        logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
        logits = logits / (self.biaffine_size**0.5)

        bs, cls_num, seq_len, seq_len = logits.shape

        results = {"logits": logits}
        if labels is not None:
            assert labels.shape == (bs, cls_num, seq_len, seq_len)
            loss = multilabel_categorical_crossentropy(
                logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
            )
            loss = loss.mean()
            results["loss"] = loss

        if is_eval:
            batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
            results["pred"] = batch_positions
        return results

    def calc_path_prob(self, probs, paths):
        """
        Args:
            probs: (cls_num, seq_len, seq_len)
            paths: a list of paths in tuple

        Returns:
            [(path: tuple, prob: float), ...]
        """
        paths_with_prob = []
        for path in paths:
            path_prob = 1.0
            for se in windowed_queue_iter(path, 2, 1, drop_last=True):
                path_prob *= probs[0, se[0], se[-1]]
            path_prob *= probs[1, path[-1], path[0]]
            paths_with_prob.append((path, path_prob))
        return paths_with_prob

    def decode(
        self,
        logits: torch.Tensor,
        top_p: float = 0.5,
        top_k: int = -1,
        legal_num_parts: tuple = None,
        labels: torch.Tensor = None,
        **kwargs,
    ):
        probs = logits.sigmoid()
        if labels is None:
            pred = (probs > top_p).long()
        else:
            # decode from the provided gold matrices instead of thresholded probabilities
            pred = labels
        preds = decode_nnw_nsw_thw_mat(pred, offsets=kwargs.get("offset"))

        if top_k == -1:
            batch_preds = preds
        else:
            batch_preds = []
            for i, paths in enumerate(preds):
                paths_with_prob = self.calc_path_prob(probs[i], paths)
                paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
                batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])

        if legal_num_parts is not None:
            # drop paths whose number of parts is not allowed
            legal_preds = []
            for ins_paths in batch_preds:
                legal_paths = []
                for path in ins_paths:
                    if len(path) in legal_num_parts:
                        legal_paths.append(path)
                legal_preds.append(legal_paths)
        else:
            legal_preds = batch_preds

        return legal_preds