import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import math
from transformers import MarianTokenizer
from datasets import load_dataset
from typing import List
from torch import Tensor
from torch.nn import Transformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from timeit import default_timer as timer
import urllib.request
import os
from torch.cuda.amp import GradScaler, autocast
import logging
logging.getLogger("datasets").setLevel(logging.ERROR)
print("CUDA是否可用:", torch.cuda.is_available())
print("PyTorch版本:", torch.__version__)
if torch.cuda.is_available():
print("CUDA版本:", torch.version.cuda)
# 设置设备
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("当前使用设备:", DEVICE)
if torch.cuda.is_available():
print(f"GPU信息: {torch.cuda.get_device_name(0)}")
print(f"当前GPU显存使用: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB")
# Initialize the tokenizer. Only the MarianMT tokenizer (its subword vocabulary) is used here,
# not the model's pretrained translation weights.
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-de-en')
# Special token indices
PAD_IDX = tokenizer.pad_token_id
# MarianTokenizer defines no BOS token (bos_token_id is None); Marian models conventionally
# start decoding from the pad token, so fall back to it as the start symbol.
BOS_IDX = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.pad_token_id
EOS_IDX = tokenizer.eos_token_id
UNK_IDX = tokenizer.unk_token_id
# Vocabulary sizes
SRC_VOCAB_SIZE = tokenizer.vocab_size
TGT_VOCAB_SIZE = tokenizer.vocab_size
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
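# Note: multiplying the token embeddings by sqrt(emb_size) follows "Attention Is All You Need";
# it keeps their magnitude comparable to the sinusoidal positional encodings added by
# PositionalEncoding above.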
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, nhead: int, src_vocab_size: int,
                 tgt_vocab_size: int, dim_feedforward: int = 512, dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
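# Mask conventions expected by nn.Transformer (batch_first=False):
#   src_mask / tgt_mask are (S, S) / (T, T) attention masks; the float mask above uses
#   -inf for blocked positions and 0.0 for allowed ones, so it is added to the attention scores.
#   The *_padding_mask tensors are (batch, seq_len) booleans where True marks padding to ignore.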
# (Kept for reference: downloads the raw Multi30k files; load_data below uses WMT14 instead.)
def download_multi30k():
    base_url = "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
    # Create the data directory
    os.makedirs("multi30k", exist_ok=True)
    # Download the train, validation and test files
    splits = ['train', 'val', 'test']
    languages = ['de', 'en']
    for split in splits:
        for lang in languages:
            filename = f"{split}.{lang}"
            url = f"{base_url}{filename}"
            path = f"multi30k/{filename}"
            if not os.path.exists(path):
                print(f"Downloading {filename}...")
                urllib.request.urlretrieve(url, path)
def load_data():
    # Load the German-English pairs of the WMT14 dataset
    dataset = load_dataset("wmt14", "de-en", cache_dir=".cache")
    # To keep training manageable, only a subset of the data is used
    train_size = 29000  # roughly the size of the Multi30k training set
    val_size = 1000
    test_size = 1000
    # Slice the rows first, then read the 'translation' column, so the full
    # WMT14 training set is never materialized in memory
    train_pairs = dataset['train'][:train_size]['translation']
    val_pairs = dataset['validation'][:val_size]['translation']
    test_pairs = dataset['test'][:test_size]['translation']
    data = {
        'train': {
            'de': [item['de'] for item in train_pairs],
            'en': [item['en'] for item in train_pairs]
        },
        'val': {
            'de': [item['de'] for item in val_pairs],
            'en': [item['en'] for item in val_pairs]
        },
        'test': {
            'de': [item['de'] for item in test_pairs],
            'en': [item['en'] for item in test_pairs]
        }
    }
    return data
# A custom Dataset wrapping the parallel sentence lists
class TranslationDataset(Dataset):
    def __init__(self, de_texts, en_texts):
        self.de_texts = de_texts
        self.en_texts = en_texts

    def __len__(self):
        return len(self.de_texts)

    def __getitem__(self, idx):
        return {
            'de': self.de_texts[idx],
            'en': self.en_texts[idx]
        }
print("正在加载数据集...")
_cached_data = load_data() # 全局缓存数据
def get_dataloader(split='train', batch_size=32):
# 使用缓存的数据而不是重新加载
data = _cached_data[split]
# 创建Dataset对象
dataset = TranslationDataset(data['de'], data['en'])
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=(split == 'train')
)
# Model hyperparameters, adjusted to reduce GPU memory usage
BATCH_SIZE = 32          # reduced from 64
EMB_SIZE = 512           # unchanged
NHEAD = 8                # unchanged
FFN_HID_DIM = 512        # reverted to 512 (was briefly 1024)
NUM_ENCODER_LAYERS = 3   # reverted to 3 (was briefly 4)
NUM_DECODER_LAYERS = 3   # reverted to 3 (was briefly 4)
NUM_EPOCHS = 18          # unchanged
# Instantiate the model
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
transformer = transformer.to(DEVICE)
# Initialize parameters with Xavier uniform
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# Gradient scaler for mixed-precision training
scaler = GradScaler()
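# Mixed-precision recipe used in train_epoch below: the forward pass and loss run under
# autocast(); the loss is scaled before backward() so small fp16 gradients do not underflow,
# and scaler.step()/scaler.update() unscale the gradients and adjust the scale factor.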
def train_epoch(model, optimizer):
    try:
        model.train()
        losses = 0
        train_dataloader = get_dataloader('train', BATCH_SIZE)
        for batch in train_dataloader:
            src_texts = batch['de']
            tgt_texts = batch['en']
            # Use automatic mixed precision for the forward pass
            with autocast():
                src_tokens = tokenizer(src_texts, padding=True, return_tensors='pt')
                tgt_tokens = tokenizer(tgt_texts, padding=True, return_tensors='pt')
                src = src_tokens['input_ids'].transpose(0, 1).to(DEVICE)
                tgt = tgt_tokens['input_ids'].transpose(0, 1).to(DEVICE)
                tgt_input = tgt[:-1, :]   # decoder input: target shifted right (drop last token)
                src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
                logits = model(src, tgt_input, src_mask, tgt_mask,
                               src_padding_mask, tgt_padding_mask, src_padding_mask)
                tgt_out = tgt[1:, :]      # training target: drop the first token
                loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            losses += loss.item()
        return losses / len(train_dataloader)
    except KeyboardInterrupt:
        print("\nTraining interrupted! Saving the current model state...")
        # Save a checkpoint; epoch/train_loss/val_loss live in the outer training loop
        # and may not be defined yet on the first epoch, so look them up defensively.
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': globals().get('epoch'),
            'train_loss': globals().get('train_loss'),
            'val_loss': globals().get('val_loss')
        }
        torch.save(checkpoint, 'transformer_translation.pth')
        print("Checkpoint saved to transformer_translation.pth")
        raise KeyboardInterrupt
def evaluate(model):
    model.eval()
    losses = 0
    val_dataloader = get_dataloader('val', BATCH_SIZE)
    with torch.no_grad():  # no gradients are needed during evaluation
        for batch in val_dataloader:
            src_texts = batch['de']
            tgt_texts = batch['en']
            src_tokens = tokenizer(src_texts, padding=True, return_tensors='pt')
            tgt_tokens = tokenizer(tgt_texts, padding=True, return_tensors='pt')
            src = src_tokens['input_ids'].transpose(0, 1).to(DEVICE)
            tgt = tgt_tokens['input_ids'].transpose(0, 1).to(DEVICE)
            tgt_input = tgt[:-1, :]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
    return losses / len(val_dataloader)
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len - 1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys
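# greedy_decode picks the single most probable token at each step and stops at EOS_IDX or
# after max_len tokens; beam search would usually give better translations but is not used here.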
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    tokens = tokenizer(src_sentence, return_tensors='pt', padding=True)
    src = tokens['input_ids'].transpose(0, 1).to(DEVICE)
    src_mask = (torch.zeros(src.shape[0], src.shape[0])).type(torch.bool).to(DEVICE)
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=src.shape[0] + 5, start_symbol=BOS_IDX).flatten()
    return tokenizer.decode(tgt_tokens.tolist(), skip_special_tokens=True)
# Free cached GPU memory before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Train the model
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s")
# Save the model
path = 'transformer_translation.pth'
torch.save(transformer.state_dict(), path)
print("Model saved successfully!")
# Reload the model
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
transformer.load_state_dict(torch.load(path, map_location=DEVICE))
transformer = transformer.to(DEVICE)
print("Model loaded successfully!")
# Test translation
print(translate(transformer, "Eine Gruppe von Freunden spielt Billard."))