ZhangCheng
commited on
Commit
·
5695f59
1
Parent(s):
894065c
Update README.md
Browse files
README.md
CHANGED
@@ -26,11 +26,12 @@ trained_tokenizer_path = 'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
|
|
26 |
|
27 |
class QuestionGeneration:
|
28 |
|
29 |
-
def __init__(self
|
30 |
self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
|
31 |
self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
|
32 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
33 |
self.model = self.model.to(self.device)
|
|
|
34 |
|
35 |
def generate(self, answer:str, context:str):
|
36 |
input_text = '<answer> %s <context> %s ' % (answer, context)
|
@@ -40,7 +41,6 @@ class QuestionGeneration:
|
|
40 |
)
|
41 |
input_ids = encoding['input_ids'].to(self.device)
|
42 |
attention_mask = encoding['attention_mask'].to(self.device)
|
43 |
-
self.model.eval()
|
44 |
beam_outputs = self.model.generate(
|
45 |
input_ids = input_ids,
|
46 |
attention_mask = attention_mask
|
|
|
26 |
|
27 |
class QuestionGeneration:
|
28 |
|
29 |
+
def __init__(self):
|
30 |
self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
|
31 |
self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
|
32 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
33 |
self.model = self.model.to(self.device)
|
34 |
+
self.model.eval()
|
35 |
|
36 |
def generate(self, answer:str, context:str):
|
37 |
input_text = '<answer> %s <context> %s ' % (answer, context)
|
|
|
41 |
)
|
42 |
input_ids = encoding['input_ids'].to(self.device)
|
43 |
attention_mask = encoding['attention_mask'].to(self.device)
|
|
|
44 |
beam_outputs = self.model.generate(
|
45 |
input_ids = input_ids,
|
46 |
attention_mask = attention_mask
|