---
language:
- he
pipeline_tag: text-generation
---

### Description

Experiments with an encoder-decoder model, where the encoder is [alephbert-base](https://huggingface.co/onlplab/alephbert-base) and the decoder is a [pruned mT5-base model](https://huggingface.co/imvladikon/het5-base).

It could be useful for generating negative and hard-negative samples for text-pair classification.
(For paraphrasing, it is better to use classical approaches rather than this one.)

### Usage

```bash
git clone https://huggingface.co/imvladikon/alephbert-encoder-t5-decoder
```

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
from transformers.modeling_outputs import BaseModelOutput
from datasets import load_dataset

enc_checkpoint = "./alephbert-encoder-t5-decoder/encoder"
enc_tokenizer = AutoTokenizer.from_pretrained(enc_checkpoint)
encoder = AutoModel.from_pretrained(enc_checkpoint).cuda()

dec_checkpoint = "./alephbert-encoder-t5-decoder/decoder"
dec_tokenizer = AutoTokenizer.from_pretrained(dec_checkpoint)
decoder = AutoModelForSeq2SeqLM.from_pretrained(dec_checkpoint).cuda()


def encode(texts):
    encoded_input = enc_tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        model_output = encoder(**encoded_input.to(encoder.device))
        embeddings = model_output.pooler_output
        embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings


def decode(embeddings, max_length=256, repetition_penalty=3.0, **kwargs):
    out = decoder.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
        max_length=max_length,
        repetition_penalty=repetition_penalty,
    )
    return [dec_tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]


encoder.eval()

text = """
מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ וייתכן אובך באזור.
""".strip()
batch = [text]
embeddings = encode(batch)

decoder.eval()
out = decoder.generate(encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)), max_length=512, repetition_penalty=3.0)
for t, o in zip(batch, out):
    print(t)
    print(dec_tokenizer.decode(o, skip_special_tokens=True))
    print('-----------')
```