|
from collections import defaultdict |
|
|
|
import torch |
|
from rex.utils.iteration import windowed_queue_iter |
|
from rex.utils.position import find_all_positions |
|
|
|
|
|
def find_paths_from_adj_mat(adj_mat: torch.Tensor) -> list[tuple[int]]:
    """Enumerate node paths encoded by a square adjacency matrix.

    A non-zero entry at ``[c, n]`` is read as a directed edge ``c -> n``.

    * Every self-loop ``c -> c`` is reported as the one-element path ``(c,)``
      and is otherwise ignored while walking.
    * Walks start from every node that has an outgoing (non self-loop) edge
      and no incoming (non self-loop) edge.
    * Each node reached during a walk emits the path leading to it, so
      prefixes of longer paths are part of the result as well.
    * Inside one walk (one start node plus one of its direct successors) each
      edge is expanded at most once. This guarantees termination on cycles;
      as a consequence, when two branches of the same walk reconverge only
      the first branch that reaches the shared edge continues through it.
    * Components without a start node (e.g. a pure cycle) yield no paths
      besides their self-loops.

    Args:
        adj_mat: square 2-D tensor; non-zero entries mark edges.

    Returns:
        List of paths, each a tuple of node indices.

    Raises:
        AssertionError: if ``adj_mat`` is not a square 2-D tensor.
    """
    # Check the rank first: indexing ``shape[1]`` on a 0-D/1-D tensor would
    # raise IndexError before the intended assertion could fail.
    assert len(adj_mat.shape) == 2 and adj_mat.shape[0] == adj_mat.shape[1]

    paths = []
    self_loops = set()
    adj_map = defaultdict(set)  # current -> {next, ...}
    rev_adj_map = defaultdict(set)  # next -> {current, ...}

    for c, n in adj_mat.detach().nonzero().tolist():
        if c == n:
            self_loops.add(c)
        else:
            adj_map[c].add(n)
            rev_adj_map[n].add(c)

    paths.extend((node,) for node in self_loops)

    def track(path: tuple[int], c: int):
        """Depth-first walk from ``c``; ``path`` holds the nodes before it."""
        visited: set[tuple[int]] = set()  # edges already expanded in this walk
        stack = [(path, c)]
        while stack:
            path, c = stack.pop()
            # Membership test (not indexing) so the defaultdict is not grown
            # with empty entries for leaf nodes.
            if c in adj_map:
                for n in adj_map[c]:
                    if (c, n) in visited:
                        continue
                    visited.add((c, n))
                    stack.append((path + (c,), n))
            # ``path`` always contains at least the start node, so every
            # popped node contributes a path ending at it.
            paths.append(path + (c,))

    start_nodes = set(adj_map) - set(rev_adj_map)
    for c in start_nodes:
        for n in adj_map[c]:
            track((c,), n)

    return paths
|
|
|
|
|
def encode_nnw_thw_mat(
    spans: list[tuple[int]], seq_len: int, nnw_id: int = 0, thw_id: int = 1
) -> torch.Tensor:
    """Encode spans into a ``(2, seq_len, seq_len)`` NNW/THW matrix.

    * A single-index span ``(i,)`` sets ``[:, i, i]`` to 1 on both channels.
    * A longer span sets channel ``nnw_id`` for every window produced by
      ``windowed_queue_iter(span, 2, 1, drop_last=True)`` (presumably the
      consecutive ``(previous, next)`` index pairs -- confirm against the
      helper) and sets channel ``thw_id`` at ``[last index, first index]``.
    * Empty spans are skipped, matching ``encode_nnw_nsw_thw_mat``; previously
      they raised ``IndexError`` on ``span[-1]``.

    Args:
        spans: spans given as tuples of position indices.
        seq_len: side length of each channel of the returned matrix.
        nnw_id: channel index receiving the window links.
        thw_id: channel index receiving the end-to-start link.

    Returns:
        Tensor of zeros and ones with shape ``(2, seq_len, seq_len)``.
    """
    mat = torch.zeros(2, seq_len, seq_len)
    for span in spans:
        if len(span) == 0:
            # Nothing to encode for an empty span.
            continue
        first, last = span[0], span[-1]
        if len(span) == 1:
            mat[:, first, first] = 1
            continue
        for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
            mat[nnw_id, s, e] = 1
        mat[thw_id, last, first] = 1
    return mat
|
|
|
|
|
def decode_nnw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    thw_id: int = 1,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Recover spans from a batch of NNW/THW matrices.

    For every instance, candidate paths are read from the ``nnw_id`` channel
    with ``find_paths_from_adj_mat``. A candidate is kept whenever the
    ``thw_id`` channel has a non-zero entry at ``[path end, path start]``.

    Args:
        batch_mat: (batch_size, 2, seq_len, seq_len)
        nnw_id: channel holding the path links.
        thw_id: channel holding the end-to-start links.
        offsets: optional per-instance value subtracted from every index of
            the returned spans; ``None`` or an empty list means no shift.

    Returns:
        One list of spans per instance; each span is a tuple of indices.
    """
    _, num_channels, num_rows, num_cols = batch_mat.shape
    assert num_rows == num_cols
    assert num_channels == 2

    decoded = []
    for idx, sample in enumerate(batch_mat):
        shift = offsets[idx] if offsets else 0

        # Group candidate paths by their (end, start) pair for direct lookup.
        by_end_start = defaultdict(set)
        for candidate in find_paths_from_adj_mat(sample[nnw_id, ...]):
            by_end_start[(candidate[-1], candidate[0])].add(candidate)

        # THW entries are stored reversed: row = end, column = start.
        end_start_pairs = sample[thw_id, ...].detach().nonzero().tolist()
        decoded.append(
            [
                tuple(pos - shift for pos in candidate)
                for end, start in end_start_pairs
                for candidate in by_end_start[(end, start)]
            ]
        )

    return decoded
|
|
|
|
|
def decode_pointer_mat(
    batch_mat: torch.LongTensor, offsets: list[int] = None
) -> list[list[tuple[int]]]:
    """Turn start/end pointer cells into contiguous index spans.

    Only channel 0 of each instance is read. Every cell ``[s, e]`` equal to 1
    becomes the span ``s, s + 1, ..., e`` (an empty tuple when ``s > e``),
    shifted down by the instance offset.

    Args:
        batch_mat: tensor indexed as ``[instance, channel, start, end]``.
        offsets: optional per-instance value subtracted from every index;
            ``None`` or an empty list means no shift.

    Returns:
        One list of spans per instance; each span is a tuple of indices.
    """
    decoded = []
    for idx, sample in enumerate(batch_mat):
        shift = offsets[idx] if offsets else 0
        hits = (sample[0] == 1).nonzero().tolist()
        decoded.append(
            [tuple(range(start - shift, end + 1 - shift)) for start, end in hits]
        )
    return decoded
|
|
|
|
|
def encode_nnw_nsw_thw_mat(
    spans: list[list[tuple[int]]],
    seq_len: int,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
) -> torch.Tensor:
    """Encode multi-part spans into a ``(3, seq_len, seq_len)`` matrix.

    Each span is a list of parts (tuples of indices). Parts containing any
    index outside ``[0, seq_len - 1]`` are dropped; the remaining parts are
    concatenated into one flat span.

    * ``nsw_id``: for every kept part that is not the last one, links its
      last index to the first index of the following part, provided that
      first index is in range.
    * ``nnw_id``: set for every window produced by
      ``windowed_queue_iter(flat, 2, 1, drop_last=True)`` over the flat span
      (presumably consecutive index pairs -- confirm against the helper).
    * ``thw_id``: links the last index of the flat span to its first index.
    * A flat span of length one sets ``[:, i, i]`` on all three channels.

    Args:
        spans: spans, each a list of index tuples.
        seq_len: side length of each channel of the returned matrix.
        nnw_id: channel index for links inside the flat span.
        nsw_id: channel index for part-to-next-part links.
        thw_id: channel index for the end-to-start link.

    Returns:
        Tensor of zeros and ones with shape ``(3, seq_len, seq_len)``.
    """
    mat = torch.zeros(3, seq_len, seq_len)
    max_index = seq_len - 1

    for parts in spans:
        last_part_idx = len(parts) - 1
        flat = ()
        for pos, part in enumerate(parts):
            # Drop parts that reach outside the sequence.
            if any(not (0 <= tok <= max_index) for tok in part):
                continue
            flat = flat + part
            if pos < last_part_idx and 0 <= parts[pos + 1][0] <= max_index:
                # Connect the end of this part to the start of the next one.
                mat[nsw_id, part[-1], parts[pos + 1][0]] = 1

        flat_len = len(flat)
        if flat_len == 1:
            mat[:, flat[0], flat[0]] = 1
        elif flat_len > 1:
            for s, e in windowed_queue_iter(flat, 2, 1, drop_last=True):
                mat[nnw_id, s, e] = 1
        if flat:
            mat[thw_id, flat[-1], flat[0]] = 1

    return mat
|
|
|
|
|
def split_tuple_by_positions(nums, positions) -> list:
    """Split a sequence into consecutive chunks at the given cut positions.

    Each position ``p`` starts a new chunk at index ``p``. Positions may be
    given in any order (they are sorted internally). A position of ``0`` or a
    repeated position yields an empty chunk.

    Args:
        nums: sliceable sequence to split (e.g. a tuple).
        positions: iterable of cut indices, each in ``[0, len(nums))``.

    Returns:
        List of slices of ``nums`` covering it from start to end.

    Raises:
        ValueError: if a position is negative or not below ``len(nums)``.

    Examples:
        >>> nums = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        >>> positions = [2, 5, 7]
        >>> split_tuple_by_positions(nums, positions)
        [(1, 2), (3, 4, 5), (6, 7), (8, 9, 10)]
    """
    total = len(nums)

    # Negative values used to slip through and, once sorted in front of the
    # other cuts, produced overlapping or empty slices; reject them as well.
    if not all(0 <= p < total for p in positions):
        raise ValueError("Invalid positions")

    # Surround the cuts with the sequence bounds, then slice between
    # neighbouring boundaries.
    bounds = [0] + sorted(positions) + [total]
    return [nums[start:end] for start, end in zip(bounds, bounds[1:])]
|
|
|
|
|
def decode_nnw_nsw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Decode NNW NSW THW matrix into a list of spans

    One span has multiple parts

    For every instance, candidate paths come from the ``nnw_id`` channel via
    ``find_paths_from_adj_mat``. For each non-zero ``[end, start]`` entry of
    the ``thw_id`` channel, every candidate path containing ``start`` and,
    at or after it, ``end`` contributes the sub-path between the first such
    occurrences. That sub-path is then cut into parts wherever a non-zero
    ``[part end, next part start]`` pair of the ``nsw_id`` channel is located
    in it by ``find_all_positions``.

    Args:
        batch_mat: (batch_size, 3, seq_len, seq_len)
        nnw_id: channel holding the path links.
        nsw_id: channel holding the part-to-next-part links.
        thw_id: channel holding the end-to-start links.
        offsets: optional per-instance value subtracted from every returned
            index; ``None`` or an empty list means no shift.

    Returns:
        One list per instance with the distinct decoded spans; each span is a
        tuple of parts and each part is a tuple of indices. Spans are
        de-duplicated through a set, so their order carries no meaning.
    """
    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
    assert seq_len1 == seq_len2
    assert cls_num == 3

    result_batch = []
    for ins_id in range(ins_num):
        offset = offsets[ins_id] if offsets else 0
        ins_span_paths = set()
        # ins_mat: (3, seq_len, seq_len)
        ins_mat = batch_mat[ins_id]

        # Shift the NSW pairs once per instance; they are matched against
        # chains that are already offset-corrected.
        nsw_connections = {
            (part1e - offset, part2s - offset)
            for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()
        }
        nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
        thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()

        # THW entries are stored reversed: row = end, column = start.
        for e, s in thw_pairs:
            for path in nnw_paths:
                if s not in path:
                    continue
                sub_path = path[path.index(s) :]
                if e not in sub_path:
                    continue
                sub_path = sub_path[: sub_path.index(e) + 1]
                chain = tuple(i - offset for i in sub_path)

                parts = []
                if len(chain) > 1:
                    all_sep_positions = set()
                    for sep in nsw_connections:
                        positions = find_all_positions(list(chain), list(sep))
                        if positions:
                            # p[0] is used as the index in ``chain`` where the
                            # pair was found (presumably the match start --
                            # confirm against find_all_positions); the next
                            # part begins one element later.
                            all_sep_positions.update(p[0] + 1 for p in positions)
                    parts = split_tuple_by_positions(chain, all_sep_positions)
                if not parts:
                    # Single-index chain: keep it as one part.
                    parts = [chain]
                ins_span_paths.add(tuple(parts))
        result_batch.append(list(ins_span_paths))

    return result_batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|