Description
DNA sequence summarization model,
pretrained on DNA + English text,
fine-tuned on English summarization data,
to test multilingual transfer ability from English to DNA.
Code example
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModelForCausalLM
import math
from transformers import LogitsProcessorList, LogitsProcessor
import torch
# Load the tokenizer for the DNA + English GPT-2 summary model.
tokenizer = AutoTokenizer.from_pretrained("dnagpt/gene_eng_gpt2_summary")
tokenizer.pad_token = tokenizer.eos_token # GPT-2 has no pad token; reuse EOS for padding
# Load the GPT-2 language-model weights.
model = GPT2LMHeadModel.from_pretrained("dnagpt/gene_eng_gpt2_summary")
model.config.pad_token_id = model.config.eos_token_id  # keep model config consistent with tokenizer padding
def classify_sequence(sequence):
    """Classify a string as "DNA", "Protein", "English", or "Unknown".

    Checks run from most to least restrictive character set, so a string of
    only A/C/G/T is reported as DNA even though those letters are also valid
    amino-acid codes and English letters. Sequence letters are assumed to be
    uppercase (lowercase a/c/g/t would fall through to "English").

    Args:
        sequence: The string to classify; surrounding whitespace is ignored.

    Returns:
        One of "DNA", "Protein", "English", or "Unknown".
    """
    dna_chars = set('ACGT')
    # The 20 standard amino-acid one-letter codes.
    protein_chars = set('ACDEFGHIKLMNPQRSTVWY')
    # Letters, digits, and common punctuation found in English text.
    english_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 ,.!?:;-"\'()')

    sequence = sequence.strip()
    # Edge case: all() is vacuously True on an empty string, which would
    # misclassify "" (or whitespace-only input) as DNA. Report it as Unknown.
    if not sequence:
        return "Unknown"
    # DNA check first: ACGT is a subset of the protein alphabet.
    if all(c in dna_chars for c in sequence):
        return "DNA"
    if all(c in protein_chars for c in sequence):
        return "Protein"
    # English text: letters of either case, digits, and common punctuation.
    if all(c in english_chars for c in sequence):
        return "English"
    # Anything else (e.g. non-ASCII or unexpected symbols) is unclassified.
    return "Unknown"
# Build the DNA-only vocabulary whitelist from the tokenizer's vocab.
# The original (Chinese) comment states that only tokens of length 2 and
# above should be kept ("只要长度2个及以上的词"); the previous code applied
# no length filter, so the single bases A/C/G/T slipped through. The
# comprehension below enforces the stated intent — TODO confirm the single
# bases really should be excluded from generation.
word_dict = tokenizer.get_vocab()
DNA_token_list = [
    word
    for word in word_dict
    if len(word) >= 2 and classify_sequence(word) == "DNA"
]
class DNAOnlyLogitsProcessor(LogitsProcessor):
    """LogitsProcessor that constrains generation to a fixed token whitelist.

    Every logit outside ``allowed_tokens`` is forced to -inf so that only
    the whitelisted (DNA) tokens can ever be selected during generation.
    """

    def __init__(self, allowed_tokens, tokenizer):
        # Resolve token strings to vocabulary ids once, at construction time.
        self.allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)

    def __call__(self, input_ids, scores):
        # Boolean mask of disallowed vocabulary positions. Built per call
        # because the batch size / device / vocab width come from `scores`.
        disallowed = torch.ones_like(scores, dtype=torch.bool)
        disallowed[:, self.allowed_token_ids] = False
        # masked_fill_ blocks disallowed ids while leaving allowed logits
        # untouched — cheaper and cleaner than the previous dense float
        # mask added to `scores`, and it does not propagate NaN from the
        # `score + (-inf)` addition.
        return scores.masked_fill_(disallowed, float("-inf"))
def get_summary_with_constraints(text, DNA_token_list):
    """Generate a DNA-token-only summary of `text` via a GPT-2 TL;DR prompt.

    Args:
        text: Input passage (a DNA sequence here) to summarize.
        DNA_token_list: Vocabulary tokens that generation is restricted to.

    Returns:
        The generated summary string (the continuation after the prompt),
        stripped of surrounding whitespace.
    """
    # GPT-2-style summarization prompt: append the TL;DR marker.
    prompt = text.strip() + " TL;DR:"
    encoded_input = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=256,  # cap the prompt length in tokens
    )
    # Restrict decoding to the DNA token whitelist.
    logits_processor = LogitsProcessorList([
        DNAOnlyLogitsProcessor(DNA_token_list, tokenizer)
    ])
    output = model.generate(
        input_ids=encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        max_new_tokens=16,       # length of the newly generated continuation
        num_beams=5,             # beam search for higher-quality output
        logits_processor=logits_processor,
        no_repeat_ngram_size=3,  # avoid repeating 3-grams
        early_stopping=True,     # stop beams once EOS is reached
    )
    # BUG FIX: isolate the summary in *token* space, then decode only the
    # newly generated tokens. The previous code sliced the decoded string at
    # `len(text) + token_count - 1`, mixing character length with token
    # count, which cut the summary at an arbitrary character offset.
    prompt_token_count = encoded_input["input_ids"].shape[1]
    new_tokens = output[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Example usage.
# English alternative input for comparison:
#test_text = "The DNA sequence analysis showed remarkable results."
test_text = "GTTATAACCTGTGAGAGTATGTTGGCGGTTTGTTGCACCTACCTTTCAAACCTCTTGTTCTTCCTGTGATTTATTTGAGGCACTCAAGTGGACAGAGACCATGAGAAATTTGAGTGGAGGCCATGTCGAAGAGTTTGTCTTGGTGGGTTTCCCTACCACTCCTCCCTTCCAGCTGCTCCTCTTTGTCCTTTTCTTTGCAATTTACCTTCTGACATTGTTGGAGAATGCACTCATTGTCTTCACAATATGGCTCACTCCAAGCCTTCATCGCCCCATGTACTTTTTCCTTGGCCATCTTTCTTTCCTGGAGCTTTGGTACATCAACGTCACCATTCCTCAGCTCTTGGCAGCCTTTCTTACCCAGGATAGTAGAGTCTCCTATGTAGGTTGCATGACCCAACTCTACTTCTTTATTGCCTTAGCCTGTACTGAATGTGTGCTGTTGGCAGTTATGGCCTATGACCGC"
print(get_summary_with_constraints(test_text, DNA_token_list))
- Downloads last month
- 16
Inference Providers
NEW
This model is not currently available via any of the supported third-party Inference Providers, and
the model is not deployed on the HF Inference API.