freddyaboulton (HF staff) committed
Commit: b72e7b5
Parent(s): 9a6ffe6
Files changed (6)
  1. .gitignore +5 -0
  2. app.py +212 -0
  3. gradio_docs.json +0 -0
  4. logo.svg +1 -0
  5. requirements.in +6 -0
  6. requirements.txt +296 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .env
+ db_creation.py
+ db_test.py
+ docs.py
+ __pycache__/
app.py ADDED
@@ -0,0 +1,212 @@
+ from __future__ import annotations as _annotations
+
+ import json
+ import os
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from typing import AsyncGenerator
+
+ import asyncpg
+ import gradio as gr
+ import numpy as np
+ import pydantic_core
+ from gradio_webrtc import (
+     AdditionalOutputs,
+     ReplyOnPause,
+     WebRTC,
+     audio_to_bytes,
+     get_twilio_turn_credentials,
+ )
+ from groq import Groq
+ from openai import AsyncOpenAI
+ from pydantic import BaseModel
+ from pydantic_ai import RunContext
+ from pydantic_ai.agent import Agent
+ from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
+
+ DOCS = json.load(open("gradio_docs.json"))
+
+ groq_client = Groq()
+ openai = AsyncOpenAI()
+
+
+ @dataclass
+ class Deps:
+     openai: AsyncOpenAI
+     pool: asyncpg.Pool
+
+
+ SYSTEM_PROMPT = (
+     "You are an assistant designed to help users answer questions about Gradio. "
+     "You have a retrieval tool that can provide relevant documentation sections based on the user query. "
+     "Be courteous and helpful to the user but feel free to refuse answering questions that are not about Gradio. "
+ )
+
+
+ agent = Agent(
+     "openai:gpt-4o",
+     deps_type=Deps,
+     system_prompt=SYSTEM_PROMPT,
+ )
+
+
+ class RetrievalResult(BaseModel):
+     content: str
+     ids: list[int]
+
+
+ @asynccontextmanager
+ async def database_connect(
+     create_db: bool = False,
+ ) -> AsyncGenerator[asyncpg.Pool, None]:
+     server_dsn, database = (
+         os.getenv("DATABASE_URL"),
+         "gradio_ai_rag",
+     )
+     if create_db:
+         conn = await asyncpg.connect(server_dsn)
+         try:
+             db_exists = await conn.fetchval(
+                 "SELECT 1 FROM pg_database WHERE datname = $1", database
+             )
+             if not db_exists:
+                 await conn.execute(f"CREATE DATABASE {database}")
+         finally:
+             await conn.close()
+
+     pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
+     try:
+         yield pool
+     finally:
+         await pool.close()
+
+
+ @agent.tool
+ async def retrieve(context: RunContext[Deps], search_query: str) -> str:
+     """Retrieve documentation sections based on a search query.
+
+     Args:
+         context: The call context.
+         search_query: The search query.
+     """
+     print(f"create embedding for {search_query}")
+     embedding = await context.deps.openai.embeddings.create(
+         input=search_query,
+         model="text-embedding-3-small",
+     )
+
+     assert (
+         len(embedding.data) == 1
+     ), f"Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}"
+     embedding = embedding.data[0].embedding
+     embedding_json = pydantic_core.to_json(embedding).decode()
+     rows = await context.deps.pool.fetch(
+         "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
+         embedding_json,
+     )
+     content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
+     ids = [row["id"] for row in rows]
+     return RetrievalResult(content=content, ids=ids).model_dump_json()
+
+
+ async def stream_from_agent(
+     audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
+ ):
+     question = groq_client.audio.transcriptions.create(
+         file=("audio-file.mp3", audio_to_bytes(audio)),
+         model="whisper-large-v3-turbo",
+         response_format="verbose_json",
+     ).text
+
+     print("text", question)
+
+     chatbot.append({"role": "user", "content": question})
+     yield AdditionalOutputs(chatbot, gr.skip())
+
+     async with database_connect(False) as pool:
+         deps = Deps(openai=openai, pool=pool)
+         async with agent.run_stream(
+             question, deps=deps, message_history=past_messages
+         ) as result:
+             for message in result.new_messages():
+                 past_messages.append(message)
+                 if isinstance(message, ModelStructuredResponse):
+                     for call in message.calls:
+                         gr_message = {
+                             "role": "assistant",
+                             "content": "",
+                             "metadata": {
+                                 "title": "🔍 Retrieving relevant docs",
+                                 "id": call.tool_id,
+                             },
+                         }
+                         chatbot.append(gr_message)
+                 if isinstance(message, ToolReturn):
+                     for gr_message in chatbot:
+                         if (
+                             gr_message.get("metadata", {}).get("id", "")
+                             == message.tool_id
+                         ):
+                             paths = []
+                             for d in DOCS:
+                                 tool_result = RetrievalResult.model_validate_json(
+                                     message.content
+                                 )
+                                 if d["id"] in tool_result.ids:
+                                     paths.append(d["path"])
+                             gr_message["content"] = (
+                                 f"Relevant Context:\n {'\n'.join(list(set(paths)))}"
+                             )
+                 yield AdditionalOutputs(chatbot, gr.skip())
+             chatbot.append({"role": "assistant", "content": ""})
+             async for message in result.stream_text():
+                 chatbot[-1]["content"] = message
+                 yield AdditionalOutputs(chatbot, gr.skip())
+             data = await result.get_data()
+             past_messages.append(ModelTextResponse(content=data))
+             yield AdditionalOutputs(gr.skip(), past_messages)
+
+
+ with gr.Blocks() as demo:
+     placeholder = """
+     <div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
+         <img src="/gradio_api/file=logo.svg" style="max-width: 200px; height: auto">
+         <div>
+             <h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
+             <h3 style="margin: 0 0 0.5rem 0">
+                 Simple RAG agent over Gradio docs built with Pydantic AI.
+             </h3>
+             <h3 style="margin: 0">
+                 Ask any question about Gradio with your natural voice and get an answer!
+             </h3>
+         </div>
+     </div>
+     """
+     past_messages = gr.State([])
+     chatbot = gr.Chatbot(
+         label="Gradio Docs Bot",
+         type="messages",
+         placeholder=placeholder,
+         avatar_images=(None, "logo.svg"),
+     )
+     audio = WebRTC(
+         label="Talk with the Agent",
+         modality="audio",
+         rtc_configuration=get_twilio_turn_credentials(),
+         mode="send",
+     )
+     audio.stream(
+         ReplyOnPause(stream_from_agent),
+         inputs=[audio, chatbot, past_messages],
+         outputs=[audio],
+     )
+     audio.on_additional_outputs(
+         lambda c, s: (c, s),
+         outputs=[chatbot, past_messages],
+         queue=False,
+         show_progress="hidden",
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch(allowed_paths=["logo.svg"])
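
Note: db_creation.py is excluded by the .gitignore above, so the script that builds the doc_sections table is not part of this commit. The sketch below is a hypothetical reconstruction, inferred only from what app.py relies on: a gradio_ai_rag database reachable via DATABASE_URL, the pgvector extension, a doc_sections(id, title, content, embedding) table, and text-embedding-3-small embeddings (1536 dimensions). The "title" and "content" keys read from gradio_docs.json are assumptions; app.py itself only reads "id" and "path" from that file.

import asyncio
import json
import os

import asyncpg
import pydantic_core
from openai import AsyncOpenAI


async def main() -> None:
    # Hypothetical stand-in for the gitignored db_creation.py; not the author's script.
    docs = json.load(open("gradio_docs.json"))
    openai = AsyncOpenAI()
    pool = await asyncpg.create_pool(f"{os.getenv('DATABASE_URL')}/gradio_ai_rag")
    async with pool.acquire() as conn:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS doc_sections (
                id integer PRIMARY KEY,
                title text NOT NULL,
                content text NOT NULL,
                embedding vector(1536) NOT NULL
            )
            """
        )
        for doc in docs:
            response = await openai.embeddings.create(
                input=doc["content"],  # assumed key; only "id"/"path" are confirmed by app.py
                model="text-embedding-3-small",
            )
            # pgvector accepts the JSON array text form, matching the query in app.py's retrieve tool.
            embedding_json = pydantic_core.to_json(response.data[0].embedding).decode()
            await conn.execute(
                "INSERT INTO doc_sections (id, title, content, embedding) "
                "VALUES ($1, $2, $3, $4) ON CONFLICT (id) DO NOTHING",
                doc["id"],
                doc.get("title", doc["path"]),
                doc["content"],
                embedding_json,
            )
    await pool.close()


if __name__ == "__main__":
    asyncio.run(main())

In practice the app's own database_connect(create_db=True) helper could be reused for the connection; the direct pool here just keeps the sketch self-contained.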
gradio_docs.json ADDED
The diff for this file is too large to render.
logo.svg ADDED
requirements.in ADDED
@@ -0,0 +1,6 @@
+ gradio-webrtc[vad]>=0.0.18
+ numba>=0.60.0
+ pydantic-ai
+ asyncpg
+ groq
+ openai
requirements.txt ADDED
@@ -0,0 +1,296 @@
+ # This file was autogenerated by uv via the following command:
+ #    uv pip compile requirements.in -o requirements.txt
+ aiofiles==23.2.1
+     # via gradio
+ aioice==0.9.0
+     # via aiortc
+ aiortc==1.9.0
+     # via gradio-webrtc
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.7.0
+     # via
+     #   gradio
+     #   groq
+     #   httpx
+     #   openai
+     #   starlette
+ asyncpg==0.30.0
+     # via -r requirements.in
+ audioread==3.0.1
+     # via librosa
+ av==12.3.0
+     # via aiortc
+ cachetools==5.5.0
+     # via google-auth
+ certifi==2024.8.30
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ cffi==1.17.1
+     # via
+     #   aiortc
+     #   cryptography
+     #   pylibsrtp
+     #   soundfile
+ charset-normalizer==3.4.0
+     # via requests
+ click==8.1.7
+     # via
+     #   typer
+     #   uvicorn
+ colorama==0.4.6
+     # via griffe
+ coloredlogs==15.0.1
+     # via onnxruntime
+ cryptography==44.0.0
+     # via
+     #   aiortc
+     #   pyopenssl
+ decorator==5.1.1
+     # via librosa
+ distro==1.9.0
+     # via
+     #   groq
+     #   openai
+ dnspython==2.7.0
+     # via aioice
+ eval-type-backport==0.2.0
+     # via pydantic-ai-slim
+ fastapi==0.115.6
+     # via gradio
+ ffmpy==0.4.0
+     # via gradio
+ filelock==3.16.1
+     # via huggingface-hub
+ flatbuffers==24.3.25
+     # via onnxruntime
+ fsspec==2024.10.0
+     # via
+     #   gradio-client
+     #   huggingface-hub
+ google-auth==2.36.0
+     # via pydantic-ai-slim
+ google-crc32c==1.6.0
+     # via aiortc
+ gradio==5.8.0
+     # via gradio-webrtc
+ gradio-client==1.5.1
+     # via gradio
+ gradio-webrtc==0.0.18
+     # via -r requirements.in
+ griffe==1.5.1
+     # via pydantic-ai-slim
+ groq==0.13.0
+     # via
+     #   -r requirements.in
+     #   pydantic-ai-slim
+ h11==0.14.0
+     # via
+     #   httpcore
+     #   uvicorn
+ httpcore==1.0.7
+     # via httpx
+ httpx==0.28.0
+     # via
+     #   gradio
+     #   gradio-client
+     #   groq
+     #   openai
+     #   pydantic-ai-slim
+     #   safehttpx
+ huggingface-hub==0.26.3
+     # via
+     #   gradio
+     #   gradio-client
+ humanfriendly==10.0
+     # via coloredlogs
+ idna==3.10
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+ ifaddr==0.2.0
+     # via aioice
+ jinja2==3.1.4
+     # via gradio
+ jiter==0.8.0
+     # via openai
+ joblib==1.4.2
+     # via
+     #   librosa
+     #   scikit-learn
+ lazy-loader==0.4
+     # via librosa
+ librosa==0.10.2.post1
+     # via gradio-webrtc
+ llvmlite==0.43.0
+     # via numba
+ logfire-api==2.6.2
+     # via pydantic-ai-slim
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via
+     #   gradio
+     #   jinja2
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ msgpack==1.1.0
+     # via librosa
+ numba==0.60.0
+     # via
+     #   -r requirements.in
+     #   librosa
+ numpy==2.0.2
+     # via
+     #   gradio
+     #   librosa
+     #   numba
+     #   onnxruntime
+     #   pandas
+     #   scikit-learn
+     #   scipy
+     #   soxr
+ onnxruntime==1.20.1
+     # via gradio-webrtc
+ openai==1.57.0
+     # via
+     #   -r requirements.in
+     #   pydantic-ai-slim
+ orjson==3.10.12
+     # via gradio
+ packaging==24.2
+     # via
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   lazy-loader
+     #   onnxruntime
+     #   pooch
+ pandas==2.2.3
+     # via gradio
+ pillow==11.0.0
+     # via gradio
+ platformdirs==4.3.6
+     # via pooch
+ pooch==1.8.2
+     # via librosa
+ protobuf==5.29.1
+     # via onnxruntime
+ pyasn1==0.6.1
+     # via
+     #   pyasn1-modules
+     #   rsa
+ pyasn1-modules==0.4.1
+     # via google-auth
+ pycparser==2.22
+     # via cffi
+ pydantic==2.10.3
+     # via
+     #   fastapi
+     #   gradio
+     #   groq
+     #   openai
+     #   pydantic-ai-slim
+ pydantic-ai==0.0.9
+     # via -r requirements.in
+ pydantic-ai-slim==0.0.9
+     # via pydantic-ai
+ pydantic-core==2.27.1
+     # via pydantic
+ pydub==0.25.1
+     # via gradio
+ pyee==12.1.1
+     # via aiortc
+ pygments==2.18.0
+     # via rich
+ pylibsrtp==0.10.0
+     # via aiortc
+ pyopenssl==24.3.0
+     # via aiortc
+ python-dateutil==2.9.0.post0
+     # via pandas
+ python-multipart==0.0.19
+     # via gradio
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via
+     #   gradio
+     #   huggingface-hub
+ requests==2.32.3
+     # via
+     #   huggingface-hub
+     #   pooch
+     #   pydantic-ai-slim
+ rich==13.9.4
+     # via typer
+ rsa==4.9
+     # via google-auth
+ ruff==0.8.2
+     # via gradio
+ safehttpx==0.1.6
+     # via gradio
+ scikit-learn==1.5.2
+     # via librosa
+ scipy==1.14.1
+     # via
+     #   librosa
+     #   scikit-learn
+ semantic-version==2.10.0
+     # via gradio
+ shellingham==1.5.4
+     # via typer
+ six==1.17.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via
+     #   anyio
+     #   groq
+     #   openai
+ soundfile==0.12.1
+     # via librosa
+ soxr==0.5.0.post1
+     # via librosa
+ starlette==0.41.3
+     # via
+     #   fastapi
+     #   gradio
+ sympy==1.13.3
+     # via onnxruntime
+ threadpoolctl==3.5.0
+     # via scikit-learn
+ tomlkit==0.13.2
+     # via gradio
+ tqdm==4.67.1
+     # via
+     #   huggingface-hub
+     #   openai
+ typer==0.15.1
+     # via gradio
+ typing-extensions==4.12.2
+     # via
+     #   anyio
+     #   fastapi
+     #   gradio
+     #   gradio-client
+     #   groq
+     #   huggingface-hub
+     #   librosa
+     #   openai
+     #   pydantic
+     #   pydantic-core
+     #   pyee
+     #   typer
+ tzdata==2024.2
+     # via pandas
+ urllib3==2.2.3
+     # via requests
+ uvicorn==0.32.1
+     # via gradio
+ websockets==14.1
+     # via gradio-client