import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel

# Custom language model built on Hugging Face's PreTrainedModel
class OBILanguageModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Token and learned positional embeddings (vocab_size matches the SentencePiece vocab)
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding_table = nn.Embedding(config.block_size, config.hidden_size)
        # batch_first=True so inputs/outputs are shaped (batch, seq, hidden),
        # matching the embedding tensors built in forward()
        self.transformer = nn.Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            num_decoder_layers=config.num_hidden_layers,
            dim_feedforward=4 * config.hidden_size,
            dropout=config.hidden_dropout_prob,
            activation='gelu',
            batch_first=True
        )
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        # Projection from hidden states to vocabulary logits (SentencePiece vocab size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
    
    def forward(self, idx, targets=None):
        # idx: (batch, seq) token ids; targets: (batch, seq) ids for the next-token loss
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device=idx.device))
        x = tok_emb + pos_emb
        # Causal mask so no position can attend to future tokens
        mask = self.transformer.generate_square_subsequent_mask(idx.size(1)).to(idx.device)
        x = self.transformer(x, x, src_mask=mask, tgt_mask=mask, memory_mask=mask)
        x = self.ln1(x)
        x = self.ln2(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), targets.view(-1))

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # Autoregressively sample max_new_tokens ids, appending each to idx
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            # Keep only the logits for the final position and sample the next token
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
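

# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal example of instantiating and sampling from OBILanguageModel,
# assuming a PretrainedConfig that carries the attributes the model reads
# above (vocab_size, hidden_size, block_size, num_attention_heads,
# num_hidden_layers, hidden_dropout_prob). All values are placeholders.
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        vocab_size=32000,          # size of the SentencePiece vocab
        hidden_size=256,
        block_size=128,            # maximum context length
        num_attention_heads=8,
        num_hidden_layers=4,
        hidden_dropout_prob=0.1,
    )
    model = OBILanguageModel(config)
    model.eval()  # disable dropout while sampling

    # Start from a single arbitrary token id and sample 20 continuations
    prompt = torch.zeros((1, 1), dtype=torch.long)
    with torch.no_grad():
        out = model.generate(prompt, max_new_tokens=20)
    print(out.shape)  # torch.Size([1, 21]) -> prompt plus 20 new ids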