Simon Clematide
committed on
Commit
·
fc83ec7
1
Parent(s):
feeca2c
Initial commit with models, scripts, and JAR files
Browse files- .gitattributes +4 -0
- lib/__init__.py +0 -0
- lib/mallet2topic_assignment_jsonl.py +346 -0
- lib/mallet_topic_inferencer.py +739 -0
- mallet/lib/mallet-deps.jar +3 -0
- mallet/lib/mallet.jar +3 -0
- models/tm/tm-de-all-v2.0.inferencer +3 -0
- models/tm/tm-de-all-v2.0.pipe +3 -0
- models/tm/tm-de-all-v2.0.vocab.lemmatization.tsv.gz +3 -0
- models/tm/tm-fr-all-v2.0.inferencer +3 -0
- models/tm/tm-fr-all-v2.0.pipe +3 -0
- models/tm/tm-fr-all-v2.0.vocab.lemmatization.tsv.gz +3 -0
- models/tm/tm-lb-all-v2.0.inferencer +3 -0
- models/tm/tm-lb-all-v2.0.pipe +3 -0
- models/tm/tm-lb-all-v2.0.vocab.lemmatization.tsv.gz +3 -0
- requirements.txt +45 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jar filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.inferencer filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.pipe filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
lib/__init__.py
ADDED
File without changes
|
lib/mallet2topic_assignment_jsonl.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
|
4 |
+
Typical output of the script:
|
5 |
+
{"topic_model":"tm-fr-all-v2.0","topic_count":100,"lang":"fr","ci_ref":"actionfem-1936-02-15-a-i0022","topics":[],"min_p":0.02}
|
6 |
+
|
7 |
+
|
8 |
+
{
|
9 |
+
"topic_count": 100,
|
10 |
+
"lang": "de",
|
11 |
+
"topics": [
|
12 |
+
{"t": "tm-de-all-v2.0_tp02_de", "p": 0.027},
|
13 |
+
{"t": "tm-de-all-v2.0_tp11_de", "p": 0.119},
|
14 |
+
{"t": "tm-de-all-v2.0_tp26_de", "p": 0.045}
|
15 |
+
],
|
16 |
+
"min_p": 0.02,
|
17 |
+
"ts": "2024.08.29",
|
18 |
+
"id": "actionfem-1927-12-15-a-i0001",
|
19 |
+
"sys_id": "tm-de-all-v2.0"
|
20 |
+
}
|
21 |
+
"""
|
22 |
+
import datetime
|
23 |
+
import logging
|
24 |
+
import argparse
|
25 |
+
import traceback
|
26 |
+
import math
|
27 |
+
import json
|
28 |
+
import re
|
29 |
+
import collections
|
30 |
+
from typing import Generator, List, Dict, Any, Optional
|
31 |
+
from smart_open import open
|
32 |
+
|
33 |
+
|
34 |
+
CI_ID_REGEX = re.compile(r"^(.+?/)?([^/]+?-\d{4}-\d{2}-\d{2}-\w-i\d{4})[^/]*$")
|
35 |
+
|
36 |
+
|
37 |
+
class Mallet2TopicAssignment:
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
args: Optional[argparse.Namespace] = None,
|
41 |
+
topic_assignment_threshold: Optional[float] = None,
|
42 |
+
lang: Optional[str] = None,
|
43 |
+
topic_model: Optional[str] = None,
|
44 |
+
numeric_topic_ids: Optional[bool] = None,
|
45 |
+
format_type: Optional[str] = None,
|
46 |
+
topic_count: Optional[int] = None,
|
47 |
+
output: Optional[str] = None,
|
48 |
+
) -> None:
|
49 |
+
|
50 |
+
self.eps = args.topic_assignment_threshold
|
51 |
+
self.lang = args.lang
|
52 |
+
self.topic_model = args.topic_model
|
53 |
+
self.numeric_topic_ids = args.numeric_topic_ids
|
54 |
+
self.format_type = args.format_type.lower() # Normalize case
|
55 |
+
self.topic_count = args.topic_count
|
56 |
+
self.output = args.output
|
57 |
+
self.args = args # Ensure we keep the args namespace
|
58 |
+
|
59 |
+
self.validate_options()
|
60 |
+
|
61 |
+
self.precision = math.ceil(abs(math.log10(self.eps))) + 1
|
62 |
+
self.padding_length = math.ceil(math.log10(self.topic_count))
|
63 |
+
self.topic_id_format = (
|
64 |
+
f"{self.topic_model}_tp{{t:0{self.padding_length}d}}_{self.lang}"
|
65 |
+
)
|
66 |
+
self.last_timestamp = (
|
67 |
+
datetime.datetime.now(tz=datetime.timezone.utc)
|
68 |
+
.replace(microsecond=0)
|
69 |
+
.isoformat()
|
70 |
+
+ "Z"
|
71 |
+
)
|
72 |
+
|
73 |
+
def validate_options(self) -> None:
    """Check option consistency; raise ValueError on invalid combinations."""
    if not 0 < self.eps < 1:
        raise ValueError("topic_assignment_threshold must be between 0 and 1.")
    if self.format_type == "sparse" and not self.topic_count:
        raise ValueError(
            "The --topic_count option is required when using the 'sparse' format."
        )
|
80 |
+
|
81 |
+
def read_tsv_files(self, filenames: List[str]) -> Generator[List[str], None, None]:
    """Chain the rows of several TSV files into a single stream of rows."""
    for name in filenames:
        yield from self.read_tsv_file(name)
|
85 |
+
|
86 |
+
def read_tsv_file(self, filename: str) -> Generator[List[str], None, None]:
    """Yield tab-split rows from one TSV file, skipping '#' comment lines.

    Progress is logged every 1000 physical lines.
    """
    with open(filename, "r", encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            if not line.startswith("#"):
                yield line.strip().split("\t")
            if line_no % 1000 == 0:
                logging.info("Processed lines: %s", line_no)
|
95 |
+
|
96 |
+
def convert_matrix_row(self, row: List[str]) -> Dict[str, Any]:
    """Convert one dense doc-topic row into a topic-assignment dict.

    ``row`` layout: [doc_no, doc_path, p(t0), p(t1), ...]. Only topics
    whose probability is >= self.eps are kept; topic labels are numeric
    or formatted via self.topic_id_format depending on configuration.
    """
    ci_id = re.sub(CI_ID_REGEX, r"\2", row[1])
    raw_probs = row[2:]
    topics = [
        {
            "t": idx if self.numeric_topic_ids else self.topic_id_format.format(t=idx),
            "p": round(prob, self.precision),
        }
        for idx, raw in enumerate(raw_probs)
        if (prob := float(raw)) >= self.eps
    ]

    return {
        "ci_id": ci_id,
        "model_id": self.topic_model,
        "lang": self.lang,
        "topic_count": len(raw_probs),
        "topics": topics,
        "min_p": self.eps,
        "ts": self.last_timestamp,
    }
|
125 |
+
|
126 |
+
def convert_sparse_row(self, row: List[str]) -> Dict[str, Any]:
    """Convert one sparse doc-topic row into a topic-assignment dict.

    ``row`` layout: [doc_no, doc_path, t0, p0, t1, p1, ...] — alternating
    topic-id / probability pairs. Only topics with p >= self.eps are kept.
    """
    ci_id = re.sub(CI_ID_REGEX, r"\2", row[1])
    topic_pairs = row[2:]
    topics = []
    for i in range(0, len(topic_pairs), 2):
        t = int(topic_pairs[i])
        p = float(topic_pairs[i + 1])
        if p < self.eps:
            continue
        # Fixed: use self.precision (computed once in __init__) instead of
        # recomputing ceil(|log10(eps)|) + 1 for every single topic — the
        # value is identical, and this keeps the sparse path consistent
        # with convert_matrix_row.
        label = t if self.numeric_topic_ids else self.topic_id_format.format(t=t)
        topics.append({"t": label, "p": round(p, self.precision)})

    return {
        "ci_id": ci_id,
        "model_id": self.topic_model,
        "lang": self.lang,
        "topic_count": self.topic_count,
        "topics": topics,
        "min_p": self.eps,
        "ts": self.last_timestamp,
    }
|
158 |
+
|
159 |
+
def parse_mallet_files(
    self, filenames: List[str]
) -> Generator[Dict[str, Any], None, None]:
    """
    Process the Mallet doc-topic rows from multiple files and yield topic
    assignments as dicts, skipping rows whose content-item id was already seen.

    Args:
        filenames (List[str]): List of paths to the input files.

    Yields:
        Dict[str, Any]: Parsed topic assignment from each line in the input files.
    """
    # Fixed: a plain set + int counter replace the former Counter whose
    # sentinel key "DUPLICATE_COUNT" shared the namespace of real ci_ids —
    # a document literally named "DUPLICATE_COUNT" would have corrupted
    # both deduplication and the final count.
    seen_ci_ids = set()
    duplicate_count = 0
    if self.format_type == "sparse":
        convert_row = self.convert_sparse_row
    elif self.format_type == "matrix":
        convert_row = self.convert_matrix_row
    else:
        raise ValueError(f"Invalid format type: {self.format_type}")

    for row in self.read_tsv_files(filenames):
        ci_id = re.sub(CI_ID_REGEX, r"\2", row[1])
        if ci_id in seen_ci_ids:
            duplicate_count += 1
            continue
        seen_ci_ids.add(ci_id)

        yield convert_row(row)

    logging.info("DUPLICATE-COUNT: %d", duplicate_count)
|
189 |
+
|
190 |
+
def run(self) -> Optional[Generator[Dict[str, Any], None, None]]:
    """
    Process the input files according to the parsed command-line arguments.

    Returns:
        Optional[Generator[Dict[str, Any], None, None]]: A generator for
        topic assignments when ``self.output`` is '<generator>'; otherwise
        the assignments are written to ``self.output`` as compact JSONL and
        None is returned. On any write/parsing error the process exits
        with status 1 after logging the traceback.
    """
    if self.output == "<generator>":
        # Caller consumes the results lazily.
        return self.parse_mallet_files(self.args.INPUT_FILES)

    try:
        with open(self.output, "w", encoding="utf-8") as sink:
            for assignment in self.parse_mallet_files(self.args.INPUT_FILES):
                line = json.dumps(
                    assignment, ensure_ascii=False, separators=(",", ":")
                )
                sink.write(line + "\n")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        logging.error("Traceback: %s", traceback.format_exc())
        exit(1)
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def setup_logging(options: argparse.Namespace) -> None:
|
219 |
+
"""
|
220 |
+
Set up logging configuration based on command line options.
|
221 |
+
"""
|
222 |
+
log_level = logging.DEBUG if options.debug else logging.INFO
|
223 |
+
logging.basicConfig(
|
224 |
+
level=log_level, filename=options.logfile if options.logfile else None
|
225 |
+
)
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def main(
|
229 |
+
args: Optional[List[str]],
|
230 |
+
) -> Optional[Generator[Dict[str, Any], None, None]]:
|
231 |
+
"""
|
232 |
+
Static method serving as the entry point of the script.
|
233 |
+
If the output option is set to '<generator>', it returns a Python generator
|
234 |
+
for topic assignments, otherwise prints results or writes to a file.
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
Optional[Generator[Dict[str, Any], None, None]]: Generator for topic assignments
|
238 |
+
if output is set to '<generator>', otherwise None.
|
239 |
+
"""
|
240 |
+
parser = argparse.ArgumentParser(
|
241 |
+
usage="%(prog)s [OPTIONS] INPUT [INPUT ...]",
|
242 |
+
description=(
|
243 |
+
"Return topic assignments from mallet textual topic modeling output."
|
244 |
+
),
|
245 |
+
epilog="Contact [email protected] for more information.",
|
246 |
+
)
|
247 |
+
|
248 |
+
parser.add_argument("--version", action="version", version="2024.10.23")
|
249 |
+
parser.add_argument(
|
250 |
+
"-l", "--logfile", help="Write log information to FILE", metavar="FILE"
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"-q",
|
254 |
+
"--quiet",
|
255 |
+
action="store_true",
|
256 |
+
help="Do not print status messages to stderr",
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"-d", "--debug", action="store_true", help="Print debug information"
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"-L",
|
263 |
+
"--lang",
|
264 |
+
"--language",
|
265 |
+
default="und",
|
266 |
+
help="ISO 639 language code two-letter or 'und' for undefined",
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"-M",
|
270 |
+
"--topic_model",
|
271 |
+
default="tm000",
|
272 |
+
help="Topic model identifier, e.g., tm001",
|
273 |
+
)
|
274 |
+
parser.add_argument(
|
275 |
+
"-N",
|
276 |
+
"--numeric_topic_ids",
|
277 |
+
action="store_true",
|
278 |
+
help="Use numeric topic IDs in the topic assignment",
|
279 |
+
)
|
280 |
+
parser.add_argument(
|
281 |
+
"-T",
|
282 |
+
"--topic_assignment_threshold",
|
283 |
+
type=float,
|
284 |
+
default=0.02,
|
285 |
+
help="Minimum probability for inclusion in the output",
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"-F",
|
289 |
+
"--format_type",
|
290 |
+
choices=["matrix", "sparse"],
|
291 |
+
default="matrix",
|
292 |
+
help="Format of the input file: 'matrix' or 'sparse'",
|
293 |
+
)
|
294 |
+
parser.add_argument(
|
295 |
+
"-C",
|
296 |
+
"--topic_count",
|
297 |
+
type=int,
|
298 |
+
help="Needed for formatting ",
|
299 |
+
required=True,
|
300 |
+
)
|
301 |
+
parser.add_argument(
|
302 |
+
"-o",
|
303 |
+
"--output",
|
304 |
+
help=(
|
305 |
+
"Path to the output file (%(default)s). If set to '<generator>' it will"
|
306 |
+
" return a generator that can be used to enumerate all results in a"
|
307 |
+
" flexible way. "
|
308 |
+
),
|
309 |
+
default="/dev/stdout",
|
310 |
+
)
|
311 |
+
|
312 |
+
parser.add_argument(
|
313 |
+
"--level",
|
314 |
+
default="INFO",
|
315 |
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
316 |
+
help="Set the logging level. Default: %(default)s",
|
317 |
+
)
|
318 |
+
parser.add_argument(
|
319 |
+
"INPUT_FILES", nargs="+", help="One or more input files to process."
|
320 |
+
)
|
321 |
+
|
322 |
+
options = parser.parse_args(args=args)
|
323 |
+
|
324 |
+
# Configure logging
|
325 |
+
Mallet2TopicAssignment.setup_logging(options)
|
326 |
+
|
327 |
+
# Validate specific arguments
|
328 |
+
if options.format_type == "sparse" and not options.topic_count:
|
329 |
+
parser.error(
|
330 |
+
"The --topic_count option is required when using the 'sparse' format"
|
331 |
+
)
|
332 |
+
|
333 |
+
# Create the application instance
|
334 |
+
app = Mallet2TopicAssignment(args=options)
|
335 |
+
|
336 |
+
# Check if output is set to '<generator>' and return a generator if so
|
337 |
+
if options.output == "<generator>":
|
338 |
+
return app.run()
|
339 |
+
|
340 |
+
# Otherwise, run normally (output to file or stdout)
|
341 |
+
app.run()
|
342 |
+
return None
|
343 |
+
|
344 |
+
|
345 |
+
if __name__ == "__main__":
    # Fixed: main()'s ``args`` parameter had no default, so the previous
    # bare ``main()`` call raised TypeError. Passing None explicitly makes
    # argparse fall back to sys.argv.
    Mallet2TopicAssignment.main(None)
|
lib/mallet_topic_inferencer.py
ADDED
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
DOCUMENTATION: This script performs vectorization and topic inference using Mallet models. It accepts a raw JSONL file,
|
5 |
+
identifies the language of the text, and applies the corresponding Mallet model for topic inference. It also supports
|
6 |
+
other input formats through a flexible InputReader abstraction (e.g., CSV, JSONL).
|
7 |
+
|
8 |
+
The benefit of this script with respect to the Mallet CLI is that it can handle
|
9 |
+
multiple languages in a single run without calling Mallet multiple times.
|
10 |
+
|
11 |
+
Classes:
|
12 |
+
- MalletVectorizer: Handles text-to-Mallet vectorization.
|
13 |
+
- LanguageInferencer: Performs topic inference using a Mallet inferencer and
|
14 |
+
the vectorizer.
|
15 |
+
- InputReader (abstract class): Defines the interface for reading input
|
16 |
+
documents.
|
17 |
+
- JsonlInputReader: Reads input from JSONL files.
|
18 |
+
- CsvInputReader: Reads input from CSV files (Mallet format).
|
19 |
+
- MalletTopicInferencer: Coordinates the process, identifies language, and manages
|
20 |
+
inference.
|
21 |
+
|
22 |
+
USAGE:
|
23 |
+
python mallet_topic_inferencer.py --input input.jsonl --output output.txt
|
24 |
+
--logfile logfile.log --input-format jsonl --level INFO --num_iterations 1000
|
25 |
+
--languages de,en --de_inferencer models/de.inferencer --de_pipe models/de.pipe
|
26 |
+
"""
|
27 |
+
|
28 |
+
import collections
|
29 |
+
import traceback
|
30 |
+
import jpype
|
31 |
+
import jpype.imports
|
32 |
+
|
33 |
+
import spacy
|
34 |
+
|
35 |
+
# from jpype.types import JString
|
36 |
+
import os
|
37 |
+
import logging
|
38 |
+
import argparse
|
39 |
+
import json
|
40 |
+
import csv
|
41 |
+
import tempfile
|
42 |
+
from typing import List, Dict, Generator, Tuple, Optional, Set
|
43 |
+
from abc import ABC, abstractmethod
|
44 |
+
import mallet2topic_assignment_jsonl as m2taj
|
45 |
+
from smart_open import open
|
46 |
+
|
47 |
+
log = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
def save_text_as_csv(text: str) -> str:
    """
    Write *text* to a temporary Mallet-style tab-separated CSV file and
    return the file's path.

    The file contains a header row plus one data row with a fixed dummy
    document id and class, as expected by the Mallet vectorizer.

    Args:
        text (str): The text to be saved in the CSV file.

    Returns:
        str: The name of the temporary CSV file (caller owns deletion).
    """
    tmp = tempfile.NamedTemporaryFile(
        delete=False, mode="w", suffix=".csv", newline="", encoding="utf-8"
    )
    # The with-block guarantees the data is flushed and the handle closed.
    with tmp:
        writer = csv.writer(tmp, delimiter="\t")
        writer.writerow(["ID", "DUMMYCLASS", "TEXT"])  # Header
        writer.writerow(["USERINPUT-2024-10-24-a-i0042", "dummy_class", text])
    return tmp.name
|
74 |
+
|
75 |
+
|
76 |
+
class Lemmatizer:
    """Dictionary-backed lemmatizer built on per-language spacy pipelines."""

    def __init__(
        self,
        languages_dict: Dict[str, str],
        lang_lemmatization_dict: Dict[str, Dict[str, str]],
    ):
        """
        Initializes the linguistic lemmatizer with specified languages and lemmatization dictionary.

        Args:
            languages_dict (Dict[str, str]): Maps language codes to the spacy
                model names to load for them.
            lang_lemmatization_dict (Dict[str, Dict[str, str]]): Per-language
                dictionaries mapping tokens to their lemmas.
        """
        self.languages_dict = languages_dict
        self.lemmatization_dict = lang_lemmatization_dict
        self.language_processors = self._load_language_processors(languages_dict)

    def _load_language_processors(
        self, languages_dict
    ) -> "Dict[str, spacy.language.Language]":
        """
        Loads spacy language processors for the specified languages.

        Parser and NER components are disabled; a sentencizer is added.

        Returns:
            Dict[str, spacy.language.Language]: Dictionary mapping language codes to spacy NLP pipelines.
        """
        processors = {}
        for code, model_name in languages_dict.items():
            nlp = spacy.load(model_name, disable=["parser", "ner"])
            nlp.add_pipe("sentencizer")
            processors[code] = nlp
        return processors

    def analyze_text(self, text: str, lang: str) -> List[str]:
        """
        Tokenize *text* and map each token (lowercased) through the
        lemmatization dictionary for *lang*.

        Args:
            text (str): Text to process.
            lang (str): Language code for the text.

        Returns:
            List[str]: Lemmas of the tokens that have entries in the
            lemmatization dictionary; unmatched tokens are dropped.

        Raises:
            ValueError: if no pipeline was loaded for ``lang``.
        """
        if lang not in self.language_processors:
            raise ValueError(f"No processing pipeline for language '{lang}'")

        doc = self.language_processors[lang](text)
        lookup = self.lemmatization_dict[lang]
        return [lemma for tok in doc if (lemma := lookup.get(tok.text.lower()))]
|
132 |
+
|
133 |
+
|
134 |
+
# ==================== Vectorization ====================
|
135 |
+
|
136 |
+
|
137 |
+
class MalletVectorizer:
    """
    Handles the vectorization of multiple documents into a format usable by Mallet using the pipe file from the model.
    """

    def __init__(self, language: str, pipe_file: str) -> None:
        # noinspection PyUnresolvedReferences
        from cc.mallet.classify.tui import Csv2Vectors  # type: ignore # Import after JVM is started

        self.vectorizer = Csv2Vectors()
        self.pipe_file = pipe_file
        self.language = language

    def run_csv2vectors(
        self,
        input_file: str,
        output_file: Optional[str] = None,
        delete_input_file_after: bool = True,
    ) -> str:
        """
        Run Csv2Vectors to vectorize the input file.

        Simple java-internal command line interface to the Csv2Vectors class in Mallet.

        Args:
            input_file: Path to the csv input file to be vectorized.
            output_file: Path where the output .mallet file should be saved;
                defaults to ``input_file + ".mallet"``.
            delete_input_file_after: Remove the input file afterwards
                (skipped while debug logging is active).

        Returns:
            str: path of the written .mallet file.
        """
        target = output_file or (input_file + ".mallet")

        # Arguments for Csv2Vectors java main class
        java_args = [
            "--input",
            input_file,
            "--output",
            target,
            "--keep-sequence",  # Keep sequence for feature extraction
            "--encoding",
            "UTF-8",
            "--use-pipe-from",
            self.pipe_file,
        ]

        logging.info("Calling mallet Csv2Vector: %s", java_args)
        self.vectorizer.main(java_args)
        logging.debug("Csv2Vector call finished.")

        # Check the flag first: a disabled cleanup never consults the logger.
        if delete_input_file_after and log.getEffectiveLevel() != logging.DEBUG:
            os.remove(input_file)
            logging.info("Cleaning up input file: %s", input_file)
        return target
|
192 |
+
|
193 |
+
|
194 |
+
class LanguageInferencer:
    """
    A class to manage Mallet inferencing for a specific language.
    Loads the inferencer and pipe file during initialization.
    """

    def __init__(self, language: str, inferencer_file: str, pipe_file: str) -> None:
        """Load the Java InferTopics entry point and the matching vectorizer.

        Raises:
            FileNotFoundError: if ``inferencer_file`` does not exist — checked
                before constructing any Java objects (fail fast; previously
                the check ran only after the JVM objects were built).
        """
        # noinspection PyUnresolvedReferences
        from cc.mallet.topics.tui import InferTopics  # type: ignore # Import after JVM is started

        if not os.path.exists(inferencer_file):
            raise FileNotFoundError(
                f"Inferencer file not found: {inferencer_file}"
            )

        self.language = language
        self.inferencer_file = inferencer_file
        self.inferencer = InferTopics()
        self.pipe_file = pipe_file
        self.vectorizer = MalletVectorizer(language=language, pipe_file=self.pipe_file)

    def run_csv2topics(
        self, csv_file: str, delete_mallet_file_after: bool = True
    ) -> str:
        """
        Perform topic inference on a single input file.

        The CSV input is first vectorized with the model's pipe file, then
        fed to the Mallet inferencer with a fixed random seed.

        Returns:
            str: path to the written doc-topics file. (Fixed: the former
            annotation and docstring claimed a ``Dict[str, str]`` of topic
            distributions was returned, but the method has always returned
            the output file path.)
        """
        # Vectorize the input file and write to a temporary file
        mallet_file = self.vectorizer.run_csv2vectors(csv_file)

        topics_file = mallet_file + ".doctopics"

        arguments = [
            "--input",
            mallet_file,
            "--inferencer",
            self.inferencer_file,
            "--output-doc-topics",
            topics_file,
            "--random-seed",
            "42",
        ]

        logging.info("Calling mallet InferTopics: %s", arguments)

        self.inferencer.main(arguments)
        logging.debug("InferTopics call finished.")

        # Check the flag first: a disabled cleanup never consults the logger.
        if delete_mallet_file_after and log.getEffectiveLevel() != logging.DEBUG:
            os.remove(mallet_file)
            logging.info("Cleaning up input file: %s", mallet_file)

        return topics_file
|
252 |
+
|
253 |
+
|
254 |
+
# ==================== Input Reader Abstraction ====================
|
255 |
+
|
256 |
+
|
257 |
+
class InputReader(ABC):
    """
    Abstract base class for input readers.
    Subclasses should implement the `read_documents` method to yield documents.
    """

    @abstractmethod
    def read_documents(self) -> Generator[Tuple[str, str], None, None]:
        """
        Yields a tuple of (document_id, text).
        Each implementation should handle its specific input format.
        """


class JsonlInputReader(InputReader):
    """
    Reads input from a JSONL file, where each line contains a JSON object
    with at least "id" and "text" fields.
    """

    def __init__(self, input_file: str) -> None:
        self.input_file = input_file

    def read_documents(self) -> Generator[Tuple[str, str], None, None]:
        """Yield (id, text) per line; missing fields default to 'unknown_id' / ''."""
        with open(self.input_file, "r", encoding="utf-8") as handle:
            for raw in handle:
                record = json.loads(raw)
                yield record.get("id", "unknown_id"), record.get("text", "")


class CsvInputReader(InputReader):
    """
    Reads input from a CSV file in Mallet's format (document ID, dummy class, text).
    Assumes that the CSV has three columns: "id", "dummyclass", and "text".
    """

    def __init__(self, input_file: str) -> None:
        self.input_file = input_file

    def read_documents(self) -> Generator[Tuple[str, str], None, None]:
        """Yield (id, lowercased text); rows with fewer than 3 columns are skipped."""
        with open(self.input_file, mode="r", encoding="utf-8") as handle:
            for fields in csv.reader(handle, delimiter="\t"):
                if len(fields) >= 3:
                    yield fields[0], fields[2].lower()
|
307 |
+
|
308 |
+
|
309 |
+
# ==================== Main Application ====================
|
310 |
+
|
311 |
+
|
312 |
+
class MalletTopicInferencer:
|
313 |
+
"""
|
314 |
+
MalletTopicInferencer class coordinates the process of reading input documents, identifying their language, and performing topic inference using Mallet models.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
318 |
+
self.args = args
|
319 |
+
self.languages = set(args.languages)
|
320 |
+
self.language_inferencers: Optional[Dict[str, LanguageInferencer]] = None
|
321 |
+
self.language_lemmatizations: Optional[Dict[str, Dict[str, str]]] = None
|
322 |
+
self.input_reader = None
|
323 |
+
self.inference_results: List[Dict[str, str]] = []
|
324 |
+
self.language_dict: Dict[str, str] = {}
|
325 |
+
self.seen_languages: Set[str] = set()
|
326 |
+
self.stats = collections.Counter()
|
327 |
+
|
328 |
+
def initialize(self) -> None:
    """Initialize the inferencers after JVM startup."""
    cfg = self.args
    self.language_inferencers = self.init_language_inferencers(cfg)
    self.input_reader = self.build_input_reader(cfg)
    self.language_lemmatizations = self.init_language_lemmatizations(cfg)
    if cfg.language_file:
        self.language_dict = self.read_language_file(cfg.language_file)
|
335 |
+
|
336 |
+
@staticmethod
def start_jvm() -> None:
    """Start the Java Virtual Machine (once) with the Mallet jars on the classpath.

    The jars are looked up first relative to the current working directory,
    then relative to this source file's directory.
    """
    if jpype.isJVMStarted():
        log.warning("JVM already running.")
        return

    jar_names = ("mallet/lib/mallet-deps.jar", "mallet/lib/mallet.jar")

    # Prefer jars relative to the current working directory.
    classpath = [os.path.join(os.getcwd(), j) for j in jar_names]
    if not all(os.path.exists(p) for p in classpath):
        # Fall back to jars next to this source file.
        here = os.path.dirname(os.path.abspath(__file__))
        classpath = [os.path.join(here, j) for j in jar_names]

    jpype.startJVM(classpath=classpath)
    log.info(f"JVM started successfully with classpath {classpath}.")
|
361 |
+
|
362 |
+
def read_language_file(self, language_file: str) -> Dict[str, str]:
|
363 |
+
"""Read the language file (JSONL) and return a dictionary of document_id -> language."""
|
364 |
+
|
365 |
+
language_dict = {}
|
366 |
+
with open(language_file, "r", encoding="utf-8") as f:
|
367 |
+
for line in f:
|
368 |
+
data = json.loads(line)
|
369 |
+
doc_id = data.get("doc_id")
|
370 |
+
language = data.get("language")
|
371 |
+
if doc_id and language:
|
372 |
+
language_dict[doc_id] = language
|
373 |
+
return language_dict
|
374 |
+
|
375 |
+
@staticmethod
|
376 |
+
def load_lemmatization_file(
|
377 |
+
lemmatization_file_path: str,
|
378 |
+
bidi: bool = False,
|
379 |
+
lowercase: bool = True,
|
380 |
+
ignore_pos: bool = True,
|
381 |
+
) -> Dict[str, str]:
|
382 |
+
"""
|
383 |
+
Load lemmatization data from the file.
|
384 |
+
:param lemmatization_file_path: Path to the lemmatization file.
|
385 |
+
:return: A dictionary mapping tokens to their corresponding lemmas.
|
386 |
+
"""
|
387 |
+
|
388 |
+
token2lemma = {}
|
389 |
+
n = 0
|
390 |
+
with open(lemmatization_file_path, "r", "utf-8") as file:
|
391 |
+
for line in file:
|
392 |
+
token, _, lemma = line.strip().split("\t")
|
393 |
+
if lowercase:
|
394 |
+
token2lemma[token.lower()] = lemma.lower()
|
395 |
+
else:
|
396 |
+
token2lemma[token] = lemma
|
397 |
+
n += 1
|
398 |
+
|
399 |
+
logging.info(
|
400 |
+
"Read %d lemmatization entries from %s", n, lemmatization_file_path
|
401 |
+
)
|
402 |
+
return token2lemma
|
403 |
+
|
404 |
+
def init_language_lemmatizations(
|
405 |
+
self, args: argparse.Namespace
|
406 |
+
) -> Dict[str, Dict[str, str]]:
|
407 |
+
"""Build a mapping of languages to their respective lemmatization dictionaries."""
|
408 |
+
|
409 |
+
language_lemmatizations: Dict[str, Dict[str, str]] = {}
|
410 |
+
for language in args.languages:
|
411 |
+
lemmatization_key = f"{language}_lemmatization"
|
412 |
+
if getattr(args, lemmatization_key, None):
|
413 |
+
lemmatization_file = getattr(args, lemmatization_key)
|
414 |
+
language_lemmatizations[language] = self.load_lemmatization_file(
|
415 |
+
lemmatization_file
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
log.info(
|
419 |
+
f"Lemmatization file for language: {language} not provided by"
|
420 |
+
" arguments. Skipping."
|
421 |
+
)
|
422 |
+
return language_lemmatizations
|
423 |
+
|
424 |
+
def identify_language(self, document_id: str, text: str) -> str:
|
425 |
+
"""Identify the language of the text using the language file or a dummy method."""
|
426 |
+
# Check if the document ID is in the language dictionary
|
427 |
+
if document_id in self.language_dict:
|
428 |
+
return self.language_dict[document_id]
|
429 |
+
# Placeholder: Assume German ("de") for now if not found in the dictionary
|
430 |
+
return "de"
|
431 |
+
|
432 |
+
def init_language_inferencers(
|
433 |
+
self, args: argparse.Namespace
|
434 |
+
) -> Dict[str, LanguageInferencer]:
|
435 |
+
"""Build a mapping of languages to their respective inferencers
|
436 |
+
|
437 |
+
Includes the vectorizer pipe for each language as well.
|
438 |
+
"""
|
439 |
+
|
440 |
+
language_inferencers: Dict[str, LanguageInferencer] = {}
|
441 |
+
for language in args.languages:
|
442 |
+
inferencer_key = f"{language}_inferencer"
|
443 |
+
pipe_key = f"{language}_pipe"
|
444 |
+
if getattr(args, inferencer_key, None) and getattr(args, pipe_key, None):
|
445 |
+
language_inferencers[language] = LanguageInferencer(
|
446 |
+
language=language,
|
447 |
+
inferencer_file=getattr(args, inferencer_key),
|
448 |
+
pipe_file=getattr(args, pipe_key),
|
449 |
+
)
|
450 |
+
else:
|
451 |
+
log.info(
|
452 |
+
f"Inferencer or pipe file for language: {language} not provided by"
|
453 |
+
" arguments. Skipping."
|
454 |
+
)
|
455 |
+
return language_inferencers
|
456 |
+
|
457 |
+
def build_input_reader(self, args: argparse.Namespace) -> InputReader:
|
458 |
+
"""Select the appropriate input reader based on the input format."""
|
459 |
+
if args.input_format == "jsonl":
|
460 |
+
return JsonlInputReader(args.input)
|
461 |
+
elif args.input_format == "csv":
|
462 |
+
return CsvInputReader(args.input)
|
463 |
+
else:
|
464 |
+
raise ValueError(f"Unsupported input format: {args.input_format}")
|
465 |
+
|
466 |
+
def process_input_file(self) -> None:
|
467 |
+
"""Process the input file, identify language, and apply the appropriate Mallet model"""
|
468 |
+
temp_files_by_language = self.write_language_specific_csv_files()
|
469 |
+
|
470 |
+
doctopics_files = self.run_topic_inference(temp_files_by_language)
|
471 |
+
logging.info(doctopics_files)
|
472 |
+
if self.args.output_format == "csv":
|
473 |
+
self.merge_inference_results(doctopics_files)
|
474 |
+
elif self.args.output_format == "jsonl":
|
475 |
+
self.merge_inference_results_jsonl(doctopics_files)
|
476 |
+
|
477 |
+
def merge_inference_results_jsonl(self, doctopics_files_by_language):
|
478 |
+
|
479 |
+
args = ["--output", "<generator>"]
|
480 |
+
m2ta_converters = {}
|
481 |
+
for lang, doctopics_file in doctopics_files_by_language.items():
|
482 |
+
topic_model_id = self.args.__dict__[f"{lang}_model_id"]
|
483 |
+
if "{lang}" in topic_model_id:
|
484 |
+
topic_model_id.format(lang=lang)
|
485 |
+
args += [
|
486 |
+
"--topic_model",
|
487 |
+
topic_model_id,
|
488 |
+
"--topic_count",
|
489 |
+
str(self.args.__dict__[f"{lang}_topic_count"]),
|
490 |
+
"--lang",
|
491 |
+
lang,
|
492 |
+
doctopics_file, # input comes last!
|
493 |
+
]
|
494 |
+
m2ta_converters[lang] = m2taj.Mallet2TopicAssignment.main(args)
|
495 |
+
for lang, m2ta_converter in m2ta_converters.items():
|
496 |
+
with open(self.args.output, "w", encoding="utf-8") as out_f:
|
497 |
+
for row in m2ta_converter:
|
498 |
+
self.stats["content_items"] += 1
|
499 |
+
print(
|
500 |
+
json.dumps(row, ensure_ascii=False, separators=(",", ":")),
|
501 |
+
file=out_f,
|
502 |
+
)
|
503 |
+
|
504 |
+
def merge_inference_results(
|
505 |
+
self, doctopics_files_by_language: Dict[str, str]
|
506 |
+
) -> None:
|
507 |
+
"""Merge the inference results from multiple languages into a single output file."""
|
508 |
+
|
509 |
+
logging.info(
|
510 |
+
"Saving CSV inference results into file %s from multiple languages: %s",
|
511 |
+
self.args.output,
|
512 |
+
doctopics_files_by_language,
|
513 |
+
)
|
514 |
+
with open(self.args.output, "w", encoding="utf-8") as out_f:
|
515 |
+
for language, doctopics_file in doctopics_files_by_language.items():
|
516 |
+
with open(doctopics_file, "r", encoding="utf-8") as f:
|
517 |
+
for line in f:
|
518 |
+
if line.startswith("#"):
|
519 |
+
continue
|
520 |
+
doc_id, topic_dist = line.strip().split("\t", 1)
|
521 |
+
print(
|
522 |
+
doc_id + "__" + language,
|
523 |
+
topic_dist,
|
524 |
+
sep="\t",
|
525 |
+
end="\n",
|
526 |
+
file=out_f,
|
527 |
+
)
|
528 |
+
|
529 |
+
def write_language_specific_csv_files(self) -> Dict[str, str]:
|
530 |
+
"""Read documents and write to language-specific temporary files"""
|
531 |
+
tsv_files_by_language = {}
|
532 |
+
|
533 |
+
for document_id, text in self.input_reader.read_documents():
|
534 |
+
language_code = self.identify_language(document_id, text)
|
535 |
+
self.stats["LANGUAGE: " + language_code] += 1
|
536 |
+
if language_code not in self.languages:
|
537 |
+
continue
|
538 |
+
|
539 |
+
if language_code not in tsv_files_by_language:
|
540 |
+
tsv_files_by_language[language_code] = tempfile.NamedTemporaryFile(
|
541 |
+
delete=False,
|
542 |
+
mode="w",
|
543 |
+
suffix=f".{language_code}.tsv",
|
544 |
+
encoding="utf-8",
|
545 |
+
)
|
546 |
+
logging.info(
|
547 |
+
"Writing documents for language: %s in temp file: %s",
|
548 |
+
language_code,
|
549 |
+
tsv_files_by_language[language_code].name,
|
550 |
+
)
|
551 |
+
|
552 |
+
print(
|
553 |
+
document_id,
|
554 |
+
language_code,
|
555 |
+
text,
|
556 |
+
sep="\t",
|
557 |
+
end="\n",
|
558 |
+
file=tsv_files_by_language[language_code],
|
559 |
+
)
|
560 |
+
|
561 |
+
# Close all temporary files
|
562 |
+
for temp_file in tsv_files_by_language.values():
|
563 |
+
temp_file.close()
|
564 |
+
|
565 |
+
# noinspection PyShadowingNames
|
566 |
+
result = {
|
567 |
+
lang: temp_file.name for lang, temp_file in tsv_files_by_language.items()
|
568 |
+
}
|
569 |
+
return result
|
570 |
+
|
571 |
+
def run_topic_inference(
|
572 |
+
self, language_specific_csv_files: Dict[str, str]
|
573 |
+
) -> Dict[str, str]:
|
574 |
+
"""Run inference for each language"""
|
575 |
+
doctopics_files_by_language = {}
|
576 |
+
for language_code, csv_file in language_specific_csv_files.items():
|
577 |
+
inferencer = self.language_inferencers.get(language_code)
|
578 |
+
if not inferencer:
|
579 |
+
log.error(f"No inferencer found for language: {language_code}")
|
580 |
+
continue
|
581 |
+
|
582 |
+
doctopics_file = inferencer.run_csv2topics(csv_file)
|
583 |
+
doctopics_files_by_language[language_code] = doctopics_file
|
584 |
+
|
585 |
+
# Clean up the temporary vectorized file if logging level is not DEBUG
|
586 |
+
if log.getEffectiveLevel() != logging.DEBUG:
|
587 |
+
logging.info("Cleaning language specific csv file: %s", csv_file)
|
588 |
+
os.remove(csv_file)
|
589 |
+
logging.debug("Resulting doctopic files: %s", doctopics_files_by_language)
|
590 |
+
return doctopics_files_by_language
|
591 |
+
|
592 |
+
def write_results_to_output(self) -> None:
|
593 |
+
"""Write the final merged inference results to the output file."""
|
594 |
+
with open(self.args.output, "w", encoding="utf-8") as out_file:
|
595 |
+
for result in self.inference_results:
|
596 |
+
out_file.write(json.dumps(result) + "\n")
|
597 |
+
log.info(f"All inferences merged and written to {self.args.output}")
|
598 |
+
|
599 |
+
    def run(self) -> None:
        """Main execution method.

        Starts the JVM, initializes all language resources, processes the
        input file, and always shuts the JVM down and reports counters —
        even when processing fails.
        """
        try:
            self.start_jvm()
            self.initialize()
            self.process_input_file()
            # self.write_results_to_output()
        except Exception as e:
            # Broad catch is deliberate: this is the top-level entry point,
            # and failures are logged with a full traceback instead of
            # propagating past the JVM shutdown below.
            log.error(f"An error occurred: {e}")
            log.error("Traceback: %s", traceback.format_exc())
        finally:
            # Always release the JVM and emit collected statistics.
            jpype.shutdownJVM()
            log.info("JVM shutdown.")
            for key, value in sorted(self.stats.items()):
                log.info(f"STATS: {key}: {value}")
614 |
+
|
615 |
+
|
616 |
+
if __name__ == "__main__":
    languages = ["de", "fr", "lb"]  # You can add more languages as needed
    parser = argparse.ArgumentParser(description="Mallet Topic Inference in Python")

    parser.add_argument("--logfile", help="Path to log file", default=None)
    parser.add_argument(
        "--level",
        default="DEBUG",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        # BUG FIX: was "$(default)s" (shell syntax); argparse interpolates
        # "%(default)s" into help strings.
        help="Logging level (%(default)s)",
    )
    parser.add_argument(
        "--input", help="Path to input file (%(default)s)", required=True
    )
    parser.add_argument(
        "--input-format",
        choices=["jsonl", "csv"],
        default="jsonl",
        help="Format of the input file",
    )
    parser.add_argument(
        "--output-format",
        choices=["jsonl", "csv"],
        help=(
            "Format of the output file: csv: raw Mallet output with docids patched into"
            " numericID-LANG, jsonl: impresso JSONL format"
        ),
    )
    parser.add_argument(
        "--output",
        help="Path to final output file. (%(default)s)",
        default="out.jsonl",
    )
    parser.add_argument(
        "--languages",
        nargs="+",
        default=languages,
        help="List of languages to support (%(default)s)",
    )
    parser.add_argument(
        "--language-file",
        help="Path to JSONL containing document_id to language mappings",
        required=False,
    )
    parser.add_argument("--model_dir", help="Path to model directory", required=True)

    # Dynamically generate per-language options for model files and settings
    # (one loop instead of three identical ones).
    for lang in languages:
        parser.add_argument(
            f"--{lang}_inferencer",
            help=f"Path to {lang} inferencer file",
        )
        parser.add_argument(f"--{lang}_pipe", help=f"Path to {lang} pipe file")
        parser.add_argument(
            f"--{lang}_lemmatization", help=f"Path to {lang} lemmatization file"
        )
        parser.add_argument(
            f"--{lang}_model_id",
            default=f"tm-{lang}-all-v2.0",
            help="Model ID can take a {lang} format placeholder (%(default)s)",
        )
        parser.add_argument(
            f"--{lang}_topic_count",
            default=100,
            type=int,  # parse CLI values as int; the default already was one
            help="Number of topics of model (%(default)s)",
        )
    args = parser.parse_args()

    logging.basicConfig(
        filename=args.logfile,
        level=args.level,
        format="%(asctime)-15s %(filename)s:%(lineno)d %(levelname)s: %(message)s",
        force=True,
    )

    # Automatically construct model file paths if not explicitly specified.
    # getattr(..., None) keeps this from raising AttributeError for languages
    # supplied via --languages that have no pre-registered per-language flags.
    for lang in args.languages:
        model_id = getattr(args, f"{lang}_model_id", f"tm-{lang}-all-v2.0")
        model_dir = args.model_dir

        pipe_path = os.path.join(model_dir, f"{model_id}.pipe")
        inferencer_path = os.path.join(model_dir, f"{model_id}.inferencer")
        lemmatization_path = os.path.join(
            model_dir, f"{model_id}.vocab.lemmatization.tsv.gz"
        )

        if not getattr(args, f"{lang}_pipe", None) and os.path.exists(pipe_path):
            logging.info("Automatically setting pipe path to %s", pipe_path)
            setattr(args, f"{lang}_pipe", pipe_path)
        if not getattr(args, f"{lang}_inferencer", None) and os.path.exists(
            inferencer_path
        ):
            logging.info("Automatically setting inferencer path to %s", inferencer_path)
            setattr(args, f"{lang}_inferencer", inferencer_path)
        if not getattr(args, f"{lang}_lemmatization", None) and os.path.exists(
            lemmatization_path
        ):
            logging.info(
                "Automatically setting lemmatization path to %s", lemmatization_path
            )
            setattr(args, f"{lang}_lemmatization", lemmatization_path)

    if not args.output_format:
        args.output_format = "jsonl" if "jsonl" in args.output else "csv"
        logging.warning("Unspecified output format set to %s", args.output_format)

    # BUG FIX: the original called args.languages.remove(lang) while iterating
    # over args.languages, which silently skips the element following each
    # removal. Build a filtered list instead.
    usable_languages = []
    for lang in args.languages:
        if getattr(args, f"{lang}_inferencer", None) and getattr(
            args, f"{lang}_pipe", None
        ):
            usable_languages.append(lang)
        else:
            logging.warning(
                "Inferencer or pipe file not provided for language: %s. Ignoring"
                " content items for this language.",
                lang,
            )
    args.languages = usable_languages
    logging.info(
        "Performing monolingual topic inference for the following languages: %s",
        args.languages,
    )

    logging.info("Arguments: %s", args)
    app = MalletTopicInferencer(args)
    app.run()
mallet/lib/mallet-deps.jar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cca84272e6a3aee57490de4691702a5d1264d56d2a651b741f034aa09052023
|
3 |
+
size 2644050
|
mallet/lib/mallet.jar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:facdf051fb2775a8b8015d154d23c7fcb1b276172ba5b517fd3877f45332cf30
|
3 |
+
size 2235683
|
models/tm/tm-de-all-v2.0.inferencer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f3961a2b868ea35605a27e19d6d7d74ab2dc692de8d41799e1ccb68e3af0b5b
|
3 |
+
size 23330363
|
models/tm/tm-de-all-v2.0.pipe
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bc69ca2007dd2380cdda1c4a4f029f1325363ea89ceacbc14bf82788abd166be
|
3 |
+
size 748181
|
models/tm/tm-de-all-v2.0.vocab.lemmatization.tsv.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:06d163134a44cc3f338b187df9fd6b4936ea4ecf7d1157b29efddefccd6456cc
|
3 |
+
size 1288289
|
models/tm/tm-fr-all-v2.0.inferencer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7acfecbf88e50ea4ca0e872ed159413bcfc8eeea2bcf51c10b6aeb13ec11864
|
3 |
+
size 8718608
|
models/tm/tm-fr-all-v2.0.pipe
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:916b0ff537284e232922016dd6fd6924ade4a405b69776d16adb5b8d61593af8
|
3 |
+
size 249343
|
models/tm/tm-fr-all-v2.0.vocab.lemmatization.tsv.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6476a943c3126eaf249685273af3445105763c6cef22e76eb32b5618724de147
|
3 |
+
size 311945
|
models/tm/tm-lb-all-v2.0.inferencer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bcffa061a4de61cee46ade61ff8cdb9637839a0bb0750a1001c0a9934ec8f53e
|
3 |
+
size 26498144
|
models/tm/tm-lb-all-v2.0.pipe
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b08b13c18f0388100820b6cffc2753c99f4a0be48da9b3e2190926fa473283bf
|
3 |
+
size 8060126
|
models/tm/tm-lb-all-v2.0.vocab.lemmatization.tsv.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b65f8503e341fad42cae0f26c93378be24208ab2453a052d742a34e6e5ee8db
|
3 |
+
size 3893233
|
requirements.txt
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-i https://pypi.org/simple
|
2 |
+
blis==0.7.11
|
3 |
+
boto3==1.35.50; python_version >= '3.8'
|
4 |
+
botocore==1.35.50; python_version >= '3.8'
|
5 |
+
catalogue==2.0.10; python_version >= '3.6'
|
6 |
+
certifi==2024.8.30; python_version >= '3.6'
|
7 |
+
charset-normalizer==3.4.0; python_full_version >= '3.7.0'
|
8 |
+
click==8.1.7; python_version >= '3.7'
|
9 |
+
confection==0.1.5; python_version >= '3.6'
|
10 |
+
cymem==2.0.8
|
11 |
+
https://github.com/explosion/spacy-models/releases/download/de_core_news_md-3.6.0/de_core_news_md-3.6.0.tar.gz
|
12 |
+
https://github.com/explosion/spacy-models/releases/download/fr_core_news_md-3.6.0/fr_core_news_md-3.6.0.tar.gz
|
13 |
+
idna==3.10; python_version >= '3.6'
|
14 |
+
jinja2==3.1.4; python_version >= '3.7'
|
15 |
+
jmespath==1.0.1; python_version >= '3.7'
|
16 |
+
jpype1==1.5.0; python_version >= '3.7'
|
17 |
+
langcodes==3.4.1; python_version >= '3.8'
|
18 |
+
language-data==1.2.0
|
19 |
+
marisa-trie==1.2.1; python_version >= '3.7'
|
20 |
+
markupsafe==3.0.2; python_version >= '3.9'
|
21 |
+
murmurhash==1.0.10; python_version >= '3.6'
|
22 |
+
numpy==1.26.4; python_version >= '3.9'
|
23 |
+
packaging==24.1; python_version >= '3.8'
|
24 |
+
pathlib-abc==0.1.1; python_version >= '3.8'
|
25 |
+
pathy==0.11.0; python_version >= '3.8'
|
26 |
+
preshed==3.0.9; python_version >= '3.6'
|
27 |
+
pydantic==1.10.18; python_version >= '3.7'
|
28 |
+
python-dateutil==2.9.0.post0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
|
29 |
+
python-dotenv==1.0.1; python_version >= '3.8'
|
30 |
+
requests==2.32.3; python_version >= '3.8'
|
31 |
+
s3transfer==0.10.3; python_version >= '3.8'
|
32 |
+
setuptools==75.2.0; python_version >= '3.8'
|
33 |
+
six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
|
34 |
+
smart-open==6.4.0; python_version >= '3.6' and python_version < '4.0'
|
35 |
+
spacy==3.6.0; python_version >= '3.6'
|
36 |
+
spacy-legacy==3.0.12; python_version >= '3.6'
|
37 |
+
spacy-loggers==1.0.5; python_version >= '3.6'
|
38 |
+
spacy-lookups-data==1.0.5; python_version >= '3.6'
|
39 |
+
srsly==2.4.8; python_version >= '3.6'
|
40 |
+
thinc==8.1.12; python_version >= '3.6'
|
41 |
+
tqdm==4.66.6; python_version >= '3.7'
|
42 |
+
typer==0.9.4; python_version >= '3.6'
|
43 |
+
typing-extensions==4.12.2; python_version >= '3.8'
|
44 |
+
urllib3==2.2.3; python_version >= '3.10'
|
45 |
+
wasabi==1.1.3; python_version >= '3.6'
|