Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import os | |
import re | |
import sys | |
import typing as tp | |
import torch | |
import pysbd | |
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM | |
import unicodedata | |
import time | |
#hy_segmenter = pysbd.Segmenter(language="hy", clean=False) not needed | |
# Hugging Face Hub identifier of the translation checkpoint loaded by Translator.
MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-xcl"

# Human-readable language label -> language code used by the translator.
LANGUAGES = dict(
    [
        ("Գրաբառ Հայոց | Classical Armenian", "xcl_Armn"),
        ("Անգլերէն | English", "eng_Latn"),
    ]
)

# Access token read from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Translation tables that were already built, keyed by replacement string.
# Building one requires scanning every Unicode code point, so each table is
# computed at most once per process and shared by all returned replacers.
_NON_PRINTING_TABLES = {}


def get_non_printing_char_replacer(replace_by: str = " "):
    """Return a function that replaces non-printing characters in a string.

    A character counts as non-printing when its Unicode general category is
    one of the "Other" categories (Cc, Cf, Cs, Co, Cn).  Note that this
    includes control characters such as tab, carriage return and newline.

    Args:
        replace_by: The string substituted for every non-printing character.

    Returns:
        A callable taking one string and returning it with every
        non-printing character replaced by ``replace_by``.
    """
    table = _NON_PRINTING_TABLES.get(replace_by)
    if table is None:
        table = {
            code_point: replace_by
            for code_point in range(sys.maxunicode + 1)
            # same as \p{C} in perl
            # see https://www.unicode.org/reports/tr44/#General_Category_Values
            if unicodedata.category(chr(code_point)) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
        }
        _NON_PRINTING_TABLES[replace_by] = table

    def replace_non_printing_char(line) -> str:
        return line.translate(table)

    return replace_non_printing_char
# def clean_text(text: str, lang) -> str: | |
# HYW_CHARS_TO_NORMALIZE = { | |
# "«": '"', | |
# "»": '"', | |
# "“": '"', | |
# "”": '"', | |
# "’": "'", | |
# "‘": "'", | |
# "–": "-", | |
# "—": "-", | |
# "ՙ": "'", | |
# "՚": "'", | |
# } | |
# DOUBLE_CHARS_TO_NORMALIZE = { | |
# "Կ՛": "Կ'", | |
# "կ՛": "կ'", | |
# "Չ՛": "Չ'", | |
# "չ՛": "չ'", | |
# "Մ՛": "Մ'", | |
# "մ՛": "մ'", | |
# } | |
# replace_nonprint = get_non_printing_char_replacer() | |
# text = replace_nonprint(text) | |
# # print(text) | |
# text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ").replace(r"[^\x00-\x7F]+", " ").replace(r"\s+", " ") | |
# text = text.strip() | |
# if lang == "xcl_Armn": | |
# text = text.translate(str.maketrans(HYW_CHARS_TO_NORMALIZE)) | |
# for k, v in DOUBLE_CHARS_TO_NORMALIZE.items(): | |
# text = text.replace(k, v) | |
# return text | |
def remove_special_characters(text):
    """Delete a fixed set of invisible/formatting characters from ``text``.

    The characters removed are: no-break space (U+00A0), zero-width space,
    zero-width non-joiner and joiner (U+200B-U+200D), left-to-right and
    right-to-left marks (U+200E, U+200F), line and paragraph separators
    (U+2028, U+2029) and the soft hyphen (U+00AD).  They are removed
    outright, not replaced by a space.
    """
    invisible_chars = re.compile(
        r'[\u00A0\u200B\u200C\u200D\u200E\u200F\u2028\u2029\xad]'
    )
    return invisible_chars.sub('', text)
def common_clean_methods(text):
    """Apply the language-independent cleanup steps to ``text``.

    Steps, in order:
      1. strip leading/trailing whitespace;
      2. remove invisible characters (via ``remove_special_characters``),
         tabs and carriage returns;
      3. blank out lines made up only of digits, ``*`` or ``-``;
      4. collapse runs of newlines and runs of spaces to a single one;
      5. strip again.

    Whitespace is collapsed after the removals because those removals can
    leave new adjacent spaces or newlines behind (for example CRLF blank
    lines, or spaces on both sides of a removed no-break space).

    Returns:
        The cleaned string.
    """
    text = text.strip()
    text = remove_special_characters(text)
    text = text.replace("\t", "")
    text = text.replace("\r", "")
    # Lines that contain nothing but digits, asterisks or hyphens are emptied.
    text = re.sub(r'^[0-9\*\-]+$', '', text, flags=re.MULTILINE)
    # Collapse last, so gaps opened by the steps above are closed as well.
    text = re.sub(r'\n+', '\n', text)
    text = re.sub(r' +', ' ', text)
    return text.strip()
def wa_clean_methods(text):
    """Normalize punctuation in Armenian-script text.

    Latin punctuation that directly follows an Armenian letter is converted
    to its Armenian counterpart, ellipses are unified to ``…``, the marks
    ``՝`` and ``՛`` are dropped after ``կ``/``մ`` (either case), and spaces
    just inside ``«»`` and ``()`` are removed.  Punctuation that follows a
    non-Armenian character (e.g. a colon after an English name) is left
    untouched.

    Returns:
        The normalized string.
    """
    armenian_before = r'(?<=[ա-ֆԱ-Ֆ])'
    armenian_after = r'(?=[ա-ֆԱ-Ֆ])'

    # Re-attach punctuation that is separated from the previous word by a space.
    text = re.sub(r' (։|:|․|…|՝|՞|`|´|~)(?=\s)', r'\1', text)

    # ASCII ellipsis -> …  This has to run BEFORE the single-dot rule below:
    # otherwise the first dot after an Armenian letter is rewritten to ․ and
    # the remaining ".." no longer matches any ellipsis rule.
    text = re.sub(r'\.\.\.', '…', text)

    # Convert : to ։ and . to ․ only when they follow an Armenian letter.
    text = re.sub(armenian_before + r'\:', '։', text)
    text = re.sub(armenian_before + r'\.', '․', text)

    # Convert ` to ՝, ´ to ՛ and ~ to ՜ after an Armenian letter.
    text = re.sub(armenian_before + r'`', '՝', text)
    text = re.sub(armenian_before + r'´', '՛', text)
    text = re.sub(armenian_before + r'~', '՜', text)

    # Convert < to « and > to » when an Armenian letter follows.
    text = re.sub(r'<' + armenian_after, '«', text)
    text = re.sub(r'>' + armenian_after, '»', text)

    # Three one-dot leaders (possibly produced by the dot rule above) -> …
    text = re.sub(r'․․․', '…', text)

    # Armenian hyphen -> ASCII hyphen.
    text = re.sub(r'֊', '-', text)

    # Drop ՝ / ՛ directly after կ or մ (upper or lower case).
    for marked, bare in (
        ('կ՝', 'կ'),
        ('Կ՝', 'Կ'),
        ('կ՛', 'կ'),
        ('Կ՛', 'Կ'),
        ('մ՝', 'մ'),
        ('Մ՝', 'Մ'),
        ('մ՛', 'մ'),
        ('Մ՛', 'Մ'),
    ):
        text = text.replace(marked, bare)

    # for « and ( there should not be a space after them
    text = re.sub(r'« ', '«', text)
    text = re.sub(r'\( ', '(', text)
    # for ) and » there should not be a space before them
    text = re.sub(r' \)', ')', text)
    text = re.sub(r' »', '»', text)
    return text
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split ``text`` into sentences and collect the text around them.

    Args:
        text: The text to split.
        splitter: Object whose ``segment(text)`` method returns the sentences.
        fix_double_space: When true, collapse every whitespace run to a
            single space and strip the text before splitting.
        ignore_errors: When true, a sentence that cannot be located in the
            text is skipped over by advancing one character instead of
            failing.

    Returns:
        ``(sentences, fillers)`` where ``fillers`` has one more element than
        ``sentences``: the text before the first sentence, between each pair
        of consecutive sentences, and after the last one.

    Raises:
        AssertionError: If a sentence is not found and ``ignore_errors`` is
            false.
    """
    if fix_double_space:
        text = re.sub(r"\s+", " ", text).strip()

    sentences = splitter.segment(text)
    fillers = []
    cursor = 0
    for sent in sentences:
        found_at = text.find(sent, cursor)
        if found_at == -1:
            assert ignore_errors, f"Sent not found after index {cursor} in text: {text}"
            # Best effort: pretend the sentence starts one character ahead.
            found_at = cursor + 1
        fillers.append(text[cursor:found_at])
        cursor = found_at + len(sent)
    # Whatever follows the last sentence is the trailing filler.
    fillers.append(text[cursor:])
    return sentences, fillers
def clean_text(text: str, lang) -> str:
    """Clean ``text`` before it is handed to the translation model.

    Non-printing characters are replaced first, then the language-independent
    cleanup is applied.  When ``lang`` is ``"xcl_Armn"`` the Armenian
    punctuation normalization runs as a final step.

    Returns:
        The cleaned string.
    """
    strip_non_printing = get_non_printing_char_replacer()
    cleaned = common_clean_methods(strip_non_printing(text))
    return wa_clean_methods(cleaned) if lang == "xcl_Armn" else cleaned
def init_tokenizer(tokenizer, new_langs=("xcl_Armn",)):
    """Add new language tokens to the tokenizer vocabulary.

    This should be done each time after the tokenizer is initialized.  The
    tokenizer is modified in place.

    Args:
        tokenizer: The tokenizer to patch.  It must expose
            ``lang_code_to_id``, ``id_to_lang_code``,
            ``_additional_special_tokens``, ``fairseq_tokens_to_ids``,
            ``fairseq_ids_to_tokens``, ``fairseq_offset``, ``sp_model``,
            ``added_tokens_encoder`` and ``added_tokens_decoder``.
        new_langs: Iterable of language codes to register.  The default is
            an immutable tuple so it cannot be altered between calls.

    Returns:
        The same tokenizer object.
    """
    for new_lang in new_langs:
        # Do not count the language twice if it is already an "added token".
        old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
        new_id = old_len - 1
        tokenizer.lang_code_to_id[new_lang] = new_id
        tokenizer.id_to_lang_code[new_id] = new_lang
        if new_lang not in tokenizer._additional_special_tokens:
            tokenizer._additional_special_tokens.append(new_lang)
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}  # <- these only work with transformers==4.33.0
    tokenizer.added_tokens_decoder = {}
    return tokenizer
# def init_tokenizer(tokenizer, new_lang='xcl_Armn'): | |
# """ Add a new language token to the tokenizer vocabulary (this should be done each time after its initialization) """ | |
# old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder) | |
# tokenizer.lang_code_to_id[new_lang] = old_len-1 | |
# tokenizer.id_to_lang_code[old_len-1] = new_lang | |
# # always move "mask" to the last position | |
# tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset | |
# tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id) | |
# tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()} | |
# if new_lang not in tokenizer._additional_special_tokens: | |
# tokenizer._additional_special_tokens.append(new_lang) | |
# # clear the added token encoder; otherwise a new token may end up there by mistake | |
# tokenizer.added_tokens_encoder = {} | |
# tokenizer.added_tokens_decoder = {} | |
class Translator:
    """Translate text between the languages listed in ``LANGUAGES``.

    Holds the seq2seq model and tokenizer loaded from ``MODEL_NAME`` plus one
    pysbd sentence segmenter for English and one for Armenian input.
    """

    def __init__(self) -> None:
        """Load model and tokenizer, move the model to CUDA when available."""
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        # Register the extra language code on the freshly loaded tokenizer.
        init_tokenizer(self.tokenizer)
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        """Translate ``text`` with a single ``generate`` call.

        Args:
            text: A string (or list of strings) to translate.  The input is
                truncated to 256 tokens.
            src_lang: Source language code set on the tokenizer.
            tgt_lang: Target language code, forced as the first token.
            max_length: Generation length limit, or ``"auto"`` to use
                ``32 + 2 * <number of input tokens>``.
            num_beams: Beam width.
            n_out: Number of sequences to return; ``None`` means one.
            **kwargs: Passed through to ``model.generate``.

        Returns:
            The single translation as a string when ``text`` is a string and
            ``n_out`` is ``None``; otherwise the list of decoded outputs.
        """
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        )
        if max_length == "auto":
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out

    def _split_sentences(self, paragraph, src_lang):
        """Return the sentences of ``paragraph`` for the given source language.

        A language without a dedicated segmenter yields the whole paragraph
        as one sentence instead of failing.
        """
        if src_lang == "eng_Latn":
            return self.eng_splitter.segment(paragraph)
        if src_lang == "xcl_Armn":
            return self.hyw_splitter.segment(paragraph)
        return [paragraph]

    def translate(self, text: str,
                  src_lang: str,
                  tgt_lang: str,
                  max_length=256,
                  num_beams=4,
                  by_sentence=True,
                  clean=True,
                  **kwargs):
        """Translate ``text`` paragraph by paragraph.

        Paragraphs are the newline-separated lines of ``text``; blank lines
        are kept as empty paragraphs so the output has the same line
        structure as the input.

        Args:
            text: The text to translate.
            src_lang: Source language code.
            tgt_lang: Target language code.
            max_length: Forwarded to ``translate_single`` / ``translate_batch``.
            num_beams: Beam width.
            by_sentence: Split each paragraph into sentences and translate
                them separately (joined with a space afterwards).
            clean: Run ``clean_text`` on every unit before translating.
            **kwargs: Forwarded to the underlying translate call.

        Returns:
            The translated paragraphs joined with newlines.
        """
        translated_paragraphs = []
        for paragraph in text.split('\n'):
            if not paragraph.strip():
                translated_paragraphs.append('')
                continue

            if by_sentence:
                sents = self._split_sentences(paragraph, src_lang)
            else:
                sents = [paragraph]
            if clean:
                sents = [clean_text(sent, src_lang) for sent in sents]
            # Segmentation or cleaning can leave nothing to translate; feeding
            # an empty string to the model would only produce noise.
            sents = [sent for sent in sents if sent.strip()]

            if not sents:
                translated_paragraphs.append('')
                continue
            if len(sents) > 1:
                results = self.translate_batch(sents, src_lang, tgt_lang,
                                               num_beams=num_beams,
                                               max_length=max_length, **kwargs)
            else:
                results = [self.translate_single(sents[0], src_lang, tgt_lang,
                                                 max_length=max_length,
                                                 num_beams=num_beams, **kwargs)]
            translated_paragraphs.append(" ".join(results))

        # Reconstruct with original paragraph structure
        return "\n".join(translated_paragraphs)

    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        """Translate a list of texts in one padded batch.

        Args:
            texts: List of strings to translate.
            src_lang: Source language code set on the tokenizer.
            tgt_lang: Target language code, forced as the first token.
            num_beams: Beam width.
            max_length: Truncation limit (in tokens) applied to the inputs.
                It is not passed to ``generate``.
            **kwargs: Passed through to ``model.generate``.

        Returns:
            List of decoded translations, one per input text.
        """
        self.tokenizer.src_lang = src_lang
        # Keep the whole encoding (input ids AND attention mask) and move it
        # to wherever the model lives, so padding is masked on every device.
        inputs = self.tokenizer(
            texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True
        ).to(self.model.device)
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Smoke test: load the model once, then time a couple of translations.
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")

    demo_inputs = (
        "Hello world!",
        "I am the greatest translator! Do not fuck with me!",
    )
    for sample in demo_inputs:
        started_at = time.time()
        print(translator.translate(sample, "eng_Latn", "xcl_Armn"))
        print("Time elapsed: ", time.time() - started_at)