nhathuy07 commited on
Commit
68016fb
·
verified ·
1 Parent(s): c0e9b69

Create APIGuard Middleware class

Browse files
Files changed (1) hide show
  1. main.py +29 -26
main.py CHANGED
@@ -3,6 +3,7 @@ from starlette.applications import Starlette
3
  from starlette.routing import Route
4
  from starlette.middleware import Middleware
5
  from starlette.middleware.cors import CORSMiddleware
 
6
  from gensim.models import KeyedVectors
7
  """Prompt templates for LLM"""
8
  from env import LLM_API_KEY
@@ -58,34 +59,36 @@ class QType(Enum):
58
  STMT = 3
59
  FILL = 6
60
 
 
 
 
61
 
 
62
 
63
- async def __api_guard(request, call_next, **kwargs):
64
- # Your custom logic here
65
- client_url = request.url.path
66
- client_ip_addr = request.client.host
67
-
68
- # IP-based rate limitation
69
- async with r.pipeline(transaction=True) as pipeline:
70
- try:
71
- res = await (pipeline.get(client_ip_addr).execute())[-1]
72
- except:
73
- res = None
74
-
75
- if res == None:
76
- ok = await pipeline.set(client_ip_addr, 25).execute()
77
- r.expire(60)
78
- elif res > 0:
79
- ok = await pipeline.set(client_ip_addr, res-1).execute()
80
- else:
81
- raise HTTPException(status_code=429, detail="This IP address is rate-limited")
82
-
83
- response = await call_next(request)
84
-
85
- # Optionally modify the response object
86
- # ...
87
 
88
- return response
 
 
 
 
 
 
 
 
 
 
 
 
89
  sys_random = SystemRandom()
90
 
91
  # TODO: Change to environment variable in prod.
@@ -651,7 +654,7 @@ middleware = [
651
  allow_origins=['http://localhost:8100', 'https://text2quiz-three.vercel.app'],
652
  allow_methods =['*'],
653
  ),
654
- Middleware(__api_guard),
655
  ]
656
 
657
  app = Starlette(debug=True,routes=[
 
3
  from starlette.routing import Route
4
  from starlette.middleware import Middleware
5
  from starlette.middleware.cors import CORSMiddleware
6
+ from starlette.middleware.base import BaseHTTPMiddleware
7
  from gensim.models import KeyedVectors
8
  """Prompt templates for LLM"""
9
  from env import LLM_API_KEY
 
59
  STMT = 3
60
  FILL = 6
61
 
62
+ class APIGuardMiddleware(BaseHTTPMiddleware):
63
+ def __init__(self, app):
64
+ super().__init__(app)
65
 
66
+ async def dispatch(self, request, call_next):
67
 
68
+ # Get current client url and client IP address
69
+ client_url = request.url.path
70
+ client_ip_addr = request.client.host
71
+
72
+ # IP-based rate limitation
73
+ async with r.pipeline(transaction=True) as pipeline:
74
+ try:
75
+ res = await (pipeline.get(client_ip_addr).execute())[-1]
76
+ except:
77
+ res = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ if res == None:
80
+ ok = await pipeline.set(client_ip_addr, 25).execute()
81
+ r.expire(60)
82
+ elif res > 0:
83
+ ok = await pipeline.set(client_ip_addr, res-1).execute()
84
+ else:
85
+ raise HTTPException(status_code=429, detail="This IP address is rate-limited")
86
+
87
+ # process the request and get the response
88
+ response = await call_next(request)
89
+ return response
90
+
91
+
92
  sys_random = SystemRandom()
93
 
94
  # TODO: Change to environment variable in prod.
 
654
  allow_origins=['http://localhost:8100', 'https://text2quiz-three.vercel.app'],
655
  allow_methods =['*'],
656
  ),
657
+ Middleware(APIGuardMiddleware),
658
  ]
659
 
660
  app = Starlette(debug=True,routes=[