akoksal commited on
Commit
7c6863d
·
1 Parent(s): 10a2445

Update app

Browse files
Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -2,12 +2,22 @@ import gradio as gr
2
  from transformers import AutoTokenizer, pipeline
3
  import torch
4
 
5
- tokenizer = AutoTokenizer.from_pretrained("notexist/ttt")
6
- tdk = pipeline('text-generation', model='notexist/ttt', tokenizer=tokenizer)
 
 
7
 
8
  def predict(name, sl, topk, topp):
9
  if name == "":
10
- x = tdk(f"<|endoftext|>",
 
 
 
 
 
 
 
 
11
  do_sample=True,
12
  max_length=64,
13
  top_k=topk,
@@ -16,9 +26,17 @@ def predict(name, sl, topk, topp):
16
  repetition_penalty=sl
17
  )[0]["generated_text"]
18
 
19
- return x[len(f"<|endoftext|>"):]
20
  else:
21
- x = tdk(f"<|endoftext|>{name}\n\n",
 
 
 
 
 
 
 
 
22
  do_sample=True,
23
  max_length=64,
24
  top_k=topk,
@@ -27,7 +45,7 @@ def predict(name, sl, topk, topp):
27
  repetition_penalty=sl
28
  )[0]["generated_text"]
29
 
30
- return x[len(f"<|endoftext|>{name}\n\n"):]
31
 
32
 
33
 
 
2
  from transformers import AutoTokenizer, pipeline
3
  import torch
4
 
5
+ tokenizer1 = AutoTokenizer.from_pretrained("notexist/ttt")
6
+ tdk1 = pipeline('text-generation', model='notexist/ttt', tokenizer=tokenizer)
7
+ tokenizer2 = AutoTokenizer.from_pretrained("notexist/ttt")
8
+ tdk2 = pipeline('text-generation', model='notexist/ttt', tokenizer=tokenizer)
9
 
10
  def predict(name, sl, topk, topp):
11
  if name == "":
12
+ x1 = tdk1(f"<|endoftext|>",
13
+ do_sample=True,
14
+ max_length=64,
15
+ top_k=topk,
16
+ top_p=topp,
17
+ num_return_sequences=1,
18
+ repetition_penalty=sl
19
+ )[0]["generated_text"]
20
+ x2 = tdk1(f"<|endoftext|>",
21
  do_sample=True,
22
  max_length=64,
23
  top_k=topk,
 
26
  repetition_penalty=sl
27
  )[0]["generated_text"]
28
 
29
+ return x1[len(f"<|endoftext|>"):]+"\n\n"+x2[len(f"<|endoftext|>"):]
30
  else:
31
+ x1 = tdk1(f"<|endoftext|>{name}\n\n",
32
+ do_sample=True,
33
+ max_length=64,
34
+ top_k=topk,
35
+ top_p=topp,
36
+ num_return_sequences=1,
37
+ repetition_penalty=sl
38
+ )[0]["generated_text"]
39
+ x2 = tdk2(f"<|endoftext|>{name}\n\n",
40
  do_sample=True,
41
  max_length=64,
42
  top_k=topk,
 
45
  repetition_penalty=sl
46
  )[0]["generated_text"]
47
 
48
+ return x1[len(f"<|endoftext|>{name}\n\n"):]+"\n\n"+x2[len(f"<|endoftext|>{name}\n\n"):]
49
 
50
 
51