import math
import re
from collections import defaultdict
from datetime import datetime
from typing import List

import torch
import torch.optim as optim
from rex import accelerator
from rex.data.data_manager import DataManager
from rex.data.dataset import CachedDataset, StreamReadDataset
from rex.tasks.simple_metric_task import SimpleMetricTask
from rex.utils.batch import decompose_batch_into_instances
from rex.utils.config import ConfigParser
from rex.utils.dict import flatten_dict
from rex.utils.io import load_jsonlines
from rex.utils.registry import register
from torch.utils.tensorboard import SummaryWriter
from transformers.optimization import (
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric
from .model import (
    MrcGlobalPointerModel,
    MrcPointerMatrixModel,
    SchemaGuidedInstructBertModel,
)
from .transform import (
    CachedLabelPointerTransform,
    CachedPointerMRCTransform,
    CachedPointerTaggingTransform,
)


@register("task")
class MrcTaggingTask(SimpleMetricTask):
    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

    def after_initialization(self):
        now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.tb_logger: SummaryWriter = SummaryWriter(
            log_dir=self.task_path / "tb_summary" / now_string,
            comment=self.config.comment,
        )

    def after_whole_train(self):
        self.tb_logger.close()

    def get_grad_norm(self):
        # Total L2 norm over all populated parameter gradients, computed the
        # same way as torch.nn.utils.clip_grad_norm_: the square root of the
        # sum of squared per-parameter gradient norms.
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1.0 / 2)
        return total_norm

    def log_loss(
        self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str
    ):
        self.tb_logger.add_scalar(
            f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx
        )
        # also track the current learning rate and the total gradient norm
        self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx)
        self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx)

    def log_metrics(
        self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str
    ):
        metrics = flatten_dict(metrics)
        self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx)

    def init_transform(self):
        return CachedPointerTaggingTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            self.config.ent_type2query_filepath,
            mode=self.config.mode,
            negative_sample_prob=self.config.negative_sample_prob,
        )

    def init_data_manager(self):
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            CachedDataset,
            self.transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=False,
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcNERMetric()

    def init_optimizer(self):
        # no weight decay for embedding/LayerNorm/bias parameters; a separate
        # learning rate for everything outside the PLM
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        # freeze PLM embeddings and encoder layers 0-3; the trailing `\.` keeps
        # the pattern from also matching two-digit layer indices such as 10-13
        non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3]\.)"

        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )
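
    # Rough illustration of how the patterns above act (these parameter names
    # are hypothetical and depend on the concrete model definition):
    #   "plm.embeddings.word_embeddings.weight" -> frozen, no weight decay
    #   "plm.encoder.layer.5.attention.self.query.weight" -> plm lr, decayed
    #   "span_head.out_proj.bias" -> other_learning_rate, no weight decay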

    def init_lr_scheduler(self):
        num_training_steps = int(
            len(self.data_manager.train_loader)
            * self.config.num_epochs
            * accelerator.num_processes
        )
        num_warmup_steps = math.floor(
            num_training_steps * self.config.warmup_proportion
        )

        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
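
    # Worked example with hypothetical numbers: 1,000 train batches, 10 epochs,
    # and 2 processes give 1000 * 10 * 2 = 20,000 scheduler steps; with
    # warmup_proportion=0.1 that is 2,000 warmup steps before cosine decay.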

    @torch.no_grad()
    def predict_api(self, texts: List[str], **kwargs):
        # inference only: disable dropout and gradient tracking
        self.model.eval()
        raw_dataset = self.transform.predict_transform(texts)
        text_ids = sorted(list({ins["id"] for ins in raw_dataset}))
        loader = self.data_manager.prepare_loader(raw_dataset)
        loader = accelerator.prepare_data_loader(loader)
        id2ents = defaultdict(set)
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            for _id, _pred in zip(batch["id"], batch_out["pred"]):
                id2ents[_id].update(_pred)
        results = [id2ents[_id] for _id in text_ids]
        return results


@register("task")
class MrcQaTask(MrcTaggingTask):
    def init_transform(self):
        return CachedPointerMRCTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
        )

    def init_model(self):
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcSpanMetric()

    @torch.no_grad()
    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dicts, each with query, context, and background strings
        """
        self.model.eval()
        raw_dataset = self.transform.predict_transform(data)
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            batch["pred"] = batch_out["pred"]
            instances = decompose_batch_into_instances(batch)
            for ins in instances:
                ins_results = []
                for index_list in ins["pred"]:
                    # map predicted token indices back to a surface string
                    tokens = [ins["raw_tokens"][i] for i in index_list]
                    ins_results.append(("".join(tokens), tuple(index_list)))
                results.append(ins_results)
        return results


class StreamReadDatasetWithLen(StreamReadDataset):
    def __len__(self):
        # hard-coded size of the training split, so that components that need
        # a dataset length (e.g. the LR scheduler) work with a streaming dataset
        return 631346


@register("task")
class SchemaGuidedInstructBertTask(MrcTaggingTask):
    def init_transform(self):
        self.transform: CachedLabelPointerTransform
        return CachedLabelPointerTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
            label_span=self.config.label_span,
            include_instructions=self.config.get("include_instructions", True),
        )

    def init_data_manager(self):
        if self.config.get("stream_mode", False):
            DatasetClass = StreamReadDatasetWithLen
            transform = self.transform.transform
        else:
            DatasetClass = CachedDataset
            transform = self.transform
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            DatasetClass,
            transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=self.config.get("stream_mode", False),
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        self.model = SchemaGuidedInstructBertModel(
            self.config.plm_dir,
            vocab_size=len(self.transform.tokenizer),
            use_rope=self.config.use_rope,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
        )
        # optionally warm-start from a previously trained checkpoint
        if self.config.get("base_model_path"):
            self.load(
                self.config.base_model_path,
                load_config=False,
                load_model=True,
                load_optimizer=False,
                load_history=False,
            )
        return self.model

    def init_optimizer(self):
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        # sentinel pattern that matches no parameter name: every parameter
        # stays trainable for this task
        non_trainable = "no_non_trainable"

        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_metric(self):
        return MultiPartSpanMetric()

    def _convert_span_to_string(self, span, token_ids, tokenizer):
        # a span is either a single index or an inclusive (start, end) pair;
        # anything else is malformed and decodes to an empty string
        string = ""
        if len(span) == 0 or len(span) > 2:
            pass
        elif len(span) == 1:
            string = tokenizer.decode(token_ids[span[0]])
        elif len(span) == 2:
            string = tokenizer.decode(token_ids[span[0] : span[1] + 1])
        return (string, self.reset_position(token_ids, span))

    def reset_position(self, input_ids: list[int], span: list[int]) -> list[int]:
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().tolist()
        if len(span) < 1:
            return span

        # shift indices so they are relative to the text segment, which starts
        # right after the tp_token / tl_token marker
        tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids(
            [self.transform.tp_token, self.transform.tl_token]
        )
        offset = 0
        if tp_token_id in input_ids:
            offset = input_ids.index(tp_token_id) + 1
        elif tl_token_id in input_ids:
            offset = input_ids.index(tl_token_id) + 1
        return [i - offset for i in span]
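
    # Toy example with hypothetical ids: for input_ids = [101, 7, 42, 15, 16]
    # where 42 is the tp_token id, offset = 3, so span [3, 4] becomes [0, 1].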

    @torch.no_grad()
    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dicts in the UDI format:
                {
                    "id": str,
                    "instruction": str,
                    "schema": {
                        "ent": list,
                        "rel": list,
                        "event": dict,
                        "cls": list,
                        "discontinuous_ent": list,
                        "hyper_rel": dict
                    },
                    "text": str,
                    "bg": str,
                    "ans": {},  # empty dict
                }
        """
        self.model.eval()
        raw_dataset = [self.transform.transform(d) for d in data]
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            batch["pred"] = batch_out["pred"]
            instances = decompose_batch_into_instances(batch)
            for ins in instances:
                # per-instance accumulators for each task type
                pred_clses = []
                pred_ents = []
                pred_rels = []
                pred_trigger_to_event = defaultdict(
                    lambda: {"event_type": "", "arguments": []}
                )
                pred_events = []
                pred_spans = []
                pred_discon_ents = []
                pred_hyper_rels = []
                raw_schema = ins["raw"]["schema"]
                for multi_part_span in ins["pred"]:
                    span = tuple(multi_part_span)
                    span_to_label = ins["span_to_label"]
                    if span[0] in span_to_label:
                        label = span_to_label[span[0]]
                        if label["task"] == "cls" and len(span) == 1:
                            pred_clses.append(label["string"])
                        elif label["task"] == "ent" and len(span) == 2:
                            string = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_ents.append((label["string"], string))
                        elif label["task"] == "rel" and len(span) == 3:
                            head = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            tail = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_rels.append((label["string"], head, tail))
                        elif label["task"] == "event":
                            # "lm" spans attach an event type to a trigger;
                            # "lr" spans attach an argument and its role to a trigger
                            if label["type"] == "lm" and len(span) == 2:
                                pred_trigger_to_event[span[1]]["event_type"] = label["string"]
                            elif label["type"] == "lr" and len(span) == 3:
                                arg = self._convert_span_to_string(
                                    span[2], ins["input_ids"], self.transform.tokenizer
                                )
                                pred_trigger_to_event[span[1]]["arguments"].append(
                                    {"argument": arg, "role": label["string"]}
                                )
                        elif label["task"] == "discontinuous_ent" and len(span) > 1:
                            parts = [
                                self._convert_span_to_string(
                                    part, ins["input_ids"], self.transform.tokenizer
                                )
                                for part in span[1:]
                            ]
                            string = " ".join([part[0] for part in parts])
                            # each part's position is already reset by
                            # _convert_span_to_string, so collect them as-is
                            position = [part[1] for part in parts]
                            pred_discon_ents.append(
                                (label["string"], string, position)
                            )
                        elif (
                            label["task"] == "hyper_rel"
                            and len(span) == 5
                            and span[3] in span_to_label
                        ):
                            q_label = span_to_label[span[3]]
                            span_1 = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            span_2 = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            span_4 = self._convert_span_to_string(
                                span[4], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_hyper_rels.append(
                                (
                                    label["string"],
                                    span_1,
                                    span_2,
                                    q_label["string"],
                                    span_4,
                                )
                            )
                    else:
                        # no label attached to the first part: emit a plain
                        # multi-part span prediction
                        pred_token_ids = []
                        for part in span:
                            _pred_token_ids = [ins["input_ids"][i] for i in part]
                            pred_token_ids.extend(_pred_token_ids)
                        span_string = self.transform.tokenizer.decode(pred_token_ids)
                        pred_spans.append(
                            (
                                span_string,
                                tuple(
                                    tuple(
                                        self.reset_position(
                                            ins["input_ids"].cpu().tolist(), part
                                        )
                                    )
                                    for part in span
                                ),
                            )
                        )
                # keep only events whose type exists in the input schema, and
                # only arguments whose role is legal for that event type
                for trigger, item in pred_trigger_to_event.items():
                    trigger = self._convert_span_to_string(
                        trigger, ins["input_ids"], self.transform.tokenizer
                    )
                    if item["event_type"] not in raw_schema["event"]:
                        continue
                    legal_roles = raw_schema["event"][item["event_type"]]
                    pred_events.append(
                        {
                            "trigger": trigger,
                            "event_type": item["event_type"],
                            "arguments": [
                                arg
                                for arg in item["arguments"]
                                if arg["role"] in legal_roles
                            ],
                        }
                    )
                results.append(
                    {
                        "id": ins["raw"]["id"],
                        "results": {
                            "cls": pred_clses,
                            "ent": pred_ents,
                            "rel": pred_rels,
                            "event": pred_events,
                            "span": pred_spans,
                            "discon_ent": pred_discon_ents,
                            "hyper_rel": pred_hyper_rels,
                        },
                    }
                )

        return results
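
    # Illustrative shape of one `results` element (values are hypothetical):
    #   {
    #       "id": "inst-0",
    #       "results": {
    #           "cls": ["positive"],
    #           "ent": [("person", ("Alice", [0, 1]))],
    #           "rel": [...], "event": [...], "span": [...],
    #           "discon_ent": [...], "hyper_rel": [...],
    #       },
    #   }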


if __name__ == "__main__":
    pass
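    # Hedged usage sketch (the config path, config keys, and ConfigParser call
    # are assumptions; the actual entry point depends on how rex tasks are
    # launched in this project):
    #
    #   config = ConfigParser.parse_cmd(cmd_args=["-c", "conf/mirror.yaml"])
    #   task = SchemaGuidedInstructBertTask(config)
    #   task.load(config.base_model_path, load_config=False)
    #   preds = task.predict_api(
    #       [{"id": "0", "instruction": "...", "schema": {...},
    #         "text": "...", "bg": "", "ans": {}}]
    #   )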