# translation.py (xcl-en-demo)
# Author: Ari Nubar Boyacıoğlu
# -*- coding: utf-8 -*-
import os
import re
import sys
import typing as tp
import torch
import pysbd
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import unicodedata
import time
MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-xcl"
LANGUAGES = {
"Գրաբառ Հայոց | Classical Armenian": "xcl_Armn",
"Անգլերէն | English": "eng_Latn",
}
HF_TOKEN = os.environ.get("HF_TOKEN")
def get_non_printing_char_replacer(replace_by: str = " "):
    """Build a translation table mapping every non-printing character to `replace_by`."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # the C* categories cover control, format, surrogate, private-use, and
        # unassigned code points, i.e. \p{C} in Perl; unicodedata.category()
        # always returns a two-letter code, so "C" alone never occurs.
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line: str) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char
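# Illustrative use of the replacer (the table is built once over the full
# Unicode range, so construct it a single time and reuse the closure):
#   _strip = get_non_printing_char_replacer()
#   _strip("a\u200bb\x07c")  # -> "a b c": zero-width space and BEL become spaces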
def remove_special_characters(text):
    # Strip invisible layout characters: NBSP, zero-width marks, bidi marks,
    # line/paragraph separators, and soft hyphens.
    pattern = r'[\u00A0\u200B\u200C\u200D\u200E\u200F\u2028\u2029\xad]'
    return re.sub(pattern, '', text)
def common_clean_methods(text):
    text = text.strip()
    text = re.sub(r'\n+', '\n', text)  # collapse blank lines
    text = re.sub(r' +', ' ', text)    # collapse runs of spaces
    text = remove_special_characters(text)
    text = text.replace("\t", "")
    text = text.replace("\r", "")
    # drop lines that consist only of digits, asterisks, or hyphens
    text = re.sub(r'^[0-9\*\-]+$', '', text, flags=re.MULTILINE)
    text = text.strip()
    return text
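# A quick illustration of the pipeline above (hypothetical input): runs of
# spaces and blank lines collapse, and a digits-only line is dropped entirely.
#   common_clean_methods("Hello   world\n\n\n42")  # -> "Hello world"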
def wa_clean_methods(text):
    # Reattach punctuation that has been separated from the preceding word.
    text = re.sub(r' (։|:|․|…|՝|՞|`|´|~)(?=\s)', r'\1', text)
    # Convert ASCII : and . after an Armenian letter to the Armenian full stop
    # and mid-dot; the lookbehind keeps embedded Latin text (e.g. names, URLs) intact.
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])\:', '։', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])\.', '․', text)
    # Map ASCII stand-ins to the proper Armenian marks:
    # ` -> ՝ (Armenian comma), ´ -> ՛ (emphasis mark), ~ -> ՜ (exclamation mark)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])`', '՝', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])´', '՛', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])~', '՜', text)
    # Angle brackets around Armenian text become guillemets: < opens before a
    # letter, > closes after one.
    text = re.sub(r'<(?=[ա-ֆԱ-Ֆ])', '«', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])>', '»', text)
    # Normalize ellipses (both ․․․ and ...) and the Armenian hyphen ֊.
    text = re.sub(r'․․․', '…', text)
    text = re.sub(r'\.\.\.', '…', text)
    text = re.sub(r'֊', '-', text)
    # Drop the apostrophe-like marks after the կ/մ verbal prefixes.
    text = re.sub(r'կ՝', 'կ', text)
    text = re.sub(r'Կ՝', 'Կ', text)
    text = re.sub(r'կ՛', 'կ', text)
    text = re.sub(r'Կ՛', 'Կ', text)
    text = re.sub(r'մ՝', 'մ', text)
    text = re.sub(r'Մ՝', 'Մ', text)
    text = re.sub(r'մ՛', 'մ', text)
    text = re.sub(r'Մ՛', 'Մ', text)
    # No space after an opening «/( and none before a closing )/».
    text = re.sub(r'« ', '«', text)
    text = re.sub(r'\( ', '(', text)
    text = re.sub(r' \)', ')', text)
    text = re.sub(r' »', '»', text)
    return text
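# Example of the Armenian normalization above (illustrative):
#   wa_clean_methods("Բարեւ աշխարհ: կ՛ուզէ ․․․")
#   # -> "Բարեւ աշխարհ։ կուզէ …"
# The ASCII colon after an Armenian letter becomes ։, the ՛ after the կ prefix
# is dropped, and the three mid-dots merge into a single ellipsis.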
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split `text` into sentences, keeping the inter-sentence "filler" strings
    so the surrounding whitespace can be restored around the translations."""
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
        text = text.strip()
    sentences = splitter.segment(text)
    fillers = []
    i = 0
    for sent in sentences:
        start_idx = text.find(sent, i)
        if ignore_errors and start_idx == -1:
            # the splitter occasionally rewrites a sentence; skip one character and move on
            start_idx = i + 1
        assert start_idx != -1, f"Sentence not found after index {i} in text: {text}"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)
    fillers.append(text[i:])
    return sentences, fillers
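# Round-trip property (a sketch; assumes every sentence the splitter returns
# occurs verbatim in the normalized text):
#   sents, fills = sentenize_with_fillers(text, splitter)
#   rebuilt = "".join(f + s for f, s in zip(fills, sents)) + fills[-1]
#   # rebuilt == re.sub(r"\s+", " ", text).strip()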
def clean_text(text: str, lang: str) -> str:
replace_nonprint = get_non_printing_char_replacer()
text = replace_nonprint(text)
text = common_clean_methods(text)
if lang == "xcl_Armn":
text = wa_clean_methods(text)
return text
def init_tokenizer(tokenizer, new_langs=("xcl_Armn",)):
    """Add new language tokens to the tokenizer vocabulary.

    This must be re-done every time the tokenizer is loaded. It patches
    internals (lang_code_to_id, fairseq_tokens_to_ids) that only exist in
    transformers==4.33.0; they were removed in later releases.
    """
    for new_lang in new_langs:
        old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
        tokenizer.lang_code_to_id[new_lang] = old_len - 1
        tokenizer.id_to_lang_code[old_len - 1] = new_lang
        if new_lang not in tokenizer._additional_special_tokens:
            tokenizer._additional_special_tokens.append(new_lang)
    # always move "<mask>" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # clear the added-token maps; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
    return tokenizer
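# Intended usage (a sketch): patch the freshly loaded tokenizer, after which the
# new code is addressable like any stock NLLB language code. The attribute names
# below are the pre-4.34 internals this patch relies on.
#   tok = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
#   tok = init_tokenizer(tok)
#   tok.lang_code_to_id["xcl_Armn"]  # usable as forced_bos_token_id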
class Translator:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        init_tokenizer(self.tokenizer)
        # pysbd has no Classical Armenian rules; modern Armenian ("hy") is the
        # closest available segmenter
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES
def translate_single(
self,
text,
src_lang,
tgt_lang,
max_length="auto",
num_beams=4,
n_out=None,
**kwargs,
):
self.tokenizer.src_lang = src_lang
encoded = self.tokenizer(
text, return_tensors="pt", truncation=True, max_length=256
)
if max_length == "auto":
max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
generated_tokens = self.model.generate(
**encoded.to(self.model.device),
forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
max_length=max_length,
num_beams=num_beams,
num_return_sequences=n_out or 1,
**kwargs,
)
out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
if isinstance(text, str) and n_out is None:
return out[0]
return out
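    # With max_length="auto", the generation budget scales with the input: a
    # 10-token source may produce up to 32 + 2 * 10 = 52 tokens, which covers
    # typical English/Classical Armenian length ratios with headroom.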
    def translate(self, text: str,
                  src_lang: str,
                  tgt_lang: str,
                  max_length=256,
                  num_beams=4,
                  by_sentence=True,
                  clean=True,
                  **kwargs):
        # Translate paragraph by paragraph so the original line structure survives.
        paragraphs = text.split('\n')
        translated_paragraphs = []
        for paragraph in paragraphs:
            if not paragraph.strip():
                translated_paragraphs.append('')
                continue
            if by_sentence:
                if src_lang == "eng_Latn":
                    sents = self.eng_splitter.segment(paragraph)
                elif src_lang == "xcl_Armn":
                    sents = self.hyw_splitter.segment(paragraph)
                else:
                    # no segmenter for this language; treat the paragraph as one sentence
                    sents = [paragraph]
                if clean:
                    sents = [clean_text(sent, src_lang) for sent in sents]
                # cleaning can empty a sentence (e.g. a digits-only line); skip those
                sents = [sent for sent in sents if sent]
                if not sents:
                    translated_paragraphs.append('')
                    continue
                if len(sents) > 1:
                    results = self.translate_batch(sents, src_lang, tgt_lang,
                                                   num_beams=num_beams,
                                                   max_length=max_length, **kwargs)
                else:
                    results = [self.translate_single(sents[0], src_lang, tgt_lang,
                                                     max_length=max_length,
                                                     num_beams=num_beams, **kwargs)]
                translated_paragraphs.append(" ".join(results))
            else:
                if clean:
                    paragraph = clean_text(paragraph, src_lang)
                translated = self.translate_single(paragraph, src_lang, tgt_lang,
                                                   max_length=max_length,
                                                   num_beams=num_beams, **kwargs)
                translated_paragraphs.append(translated)
        # Reassemble with the original paragraph breaks.
        return "\n".join(translated_paragraphs)
    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        self.tokenizer.src_lang = src_lang
        # Move the whole encoding (input_ids *and* attention_mask) to the model
        # device; dropping the mask on padded batches can corrupt the output.
        inputs = self.tokenizer(texts, return_tensors="pt", max_length=max_length,
                                padding=True, truncation=True).to(self.model.device)
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
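# Hypothetical chunked-batching sketch (not part of the class API): when
# translating many sentences at once, fixed-size chunks bound peak GPU memory.
#   CHUNK = 8  # assumed batch size
#   outs = []
#   for i in range(0, len(sents), CHUNK):
#       outs.extend(translator.translate_batch(sents[i:i + CHUNK],
#                                              "eng_Latn", "xcl_Armn"))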
if __name__ == "__main__":
print("Initializing translator...")
translator = Translator()
print("Translator initialized.")
start_time = time.time()
print(translator.translate("Hello world!", "eng_Latn", "xcl_Armn"))
print("Time elapsed: ", time.time() - start_time)
start_time = time.time()
print(translator.translate("I am the greatest translator! Do not fuck with me!", "eng_Latn", "xcl_Armn"))
print("Time elapsed: ", time.time() - start_time)