# -*- coding: utf-8 -*-
import os
import re
import sys
import time
import typing as tp
import unicodedata

import torch
import pysbd
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM

# hy_segmenter = pysbd.Segmenter(language="hy", clean=False)  # not needed

MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-xcl"

LANGUAGES = {
    "Գրաբառ Հայոց | Classical Armenian": "xcl_Armn",
    "Անգլերէն | English": "eng_Latn",
}

HF_TOKEN = os.environ.get("HF_TOKEN")


def get_non_printing_char_replacer(replace_by: str = " "):
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char


# Earlier normalization-based cleaner, kept for reference (superseded by clean_text below).
# def clean_text(text: str, lang) -> str:
#     HYW_CHARS_TO_NORMALIZE = {
#         "«": '"',
#         "»": '"',
#         "“": '"',
#         "”": '"',
#         "’": "'",
#         "‘": "'",
#         "–": "-",
#         "—": "-",
#         "ՙ": "'",
#         "՚": "'",
#     }
#     DOUBLE_CHARS_TO_NORMALIZE = {
#         "Կ՛": "Կ'",
#         "կ՛": "կ'",
#         "Չ՛": "Չ'",
#         "չ՛": "չ'",
#         "Մ՛": "Մ'",
#         "մ՛": "մ'",
#     }
#     replace_nonprint = get_non_printing_char_replacer()
#     text = replace_nonprint(text)
#     # print(text)
#     text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ").replace(r"[^\x00-\x7F]+", " ").replace(r"\s+", " ")
#     text = text.strip()
#     if lang == "xcl_Armn":
#         text = text.translate(str.maketrans(HYW_CHARS_TO_NORMALIZE))
#         for k, v in DOUBLE_CHARS_TO_NORMALIZE.items():
#             text = text.replace(k, v)
#     return text


def remove_special_characters(text):
    # Regex pattern for invisible/special characters (non-breaking space,
    # zero-width and directional marks, line/paragraph separators, soft hyphen).
    pattern = r'[\u00A0\u200B\u200C\u200D\u200E\u200F\u2028\u2029\xad]'
    return re.sub(pattern, '', text)


def common_clean_methods(text):
    text = text.strip()
    text = re.sub(r'\n+', '\n', text)
    text = re.sub(r' +', ' ', text)
    text = remove_special_characters(text)
    text = text.replace("\t", "")
    text = text.replace("\r", "")
    # Blank out lines consisting only of digits, asterisks, or hyphens.
    text = re.sub(r'^[0-9\*\-]+$', '', text, flags=re.MULTILINE)
    text = text.strip()
    return text
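
# Hedged illustration (not part of the original script): the cleaning helpers above
# are plain, side-effect-free functions, so they can be sanity-checked in isolation.
# The expected result below is inferred from the regexes in common_clean_methods.
def _demo_common_clean():
    sample = "Hello\t  world\r\n\n***\n"
    cleaned = common_clean_methods(sample)
    assert cleaned == "Hello world", cleaned
    return cleaned
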
def wa_clean_methods(text):
    # Merge isolated punctuation marks back onto the preceding word.
    text = re.sub(r' (։|:|․|…|՝|՞|`|´|~)(?=\s)', r'\1', text)
    # Convert : to ։ and . to ․ only after Armenian letters, so English names
    # containing ":" or "." inside an Armenian text are left untouched.
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])\:', '։', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])\.', '․', text)
    # Convert ` to ՝, ´ to ՛, and ~ to ՜ after Armenian letters.
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])`', '՝', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])´', '՛', text)
    text = re.sub(r'(?<=[ա-ֆԱ-Ֆ])~', '՜', text)
    # Convert < and > to the Armenian quotation marks « and ».
    text = re.sub(r'<(?=[ա-ֆԱ-Ֆ])', '«', text)
    text = re.sub(r'>(?=[ա-ֆԱ-Ֆ])', '»', text)
    # Convert ․․․ and ... to the single ellipsis character … (Alt + 0133).
    text = re.sub(r'․․․', '…', text)
    text = re.sub(r'\.\.\.', '…', text)
    # Convert the Armenian hyphen ֊ to a plain -.
    text = re.sub(r'֊', '-', text)
    # Drop the ՝/՛ marks after կ and մ.
    text = re.sub(r'կ՝', 'կ', text)
    text = re.sub(r'Կ՝', 'Կ', text)
    text = re.sub(r'կ՛', 'կ', text)
    text = re.sub(r'Կ՛', 'Կ', text)
    text = re.sub(r'մ՝', 'մ', text)
    text = re.sub(r'Մ՝', 'Մ', text)
    text = re.sub(r'մ՛', 'մ', text)
    text = re.sub(r'Մ՛', 'Մ', text)
    # There should not be a space after « and (.
    text = re.sub(r'« ', '«', text)
    text = re.sub(r'\( ', '(', text)
    # There should not be a space before ) and ».
    text = re.sub(r' \)', ')', text)
    text = re.sub(r' »', '»', text)
    return text


def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split text into sentences, also returning the "filler" strings between them
    so the original text can be reconstructed around the sentences."""
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
    text = text.strip()
    sentences = splitter.segment(text)
    fillers = []
    i = 0
    for sent in sentences:
        start_idx = text.find(sent, i)
        if ignore_errors and start_idx == -1:
            start_idx = i + 1
        assert start_idx != -1, f"Sent not found after index {i} in text: {text}"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)
    fillers.append(text[i:])
    return sentences, fillers


def clean_text(text: str, lang) -> str:
    replace_nonprint = get_non_printing_char_replacer()
    text = replace_nonprint(text)
    text = common_clean_methods(text)
    if lang == "xcl_Armn":
        text = wa_clean_methods(text)
    return text


def init_tokenizer(tokenizer, new_langs=["xcl_Armn"]):
    """Add new language tokens to the tokenizer vocabulary
    (this should be done each time after its initialization)."""
    for new_lang in new_langs:
        old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
        tokenizer.lang_code_to_id[new_lang] = old_len - 1
        tokenizer.id_to_lang_code[old_len - 1] = new_lang
        if new_lang not in tokenizer._additional_special_tokens:
            tokenizer._additional_special_tokens.append(new_lang)
    # always move "<mask>" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # clear the added token encoder; otherwise a new token may end up there by mistake
    # (these two assignments only work with transformers==4.33.0)
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
    return tokenizer
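
# Hedged sketch (not part of the original script): sentenize_with_fillers returns the
# sentences plus the text between them, so interleaving fillers and sentences should
# rebuild the whitespace-normalized input. The splitter settings and the sample
# sentence below are illustrative assumptions.
def _demo_sentenize_roundtrip():
    splitter = pysbd.Segmenter(language="en", clean=False)
    text = "Hello world. How are you?"
    sentences, fillers = sentenize_with_fillers(text, splitter)
    # fillers has one more element than sentences (leading, inner, and trailing gaps).
    rebuilt = "".join(f + s for f, s in zip(fillers, sentences)) + fillers[-1]
    assert rebuilt == text, rebuilt
    return sentences, fillers
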
class Translator:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
        init_tokenizer(self.tokenizer)
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
        self.languages = LANGUAGES

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        )
        if max_length == "auto":
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out

    def translate(self, text: str, src_lang: str, tgt_lang: str, max_length=256,
                  num_beams=4, by_sentence=True, clean=True, **kwargs):
        # Split into paragraphs
        paragraphs = text.split('\n')
        translated_paragraphs = []

        for paragraph in paragraphs:
            if not paragraph.strip():
                translated_paragraphs.append('')
                continue

            if by_sentence:
                if src_lang == "eng_Latn":
                    sents = self.eng_splitter.segment(paragraph)
                elif src_lang == "xcl_Armn":
                    sents = self.hyw_splitter.segment(paragraph)

                if clean:
                    sents = [clean_text(sent, src_lang) for sent in sents]

                if not sents:
                    # The segmenter can return an empty list for degenerate input.
                    translated_paragraphs.append('')
                    continue

                if len(sents) > 1:
                    results = self.translate_batch(sents, src_lang, tgt_lang,
                                                   num_beams=num_beams, max_length=max_length, **kwargs)
                else:
                    results = [self.translate_single(sents[0], src_lang, tgt_lang,
                                                     max_length=max_length, num_beams=num_beams, **kwargs)]
                translated_paragraphs.append(" ".join(results))
            else:
                if clean:
                    paragraph = clean_text(paragraph, src_lang)
                translated = self.translate_single(paragraph, src_lang, tgt_lang,
                                                   max_length=max_length, num_beams=num_beams, **kwargs)
                translated_paragraphs.append(translated)

        # Reconstruct with original paragraph structure
        return "\n".join(translated_paragraphs)

    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
        self.tokenizer.src_lang = src_lang
        # Pass the full encoding (input_ids and attention_mask) so padding tokens
        # are masked out, and move it to whichever device the model is on.
        inputs = self.tokenizer(
            texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True
        ).to(self.model.device)
        translated_tokens = self.model.generate(
            **inputs,
            num_beams=num_beams,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            **kwargs,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
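
# Hedged sketch (not part of the original script): the expected call shape for
# translate_batch, which takes a list of already-segmented source sentences.
# Wrapped in a function so nothing heavy runs at import time; the sample
# sentences are illustrative.
def _demo_translate_batch(translator: Translator):
    sents = ["Hello world!", "How are you?"]
    return translator.translate_batch(sents, "eng_Latn", "xcl_Armn", num_beams=4)
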
if __name__ == "__main__":
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")

    start_time = time.time()
    print(translator.translate("Hello world!", "eng_Latn", "xcl_Armn"))
    print("Time elapsed: ", time.time() - start_time)

    start_time = time.time()
    print(translator.translate("I am the greatest translator! Do not fuck with me!", "eng_Latn", "xcl_Armn"))
    print("Time elapsed: ", time.time() - start_time)