alexkueck commited on
Commit
5076c3d
·
verified ·
1 Parent(s): 9b64feb

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +53 -7
utils.py CHANGED
@@ -19,6 +19,10 @@ import operator
19
  from typing import Annotated, Sequence, TypedDict
20
  import pprint
21
 
 
 
 
 
22
  import gradio as gr
23
  from pypinyin import lazy_pinyin
24
  import tiktoken
@@ -51,7 +55,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field
51
  from langchain_core.runnables import RunnablePassthrough
52
  from langchain.text_splitter import RecursiveCharacterTextSplitter
53
  from chromadb.errors import InvalidDimensionException
54
- import io
55
  #from PIL import Image, ImageDraw, ImageOps, ImageFont
56
  #import base64
57
  #from tempfile import NamedTemporaryFile
@@ -127,6 +131,18 @@ urls = [
127
  ]
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  ##################################################
131
  #Normalisierung eines Prompts
132
  ##################################################
@@ -303,6 +319,7 @@ def llm_chain2(llm, prompt):
303
  def rag_chain(llm, prompt, retriever):
304
  #Langgraph nutzen für ein wenig mehr Intelligenz beim Dokumente suchen
305
  relevant_docs=[]
 
306
  relevant_docs = retriever.get_relevant_documents(prompt)
307
 
308
  print("releant docs1......................")
@@ -313,14 +330,43 @@ def rag_chain(llm, prompt, retriever):
313
  #result = llm_chain.run({"context": relevant_docs, "question": prompt})
314
  # Erstelle ein PromptTemplate mit Platzhaltern für Kontext und Frage
315
  #RAG_CHAIN_PROMPT = PromptTemplate(template="Context: {context}\n\nQuestion: {question}\n\nAnswer:")
316
-
317
- # Erstelle eine RunnableSequence
318
- chain = RunnableSequence(steps=[RAG_CHAIN_PROMPT, llm])
319
- # Verwende die Kette
320
- result = chain.invoke({"context": relevant_docs, "question": prompt})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  else:
322
  # keine relevanten Dokumente gefunden
323
- result = "Keine relevanten Dokumente gefunden"
 
 
 
324
 
325
  return result
326
 
 
19
  from typing import Annotated, Sequence, TypedDict
20
  import pprint
21
 
22
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
23
+ from sentence_transformers import SentenceTransformer, util
24
+ from typing import List, Dict
25
+
26
  import gradio as gr
27
  from pypinyin import lazy_pinyin
28
  import tiktoken
 
55
  from langchain_core.runnables import RunnablePassthrough
56
  from langchain.text_splitter import RecursiveCharacterTextSplitter
57
  from chromadb.errors import InvalidDimensionException
58
+ #import io
59
  #from PIL import Image, ImageDraw, ImageOps, ImageFont
60
  #import base64
61
  #from tempfile import NamedTemporaryFile
 
131
  ]
132
 
133
 
134
+
135
+ ##################################################
136
+ #Modell und Tokenizer für die Anfrage der RAG Chain
137
+ ##################################################
138
+ # Schritt 1: Initialisiere den Sentence-Transformer und das Generierungsmodell
139
+ embedder_modell = SentenceTransformer('all-MiniLM-L6-v2')
140
+ HF_MODELL = "t5-small"
141
+ modell_rag = AutoModelForSeq2SeqLM.from_pretrained(HF_MODELL)
142
+ tokenizer_rag = AutoTokenizer.from_pretrained(HF_MODELL)
143
+
144
+
145
+
146
  ##################################################
147
  #Normalisierung eines Prompts
148
  ##################################################
 
319
  def rag_chain(llm, prompt, retriever):
320
  #Langgraph nutzen für ein wenig mehr Intelligenz beim Dokumente suchen
321
  relevant_docs=[]
322
+ most_relevant_docs=[]
323
  relevant_docs = retriever.get_relevant_documents(prompt)
324
 
325
  print("releant docs1......................")
 
330
  #result = llm_chain.run({"context": relevant_docs, "question": prompt})
331
  # Erstelle ein PromptTemplate mit Platzhaltern für Kontext und Frage
332
  #RAG_CHAIN_PROMPT = PromptTemplate(template="Context: {context}\n\nQuestion: {question}\n\nAnswer:")
333
+
334
+ # Inahlte Abrufen der relevanten Dokumente
335
+ doc_contents = [doc["content"] for doc in relevant_docs]
336
+
337
+ #Berechne die Ähnlichkeiten und finde das relevanteste Dokument
338
+ question_embedding = embedder_modell.encode(prompt, convert_to_tensor=True)
339
+ doc_embeddings = embedder_modell.encode(doc_contents, convert_to_tensor=True)
340
+ similarity_scores = util.pytorch_cos_sim(question_embedding, doc_embeddings)
341
+ most_relevant_doc_indices = similarity_scores.argsort(descending=True).squeeze().tolist()
342
+
343
+ #Erstelle eine Liste der relevantesten Dokumente
344
+ most_relevant_docs = [relevant_docs[i] for i in most_relevant_doc_indices]
345
+
346
+ #Kombiniere die Inhalte aller relevanten Dokumente
347
+ combined_content = " ".join([doc["content"] for doc in most_relevant_docs])
348
+
349
+ #Formuliere die Eingabe für das Generierungsmodell
350
+ input_text = f"frage: {prompt} kontext: {combined_content}"
351
+ inputs = tokenizer_rag(input_text, return_tensors="pt", max_length=1024, truncation=True)
352
+
353
+ #Generiere die Antwort
354
+ outputs = model_rag.generate(inputs['input_ids'], max_length=150, num_beams=2, early_stopping=True)
355
+ answer = tokenizer_rag.decode(outputs[0], skip_special_tokens=True)
356
+
357
+
358
+ # Erstelle das Ergebnis-Dictionary
359
+ result = {
360
+ "answer": answer,
361
+ "relevant_docs": most_relevant_docs
362
+ }
363
+
364
  else:
365
  # keine relevanten Dokumente gefunden
366
+ result = {
367
+ "answer": "Keine relevanten Dokumente gefunden",
368
+ "relevant_docs": most_relevant_docs
369
+ }
370
 
371
  return result
372