mannadamay12 committed
Commit 26b862a · verified · 1 Parent(s): ec21171

Update app.py

Files changed (1):
  1. app.py +50 -70
app.py CHANGED
@@ -3,91 +3,71 @@ import torch
  import gradio as gr
  import spaces
  from huggingface_hub import InferenceClient
- from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.vectorstores import Chroma
  from langchain.prompts import PromptTemplate

- # Verify PyTorch version compatibility
- TORCH_VERSION = torch.__version__
- SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']
- if TORCH_VERSION.rsplit('+')[0] not in SUPPORTED_TORCH_VERSIONS:
-     print(f"Warning: Current PyTorch version {TORCH_VERSION} may not be compatible with ZeroGPU. "
-           f"Supported versions are: {', '.join(SUPPORTED_TORCH_VERSIONS)}")

- # Initialize components outside of GPU scope
- client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
- embeddings = HuggingFaceEmbeddings(
-     model_name="sentence-transformers/all-MiniLM-L6-v2",
-     model_kwargs={"device": "cpu"}  # Keep embeddings on CPU
  )

- # Load database
  db = Chroma(
      persist_directory="db",
      embedding_function=embeddings
  )

- # Prompt templates
- DEFAULT_SYSTEM_PROMPT = """
- Based on the information in this document provided in context, answer the question as accurately as possible in 1 or 2 lines. If the information is not in the context,
- respond with "I don't know" or a similar acknowledgment that the answer is not available.
- """.strip()
-
- def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
-     return f"""
- [INST] <<SYS>>
- {system_prompt}
- <</SYS>>
-
- {prompt} [/INST]
- """.strip()
-
- template = generate_prompt(
-     """
- {context}
-
- Question: {question}
- """,
-     system_prompt="Use the following pieces of context to answer the question at the end. Do not provide commentary or elaboration more than 1 or 2 lines.?"
  )

- prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])

- @spaces.GPU(duration=30)  # Reduced duration for faster queue priority
- def respond(
-     message,
-     history,
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     """GPU-accelerated response generation"""
      try:
-         # Retrieve context (CPU operation)
-         docs = db.similarity_search(message, k=2)
-         context = "\n".join([doc.page_content for doc in docs])
-         print(f"Retrieved context: {context[:200]}...")
-
-         # Format prompt
-         formatted_prompt = prompt_template.format(
-             context=context,
-             question=message
-         )
-         print(f"Full prompt: {formatted_prompt}")
-
-         # Stream response (GPU operation)
-         response = ""
-         for message in client.text_generation(
-             prompt=formatted_prompt,
-             max_new_tokens=max_tokens,
-             stream=True,
-             temperature=temperature,
-             top_p=top_p,
-         ):
-             response += message
-             yield response
-
      except Exception as e:
          yield f"An error occurred: {str(e)}"
 
 
  import gradio as gr
  import spaces
  from huggingface_hub import InferenceClient
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
  from langchain.vectorstores import Chroma
  from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain.llms import HuggingFacePipeline
+ from transformers import AutoTokenizer, TextStreamer, pipeline, BitsAndBytesConfig, AutoModelForCausalLM

+ # Model initialization
+ model_id = "meta-llama/Llama-3.2-3B-Instruct"
+ token = os.environ.get("HF_TOKEN")

+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     token=token,
+     quantization_config=bnb_config
+ )
+
+ # Initialize InstructEmbeddings
+ embeddings = HuggingFaceInstructEmbeddings(
+     model_name="hkunlp/instructor-base",
+     model_kwargs={"device": "cpu"}
  )

  db = Chroma(
      persist_directory="db",
      embedding_function=embeddings
  )

+ # Setup pipeline
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ text_pipeline = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=500,
+     temperature=0.1,
+     top_p=0.95,
+     repetition_penalty=1.15,
+     streamer=streamer,
  )

+ # Create LLM chain
+ llm = HuggingFacePipeline(pipeline=text_pipeline)
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=db.as_retriever(search_kwargs={"k": 2}),
+     return_source_documents=False,
+     chain_type_kwargs={"prompt": prompt_template}
+ )

+ @spaces.GPU(duration=30)
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
      try:
+         # Use the QA chain directly
+         response = qa_chain.invoke({"query": message})
+         yield response["result"]
      except Exception as e:
          yield f"An error occurred: {str(e)}"
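
One gap worth flagging in this revision: the new qa_chain still passes prompt_template through chain_type_kwargs, but the commit deletes the generate_prompt/PromptTemplate block that used to define it, and token = os.environ.get("HF_TOKEN") relies on an os import that is not visible in this hunk. The sketch below shows one way the missing pieces could look; the prompt wording and the extra import are assumptions, not part of the commit.

import os  # assumption: only needed if os is not already imported earlier in app.py
from langchain.prompts import PromptTemplate

# Illustrative prompt for the "stuff" RetrievalQA chain above; the exact wording is a sketch.
template = """Use the following pieces of context to answer the question at the end.
Answer in 1 or 2 lines. If the answer is not in the context, say "I don't know".

{context}

Question: {question}
Answer:"""

prompt_template = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)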
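
The commit also swaps the embedding model from sentence-transformers/all-MiniLM-L6-v2 to hkunlp/instructor-base. A Chroma collection persisted under db/ with the old embeddings is generally not queryable through a different embedding function, because the two models do not produce interchangeable vectors, so the index would likely need to be rebuilt. A rough sketch, assuming the original document chunks are still available as a list named docs:

from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma

# Assumption: docs is the list of LangChain Document chunks the index was originally built from.
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"},
)

# Rebuild and persist the vector store so retrieval uses the new embedding space.
db = Chroma.from_documents(docs, embeddings, persist_directory="db")
db.persist()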
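
The hunk does not show how respond is wired into the Gradio UI. Because app.py imports gradio and respond keeps the (message, history, system_message, max_tokens, temperature, top_p) signature, the wiring presumably follows the standard gr.ChatInterface pattern sketched below; the widget labels, ranges, and defaults are assumptions. Note that after this change respond no longer reads system_message, max_tokens, temperature, or top_p, since generation settings are fixed in text_pipeline.

# Hypothetical UI wiring, not part of this commit.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()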