JuanJoseMV committed on
Commit
8f5d925
·
1 Parent(s): 71eacb0
Files changed (2) hide show
  1. NeuralTextGenerator.py +2 -2
  2. app.py +3 -3
NeuralTextGenerator.py CHANGED
@@ -20,7 +20,7 @@ DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
20
 
21
 
22
  class BertTextGenerator:
23
- def __init__(self, model_version, device=DEFAULT_DEVICE, use_apex=APEX_AVAILABLE, use_fast=True,
24
  do_basic_tokenize=True):
25
  """
26
  Wrapper of a BERT model from AutoModelForMaskedLM from huggingfaces.
@@ -47,7 +47,7 @@ class BertTextGenerator:
47
  self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
48
  loss_scale="dynamic")
49
 
50
- self.tokenizer = AutoTokenizer.from_pretrained(model_version, do_lower_case="uncased" in model_version,
51
  use_fast=use_fast,
52
  do_basic_tokenize=do_basic_tokenize) # added to avoid splitting of unused tokens
53
  self.num_attention_masks = len(self.model.base_model.base_model.encoder.layer)
 
20
 
21
 
22
  class BertTextGenerator:
23
+ def __init__(self, model_version, tokenizer, device=DEFAULT_DEVICE, use_apex=APEX_AVAILABLE, use_fast=True,
24
  do_basic_tokenize=True):
25
  """
26
  Wrapper of a BERT model from AutoModelForMaskedLM from huggingfaces.
 
47
  self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
48
  loss_scale="dynamic")
49
 
50
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, do_lower_case="uncased" in model_version,
51
  use_fast=use_fast,
52
  do_basic_tokenize=do_basic_tokenize) # added to avoid splitting of unused tokens
53
  self.num_attention_masks = len(self.model.base_model.base_model.encoder.layer)
app.py CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
2
  from NeuralTextGenerator import BertTextGenerator
3
 
4
  model_name = "cardiffnlp/twitter-xlm-roberta-base" #"dbmdz/bert-base-italian-uncased"
5
- en_model = BertTextGenerator(model_name)
6
 
7
  finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
8
- finetunned_BERT_en_model = BertTextGenerator(finetunned_BERT_model_name)
9
 
10
  finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
11
- finetunned_RoBERTa_en_model = BertTextGenerator(finetunned_RoBERTa_model_name)
12
 
13
  special_tokens = [
14
  '[POSITIVE-0]',
 
2
  from NeuralTextGenerator import BertTextGenerator
3
 
4
  model_name = "cardiffnlp/twitter-xlm-roberta-base" #"dbmdz/bert-base-italian-uncased"
5
+ en_model = BertTextGenerator(model_name, tokenizer='xlm-roberta')
6
 
7
  finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
8
+ finetunned_BERT_en_model = BertTextGenerator(finetunned_BERT_model_name, tokenizer='bert-base-uncased')
9
 
10
  finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
11
+ finetunned_RoBERTa_en_model = BertTextGenerator(finetunned_RoBERTa_model_name, tokenizer='xlm-roberta')
12
 
13
  special_tokens = [
14
  '[POSITIVE-0]',