PierreJousselin committed on
Commit
4bbf0b6
·
verified ·
1 Parent(s): 329b94e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -7,12 +7,15 @@ model_name = "PierreJousselin/gpt2" # Replace with the name you used on Hugging
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu") # Force model to load on CPU
9
 
 
 
 
10
  # Function for generating responses using the model
11
  def generate_response(prompt):
12
  # Tokenize input prompt
13
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
14
 
15
- # Make sure the inputs are moved to the CPU
16
  input_ids = inputs["input_ids"].to("cpu")
17
 
18
  # Generate output (ensure it's on CPU)
@@ -22,13 +25,14 @@ def generate_response(prompt):
22
  response = tokenizer.decode(output[0], skip_special_tokens=True)
23
  return response
24
 
25
- # Create a Gradio interface
26
  iface = gr.Interface(
27
  fn=generate_response, # Function to call for generating response
28
  inputs=gr.Textbox(label="Input Prompt"), # Input type (text box for prompt)
29
  outputs=gr.Textbox(label="Generated Response"), # Output type (text box for response)
30
- live=True # Whether to update output live as user types
 
31
  )
32
 
33
- # Launch the interface
34
- iface.launch()
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu") # Force model to load on CPU
9
 
10
+ # Ensure pad_token_id is set to eos_token_id to avoid errors
11
+ model.config.pad_token_id = model.config.eos_token_id
12
+
13
  # Function for generating responses using the model
14
  def generate_response(prompt):
15
  # Tokenize input prompt
16
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
17
 
18
+ # Ensure the inputs are moved to the CPU
19
  input_ids = inputs["input_ids"].to("cpu")
20
 
21
  # Generate output (ensure it's on CPU)
 
25
  response = tokenizer.decode(output[0], skip_special_tokens=True)
26
  return response
27
 
28
+ # Create a Gradio interface with a "Generate" button
29
  iface = gr.Interface(
30
  fn=generate_response, # Function to call for generating response
31
  inputs=gr.Textbox(label="Input Prompt"), # Input type (text box for prompt)
32
  outputs=gr.Textbox(label="Generated Response"), # Output type (text box for response)
33
+ live=False, # Disable live update; only update when button is clicked
34
+ allow_flagging="never" # Prevent flagging (optional, if you don't need it)
35
  )
36
 
37
+ # Launch the interface with a "Generate" button
38
+ iface.launch(share=True) # You can set share=True if you want a public link