File size: 2,452 Bytes
e35584d
dcc7444
e35584d
8bac0a3
 
 
 
ffe6ef0
dcc7444
 
8bac0a3
 
dcc7444
8bac0a3
7874fb0
dcc7444
 
 
 
 
 
 
 
 
 
91da428
dcc7444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91da428
dcc7444
 
 
 
 
 
 
 
 
 
91da428
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from transformers import LogitsProcessorList, StoppingCriteriaList

def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    temperature=1.0,
    do_sample=True,
    pad_token_id=None,
    eos_token_id=None,
    **kwargs
):
    """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

    ``self`` is called once per step as
    ``self(input_ids=..., attention_mask=..., past_key_values=..., use_cache=True,
    return_dict=True)`` and must return an object exposing ``.logits``; the
    logits of the last position choose the next token. If the returned object
    has a non-None ``past_key_values`` it is fed back on the next step together
    with only the newest token; otherwise the full sequence is re-encoded.

    Args:
        input_ids: Prompt token ids, indexed below as (batch, seq_len).
        attention_mask: Mask matching ``input_ids``; defaults to all ones. It is
            extended with ones for every generated token.
        max_new_tokens: Number of tokens to generate beyond the prompt. Required.
        temperature: Softmax temperature used when sampling; must be > 0 when
            ``do_sample`` is True.
        do_sample: Sample from the softmax distribution if True, otherwise take
            the argmax (greedy decoding).
        pad_token_id: Never selected by the model (unless it is also an EOS id);
            written into rows that have already finished.
        eos_token_id: An int, or a list/tuple of ints. A row finishes once it
            emits one of them; generation stops when every row has finished.
            ``None`` disables early stopping.
        **kwargs: Optional extras:
            ``return_dict_in_generate`` (bool) -- return a dict instead of a tensor;
            ``min_new_tokens`` (int, default 0) -- EOS is suppressed until this
            many tokens have been generated;
            ``logits_processor`` -- callable ``(input_ids, scores) -> scores``
            applied to the next-token logits each step.

    Returns:
        The extended token ids (prompt followed by generated tokens) or, when
        ``return_dict_in_generate`` is set, a dict with keys ``"sequences"``,
        ``"scores"`` (always None), ``"attentions"`` and ``"hidden_states"``.
        The latter two are taken from the last forward call; None if no step ran.

    Raises:
        ValueError: If ``max_new_tokens`` is None or negative, or if sampling
            with a non-positive ``temperature``.
    """
    if max_new_tokens is None:
        raise ValueError("max_new_tokens must be provided")
    if max_new_tokens < 0:
        raise ValueError("max_new_tokens must be non-negative")
    if do_sample and temperature <= 0:
        raise ValueError("temperature must be > 0 when do_sample=True")

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    min_new_tokens = kwargs.get("min_new_tokens", 0) or 0
    logits_processor = kwargs.get("logits_processor")

    # Normalise EOS to a list so single ids and id lists share one code path.
    if eos_token_id is None:
        eos_ids = []
    elif isinstance(eos_token_id, (list, tuple)):
        eos_ids = [int(t) for t in eos_token_id]
    else:
        eos_ids = [int(eos_token_id)]

    batch_size = input_ids.shape[0]
    unfinished = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
    eos_tensor = (
        torch.tensor(eos_ids, dtype=torch.long, device=input_ids.device)
        if eos_ids
        else None
    )
    # Token written into rows that already emitted EOS.
    fill_token = None
    if eos_ids:
        fill_token = pad_token_id if pad_token_id is not None else eos_ids[0]
    # Masking pad when it doubles as EOS would make EOS unreachable.
    mask_pad = pad_token_id is not None and pad_token_id not in eos_ids

    past_key_values = None
    model_outputs = None

    # Generation never needs gradients; this keeps autograd from retaining
    # activations (including the KV cache) across steps.
    with torch.no_grad():
        for step in range(max_new_tokens):
            # With a cache only the newest token has to be encoded.
            step_input = input_ids if past_key_values is None else input_ids[:, -1:]
            model_outputs = self(
                input_ids=step_input,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            past_key_values = getattr(model_outputs, "past_key_values", None)

            # Copy so masking below does not mutate the model's output tensor.
            next_token_logits = model_outputs.logits[:, -1, :].clone()

            if mask_pad:
                next_token_logits[:, pad_token_id] = -float("inf")  # never select pad
            if eos_ids and step < min_new_tokens:
                next_token_logits[:, eos_ids] = -float("inf")  # no EOS before the minimum
            if logits_processor is not None:
                next_token_logits = logits_processor(input_ids, next_token_logits)

            if do_sample:
                probabilities = torch.nn.functional.softmax(
                    next_token_logits / temperature, dim=-1
                )
                next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
            else:
                next_token = next_token_logits.argmax(dim=-1)

            if eos_tensor is not None:
                # Rows that already finished only receive the fill token.
                next_token = torch.where(
                    unfinished, next_token, torch.full_like(next_token, fill_token)
                )

            next_token = next_token.to(input_ids.dtype)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1
            )

            if eos_tensor is not None:
                hit_eos = (next_token.unsqueeze(-1).long() == eos_tensor).any(dim=-1)
                unfinished &= ~hit_eos
                if not unfinished.any():
                    break

    if kwargs.get("return_dict_in_generate", False):
        return {
            "sequences": input_ids,
            "scores": None,  # per-step scores are not collected
            "attentions": getattr(model_outputs, "attentions", None),
            "hidden_states": getattr(model_outputs, "hidden_states", None),
        }
    return input_ids