Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request, Response | |
from fastapi.responses import HTMLResponse, RedirectResponse | |
from fastapi.templating import Jinja2Templates | |
from google.oauth2 import id_token | |
from google.auth.transport import requests as google_requests | |
from google_auth_oauthlib.flow import Flow | |
from chainlit.utils import mount_chainlit | |
import secrets | |
import json | |
import base64 | |
from constants import ( | |
OAUTH_GOOGLE_CLIENT_ID, | |
OAUTH_GOOGLE_CLIENT_SECRET, | |
CHAINLIT_URL, | |
) | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
import os | |
# Google OAuth credentials, re-exported under shorter module-level names.
GOOGLE_CLIENT_SECRET = OAUTH_GOOGLE_CLIENT_SECRET
GOOGLE_CLIENT_ID = OAUTH_GOOGLE_CLIENT_ID
# Callback URL handed to the OAuth flow (both as a registered redirect URI
# and as the redirect_uri used for the authorization request).
GOOGLE_REDIRECT_URI = "{}/auth/oauth/google/callback".format(CHAINLIT_URL)
app = FastAPI()

# CORS policy: any origin, method and request header is accepted, and the
# custom X-User-Info header is exposed to browser scripts.
# NOTE(review): "*" origins are wide open — narrow them for production.
app.add_middleware(
    CORSMiddleware,
    expose_headers=["X-User-Info"],
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# URL prefix under which the Chainlit app is mounted.
CHAINLIT_PATH = "/chainlit_tutor"

# Jinja2 templates (login.html, dashboard.html, error.html) live in ./templates.
templates = Jinja2Templates(directory="templates")

# In-memory map: session token -> user-info dict. As a plain module-level
# dict it is per-process and starts empty on every restart.
session_store: dict = dict()
# Explicit role assignments keyed by the user's email address. Any email not
# listed here falls back to ["student"] in get_user_role().
# NOTE(review): in this copy every key reads "[email protected]", which looks
# like a redacted/obfuscated placeholder rather than real addresses. Because
# all four keys are identical, the dict literal collapses to ONE entry and the
# last value (["guest"]) wins — restore the real addresses before relying on it.
USER_ROLES = {
    "[email protected]": ["instructor", "bu"],
    "[email protected]": ["instructor", "bu"],
    "[email protected]": ["instructor", "bu"],
    "[email protected]": ["guest"],
    # Add more users and roles as needed
}
# Single Google OAuth flow object used both to build the consent URL
# (login_google) and to exchange the callback code for tokens (auth_google).
# The ``scopes=`` argument is what gets requested from Google; the "scopes"
# entry inside the client config is kept as it was.
# NOTE(review): google_auth_oauthlib presumably ignores that inner entry — confirm.
# NOTE(review): fetch_token() stores tokens on this shared object, so the
# callback must keep exchanging and reading credentials without an await in
# between (as it does today) to avoid mixing up concurrent users.
flow = Flow.from_client_config(
    client_config={
        "web": {
            "client_id": GOOGLE_CLIENT_ID,
            "client_secret": GOOGLE_CLIENT_SECRET,
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
            "redirect_uris": [GOOGLE_REDIRECT_URI],
            "scopes": ["openid"],
        }
    },
    scopes=[
        "openid",
        "https://www.googleapis.com/auth/userinfo.email",
        "https://www.googleapis.com/auth/userinfo.profile",
    ],
    redirect_uri=GOOGLE_REDIRECT_URI,
)
def get_user_role(username: str):
    """Return the list of roles configured for ``username``.

    Users that have no entry in USER_ROLES are treated as students.
    """
    fallback_roles = ["student"]
    return USER_ROLES.get(username, fallback_roles)
def get_user_info_from_cookie(request: "Request"):
    """Decode the user-info dict carried in the ``X-User-Info`` cookie.

    The cookie holds plain base64-encoded JSON with no signature, so its
    contents are client-controlled and are not proof of identity.

    Args:
        request: Incoming request whose cookies are inspected.

    Returns:
        The decoded dict, or None when the cookie is absent/empty, is not
        valid base64/UTF-8/JSON, or does not decode to a JSON object.
    """
    user_info_encoded = request.cookies.get("X-User-Info")
    if not user_info_encoded:
        return None
    try:
        user_info_json = base64.b64decode(user_info_encoded).decode()
        user_info = json.loads(user_info_json)
    except (ValueError, RecursionError) as e:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses; RecursionError covers absurdly nested JSON.
        print(f"Error decoding user info: {e}")
        return None
    if not isinstance(user_info, dict):
        # Callers do .get()/["email"] on the result; reject lists, strings, etc.
        print(
            "Error decoding user info: expected a JSON object, "
            f"got {type(user_info).__name__}"
        )
        return None
    return user_info
def get_user_info(request: Request):
    """Return the server-side session for the request's ``session_token`` cookie.

    Yields the user-info dict stored in ``session_store``, or None when the
    cookie is missing/empty or the token is not a known session.
    """
    token = request.cookies.get("session_token")
    if not token:
        return None
    return session_store.get(token)
async def login_page(request: Request):
    """Render login.html, or skip straight to /post-signin for Google users.

    Only the ``google_signed_in`` flag from the X-User-Info cookie triggers the
    redirect; guests and anonymous visitors see the login page.
    """
    cookie_user = get_user_info_from_cookie(request)
    if not (cookie_user and cookie_user.get("google_signed_in")):
        return templates.TemplateResponse("login.html", {"request": request})
    return RedirectResponse("/post-signin")
async def login_guest():
    """Create an anonymous guest session and redirect to ``/post-signin``.

    A random ``guest_<hex>`` identifier is stored in ``session_store`` under a
    fresh session token, and the same user info is mirrored into the
    ``X-User-Info`` cookie as base64-encoded JSON.

    Returns:
        A 303 redirect to /post-signin carrying the ``session_token`` and
        ``X-User-Info`` cookies (both HttpOnly).
    """
    session_token = secrets.token_hex(16)
    username = f"guest_{secrets.token_hex(8)}"
    user_info = {
        "email": username,
        "name": "Guest",
        "profile_image": "",
        "google_signed_in": False,  # Ensure guest users do not have this flag
    }
    session_store[session_token] = user_info

    user_info_encoded = base64.b64encode(json.dumps(user_info).encode()).decode()

    # 303 so the browser follows up with a GET.
    response = RedirectResponse(url="/post-signin", status_code=303)
    # The session token is a bearer credential: keep it out of reach of page
    # scripts (it was previously set without HttpOnly).
    response.set_cookie(key="session_token", value=session_token, httponly=True)
    response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
    return response
async def login_google(request: Request):
    """Start Google sign-in, or skip it for users already signed in with Google.

    Users whose ``X-User-Info`` cookie already carries ``google_signed_in`` are
    sent straight to ``/post-signin``. Everyone else (including guests) is
    redirected to Google's consent page with their session cookies cleared.
    """
    user_info = get_user_info_from_cookie(request)
    if user_info and user_info.get("google_signed_in"):
        return RedirectResponse("/post-signin")

    authorization_url, _ = flow.authorization_url(prompt="consent")
    response = RedirectResponse(authorization_url)
    # Clear any existing session cookies (e.g. from a guest session) so they
    # cannot conflict with the session created by the OAuth callback. Deleting
    # them on the response actually returned replaces the old trick of copying
    # a throwaway response's headers (including its Location) onto this one.
    response.delete_cookie(key="session_token")
    response.delete_cookie(key="X-User-Info")
    return response
async def auth_google(request: Request):
    """Handle the Google OAuth callback and open a Google-backed session.

    Exchanges the ``code`` query parameter for tokens, verifies the returned
    ID token against GOOGLE_CLIENT_ID, and requires the ID token to report the
    email as verified, because roles are granted by email address. On success
    the user info is stored in ``session_store``, mirrored into the
    ``X-User-Info`` cookie, and the browser is sent to ``/post-signin`` (303).
    On any failure the error is printed and the browser goes back to ``/``.
    """
    code = request.query_params.get("code")
    if not code:
        # An OAuth 2.0 error response (e.g. consent denied) carries
        # ``error=...`` instead of ``code``; there is nothing to exchange.
        print(
            "Error during Google OAuth callback: missing code "
            f"(error={request.query_params.get('error')})"
        )
        return RedirectResponse(url="/", status_code=302)

    try:
        flow.fetch_token(code=code)
        credentials = flow.credentials
        user_info = id_token.verify_oauth2_token(
            credentials.id_token, google_requests.Request(), GOOGLE_CLIENT_ID
        )
        email = user_info["email"]
        # Only trust the email for role assignment when Google marks it
        # verified (accept the boolean and the legacy string form).
        if user_info.get("email_verified") not in (True, "true"):
            raise ValueError("Google account email is not verified")
    except Exception as e:
        # Boundary handler: any OAuth/verification failure sends the user
        # back to the login page instead of surfacing a 500.
        print(f"Error during Google OAuth callback: {e}")
        return RedirectResponse(url="/", status_code=302)

    session_token = secrets.token_hex(16)
    session_store[session_token] = {
        "email": email,
        "name": user_info.get("name", ""),
        "profile_image": user_info.get("picture", ""),
        "google_signed_in": True,  # Set this flag to True for Google-signed users
    }
    user_info_encoded = base64.b64encode(
        json.dumps(session_store[session_token]).encode()
    ).decode()

    response = RedirectResponse(url="/post-signin", status_code=303)
    # HttpOnly: the session token is a bearer credential.
    response.set_cookie(key="session_token", value=session_token, httponly=True)
    response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
    return response
async def post_signin(request: Request):
    """Render the dashboard for the signed-in user, or redirect to ``/``.

    The server-side session (``session_token`` cookie looked up in
    ``session_store``) is preferred over the ``X-User-Info`` cookie. That
    cookie is plain base64-encoded JSON, so a client can edit it to claim any
    email and therefore any role. The cookie is only used as a fallback when
    no server-side session is found.

    NOTE(review): the fallback still trusts the unsigned cookie for clients
    without a live session; signing the cookie would close that gap.
    """
    user_info = get_user_info(request) or get_user_info_from_cookie(request)
    # .get() so a cookie lacking "email" (or not a dict) redirects instead of
    # raising a 500.
    username = user_info.get("email") if isinstance(user_info, dict) else None
    if not username:
        return RedirectResponse("/")

    role = get_user_role(username)
    # The raw cookie value is handed to the template unchanged.
    jwt_token = request.cookies.get("X-User-Info")
    return templates.TemplateResponse(
        "dashboard.html",
        {
            "request": request,
            "username": username,
            "role": role,
            "jwt_token": jwt_token,
        },
    )
async def start_tutor(request: Request):
    """Redirect a user with valid X-User-Info to the mounted Chainlit app.

    The decoded cookie payload is re-serialized and set again (HttpOnly) on
    the 303 redirect to CHAINLIT_PATH. Without decodable user info the
    browser is sent back to ``/``.
    """
    cookie_user = get_user_info_from_cookie(request)
    if not cookie_user:
        return RedirectResponse(url="/")
    payload = json.dumps(cookie_user).encode()
    encoded = base64.b64encode(payload).decode()
    redirect = RedirectResponse(CHAINLIT_PATH, status_code=303)
    redirect.set_cookie(key="X-User-Info", value=encoded, httponly=True)
    return redirect
async def exception_handler(request: Request, exc: Exception):
    """Render error.html with HTTP 500 for an unhandled exception."""
    # NOTE(review): str(exc) is shown to the end user and may expose internals.
    context = {"request": request, "error": str(exc)}
    return templates.TemplateResponse("error.html", context, status_code=500)
async def app_logout(request: Request, response: Response):
    """Log the user out and redirect to ``/``.

    Drops the server-side session named by the ``session_token`` cookie and
    clears both the ``session_token`` and ``X-User-Info`` cookies.

    Args:
        request: Incoming request; its ``session_token`` cookie identifies the
            session to discard.
        response: Unused; kept for signature compatibility. FastAPI ignores
            headers set on this injected response when the handler returns its
            own Response object, which is why deleting cookies on it never
            logged anyone out.

    Returns:
        A 302 redirect to ``/`` that carries the cookie deletions.
    """
    session_token = request.cookies.get("session_token")
    if session_token:
        session_store.pop(session_token, None)

    redirect = RedirectResponse(url="/", status_code=302)
    # Cookie deletions must be set on the response that is actually returned.
    redirect.delete_cookie("session_token")
    redirect.delete_cookie("X-User-Info")
    print("logout_page called")
    return redirect
# Serve the Chainlit app defined in main.py under CHAINLIT_PATH.
mount_chainlit(app=app, path=CHAINLIT_PATH, target="main.py")

if __name__ == "__main__":
    # Direct-run entry point: serve on the loopback interface, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="127.0.0.1")