riotu-lab commited on
Commit
2a5a75a
·
verified ·
1 Parent(s): 11d504a

Upload app

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+
4
+ # Define the function to handle text generation
5
+ def generate_text(model_name, text, num_beams, max_length, top_p, temperature, repetition_penalty, no_repeat_ngram_size):
6
+ # Load tokenizer and model
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+
10
+ # Initialize pipeline
11
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
12
+
13
+ # Generate text with the specified parameters
14
+ generated_text = pipe(text,
15
+ pad_token_id=tokenizer.eos_token_id,
16
+ num_beams=num_beams,
17
+ max_length=max_length,
18
+ top_p=top_p,
19
+ temperature=temperature,
20
+ repetition_penalty=repetition_penalty,
21
+ no_repeat_ngram_size=no_repeat_ngram_size)[0]['generated_text']
22
+
23
+ return generated_text
24
+
25
+ # Define model options
26
+ model_options = [
27
+ "riotu-lab/ArabianGPT-01B",
28
+ "riotu-lab/ArabianGPT-03B",
29
+ "riotu-lab/ArabianGPT-08B"
30
+ ]
31
+
32
+ # Define Gradio interface components
33
+ inputs_component = [
34
+ gr.Dropdown(choices=model_options, label="Select Model"),
35
+ gr.Textbox(lines=2, placeholder="Enter your text here...", label="Input Text"),
36
+ gr.Slider(minimum=1, maximum=10, step=1, default=5, label="Num Beams"),
37
+ gr.Slider(minimum=50, maximum=300, step=10, default=200, label="Max Length"),
38
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, default=0.9, label="Top p"),
39
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, default=0.1, label="Temperature"),
40
+ gr.Slider(minimum=1.0, maximum=5.0, step=0.5, default=3.0, label="Repetition Penalty"),
41
+ gr.Slider(minimum=2, maximum=5, step=1, default=3, label="No Repeat Ngram Size")
42
+ ]
43
+
44
+ # Setup the interface
45
+ iface = gr.Interface(
46
+ fn=generate_text,
47
+ inputs=inputs_component,
48
+ outputs="text",
49
+ title="ArabianGPT Playground",
50
+ description="Explore the capabilities of ArabianGPT models. Adjust the hyperparameters to see how they affect text generation.",
51
+ live=True
52
+ )
53
+
54
+ # Launch the app
55
+ iface.launch()