Blancior committed
Commit bd50a3b · verified · 1 Parent(s): be5cf4f

Update app.py

Files changed (1):
  1. app.py +40 -21
app.py CHANGED
@@ -2,28 +2,47 @@ import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch
 
- # Model initialization
- model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it", device_map="auto", torch_dtype=torch.float16)
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
-
- def generate_description(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt")
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=100,
-         temperature=0.7,
-         top_p=0.9,
-         repetition_penalty=1.2,
-         do_sample=True
-     )
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Interface
- interface = gr.Interface(
-     fn=generate_description,
-     inputs=gr.Textbox(label="Prompt"),
-     outputs=gr.Textbox(label="Generated Description"),
-     title="RPG Battle Descriptions Generator"
- )
-
- interface.launch()
+ def load_model():
+     model_name = "TheBloke/Llama-2-13B-chat-GPTQ"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         device_map="auto",
+         trust_remote_code=True,
+         revision="main"
+     )
+     return model, tokenizer
+
+ def generate_response(prompt, max_length=100):
+     try:
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=int(max_length),  # the Gradio slider passes a float; generate() expects an int
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.2,
+             do_sample=True
+         )
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ print("Loading model...")
+ model, tokenizer = load_model()
+ print("Model loaded!")
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=generate_response,
+     inputs=[
+         gr.Textbox(label="Prompt", lines=5),
+         gr.Slider(minimum=1, maximum=500, value=100, label="Max Length")
+     ],
+     outputs=gr.Textbox(label="Response", lines=5),
+     title="Llama 2 Chat Bot",
+     description="An RPG bot based on Llama 2"
+ )
+
+ iface.launch()
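
Note on the new checkpoint: GPTQ-quantized models such as TheBloke/Llama-2-13B-chat-GPTQ do not load through AutoModelForCausalLM on a bare transformers install; recent transformers versions hand GPTQ weights off to the optimum and auto-gptq packages, so both would presumably have to be listed in the Space's requirements.txt (not shown in this commit).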
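
One thing generate_response still skips is Llama 2's chat template: the -chat checkpoints were fine-tuned on [INST] ... [/INST] turns, so raw prompts tend to give weaker completions. Below is a minimal sketch of the wrapping a caller could apply before invoking generate_response; the helper name and system prompt are illustrative, not part of this commit:

def build_llama2_prompt(user_message, system_prompt="You are an RPG battle narrator."):
    # Llama-2-chat expects [INST] <<SYS>> ... <</SYS>> user text [/INST];
    # the tokenizer adds the leading <s> BOS token on its own.
    return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"

# Example: generate_response(build_llama2_prompt("Describe a goblin ambush."), 150)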