|
import os |
|
import gradio as gr |
|
from langchain_community.vectorstores import Chroma
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEndpoint
import time
import hashlib
|
ANTI_BOT_PW = os.getenv("CORRECT_VALIDATE") |
|
PATH_WORK = "."
CHROMA_DIR = "/chroma/kkg"
CHROMA_PDF = "./chroma/kkg/pdf"
CHROMA_WORD = "./chroma/kkg/word"
CHROMA_EXCEL = "./chroma/kkg/excel"
|
|
|
MODEL_NAME_HF = "mistralai/Mixtral-8x7B-Instruct-v0.1"

hf_token = os.getenv("HF_READ")
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
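
# Extractive QA pipeline used by get_rag_response() below. The checkpoint is
# left at the transformers default here -- the concrete model is an assumption.
qa_pipeline = pipeline("question-answering")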
|
# Global RAG state, initialised lazily in generate_auswahl()
vektordatenbank = None
retriever = None

file_path_download = ""
|
|
def generate_text(prompt, chatbot, history, vektordatenbank, websuche, top_p=0.6, temperature=0.2, max_new_tokens=4048, max_context_length_tokens=2048, repetition_penalty=1.3, top_k=35):
    print("Plain-text request ..............................")
    if prompt == "":
        raise gr.Error("Prompt ist erforderlich.")

    try:
        print("HF request .......................")
        # LLM client for the Hugging Face Inference API; the sampling
        # parameters come straight from the function arguments.
        llm = HuggingFaceEndpoint(
            repo_id=MODEL_NAME_HF,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Fold the chat history into the prompt (helper defined elsewhere in this project).
        history_text_und_prompt = generate_prompt_with_history(prompt, history)

        print("Calling LLM with RAG ...........")
        result = rag_chain(llm, history_text_und_prompt, retriever)

        print("rag_chain result .....................")
        print(result)
    except Exception as e:
        raise gr.Error(str(e))

    # The second value is the status text shown in the UI; websuche is
    # accepted for API compatibility but not evaluated here.
    return result, "Antwort der KI ..."
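

# generate_prompt_with_history() and rag_chain() are defined elsewhere in this
# project. A minimal sketch of what rag_chain() is assumed to do, built from
# the LLMChain/PromptTemplate imports above (illustrative only, not the
# project's actual implementation):
def rag_chain_sketch(llm, prompt_text, retriever):
    # Fetch the most relevant chunks from the vector store.
    docs = retriever.get_relevant_documents(prompt_text)
    context = "\n\n".join(doc.page_content for doc in docs)
    # Stuff the retrieved context into a plain question-answering prompt.
    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="Beantworte die Frage nur anhand des folgenden Kontexts.\n\nKontext:\n{context}\n\nFrage: {question}",
    )
    chain = LLMChain(llm=llm, prompt=qa_prompt)
    return chain.run(context=context, question=prompt_text)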
|
|
def validate_input(user_input_validate, validate=False):
    user_input_hashed = hash_input(user_input_validate)
    if user_input_hashed == hash_input(ANTI_BOT_PW):
        return "Richtig! Weiter geht's ...", True, gr.Textbox(visible=False), gr.Button(visible=False)
    else:
        return "Falsche Antwort!!!!!!!!!", False, gr.Textbox(label="", placeholder="Bitte tippen Sie das oben im Moodle Kurs angegebene Wort ein, um zu beweisen, dass Sie kein Bot sind.", visible=True, scale=5), gr.Button("Validieren", visible=True)
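

# hash_input() is defined elsewhere in this project. A minimal sketch, assuming
# an unsalted SHA-256 digest (the actual hashing scheme is an assumption):
def hash_input_sketch(input_string):
    return hashlib.sha256(input_string.encode()).hexdigest()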
|
|
def custom_css(): |
|
return """ |
|
body, html { |
|
    background-color: #303030; /* dark background */
    color: #353535;
|
} |
|
""" |
|
|
def get_rag_response(question):
    # Uses the global Chroma store; assumes it has been initialised
    # (see generate_auswahl below).
    docs = vektordatenbank.similarity_search(question, k=5)
    passages = [doc.page_content for doc in docs]
    links = [doc.metadata.get('url', 'No URL available') for doc in docs]

    # Concatenate the passages and let the extractive QA pipeline answer.
    context = " ".join(passages)
    answer = qa_pipeline(question=question, context=context)['answer']

    # Return the answer together with the source passages and their links.
    response = {
        "answer": answer,
        "documents": [{"link": link, "passage": passage} for link, passage in zip(links, passages)]
    }

    return response
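
# Example call (hypothetical question):
#   get_rag_response("Wann findet das Schulfest statt?")
#   -> {"answer": "...", "documents": [{"link": ..., "passage": ...}, ...]}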
|
|
def generate_auswahl(prompt_in, file, file_history, chatbot, history, anzahl_docs=3, top_p=0.6, temperature=0.5, max_new_tokens=4048, max_context_length_tokens=2048, repetition_penalty=1.3, top_k=5, websuche="Aus", validate=False):
    global vektordatenbank, retriever

    if validate and prompt_in is not None and prompt_in != "":
        neu_file = file_history

        prompt = normalise_prompt(prompt_in)

        # Build the vector database on first use.
        if vektordatenbank is None:
            print("Building the vector database ...")
            splits = document_loading_splitting()
            vektordatenbank, retriever = document_storage_chroma(splits)

        status = "Antwort der KI ..."
        if file is None and file_history is None:
            result, status = generate_text(prompt, chatbot, history, vektordatenbank, websuche, top_p=top_p, temperature=temperature, max_new_tokens=max_new_tokens, max_context_length_tokens=max_context_length_tokens, repetition_penalty=repetition_penalty, top_k=top_k)
            history = history + [[prompt, result]]
        else:
            # A newly uploaded file takes precedence over the file history.
            if file is not None:
                neu_file = file

            result = generate_text_zu_doc(neu_file, prompt, anzahl_docs, rag_option, chatbot, history, vektordatenbank)

            if file is not None:
                history = history + [[(file,), None], [prompt, result]]
            else:
                history = history + [[prompt, result]]

        # Stream the answer character by character into the chat window.
        chatbot[-1][1] = ""
        for character in result:
            chatbot[-1][1] += character
            time.sleep(0.03)
            yield chatbot, history, None, neu_file, status
            if shared_state.interrupted:
                shared_state.recover()
                try:
                    yield chatbot, history, None, neu_file, "Stop: Success"
                except Exception:
                    pass
    else:
        yield chatbot, history, None, file_history, "Erst validieren oder einen Prompt eingeben!"
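

# document_loading_splitting() and document_storage_chroma() are defined
# elsewhere in this project. A minimal sketch of the storage step, assuming
# the splits are LangChain Document objects; the embedding model named here
# is an assumption, not the project's actual choice:
def document_storage_chroma_sketch(splits):
    from langchain_community.embeddings import HuggingFaceEmbeddings
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Persist the vector store under CHROMA_DIR and hand back a retriever.
    db = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=CHROMA_DIR)
    return db, db.as_retriever(search_kwargs={"k": 3})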
|
|
print("Start GUI validation pre-check")

print("Start GUI main application")
with open("custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
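
# themeAlex, description_top and description are defined elsewhere in this
# project (a Gradio theme object and Markdown strings for the page header and
# footer); they are used by the Blocks layout below.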
|
|
|
|
|
additional_inputs = [ |
|
gr.Slider(label="Temperature", value=0.65, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Höhere Werte erzeugen diversere Antworten", visible=True), |
|
gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=4096, step=64, interactive=True, info="Maximale Anzahl neuer Tokens", visible=True), |
|
gr.Slider(label="Top-p (nucleus sampling)", value=0.6, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Höhere Werte verwenden auch Tokens mit niedrigerer Wahrscheinlichkeit.", visible=True), |
|
gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Strafe für wiederholte Tokens", visible=True) |
|
] |
|
with gr.Blocks(css=customCSS, theme=themeAlex) as demo: |
|
|
|
validate = gr.State(False) |
|
history = gr.State([]) |
|
    uploaded_file_paths = gr.State([])
|
history3 = gr.State([]) |
|
    uploaded_file_paths3 = gr.State([])
|
|
|
chats = gr.State({}) |
|
|
|
user_question = gr.State("") |
|
|
|
|
|
user_question2 = gr.State("") |
|
user_question3 = gr.State("") |
|
attached_file = gr.State(None) |
|
attached_file_history = gr.State(None) |
|
attached_file3 = gr.State(None) |
|
attached_file_history3 = gr.State(None) |
|
status_display = gr.State("") |
|
status_display2 = gr.State("") |
|
status_display3 = gr.State("") |
|
|
gr.Markdown(description_top) |
|
with gr.Row(): |
|
        user_input_validate = gr.Textbox(label="Bitte das oben im Moodle Kurs angegebene Wort eingeben, um die Anwendung zu starten", visible=True, interactive=True, scale=7)
        validate_btn = gr.Button("Validieren", visible=True)
|
|
|
|
|
with gr.Tab("KKG Chatbot"): |
|
with gr.Row(): |
|
|
|
status_display = gr.Markdown("Antwort der KI ...", visible = True) |
|
with gr.Row(): |
|
with gr.Column(scale=5): |
|
with gr.Row(): |
|
chatbot = gr.Chatbot(elem_id="li-chat",show_copy_button=True) |
|
with gr.Row(): |
|
with gr.Column(scale=12): |
|
user_input = gr.Textbox( |
|
show_label=False, placeholder="Gib hier deinen Prompt ein...", |
|
container=False |
|
) |
|
with gr.Column(min_width=70, scale=1): |
|
submitBtn = gr.Button("Senden") |
|
with gr.Column(min_width=70, scale=1): |
|
cancelBtn = gr.Button("Stop") |
|
with gr.Row(): |
|
image_display = gr.Image( visible=False) |
|
upload = gr.UploadButton("📁", file_types=["image", "pdf", "docx", "pptx", "xlsx"], scale = 10) |
|
emptyBtn = gr.ClearButton([user_input, chatbot, history, attached_file, attached_file_history, image_display], value="🧹 Neue Session", scale=10) |
|
|
|
with gr.Column(): |
|
with gr.Column(min_width=50, scale=1): |
|
with gr.Tab(label="Chats ..."): |
|
|
file_download = gr.File(label="Noch keine Chatsverläufe", visible=True, interactive = False, file_count="multiple",) |
|
|
|
with gr.Tab(label="Parameter"): |
|
|
|
rag_option = gr.Radio(["Aus", "An"], label="KKG Erweiterungen (RAG)", value = "Aus") |
|
model_option = gr.Radio(["OpenAI", "HuggingFace"], label="Modellauswahl", value = "OpenAI") |
|
websuche = gr.Radio(["Aus", "An"], label="Web-Suche", value = "Aus") |
|
|
|
|
|
top_p = gr.Slider( |
|
        minimum=0.0,
|
maximum=1.0, |
|
value=0.95, |
|
step=0.05, |
|
interactive=True, |
|
label="Top-p", |
|
visible=False, |
|
) |
|
top_k = gr.Slider( |
|
minimum=1, |
|
maximum=100, |
|
value=35, |
|
step=1, |
|
interactive=True, |
|
label="Top-k", |
|
visible=False, |
|
) |
|
temperature = gr.Slider( |
|
minimum=0.1, |
|
maximum=2.0, |
|
value=0.2, |
|
step=0.1, |
|
interactive=True, |
|
label="Temperature", |
|
visible=False |
|
) |
|
max_length_tokens = gr.Slider( |
|
minimum=0, |
|
maximum=512, |
|
value=512, |
|
step=8, |
|
interactive=True, |
|
label="Max Generation Tokens", |
|
visible=False, |
|
) |
|
max_context_length_tokens = gr.Slider( |
|
minimum=0, |
|
maximum=4096, |
|
value=2048, |
|
step=128, |
|
interactive=True, |
|
label="Max History Tokens", |
|
visible=False, |
|
) |
|
repetition_penalty=gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Strafe für wiederholte Tokens", visible=False) |
|
anzahl_docs = gr.Slider(label="Anzahl Dokumente", value=3, minimum=1, maximum=10, step=1, interactive=True, info="wie viele Dokumententeile aus dem Vektorstore an den prompt gehängt werden", visible=False) |
|
openai_key = gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1, visible = False) |
|
|
|
|
|
gr.Markdown(description) |
|
|
predict_args = dict( |
|
fn=generate_auswahl, |
|
inputs=[ |
|
user_question, |
|
attached_file, |
|
attached_file_history, |
|
chatbot, |
|
history, |
|
anzahl_docs, |
|
top_p, |
|
temperature, |
|
max_length_tokens, |
|
max_context_length_tokens, |
|
repetition_penalty, |
|
top_k, |
|
websuche, |
|
validate |
|
], |
|
outputs=[chatbot, history, attached_file, attached_file_history, status_display], |
|
show_progress=True, |
|
) |
|
|
|
reset_args = dict( |
|
fn=reset_textbox, inputs=[], outputs=[user_input, status_display] |
|
) |
|
|
|
|
|
transfer_input_args = dict( |
|
fn=add_text, inputs=[chatbot, history, user_input, attached_file, attached_file_history], outputs=[chatbot, history, user_question, attached_file, attached_file_history, image_display , user_input], show_progress=True |
|
) |
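
    # add_text, reset_textbox, file_anzeigen, file_loeschen, clear_all,
    # cancel_outputing and shared_state are helper callbacks/objects defined
    # elsewhere in this project.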
|
|
validate_btn.click(validate_input, inputs=[user_input_validate, validate], outputs=[status_display, validate, user_input_validate, validate_btn]) |
|
user_input_validate.submit(validate_input, inputs=[user_input_validate, validate], outputs=[status_display, validate, user_input_validate, validate_btn]) |
|
|
|
predict_event1 = user_input.submit(**transfer_input_args, queue=False,).then(**predict_args) |
|
predict_event2 = submitBtn.click(**transfer_input_args, queue=False,).then(**predict_args) |
|
predict_event3 = upload.upload(file_anzeigen, [upload], [image_display, image_display, attached_file] ) |
|
emptyBtn.click(clear_all, [history, uploaded_file_paths, chats], [attached_file, image_display, uploaded_file_paths, history, file_download, chats]) |
|
|
|
image_display.select(file_loeschen, [], [attached_file, image_display]) |
|
|
cancelBtn.click(cancel_outputing, [], [status_display], cancels=[predict_event1,predict_event2, predict_event3]) |
|
|
|
|