description

DNA sequence summary model

Pretrained on a mix of DNA sequences and English text.

Fine-tuned on English summarization data.

Used to test cross-lingual transfer of the summarization ability from English to DNA.

code example

from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModelForCausalLM
import math
from transformers import LogitsProcessorList, LogitsProcessor
import torch


# Load the tokenizer published with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("dnagpt/gene_eng_gpt2_summary")
tokenizer.pad_token = tokenizer.eos_token  # Reuse the EOS token as the padding token.

# Load the GPT-2 language model from the same checkpoint and make its
# configured pad token id match the EOS token id.
model = GPT2LMHeadModel.from_pretrained("dnagpt/gene_eng_gpt2_summary")
model.config.pad_token_id = model.config.eos_token_id

def classify_sequence(sequence):
    """Classify a string as DNA, protein, English text, or unknown.

    Leading and trailing whitespace is stripped before classification. The
    checks run from the most restrictive alphabet to the least restrictive
    one, so a string made only of ``A``, ``C``, ``G`` and ``T`` is reported
    as DNA even though those letters are also valid amino-acid codes.
    Matching is case-sensitive: the DNA and protein alphabets contain
    upper-case letters only, so lower-case input can only match "English".

    Args:
        sequence: The text to classify.

    Returns:
        One of ``"DNA"``, ``"Protein"``, ``"English"`` or ``"Unknown"``.
        Empty or whitespace-only input returns ``"Unknown"``.
    """
    dna_chars = set('ACGT')
    protein_chars = set('ACDEFGHIKLMNPQRSTVWY')
    english_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 ,.!?:;-"\'()')

    sequence = sequence.strip()

    # all() is vacuously true for an empty iterable, so without this guard an
    # empty string would be labelled "DNA".
    if not sequence:
        return "Unknown"

    # Most restrictive alphabet first: nucleotides.
    if all(c in dna_chars for c in sequence):
        return "DNA"

    # Then the 20 standard amino-acid letters.
    if all(c in protein_chars for c in sequence):
        return "Protein"

    # Finally plain English: letters, digits and common punctuation.
    if all(c in english_chars for c in sequence):
        return "English"

    # Contains at least one character outside every known alphabet.
    return "Unknown"

# Collect every vocabulary token that classify_sequence() labels as "DNA".
# No minimum-length filter is applied here, so single-character tokens such
# as "A" are kept as well.
word_dict = tokenizer.get_vocab()

DNA_token_list = [
    token for token in word_dict if classify_sequence(token) == "DNA"
]


class DNAOnlyLogitsProcessor(LogitsProcessor):
    """Logits processor that restricts generation to an allow-list of tokens.

    Every token whose id is not in the allow-list has its score set to
    ``-inf``; allowed tokens keep their original scores.

    Args:
        allowed_tokens: Token strings that may be generated.
        tokenizer: Tokenizer used to map ``allowed_tokens`` to ids via
            ``convert_tokens_to_ids``.

    Raises:
        ValueError: If ``allowed_tokens`` maps to an empty list of ids, which
            would mask out every token.
    """

    def __init__(self, allowed_tokens, tokenizer):
        self.allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)
        # An empty allow-list would turn every score into -inf and leave the
        # generator with nothing valid to pick, so fail early and loudly.
        if isinstance(self.allowed_token_ids, list) and not self.allowed_token_ids:
            raise ValueError("allowed_tokens must contain at least one token")

    def __call__(self, input_ids, scores):
        """Return ``scores`` with all non-allowed token columns set to ``-inf``.

        Args:
            input_ids: Token ids generated so far (unused).
            scores: Score tensor indexed as ``[row, token_id]``.

        Returns:
            A new tensor shaped like ``scores``; the input is not modified.
        """
        # Additive mask: -inf everywhere except 0 at the allowed token ids.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_token_ids] = 0
        # Out-of-place add so the caller's tensor is left untouched.
        return scores + mask

def get_summary_with_constraints(text, DNA_token_list):
    """Generate a summary of ``text`` using only tokens from ``DNA_token_list``.

    The input is stripped, suffixed with the ``" TL;DR:"`` cue, encoded with
    the module-level ``tokenizer`` and passed to the module-level ``model``
    for beam-search generation. A ``DNAOnlyLogitsProcessor`` limits every
    generated token to ``DNA_token_list``.

    Args:
        text: The text (DNA sequence or otherwise) to summarize.
        DNA_token_list: Token strings the model is allowed to generate.

    Returns:
        The decoded continuation only (prompt excluded), with surrounding
        whitespace stripped.
    """
    # Append the summarization cue to the cleaned input.
    text = text.strip() + " TL;DR:"

    # NOTE(review): truncation to 256 tokens may cut off the trailing
    # " TL;DR:" cue for very long inputs — confirm whether that matters.
    encoded_input = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=256,  # Maximum number of prompt tokens.
    )

    # Restrict generation to the allowed DNA tokens.
    logits_processor = LogitsProcessorList([
        DNAOnlyLogitsProcessor(DNA_token_list, tokenizer)
    ])

    output = model.generate(
        input_ids=encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        max_new_tokens=16,       # Upper bound on the number of generated tokens.
        num_beams=5,             # Beam-search width.
        logits_processor=logits_processor,
        no_repeat_ngram_size=3,  # Forbid repeating any 3-gram.
        early_stopping=True,     # Stop once enough finished beams exist.
    )

    # The returned sequence starts with the prompt tokens. Slice by token
    # count and decode only the continuation, rather than guessing a character
    # offset into the decoded string (which breaks when the prompt contains
    # whitespace or was truncated).
    prompt_length = encoded_input["input_ids"].shape[1]
    new_token_ids = output[0][prompt_length:]
    summary = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    return summary

# Example usage: summarize a raw DNA sequence. The commented-out line below is
# an alternative English input.
#test_text = "The DNA sequence analysis showed remarkable results."
test_text = "GTTATAACCTGTGAGAGTATGTTGGCGGTTTGTTGCACCTACCTTTCAAACCTCTTGTTCTTCCTGTGATTTATTTGAGGCACTCAAGTGGACAGAGACCATGAGAAATTTGAGTGGAGGCCATGTCGAAGAGTTTGTCTTGGTGGGTTTCCCTACCACTCCTCCCTTCCAGCTGCTCCTCTTTGTCCTTTTCTTTGCAATTTACCTTCTGACATTGTTGGAGAATGCACTCATTGTCTTCACAATATGGCTCACTCCAAGCCTTCATCGCCCCATGTACTTTTTCCTTGGCCATCTTTCTTTCCTGGAGCTTTGGTACATCAACGTCACCATTCCTCAGCTCTTGGCAGCCTTTCTTACCCAGGATAGTAGAGTCTCCTATGTAGGTTGCATGACCCAACTCTACTTCTTTATTGCCTTAGCCTGTACTGAATGTGTGCTGTTGGCAGTTATGGCCTATGACCGC"

print(get_summary_with_constraints(test_text, DNA_token_list))
Downloads last month
16
Safetensors
Model size
162M params
Tensor type
F32
·
Inference Providers NEW
This model is not currently available via any of the supported third-party Inference Providers, and the model is not deployed on the HF Inference API.