Stevross committed
Commit fb6312f · 1 Parent(s): 7ce9bfb

Update app.py

Files changed (1): app.py +9 -6
app.py CHANGED
@@ -1,12 +1,12 @@
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
 
-# Load the model and tokenizer in full precision
-model = AutoModelForCausalLM.from_pretrained("PAIXAI/Astrid-1B").to(dtype=torch.float32)
+# Load the model and tokenizer in full precision and ensure it's on CPU
+model = AutoModelForCausalLM.from_pretrained("PAIXAI/Astrid-1B").to(dtype=torch.float32).cpu()
 tokenizer = AutoTokenizer.from_pretrained("PAIXAI/Astrid-1B")
 
-# Initialize the pipeline with the model and tokenizer
-generate_text = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
+# Initialize the pipeline with the model and tokenizer to run on CPU
+generate_text = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=-1)  # -1 forces CPU usage
 
 # Streamlit UI
 st.title("Astrid-1B Chatbot")
@@ -14,7 +14,10 @@ st.write("Test the Astrid-1B chatbot from Hugging Face!")
 
 user_input = st.text_input("Enter your question:")
 if user_input:
-    response = generate_text(user_input, min_new_tokens=2, max_new_tokens=256, do_sample=False, num_beams=1, temperature=0.3, repetition_penalty=1.2, renormalize_logits=True)
-    st.write("Response:", response[0]["generated_text"])
+    try:
+        response = generate_text(user_input, min_new_tokens=2, max_new_tokens=256, do_sample=False, num_beams=1, temperature=0.3, repetition_penalty=1.2, renormalize_logits=True)
+        st.write("Response:", response[0]["generated_text"])
+    except Exception as e:
+        st.write("Error:", str(e))
 
 st.write("Note: This is a simple UI for demonstration purposes. Ensure you have the required libraries installed and adjust the model parameters as needed.")