Miguel Castro committed on
Commit 06b46a4 · 1 Parent(s): 0560429

Update generate_text function with attention_mask and pad_token_id

Files changed (1)
  1. script_analyzer.py +14 -5
script_analyzer.py CHANGED
@@ -11,10 +11,19 @@ sentiment_classifier = pipeline("sentiment-analysis", model=model_sentiment, tok
 tokenizer_gpt = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
 model_gpt = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
 
-# Helper function to generate text using GPT-Neo or GPT-J
+# Updated generate_text function with attention_mask and pad_token_id
 def generate_text(prompt):
-    inputs = tokenizer_gpt(prompt, return_tensors="pt", truncation=True, max_length=50)
-    outputs = model_gpt.generate(inputs["input_ids"], max_length=50, num_return_sequences=1, no_repeat_ngram_size=2)
+    # Prepare the input tensors with attention_mask and padding
+    inputs = tokenizer_gpt(prompt, return_tensors="pt", padding=True, truncation=True, max_length=50)
+    # Generate text using max_new_tokens instead of max_length
+    outputs = model_gpt.generate(
+        inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        max_new_tokens=50,  # Controls the new tokens generated beyond input length
+        num_return_sequences=1,
+        no_repeat_ngram_size=2,
+        pad_token_id=tokenizer_gpt.eos_token_id  # Sets padding to eos_token_id to prevent issues
+    )
     generated_text = tokenizer_gpt.decode(outputs[0], skip_special_tokens=True)
     return generated_text.strip()
 
@@ -114,5 +123,5 @@ with gr.Blocks() as interface:
 
     display_dashboard_button.click(display_dashboard, inputs=script_input, outputs=[output_dashboard, output_graph])
 
-# Launch the Gradio app with sharing enabled
-interface.launch(share=True)
+# Launch the Gradio app (no need for share=True in Hugging Face Spaces)
+interface.launch()
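
For quick local testing, a minimal standalone sketch of the updated generation call, assuming the same GPT-Neo checkpoint. The pad_token assignment and the example prompt are illustrative additions, not part of this commit: the GPT-Neo tokenizer ships without a pad token, so calling the tokenizer with padding=True needs one.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer_gpt = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model_gpt = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")

# Assumption: reuse the EOS token as the pad token; the GPT-Neo tokenizer
# has no pad token by default, and padding=True requires one.
tokenizer_gpt.pad_token = tokenizer_gpt.eos_token

def generate_text(prompt):
    # Tokenize with padding/truncation so an attention_mask is returned
    inputs = tokenizer_gpt(prompt, return_tensors="pt", padding=True, truncation=True, max_length=50)
    outputs = model_gpt.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # tells generate() which positions are real tokens
        max_new_tokens=50,                        # new tokens beyond the prompt, not total length
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer_gpt.eos_token_id,  # silences the missing pad_token_id warning
    )
    return tokenizer_gpt.decode(outputs[0], skip_special_tokens=True).strip()

print(generate_text("INT. CONTROL ROOM - NIGHT"))  # hypothetical prompt for a quick smoke test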