jjreif committed
Commit fa6f171 · 1 Parent(s): 6d4c8c6

Upload handler.py

Files changed (1)
  1. handler.py +47 -0
handler.py ADDED
@@ -0,0 +1,47 @@
+ from typing import Dict, List, Any
+ from transformers import NougatProcessor, VisionEncoderDecoderModel
+ import torch
+
+
+ # check for GPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load the Nougat processor and model from the repository path
+         self.processor = NougatProcessor.from_pretrained(path)
+         self.model = VisionEncoderDecoderModel.from_pretrained(path)
+         # move model to device
+         self.model.to(device)
+         # self.decoder_input_ids = self.processor.tokenizer(
+         #     "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
+         # ).input_ids
+
+     def __call__(self, data: Dict[str, Any]) -> Any:
+         # the endpoint delivers the request payload under the "inputs" key
+         inputs = data.pop("inputs", data)
+
+         # preprocess the input image into pixel values
+         pixel_values = self.processor(inputs, return_tensors="pt").pixel_values
+         print(type(pixel_values))  # log the type of the preprocessed tensor
+         # forward pass
+         outputs = self.model.generate(
+             pixel_values.to(device),
+             min_length=1,
+             # decoder_input_ids=self.decoder_input_ids.to(device),
+             max_length=3584,
+             # early_stopping=True,
+             # pad_token_id=self.processor.tokenizer.pad_token_id,
+             # eos_token_id=self.processor.tokenizer.eos_token_id,
+             # use_cache=True,
+             # num_beams=1,
+             bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+             # return_dict_in_generate=True,
+         )
+         # decode the generated tokens and post-process the markdown
+         prediction = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
+         prediction = self.processor.post_process_generation(prediction, fix_markdown=False)
+
+         return prediction
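
For reference, a handler like this can be smoke-tested locally before the repository is deployed as a custom Inference Endpoint. The sketch below is illustrative and not part of the commit: the local directory ./nougat-model and the test image page.png are hypothetical, and the snippet simply instantiates EndpointHandler and passes a PIL image under the "inputs" key, the payload shape __call__ expects.

# local smoke test (illustrative sketch, not part of the commit)
from PIL import Image
from handler import EndpointHandler

# point the handler at a local copy of the model repository (hypothetical path)
handler = EndpointHandler(path="./nougat-model")

# the handler receives a dict whose "inputs" entry holds the document image
image = Image.open("page.png").convert("RGB")
prediction = handler({"inputs": image})
print(prediction)  # post-processed text generated from the page image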