ZhangCheng committed on
Commit
7061d13
·
1 Parent(s): ed42078

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +61 -1
README.md CHANGED
@@ -1 +1,61 @@
1
- #T5 Base Fine-Tuned on SQuAD for Question Generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ datasets:
4
+ - squad
5
+ widget:
6
+ - text: "<answer> T5 <context> Cheng fine-tuned T5 on SQuAD for question generation."
7
+ example_title: "Example 1"
8
+ - text: "<answer> SQuAD <context> Cheng fine-tuned T5 on SQuAD dataset for question generation."
9
+ example_title: "Example 2"
10
+ - text: "<answer> deep learning <context> Deep learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning."
11
+ example_title: "Example 3"
12
+ ---
13
+
14
+ # T5-Base Fine-Tuned on SQuAD for Question Generation
15
+
16
+ ### Model in Action:
17
+
18
+ ```python
19
+ import torch
20
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
21
+
22
# Hub repository id that hosts both the fine-tuned weights and the tokenizer
# files; the tokenizer path is kept as a separate name so it can be overridden
# independently.
trained_model_path = 'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
trained_tokenizer_path = trained_model_path
24
+
25
class QuestionGeneration:
    """Generate a question for a given answer and context with a T5 model."""

    def __init__(self, model_dir=None):
        """Load the model and tokenizer and place the model on a device.

        Args:
            model_dir: Optional path or hub id to load both the model and the
                tokenizer from. When ``None``, the module-level
                ``trained_model_path`` and ``trained_tokenizer_path`` are used.
        """
        # Honour an explicitly supplied checkpoint; previously this argument
        # was accepted but silently ignored.
        if model_dir is not None:
            model_path = tokenizer_path = model_dir
        else:
            model_path = trained_model_path
            tokenizer_path = trained_tokenizer_path

        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        # Use the GPU when CUDA is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

    def generate(self, answer: str, context: str) -> dict:
        """Generate one question whose answer is ``answer`` within ``context``.

        Args:
            answer: The answer text to build the question around.
            context: The passage that contains the answer.

        Returns:
            A dict with the decoded ``'question'`` string and the original
            ``'answer'``.
        """
        # Prompt layout expected by the fine-tuned checkpoint (the trailing
        # space is part of the original format and is kept as-is).
        input_text = '<answer> %s <context> %s ' % (answer, context)
        # Calling the tokenizer directly replaces the deprecated encode_plus.
        encoding = self.tokenizer(input_text, return_tensors='pt')
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        # Make sure the model is in evaluation mode before generating.
        self.model.eval()
        output_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Only the first returned sequence is decoded.
        question = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return {'question': question, 'answer': answer}
52
+
53
if __name__ == "__main__":
    # Demo: build a generator and ask for a question whose answer is
    # "ZhangCheng" in the sample sentence.
    generator = QuestionGeneration()
    result = generator.generate(
        answer='ZhangCheng',
        context='ZhangCheng fine-tuned T5 on SQuAD dataset for question generation.',
    )
    print(result['question'])
    # Output:
    # Who fine-tuned T5 on SQuAD dataset for question generation?
61
+ ```