# Source: Mirror/src/utils.py (commit 5953ef9, "update")
from collections import defaultdict
import torch
from rex.utils.iteration import windowed_queue_iter
from rex.utils.position import find_all_positions
def find_paths_from_adj_mat(adj_mat: torch.Tensor) -> list[tuple[int]]:
    """Enumerate node paths encoded by a square adjacency matrix.

    Every non-zero entry ``adj_mat[c, n]`` is read as a directed edge
    ``c -> n``.  The returned list contains:

    * a single-node path ``(c,)`` for every non-zero diagonal entry, and
    * for every node that has outgoing edges but no incoming ones, each path
      of two or more nodes reached by walking forward from it.  Intermediate
      prefixes are reported as well as the full walks, e.g. the edges
      ``0 -> 1 -> 2`` yield both ``(0, 1)`` and ``(0, 1, 2)``.

    Within the walk that begins with one particular first edge, every edge is
    followed at most once.  This keeps the walk finite on cyclic graphs, but
    it also means that when branches re-join only one of the alternative
    routes continues past the join.  A cycle that no start node leads into
    produces no path at all.

    Args:
        adj_mat: 2-D square tensor; non-zero entries mark edges.

    Returns:
        List of node-index tuples.  The order follows set iteration order and
        should not be relied upon.

    Raises:
        AssertionError: if ``adj_mat`` is not a square 2-D tensor.
    """
    # Check the rank first: indexing ``shape[1]`` on a 1-D tensor would raise
    # IndexError before the assertion could report the real problem.
    assert adj_mat.dim() == 2 and adj_mat.shape[0] == adj_mat.shape[1]

    self_loops = set()
    adj_map = defaultdict(set)
    has_incoming = set()
    # current -> next
    for c, n in adj_mat.detach().nonzero().tolist():
        if c == n:
            # self-loops are reported separately and never walked through
            self_loops.add(c)
        else:
            adj_map[c].add(n)
            has_incoming.add(n)

    paths = [(node,) for node in self_loops]

    def track(path: tuple[int], c: int) -> None:
        """Walk forward from ``c`` (reached via ``path``), recording paths."""
        visited: set[tuple[int, int]] = set()
        stack = [(path, c)]
        while stack:
            prefix, node = stack.pop()
            # ``.get`` avoids inserting empty entries into the defaultdict
            for nxt in adj_map.get(node, ()):
                edge = (node, nxt)
                if edge in visited:
                    continue
                visited.add(edge)
                stack.append((prefix + (node,), nxt))
            # ``prefix`` always holds at least the start node, so every popped
            # node closes a reportable path (not only the leaves).
            paths.append(prefix + (node,))

    # Only nodes without predecessors can begin a path.
    start_nodes = set(adj_map) - has_incoming
    for start in start_nodes:
        for first in adj_map[start]:
            track((start,), first)

    return paths
def encode_nnw_thw_mat(
    spans: list[tuple[int]], seq_len: int, nnw_id: int = 0, thw_id: int = 1
) -> torch.Tensor:
    """Encode spans into a two-channel ``(2, seq_len, seq_len)`` matrix.

    For a span of token indices ``(t0, t1, ..., tk)``:

    * channel ``nnw_id`` gets a 1 at ``[t_i, t_{i+1}]`` for every pair of
      consecutive tokens, and
    * channel ``thw_id`` gets a 1 at ``[tk, t0]`` (last token -> first token).

    A single-token span ``(t,)`` sets ``[t, t]`` in *both* channels.
    Empty spans are skipped.

    Args:
        spans: token-index tuples to encode.
        seq_len: side length of the produced matrix.
        nnw_id: channel index used for consecutive-token links.
        thw_id: channel index used for the last-to-first link.

    Returns:
        Float tensor of shape ``(2, seq_len, seq_len)`` filled with 0/1.
    """
    mat = torch.zeros(2, seq_len, seq_len)
    for span in spans:
        if not span:
            # Nothing to encode; ``span[-1]`` below would raise IndexError.
            continue
        if len(span) == 1:
            mat[:, span[0], span[0]] = 1
            continue
        # consecutive (current, next) token pairs
        for s, e in zip(span, span[1:]):
            mat[nnw_id, s, e] = 1
        mat[thw_id, span[-1], span[0]] = 1
    return mat
def decode_nnw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    thw_id: int = 1,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Decode a batch of two-channel NNW/THW matrices into spans.

    For every instance, paths are extracted from the ``nnw_id`` channel and a
    path is kept as a span only when the ``thw_id`` channel has a non-zero
    entry at ``[path[-1], path[0]]``.

    Args:
        batch_mat: (batch_size, 2, seq_len, seq_len)
        nnw_id: channel index holding the token-to-next-token links.
        thw_id: channel index holding the last-to-first links.
        offsets: optional per-instance value subtracted from every decoded
            index; when omitted (or empty) no shift is applied.

    Returns:
        One list of index tuples per instance.
    """
    batch_size, num_channels, num_rows, num_cols = batch_mat.shape
    assert num_rows == num_cols
    assert num_channels == 2

    decoded = []
    for idx, ins_mat in enumerate(batch_mat):
        # ins_mat: (2, seq_len, seq_len)
        shift = offsets[idx] if offsets else 0

        # index the candidate paths by their (last, first) node pair
        paths_by_end_start = defaultdict(set)
        for candidate in find_paths_from_adj_mat(ins_mat[nnw_id, ...]):
            paths_by_end_start[(candidate[-1], candidate[0])].add(candidate)

        # THW entries are stored reversed: row = span end, column = span start
        end_start_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
        spans = [
            tuple(tok - shift for tok in candidate)
            for end, start in end_start_pairs
            for candidate in paths_by_end_start[(end, start)]
        ]
        decoded.append(spans)
    return decoded
def decode_pointer_mat(
    batch_mat: torch.LongTensor, offsets: list[int] = None
) -> list[list[tuple[int]]]:
    """Decode start/end pointer matrices into contiguous index spans.

    Only channel 0 of each instance is read.  Every entry equal to 1 at
    ``[s, e]`` becomes the span ``(s, s + 1, ..., e)``, shifted down by the
    instance offset.  An entry with ``e < s`` yields an empty tuple.

    Args:
        batch_mat: batched tensor indexed as ``batch_mat[i, 0]``; that slice
            is expected to be 2-D (presumably ``(seq_len, seq_len)``).
        offsets: optional per-instance value subtracted from every decoded
            index; when omitted (or empty) no shift is applied.

    Returns:
        One list of index tuples per instance.
    """
    batch_paths = []
    for i, ins_mat in enumerate(batch_mat):
        offset = offsets[i] if offsets else 0
        coordinates = (ins_mat[0] == 1).nonzero().tolist()
        batch_paths.append(
            [tuple(range(s - offset, e + 1 - offset)) for s, e in coordinates]
        )
    return batch_paths
def encode_nnw_nsw_thw_mat(
    spans: list[list[tuple[int]]],
    seq_len: int,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
) -> torch.Tensor:
    """Encode multi-part spans into a ``(3, seq_len, seq_len)`` matrix.

    Each span is a list of parts, each part a tuple of token indices.  Parts
    containing any index outside ``[0, seq_len - 1]`` are dropped; the
    remaining parts are concatenated into one token chain.  Then:

    * channel ``nnw_id`` links every pair of consecutive tokens of the chain,
    * channel ``nsw_id`` links the last token of a kept part to the first
      token of the part that follows it in the input (if that first token is
      in range), marking where one part ends and the next begins,
    * channel ``thw_id`` links the chain's last token to its first token.

    A chain of exactly one token sets ``[t, t]`` in all three channels.
    Empty parts contribute nothing and never produce an ``nsw_id`` link.

    Args:
        spans: list of spans; every span is a list of index tuples (parts).
        seq_len: side length of the produced matrix.
        nnw_id: channel index for consecutive-token links.
        nsw_id: channel index for part-boundary links.
        thw_id: channel index for the last-to-first link.

    Returns:
        Float tensor of shape ``(3, seq_len, seq_len)`` filled with 0/1.
    """
    mat = torch.zeros(3, seq_len, seq_len)
    for parts in spans:
        span = ()
        last_part_idx = len(parts) - 1
        for p_i, part in enumerate(parts):
            if not all(0 <= el < seq_len for el in part):
                continue
            span += part
            # An empty part has no last token, and the final part has no
            # successor: neither can start a part-boundary link.
            if not part or p_i == last_part_idx:
                continue
            next_part = parts[p_i + 1]
            if next_part and 0 <= next_part[0] < seq_len:
                # current part to next part
                mat[nsw_id, part[-1], next_part[0]] = 1
        if not span:
            continue
        if len(span) == 1:
            mat[:, span[0], span[0]] = 1
        else:
            # consecutive (current, next) token pairs
            for s, e in zip(span, span[1:]):
                mat[nnw_id, s, e] = 1
        mat[thw_id, span[-1], span[0]] = 1
    return mat
def split_tuple_by_positions(nums, positions) -> list:
    """Split ``nums`` into consecutive slices that begin at ``positions``.

    ``positions`` may be given in any order; they are sorted before cutting.
    Position ``0`` or a repeated position yields an empty slice.

    Args:
        nums: sliceable sequence (e.g. a tuple) to split.
        positions: re-iterable collection of cut indices, each within
            ``[0, len(nums))``.

    Returns:
        List of slices of ``nums`` that concatenate back to ``nums``.

    Raises:
        ValueError: if any position lies outside ``[0, len(nums))``.

    Examples:
        >>> nums = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        >>> positions = [2, 5, 7]
        >>> split_tuple_by_positions(nums, positions)
        [(1, 2), (3, 4, 5), (6, 7), (8, 9, 10)]
    """
    size = len(nums)
    # Negative positions would be interpreted as from-the-end indices by the
    # slices below and silently produce wrong pieces, so reject them too.
    if not all(0 <= p < size for p in positions):
        raise ValueError("Invalid positions")

    # Surround the cut points with the outer boundaries of the sequence
    bounds = [0, *sorted(positions), size]
    return [nums[start:end] for start, end in zip(bounds, bounds[1:])]
def decode_nnw_nsw_thw_mat(
    batch_mat: torch.LongTensor,
    nnw_id: int = 0,
    nsw_id: int = 1,
    thw_id: int = 2,
    offsets: list[int] = None,
) -> list[list[tuple[int]]]:
    """Decode NNW NSW THW matrix into a list of spans

    One span has multiple parts.  For every instance:

    1. paths are extracted from the ``nnw_id`` channel;
    2. for each non-zero ``thw_id`` entry ``[e, s]``, every path containing
       ``s`` followed (at or after it) by ``e`` contributes the sub-path from
       the first ``s`` up to the first ``e`` after it;
    3. that sub-path is cut after every position where a non-zero ``nsw_id``
       entry ``[a, b]`` occurs as two adjacent elements ``a, b``.

    Duplicate spans within one instance are removed, so the order of the
    spans in the output is not meaningful.

    Args:
        batch_mat: (batch_size, 3, seq_len, seq_len)
        nnw_id: channel index holding token-to-next-token links.
        nsw_id: channel index holding part-boundary links.
        thw_id: channel index holding last-to-first links.
        offsets: optional per-instance value subtracted from every decoded
            index; when omitted (or empty) no shift is applied.

    Returns:
        Per instance, a list of spans; each span is a tuple of parts and each
        part a tuple of (offset-corrected) token indices.
    """
    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
    assert seq_len1 == seq_len2
    assert cls_num == 3

    result_batch = []
    for ins_id in range(ins_num):
        offset = offsets[ins_id] if offsets else 0
        ins_span_paths = set()
        # ins_mat: (3, seq_len, seq_len)
        ins_mat = batch_mat[ins_id]
        # Part-boundary links, shifted once here so they can be matched
        # directly against the offset-corrected chains below.
        nsw_seps = [
            [part1e - offset, part2s - offset]
            for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()
        ]
        nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
        thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
        # reversed match, end -> start
        for e, s in thw_pairs:
            for path in nnw_paths:
                if s not in path:
                    continue
                sub_path = path[path.index(s) :]
                if e not in sub_path:
                    continue
                sub_path = sub_path[: sub_path.index(e) + 1]
                chain = tuple(i - offset for i in sub_path)
                parts = [chain]
                # cut path into multiple spans if there are skip links
                if len(chain) > 1:
                    chain_items = list(chain)
                    all_sep_positions = set()
                    for sep in nsw_seps:
                        # NOTE(review): find_all_positions is assumed to return
                        # an iterable of matches whose first element is the
                        # start index of ``sep`` inside the chain - confirm.
                        positions = find_all_positions(chain_items, sep)
                        if positions:
                            # +1: (5, 6, 269) with (6, 269) as sep, found position is 1,
                            # while we want to split after 6, which needs +1
                            all_sep_positions.update(p[0] + 1 for p in positions)
                    parts = split_tuple_by_positions(chain, all_sep_positions)
                    if not parts:
                        parts = [chain]
                ins_span_paths.add(tuple(parts))
        result_batch.append(list(ins_span_paths))
    return result_batch
# def encode_nnw_nsw_thw_mat(
# spans: list[list[tuple[int]]],
# seq_len: int,
# nnw_id: int = 0,
# nsw_id: int = 1,
# thw_id: int = 2,
# ) -> torch.Tensor:
# mat = torch.zeros(3, seq_len, seq_len)
# for span in spans:
# for p_i, part in enumerate(span):
# if len(part) == 1:
# mat[:, part[0], part[0]] = 1
# else:
# for s, e in windowed_queue_iter(part, 2, 1, drop_last=True):
# mat[nnw_id, s, e] = 1
# if p_i < len(span) - 1:
# # current part to next part
# mat[nsw_id, span[p_i][-1], span[p_i + 1][0]] = 1
# mat[thw_id, span[-1][-1], span[0][0]] = 1
# return mat
# def decode_nnw_nsw_thw_mat(
# batch_mat: torch.LongTensor,
# nnw_id: int = 0,
# nsw_id: int = 1,
# thw_id: int = 2,
# offsets: list[int] = None,
# ) -> list[list[tuple[int]]]:
# """Decode NNW NSW THW matrix into a list of spans
# One span has multiple parts
# Args:
# batch_mat: (batch_size, 3, seq_len, seq_len)
# """
# ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
# assert seq_len1 == seq_len2
# assert cls_num == 2
# result_batch = []
# for ins_id in range(ins_num):
# offset = offsets[ins_id] if offsets else 0
# ins_span_paths = []
# # ins_mat: (3, seq_len, seq_len)
# ins_mat = batch_mat[ins_id]
# nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
# path_index = {"s": defaultdict(set), "e": defaultdict(set)}
# for path in nnw_paths:
# s = path[0]
# e = path[-1]
# path_index["s"][s].add(path)
# path_index["e"][e].add(path)
# nsw_connections = {(part1e, part2s) for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()}
# thw_connections = {(span_e, span_s) for span_e, span_s in ins_mat[thw_id, ...].detach().nonzero().tolist()}
# for e, s in thw_connections:
# path_span_combinations = []
# for part1_e, part2_s in nsw_connections:
# part1s = path_index["e"][part1_e]
# part2s = path_index["s"][part2_s]
# # for part1 in part1s:
# # for part2 in part2s:
# # if ()
# end_start_to_paths = defaultdict(set)
# for path in nnw_paths:
# end_start_to_paths[(path[-1], path[0])].add(path)
# # reversed match, end -> start
# for e, s in thw_pairs:
# for path in end_start_to_paths[(e, s)]:
# ins_span_paths.append(tuple(i - offset for i in path))
# result_batch.append(ins_span_paths)
# return result_batch