Spaces:
Running
Running
Create APIGuard Middleware class
Browse files
main.py
CHANGED
@@ -3,6 +3,7 @@ from starlette.applications import Starlette
|
|
3 |
from starlette.routing import Route
|
4 |
from starlette.middleware import Middleware
|
5 |
from starlette.middleware.cors import CORSMiddleware
|
|
|
6 |
from gensim.models import KeyedVectors
|
7 |
"""Prompt templates for LLM"""
|
8 |
from env import LLM_API_KEY
|
@@ -58,34 +59,36 @@ class QType(Enum):
|
|
58 |
STMT = 3
|
59 |
FILL = 6
|
60 |
|
|
|
|
|
|
|
61 |
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
res = None
|
74 |
-
|
75 |
-
if res == None:
|
76 |
-
ok = await pipeline.set(client_ip_addr, 25).execute()
|
77 |
-
r.expire(60)
|
78 |
-
elif res > 0:
|
79 |
-
ok = await pipeline.set(client_ip_addr, res-1).execute()
|
80 |
-
else:
|
81 |
-
raise HTTPException(status_code=429, detail="This IP address is rate-limited")
|
82 |
-
|
83 |
-
response = await call_next(request)
|
84 |
-
|
85 |
-
# Optionally modify the response object
|
86 |
-
# ...
|
87 |
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
sys_random = SystemRandom()
|
90 |
|
91 |
# TODO: Change to environment variable in prod.
|
@@ -651,7 +654,7 @@ middleware = [
|
|
651 |
allow_origins=['http://localhost:8100', 'https://text2quiz-three.vercel.app'],
|
652 |
allow_methods =['*'],
|
653 |
),
|
654 |
-
Middleware(
|
655 |
]
|
656 |
|
657 |
app = Starlette(debug=True,routes=[
|
|
|
3 |
from starlette.routing import Route
|
4 |
from starlette.middleware import Middleware
|
5 |
from starlette.middleware.cors import CORSMiddleware
|
6 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
7 |
from gensim.models import KeyedVectors
|
8 |
"""Prompt templates for LLM"""
|
9 |
from env import LLM_API_KEY
|
|
|
59 |
STMT = 3
|
60 |
FILL = 6
|
61 |
|
62 |
+
class APIGuardMiddleware(BaseHTTPMiddleware):
|
63 |
+
def __init__(self, app):
|
64 |
+
super().__init__(app)
|
65 |
|
66 |
+
async def dispatch(self, request, call_next):
|
67 |
|
68 |
+
# Get current client url and client IP address
|
69 |
+
client_url = request.url.path
|
70 |
+
client_ip_addr = request.client.host
|
71 |
+
|
72 |
+
# IP-based rate limitation
|
73 |
+
async with r.pipeline(transaction=True) as pipeline:
|
74 |
+
try:
|
75 |
+
res = await (pipeline.get(client_ip_addr).execute())[-1]
|
76 |
+
except:
|
77 |
+
res = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
if res == None:
|
80 |
+
ok = await pipeline.set(client_ip_addr, 25).execute()
|
81 |
+
r.expire(60)
|
82 |
+
elif res > 0:
|
83 |
+
ok = await pipeline.set(client_ip_addr, res-1).execute()
|
84 |
+
else:
|
85 |
+
raise HTTPException(status_code=429, detail="This IP address is rate-limited")
|
86 |
+
|
87 |
+
# process the request and get the response
|
88 |
+
response = await call_next(request)
|
89 |
+
return response
|
90 |
+
|
91 |
+
|
92 |
sys_random = SystemRandom()
|
93 |
|
94 |
# TODO: Change to environment variable in prod.
|
|
|
654 |
allow_origins=['http://localhost:8100', 'https://text2quiz-three.vercel.app'],
|
655 |
allow_methods =['*'],
|
656 |
),
|
657 |
+
Middleware(APIGuardMiddleware),
|
658 |
]
|
659 |
|
660 |
app = Starlette(debug=True,routes=[
|