|
import os
import concurrent.futures
import re
from datetime import datetime

import faiss
import numpy as np
import torch
import gradio as gr
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from huggingface_hub import InferenceClient
|
class DocumentRetrievalAndGeneration: |
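    """Retrieval-augmented generation pipeline: loads text files, chunks them,
    indexes the chunks in FAISS, and answers queries with a hosted chat model
    via the Hugging Face Inference API."""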
|
def __init__(self, embedding_model_name, lm_model_id, data_folder): |
|
|
|
hf="hf_VuNNBwnFqlcKzV" |
|
token="vCfLXEBxyAOftxvlWpwf" |
|
self.hf_token=hf+token |
|
|
|
self.all_splits = self.load_documents(data_folder) |
|
self.embeddings = SentenceTransformer(embedding_model_name) |
|
self.cpu_index = self.create_faiss_index() |
|
self.llm = self.initialize_llm2(lm_model_id) |
|
|
|
|
|
def load_documents(self, folder_path): |
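        """Load every text file under folder_path and split it into overlapping chunks."""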
|
loader = DirectoryLoader(folder_path, loader_cls=TextLoader) |
|
documents = loader.load() |
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250) |
|
all_splits = text_splitter.split_documents(documents) |
|
        print("Number of documents:", len(documents))
        print("Number of chunks:", len(all_splits))
|
return all_splits |
|
|
|
def create_faiss_index(self): |
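        """Embed every chunk with the sentence-transformer model and build a flat L2 FAISS index."""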
|
all_texts = [split.page_content for split in self.all_splits] |
|
embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy() |
|
index = faiss.IndexFlatL2(embeddings.shape[1]) |
|
index.add(embeddings) |
|
return index |
|
|
|
def initialize_llm(self, model_id): |
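        """Load the model locally in 4-bit NF4 quantization and wrap it in a text-generation pipeline.
        Kept as an alternative to initialize_llm2, which uses the hosted Inference API instead."""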
|
bnb_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_use_double_quant=True, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_compute_dtype=torch.bfloat16 |
|
) |
|
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, token=self.hf_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=self.hf_token)
|
generate_text = pipeline( |
|
model=model, |
|
tokenizer=tokenizer, |
|
return_full_text=True, |
|
task='text-generation', |
|
temperature=0.6, |
|
max_new_tokens=256, |
|
) |
|
return generate_text |
|
    def initialize_llm2(self, model_id):
        # Use the hosted Inference API instead of loading the model locally.
        # Note: model_id is currently ignored; the client always targets zephyr-7b-beta.
        client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=self.hf_token)
        return client

def generate_response_with_timeout(self, model_inputs): |
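        """Run local generation in a worker thread and abort if it exceeds the timeout.
        Only applicable when self.llm is a local transformers pipeline (initialize_llm);
        it is not used with the Inference API client returned by initialize_llm2."""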
|
try: |
|
with concurrent.futures.ThreadPoolExecutor() as executor: |
|
future = executor.submit(self.llm.model.generate, model_inputs, max_new_tokens=1000, do_sample=True) |
|
generated_ids = future.result(timeout=800) |
|
return generated_ids |
|
        except concurrent.futures.TimeoutError:
            return "Text generation process timed out"
|
|
|
    def query_and_generate_response(self, query):
        """Retrieve the top-k chunks for the query and generate an answer with the LLM."""
        start_time = datetime.now()
        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        # Search the FAISS index for the five nearest chunks.
        distances, indices = self.cpu_index.search(np.array([query_embedding]), k=5)

        content = ""

        for i, idx in enumerate(indices[0]):
            if idx < len(self.all_splits):
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"
                # distances is indexed by result position, not by chunk id.
                distance = distances[0][i]
                print("CHUNK", idx)
                print("Distance :", distance)
                print(self.all_splits[idx].page_content)
                print("############################")
            else:
                print(f"Index {idx} is out of bounds. Skipping.")
|
|
|
        prompt = f"""<s>
You are a knowledgeable assistant with access to a comprehensive database.
I need you to answer my question and provide related information in a specific format.
I have provided the five most relevant retrieved chunks below; choose the most suitable ones for answering the query.
{content}
Here's what I need:
Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
Here's my question:
Query:
Solution==>
RETURN ONLY THE SOLUTION. IF THERE IS NO RELATABLE ANSWER IN THE RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
IF THE QUERY AND THE RETRIEVED CHUNKS DO NOT CORRELATE MEANINGFULLY, OR IF THE QUERY IS NOT RELEVANT TO TDA2 OR RELATED TOPICS, RETURN "NO SOLUTION AVAILABLE".
Example1
Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",

Example2
Query: "Can BQ25896 support I2C interface?",
Solution: "Yes, the BQ25896 charger supports the I2C interface for communication."
Example3
Query: "Who is the fastest runner in the world",
Solution: "NO SOLUTION AVAILABLE"
Example4
Query: "What is the price of the latest Apple MacBook?"
Solution: "NO SOLUTION AVAILABLE"
</s>
"""
|
        messages = [{"role": "system", "content": prompt}]
        messages.append({"role": "user", "content": query})
|
response = "" |
|
|
|
        # Stream the chat completion from the Inference API and accumulate the generated text.
        for chunk in self.llm.chat_completion(messages, max_tokens=2048, stream=True, temperature=0.7):
            token = chunk.choices[0].delta.content
            if token:
                response += token
|
|
|
        generated_response = response
        # Fallback extraction for instruct-style templated outputs; chat_completion normally returns
        # plain text, in which case neither pattern matches and the full response is used as-is.
        match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
        match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
|
if match1: |
|
solution_text = match1.group(1).strip() |
|
if "Solution:" in solution_text: |
|
solution_text = solution_text.split("Solution:", 1)[1].strip() |
|
elif match2: |
|
solution_text = match2.group(1).strip() |
|
else: |
|
solution_text=generated_response |
|
print("Generated response:", generated_response) |
|
print("Time elapsed:", elapsed_time) |
|
print("Device in use:", self.llm.device) |
|
|
|
return solution_text, content |
|
|
|
def qa_infer_gradio(self, query): |
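        """Gradio handler: returns (solution_text, retrieved_content) for the two output boxes."""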
|
response = self.query_and_generate_response(query) |
|
return response |
|
|
|
if __name__ == "__main__": |
|
print("starting...") |
|
embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12' |
|
|
|
lm_model_id= "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" |
|
data_folder = 'text_files' |
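    # Build the retrieval-and-generation pipeline once at startup; the Gradio handler reuses it per query.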
|
|
|
doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder) |
|
|
|
def launch_interface(): |
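        """Build the Gradio interface (query in, solution and retrieved context out) and launch it."""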
|
css_code = """ |
|
.gradio-container { |
|
background-color: #daccdb; |
|
} |
|
/* Button styling for all buttons */ |
|
button { |
|
background-color: #927fc7; /* Default color for all other buttons */ |
|
color: black; |
|
border: 1px solid black; |
|
padding: 10px; |
|
margin-right: 10px; |
|
font-size: 16px; /* Increase font size */ |
|
font-weight: bold; /* Make text bold */ |
|
} |
|
""" |
|
EXAMPLES = ["What are the main types of blood cancer, and how do they differ in terms of symptoms, progression, and treatment options? ", |
|
"What are the latest advancements in the treatment of blood cancer, and how do they improve patient outcomes compared to traditional therapies?", |
|
"How do genetic factors and environmental exposures contribute to the risk of developing blood cancer, and what preventive measures can be taken?"] |
|
|
|
interface = gr.Interface( |
|
fn=doc_retrieval_gen.qa_infer_gradio, |
|
inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")], |
|
allow_flagging='never', |
|
examples=EXAMPLES, |
|
cache_examples=False, |
|
            outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RETRIEVED CONTEXT")],
|
css=css_code |
|
) |
|
|
|
interface.launch(debug=True) |
|
|
|
launch_interface() |
|
|