ai-forever commited on
Commit
8f9f40e
·
1 Parent(s): 3526dc3

add proper generation wrapper

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -1,5 +1,9 @@
1
  import gradio as gr
2
 
 
 
 
 
3
  description = "Multilingual generation with mGPT"
4
  title = "Generate your own example"
5
 
@@ -11,12 +15,29 @@ article = (
11
  "</p>"
12
  )
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  interface = gr.Interface.load("huggingface/sberbank-ai/mGPT",
15
  description=description,
16
  examples=examples,
 
 
 
17
  thumbnail = 'https://habrastorage.org/r/w1560/getpro/habr/upload_files/26a/fa1/3e1/26afa13e1d1a56f54c7b0356761af7b8.png',
18
  theme = "peach",
19
  article = article
20
  )
21
 
22
- interface.launch()
 
1
  import gradio as gr
2
 
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/mGPT")
5
+ model = GPT2LMHeadModel.from_pretrained("sberbank-ai/mGPT")
6
+
7
  description = "Multilingual generation with mGPT"
8
  title = "Generate your own example"
9
 
 
15
  "</p>"
16
  )
17
 
18
+ def generate(prompt: str):
19
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda(device)
20
+ out = model.generate(
21
+ input_ids,
22
+ min_length=100,
23
+ max_length=200,
24
+ top_p=0.8,
25
+ top_k=0
26
+ no_repeat_ngram_size=5
27
+ )
28
+ generated_text = list(map(tokenizer.decode, out))[0]
29
+ return generated_text
30
+
31
+
32
  interface = gr.Interface.load("huggingface/sberbank-ai/mGPT",
33
  description=description,
34
  examples=examples,
35
+ fn=generate,
36
+ inputs="text",
37
+ outputs='text',
38
  thumbnail = 'https://habrastorage.org/r/w1560/getpro/habr/upload_files/26a/fa1/3e1/26afa13e1d1a56f54c7b0356761af7b8.png',
39
  theme = "peach",
40
  article = article
41
  )
42
 
43
+ interface.launch(enable_queue=True)