Daimon commited on
Commit
68ea804
·
1 Parent(s): 16f4450

Fixed load_model() method

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -11,8 +11,8 @@ st.set_page_config(page_title="Translation Demo", page_icon=":milky_way:", layou
11
 
12
  @st.cache
13
  def load_model():
14
- return M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
15
-
16
 
17
  def get_translation(src_code, trg_code, src):
18
 
@@ -23,6 +23,7 @@ def get_translation(src_code, trg_code, src):
23
  #forced_bos_token_id=tokenizer.lang_code_to_id[trg_code]
24
  #)
25
  #trg = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
 
26
  tokenizer.tgt_lang = trg_code
27
  encoded = tokenizer(src, return_tensors="pt")
28
  generated_tokens = model.generate(**encoded)
 
11
 
12
  @st.cache
13
  def load_model():
14
+ model = M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
15
+ return model
16
 
17
  def get_translation(src_code, trg_code, src):
18
 
 
23
  #forced_bos_token_id=tokenizer.lang_code_to_id[trg_code]
24
  #)
25
  #trg = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
26
+ model = load_model()
27
  tokenizer.tgt_lang = trg_code
28
  encoded = tokenizer(src, return_tensors="pt")
29
  generated_tokens = model.generate(**encoded)