mannadamay12 committed
Commit 23a54f8 · verified · Parent(s): 7321565

Update app.py

Files changed (1):
  1. app.py (+43 -17)
app.py CHANGED
@@ -64,37 +64,63 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
     try:
         model, tokenizer = initialize_model()
 
-        # Get relevant context from the database
+        # Get context from database
         retriever = db.as_retriever(search_kwargs={"k": 2})
         docs = retriever.get_relevant_documents(message)
         context = "\n".join([doc.page_content for doc in docs])
 
-        # Generate the complete prompt
+        # Generate prompt
        prompt = generate_prompt(context=context, question=message, system_prompt=system_message)
 
-        # Set up the streamer
-        streamer = CustomTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-        # Set up the pipeline
-        text_pipeline = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
+        # Generate response without streamer for direct string output
+        output = text_pipeline(
+            prompt,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
             repetition_penalty=1.15,
-            streamer=streamer,
-        )
-
-        # Generate response
-        _ = text_pipeline(prompt, max_new_tokens=max_tokens)
-
-        # Return only the generated response
-        yield streamer.output_text.strip()
+            return_full_text=False
+        )[0]['generated_text']
+
+        yield output.strip()
 
     except Exception as e:
         yield f"An error occurred: {str(e)}"
+# def respond(message, history, system_message, max_tokens, temperature, top_p):
+#     try:
+#         model, tokenizer = initialize_model()
+
+#         # Get relevant context from the database
+#         retriever = db.as_retriever(search_kwargs={"k": 2})
+#         docs = retriever.get_relevant_documents(message)
+#         context = "\n".join([doc.page_content for doc in docs])
+
+#         # Generate the complete prompt
+#         prompt = generate_prompt(context=context, question=message, system_prompt=system_message)
+
+#         # Set up the streamer
+#         streamer = CustomTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+#         # Set up the pipeline
+#         text_pipeline = pipeline(
+#             "text-generation",
+#             model=model,
+#             tokenizer=tokenizer,
+#             max_new_tokens=max_tokens,
+#             temperature=temperature,
+#             top_p=top_p,
+#             repetition_penalty=1.15,
+#             streamer=streamer,
+#         )
+
+#         # Generate response
+#         _ = text_pipeline(prompt, max_new_tokens=max_tokens)
+
+#         # Return only the generated response
+#         yield streamer.output_text.strip()
+
+#     except Exception as e:
+#         yield f"An error occurred: {str(e)}"
 
 # Create Gradio interface
 demo = gr.ChatInterface(
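
For context: after this change, respond() calls a text_pipeline that is no longer constructed inside the function, so the code assumes a pipeline defined once at module scope elsewhere in app.py. Below is a minimal sketch of that non-streaming call pattern with the transformers library, showing how return_full_text=False makes the pipeline return only the newly generated completion so no prompt-stripping is needed. The model ID, prompt, and generation values here are illustrative assumptions, not taken from this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # assumption: any causal LM works here

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Build the pipeline once at module scope and reuse it across requests;
# per-request generation knobs are passed at call time instead.
text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

prompt = "Question: What does the retriever return?\nAnswer:"

# The pipeline returns a list of dicts; with return_full_text=False the
# 'generated_text' field holds only the completion, not the echoed prompt.
output = text_pipeline(
    prompt,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.15,
    return_full_text=False,
)[0]["generated_text"]

print(output.strip())

The trade-off against the commented-out streamer version is that tokens are no longer surfaced incrementally: respond() now yields the whole reply in one chunk, which gr.ChatInterface accepts, at the cost of perceived latency on long generations.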