File size: 2,105 Bytes
cc9c7ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import argparse
from omegaconf import OmegaConf


def binary_intersection(lst1, lst2):
  """Return the distinct values of ``lst1`` that also occur in ``lst2``.

  The result is de-duplicated and ordered by ``set`` iteration, not by the
  order of either input. Elements of both inputs must be hashable.

  Args:
    lst1: Sequence of candidate values.
    lst2: Iterable of values to test membership against.

  Returns:
    A new list containing each common value exactly once.
  """
  # Build the lookup set once: ``value in lst2`` on a list is a linear scan
  # per element, which made the original O(len(lst1) * len(lst2)).
  lookup = set(lst2)
  # Feeding ``set`` the same filtered sequence as before keeps the insertion
  # history identical, so the returned order matches the original.
  return list(set([value for value in lst1 if value in lookup]))


def binary_union(lst1, lst2):
  """Merge two lists and keep each distinct value once.

  The order of the returned list follows ``set`` iteration order.
  """
  merged = lst1 + lst2
  return [*set(merged)]


def _combine_column(files, merge_type, column):
  """Merge one tab-separated column of several prediction files, keyed by id.

  Each non-blank line is split on tabs. Field 0 is parsed as an int id, and
  field ``column`` is evaluated into a Python value (presumably a list of
  span positions; TODO confirm against whatever produces these files).

  An id's first occurrence is stored exactly as parsed. Every later
  occurrence, in the same file or a later one, is merged into it with the
  selected merge function. An id that appears in only some files is merged
  only across those files. Under intersection, files that lack the id
  therefore do not empty it.

  Args:
    files: Paths of the prediction files, read in the given order.
    merge_type: ``"union"`` selects ``binary_union``; any other value
      selects ``binary_intersection``.
    column: Index of the tab-separated field to evaluate and merge.

  Returns:
    Dict mapping each int id to its merged value.

  Raises:
    ValueError: if field 0 of a non-blank line is not an integer.
    IndexError: if a line has fewer than ``column + 1`` fields.
  """
  fn = binary_union if merge_type == "union" else binary_intersection
  text = {}
  for fil in files:
    with open(fil, "r") as f:
      for line in f:
        # Skip blank lines (e.g. a trailing newline at EOF); int("") would
        # otherwise raise ValueError and abort the whole merge.
        if not line.strip():
          continue
        line_split = line.split("\t")
        key = int(line_split[0])
        # NOTE(review): eval executes arbitrary code found in the input file.
        # If these files can come from an untrusted source, switch to
        # ast.literal_eval (works as long as the fields are plain literals
        # such as lists of ints; confirm).
        value = eval(line_split[column])
        if key in text:
          text[key] = fn(text[key], value)
        else:
          text[key] = value
  return text


def combine(files, type="union"):
  """Merge field 1 (the first value column) of each file by id.

  Args:
    files: Paths of tab-separated prediction files.
    type: ``"union"`` to union values per id; any other value intersects.

  Returns:
    Dict mapping each int id to its merged field-1 value.
  """
  return _combine_column(files, type, 1)


def combine_I(files, type="union"):
  """Merge field 2 (the second value column) of each file by id.

  Args:
    files: Paths of tab-separated prediction files.
    type: ``"union"`` to union values per id; any other value intersects.

  Returns:
    Dict mapping each int id to its merged field-2 value.
  """
  return _combine_column(files, type, 2)


def write_dict_to_file(text, text_I, path):
  """Write merged predictions as one tab-separated line per id.

  Each line is ``<id><TAB>str(text[id])<TAB>str(text_I[id])`` followed by a
  newline, emitted in the iteration order of ``text``.

  Args:
    text: Mapping from id to the merged first value column.
    text_I: Mapping from id to the merged second value column. It must
      contain every key of ``text``.
    path: Output file path; an existing file is overwritten.

  Raises:
    KeyError: if an id in ``text`` is missing from ``text_I``.
  """
  with open(path, "w") as f:
    for key, spans in text.items():
      f.write(f"{key}\t{str(spans)}\t{str(text_I[key])}\n")


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
    prog="combine_preds.py", description="Combine span predictions."
  )
  parser.add_argument(
    "--config",
    type=str,
    action="store",
    # Without a config the script cannot run; fail early with a usage error
    # instead of crashing inside OmegaConf.load(None).
    required=True,
    help="The configuration for combining predictions.",
  )
  args = parser.parse_args()
  # The config is read for ``files`` (input prediction paths), ``type``
  # ("union", or anything else for intersection) and ``path`` (output file).
  combine_config = OmegaConf.load(args.config)
  text = combine(combine_config.files, combine_config.type)
  text_I = combine_I(combine_config.files, combine_config.type)

  # os.path.dirname returns "" for a bare filename, and os.makedirs("")
  # raises, so only create a directory when the path actually has one.
  # exist_ok=True also removes the check-then-create race.
  out_dir = os.path.dirname(combine_config.path)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)
  write_dict_to_file(text, text_I, combine_config.path)