---
library_name: transformers
tags: []
---

# Description

A DNA sequence summarization model: pretrained on DNA + English text, then finetuned on English summarization data, to test multilingual transfer ability from English to DNA.

# Code example

```python
import torch
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    LogitsProcessor,
    LogitsProcessorList,
)

# Load the GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("dnagpt/gene_eng_gpt2_summary")
tokenizer.pad_token = tokenizer.eos_token  # use the EOS token for padding

# Load the GPT-2 model
model = GPT2LMHeadModel.from_pretrained("dnagpt/gene_eng_gpt2_summary")
model.config.pad_token_id = model.config.eos_token_id


def classify_sequence(sequence):
    # Character sets (sequences are assumed to be upper case)
    dna_chars = set('ACGT')
    protein_chars = set('ACDEFGHIKLMNPQRSTVWY')
    english_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 ,.!?:;-"\'()')

    # Strip surrounding whitespace and check the length
    sequence = sequence.strip()
    if not sequence:
        return "Unknown"

    # DNA sequence?
    if all(c in dna_chars for c in sequence):
        return "DNA"

    # Protein sequence?
    if all(c in protein_chars for c in sequence):
        return "Protein"

    # English text (letters, digits, and common punctuation)?
    if all(c in english_chars for c in sequence):
        return "English"

    # None of the above: cannot classify
    return "Unknown"


# Collect the DNA tokens from the vocabulary (only tokens of length 2 or more)
word_dict = tokenizer.get_vocab()
DNA_token_list = []
for word in word_dict:
    if len(word) >= 2 and classify_sequence(word) == "DNA":
        DNA_token_list.append(word)


class DNAOnlyLogitsProcessor(LogitsProcessor):
    def __init__(self, allowed_tokens, tokenizer):
        self.allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)

    def __call__(self, input_ids, scores):
        # Build a mask that sets the scores of all disallowed tokens to -inf
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_token_ids] = 0
        scores += mask
        return scores


def get_summary_with_constraints(text, DNA_token_list):
    # Preprocess the input text and append the summarization prompt
    text = text.strip() + " TL;DR:"

    # Encode the input text
    encoded_input = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=256,  # maximum input length
    )

    # LogitsProcessor that restricts generation to DNA tokens
    logits_processor = LogitsProcessorList([
        DNAOnlyLogitsProcessor(DNA_token_list, tokenizer)
    ])

    # Use max_new_tokens to control the generated length
    output = model.generate(
        input_ids=encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        max_new_tokens=16,       # length of the newly generated text
        num_beams=5,             # beam search for more stable output
        logits_processor=logits_processor,
        no_repeat_ngram_size=3,  # avoid generating repeated n-grams
        early_stopping=True,     # stop beam search early when finished
    )

    # Decode only the newly generated tokens, skipping the prompt
    prompt_length = encoded_input["input_ids"].shape[1]
    summary = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return summary.strip()


# Example usage
# test_text = "The DNA sequence analysis showed remarkable results."
test_text = "GTTATAACCTGTGAGAGTATGTTGGCGGTTTGTTGCACCTACCTTTCAAACCTCTTGTTCTTCCTGTGATTTATTTGAGGCACTCAAGTGGACAGAGACCATGAGAAATTTGAGTGGAGGCCATGTCGAAGAGTTTGTCTTGGTGGGTTTCCCTACCACTCCTCCCTTCCAGCTGCTCCTCTTTGTCCTTTTCTTTGCAATTTACCTTCTGACATTGTTGGAGAATGCACTCATTGTCTTCACAATATGGCTCACTCCAAGCCTTCATCGCCCCATGTACTTTTTCCTTGGCCATCTTTCTTTCCTGGAGCTTTGGTACATCAACGTCACCATTCCTCAGCTCTTGGCAGCCTTTCTTACCCAGGATAGTAGAGTCTCCTATGTAGGTTGCATGACCCAACTCTACTTCTTTATTGCCTTAGCCTGTACTGAATGTGTGCTGTTGGCAGTTATGGCCTATGACCGC"
print(get_summary_with_constraints(test_text, DNA_token_list))
```
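As a quick sanity check, a minimal sketch reusing the `classify_sequence` and `get_summary_with_constraints` functions defined above: since the logits processor restricts generation to DNA tokens, the decoded summary should itself classify as DNA once any spaces the BPE decoding inserts between tokens are stripped.

```python
# Sanity-check sketch (assumes the code block above has already run):
# the constrained summary should contain only A/C/G/T characters.
summary = get_summary_with_constraints(test_text, DNA_token_list)
print(summary)
print(classify_sequence(summary.replace(" ", "")))  # expected: "DNA"
```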