ZhangCheng commited on
Commit
5695f59
·
1 Parent(s): 894065c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -26,11 +26,12 @@ trained_tokenizer_path = 'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
26
 
27
  class QuestionGeneration:
28
 
29
- def __init__(self, model_dir=None):
30
  self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
31
  self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
32
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
  self.model = self.model.to(self.device)
 
34
 
35
  def generate(self, answer:str, context:str):
36
  input_text = '<answer> %s <context> %s ' % (answer, context)
@@ -40,7 +41,6 @@ class QuestionGeneration:
40
  )
41
  input_ids = encoding['input_ids'].to(self.device)
42
  attention_mask = encoding['attention_mask'].to(self.device)
43
- self.model.eval()
44
  beam_outputs = self.model.generate(
45
  input_ids = input_ids,
46
  attention_mask = attention_mask
 
26
 
27
  class QuestionGeneration:
28
 
29
+ def __init__(self):
30
  self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
31
  self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
32
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
  self.model = self.model.to(self.device)
34
+ self.model.eval()
35
 
36
  def generate(self, answer:str, context:str):
37
  input_text = '<answer> %s <context> %s ' % (answer, context)
 
41
  )
42
  input_ids = encoding['input_ids'].to(self.device)
43
  attention_mask = encoding['attention_mask'].to(self.device)
 
44
  beam_outputs = self.model.generate(
45
  input_ids = input_ids,
46
  attention_mask = attention_mask