Kuberwastaken committed
Commit 736f3ff · 1 Parent(s): 6a0dc28

Added Support for searching movies

.gitignore CHANGED
@@ -1 +1 @@
-treat-env
+treat-scrape
__pycache__/script_search_api.cpython-310.pyc ADDED
Binary file (6.02 kB).
 
gradio_app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
 from model.analyzer import analyze_content
 import asyncio
 import time
+import httpx
+import json
 
 custom_css = """
 * {
@@ -213,29 +215,81 @@ footer {
 }
 """
 
+
+async def fetch_and_analyze_script(movie_name, progress=gr.Progress(track_tqdm=True)):
+    try:
+        async with httpx.AsyncClient(timeout=60.0) as client:
+            # Start the analysis request
+            progress(0.2, desc="Initiating script search...")
+            response = await client.get(
+                f"http://localhost:8000/api/fetch_and_analyze",
+                params={"movie_name": movie_name}
+            )
+
+            if response.status_code == 200:
+                # Start progress polling
+                while True:
+                    progress_response = await client.get(
+                        f"http://localhost:8000/api/progress",
+                        params={"movie_name": movie_name}
+                    )
+
+                    if progress_response.status_code == 200:
+                        progress_data = progress_response.json()
+                        current_progress = progress_data["progress"]
+                        current_status = progress_data.get("status", "Processing...")
+
+                        progress(current_progress, desc=current_status)
+
+                        if current_progress >= 1.0:
+                            break
+
+                    await asyncio.sleep(0.5)  # Poll every 500ms
+
+                result = response.json()
+                triggers = result.get("detected_triggers", [])
+
+                if not triggers or triggers == ["None"]:
+                    formatted_result = "✓ No triggers detected in the content."
+                else:
+                    trigger_list = "\n".join([f"• {trigger}" for trigger in triggers])
+                    formatted_result = f"⚠ Triggers Detected:\n{trigger_list}"
+
+                return formatted_result
+            else:
+                return f"Error: Server returned status code {response.status_code}"
+
+    except httpx.TimeoutError:
+        return "Error: Request timed out. Please try again."
+    except Exception as e:
+        return f"An unexpected error occurred: {str(e)}"
+
+async def track_progress(movie_name, progress):
+    async with httpx.AsyncClient() as client:
+        while True:
+            response = await client.get(f"http://localhost:8000/api/progress", params={"movie_name": movie_name})
+            if response.status_code == 200:
+                progress_data = response.json()
+                progress(progress_data["progress"], desc="Tracking progress...")
+                if progress_data["progress"] >= 1.0:
+                    break
+            await asyncio.sleep(1)
+
 def analyze_with_loading(text, progress=gr.Progress()):
-    """
-    Synchronous wrapper for the async analyze_content function with smooth progress updates
-    """
-    # Initialize progress
     progress(0, desc="Starting analysis...")
 
-    # Initial setup phase - smoother progression
     for i in range(25):
-        time.sleep(0.04)  # Slightly longer sleep for smoother animation
+        time.sleep(0.04)
         progress((i + 1) / 100, desc="Initializing analysis...")
 
-    # Pre-processing phase
     for i in range(25, 45):
         time.sleep(0.03)
         progress((i + 1) / 100, desc="Pre-processing content...")
 
-    # Perform analysis
     progress(0.45, desc="Analyzing content...")
     try:
         result = asyncio.run(analyze_content(text))
 
-        # Analysis progress simulation
         for i in range(45, 75):
             time.sleep(0.03)
             progress((i + 1) / 100, desc="Processing results...")
@@ -243,12 +297,10 @@ def analyze_with_loading(text, progress=gr.Progress()):
     except Exception as e:
         return f"Error during analysis: {str(e)}"
 
-    # Final processing with smooth progression
     for i in range(75, 100):
         time.sleep(0.02)
         progress((i + 1) / 100, desc="Finalizing results...")
 
-    # Format the results
     triggers = result["detected_triggers"]
     if triggers == ["None"]:
         return "✓ No triggers detected in the content."
@@ -256,9 +308,7 @@ def analyze_with_loading(text, progress=gr.Progress()):
         trigger_list = "\n".join([f"• {trigger}" for trigger in triggers])
         return f"⚠ Triggers Detected:\n{trigger_list}"
 
-# Create the Gradio interface
 with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
-    # Title section
     gr.HTML("""
         <div class="treat-title">
             <h1>TREAT</h1>
@@ -270,7 +320,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
         </div>
     """)
 
-    # Content input section
     with gr.Row():
         with gr.Column(elem_classes="content-area"):
             input_text = gr.Textbox(
@@ -279,15 +328,21 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
                 lines=8,
                 interactive=True
             )
+    with gr.Row():
+        search_query = gr.Textbox(
+            label="Search Movie Scripts",
+            placeholder="Enter movie title...",
+            lines=1,
+            interactive=True
+        )
 
-    # Button section
     with gr.Row(elem_classes="center-row"):
         analyze_btn = gr.Button(
            "✨ Analyze Content",
            variant="primary"
        )
+        search_button = gr.Button("🔍 Search and Analyze Script")
 
-    # Results section
     with gr.Row():
         with gr.Column(elem_classes="results-area"):
             output_text = gr.Textbox(
@@ -295,16 +350,25 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
                 lines=5,
                 interactive=False
             )
+            status_text = gr.Markdown(
+                label="Status",
+                value=""
+            )
 
-    # Set up the click event
     analyze_btn.click(
         fn=analyze_with_loading,
         inputs=[input_text],
         outputs=[output_text],
         api_name="analyze"
     )
+
+    search_button.click(
+        fn=fetch_and_analyze_script,
+        inputs=[search_query],
+        outputs=[output_text],
+        show_progress=True
+    )
 
-    # Footer section
     gr.HTML("""
         <div class="footer">
             <p>Made with <span class="heart">💖</span> by <a href="https://www.linkedin.com/in/kubermehta/" target="_blank">Kuber Mehta</a></p>
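Note that the search flow added to gradio_app.py talks to a separate FastAPI process over plain HTTP, so the Gradio UI and script_search_api.py have to run side by side. Below is a minimal sketch of exercising that contract from outside Gradio, assuming the API has already been started with "python script_search_api.py" and is listening on port 8000; "Inception" is only an example title:

    import asyncio
    import httpx

    async def smoke_test(movie_name: str = "Inception"):
        # Hit the same two endpoints that the search button's callback drives.
        async with httpx.AsyncClient(timeout=None) as client:
            progress = await client.get(
                "http://localhost:8000/api/progress",
                params={"movie_name": movie_name},
            )
            print(progress.json())  # e.g. {"progress": 0, "status": "Waiting to start..."}

            # This call blocks until scraping and analysis finish, hence no client timeout here.
            result = await client.get(
                "http://localhost:8000/api/fetch_and_analyze",
                params={"movie_name": movie_name},
            )
            print(result.status_code, result.json())

    asyncio.run(smoke_test())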
requirements.txt CHANGED
@@ -1,8 +1,11 @@
-flask
-flask_cors
 torch
 gradio
-transformers
 accelerate
 safetensors
-huggingface-hub
+huggingface-hub
+fastapi
+httpx
+beautifulsoup4
+bs4
+httpx
+json
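As committed, the dependency list carries some redundancy: httpx appears twice, bs4 is only a thin shim that installs beautifulsoup4, and json is part of the Python standard library rather than a pip-installable requirement. A deduplicated equivalent might look as follows (an editorial sketch, not what the commit ships; requests and uvicorn are listed only because script_search_api.py imports them directly):

    torch
    gradio
    accelerate
    safetensors
    huggingface-hub
    fastapi
    uvicorn
    httpx
    requests
    beautifulsoup4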
script_search_api.py ADDED
@@ -0,0 +1,212 @@
+# script_search_api.py
+
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+import requests
+from bs4 import BeautifulSoup
+from model.analyzer import analyze_content
+import logging
+from difflib import get_close_matches
+import re
+from typing import Dict
+from dataclasses import dataclass
+from datetime import datetime
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+BASE_URL = "https://imsdb.com"
+ALL_SCRIPTS_URL = f"{BASE_URL}/all-scripts.html"
+
+@dataclass
+class ProgressInfo:
+    progress: float
+    status: str
+    timestamp: datetime
+
+progress_tracker: Dict[str, ProgressInfo] = {}
+
+def update_progress(movie_name: str, progress: float, message: str):
+    """
+    Update the progress tracker with current progress and status message.
+    """
+    progress_tracker[movie_name] = ProgressInfo(
+        progress=progress,
+        status=message,
+        timestamp=datetime.now()
+    )
+    logger.info(f"{message} (Progress: {progress * 100:.0f}%)")
+
+def find_movie_link(movie_name: str, soup: BeautifulSoup) -> str | None:
+    """
+    Find the closest matching movie link from the script database.
+    """
+    movie_links = {link.text.strip().lower(): link['href'] for link in soup.find_all('a', href=True)}
+    close_matches = get_close_matches(movie_name.lower(), movie_links.keys(), n=1, cutoff=0.6)
+
+    if close_matches:
+        logger.info(f"Close match found: {close_matches[0]}")
+        return BASE_URL + movie_links[close_matches[0]]
+
+    logger.info("No close match found.")
+    return None
+
+def find_script_link(soup: BeautifulSoup, movie_name: str) -> str | None:
+    """
+    Find the script download link for a given movie.
+    """
+    patterns = [
+        f'Read "{movie_name}" Script',
+        f'Read "{movie_name.title()}" Script',
+        f'Read "{movie_name.upper()}" Script',
+        f'Read "{movie_name.lower()}" Script'
+    ]
+
+    for link in soup.find_all('a', href=True):
+        link_text = link.text.strip()
+        if any(pattern.lower() in link_text.lower() for pattern in patterns):
+            return link['href']
+        elif all(word.lower() in link_text.lower() for word in ["Read", "Script", movie_name]):
+            return link['href']
+    return None
+
+def fetch_script(movie_name: str) -> str | None:
+    """
+    Fetch and extract the script content for a given movie.
+    """
+    # Initial page load
+    update_progress(movie_name, 0.1, "Fetching the script database...")
+    try:
+        response = requests.get(ALL_SCRIPTS_URL)
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.error(f"Failed to load the main page: {str(e)}")
+        return None
+
+    # Search for movie
+    update_progress(movie_name, 0.2, "Searching for the movie...")
+    soup = BeautifulSoup(response.text, 'html.parser')
+    movie_link = find_movie_link(movie_name, soup)
+
+    if not movie_link:
+        logger.error(f"Script for '{movie_name}' not found.")
+        return None
+
+    # Fetch movie page
+    update_progress(movie_name, 0.3, "Loading movie details...")
+    try:
+        response = requests.get(movie_link)
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.error(f"Failed to load the movie page: {str(e)}")
+        return None
+
+    # Find script link
+    update_progress(movie_name, 0.4, "Locating script download...")
+    soup = BeautifulSoup(response.text, 'html.parser')
+    script_link = find_script_link(soup, movie_name)
+
+    if not script_link:
+        logger.error(f"Unable to find script link for '{movie_name}'.")
+        return None
+
+    # Fetch script content
+    script_page_url = BASE_URL + script_link
+    update_progress(movie_name, 0.5, "Downloading script content...")
+
+    try:
+        response = requests.get(script_page_url)
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.error(f"Failed to load the script: {str(e)}")
+        return None
+
+    # Extract script text
+    update_progress(movie_name, 0.6, "Extracting script text...")
+    soup = BeautifulSoup(response.text, 'html.parser')
+    script_content = soup.find('pre')
+
+    if script_content:
+        update_progress(movie_name, 0.7, "Script extracted successfully")
+        return script_content.get_text()
+    else:
+        logger.error("Failed to extract script content.")
+        return None
+
+@app.get("/api/fetch_and_analyze")
+async def fetch_and_analyze(movie_name: str):
+    """
+    Fetch and analyze a movie script, with progress tracking.
+    """
+    try:
+        # Initialize progress
+        update_progress(movie_name, 0.0, "Starting script search...")
+
+        # Fetch script
+        script_text = fetch_script(movie_name)
+        if not script_text:
+            raise HTTPException(status_code=404, detail="Script not found or error occurred")
+
+        # Analyze content
+        update_progress(movie_name, 0.8, "Analyzing script content...")
+        result = await analyze_content(script_text)
+
+        # Finalize
+        update_progress(movie_name, 1.0, "Analysis complete!")
+        return result
+
+    except Exception as e:
+        logger.error(f"Error in fetch_and_analyze: {str(e)}", exc_info=True)
+        # Clean up progress tracker in case of error
+        if movie_name in progress_tracker:
+            del progress_tracker[movie_name]
+        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
+
+@app.get("/api/progress")
+def get_progress(movie_name: str):
+    """
+    Get the current progress and status for a movie analysis.
+    """
+    if movie_name not in progress_tracker:
+        return {
+            "progress": 0,
+            "status": "Waiting to start..."
+        }
+
+    progress_info = progress_tracker[movie_name]
+
+    # Clean up old entries (optional)
+    current_time = datetime.now()
+    if (current_time - progress_info.timestamp).total_seconds() > 3600:  # 1 hour timeout
+        del progress_tracker[movie_name]
+        return {
+            "progress": 0,
+            "status": "Session expired. Please try again."
+        }
+
+    return {
+        "progress": progress_info.progress,
+        "status": progress_info.status
+    }
+
+@app.on_event("startup")
+async def startup_event():
+    """
+    Initialize the server and clear any existing progress data.
+    """
+    progress_tracker.clear()
+    logger.info("Server started, progress tracker initialized")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
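A note on the fuzzy title lookup: find_movie_link feeds every link text from the IMSDb all-scripts page into difflib.get_close_matches with n=1 and cutoff=0.6, so a slightly misspelled query still resolves to a script while an unrelated query falls through to the 404 path. A small standalone illustration, with a made-up candidate list standing in for the scraped link texts:

    from difflib import get_close_matches

    candidate_titles = ["interstellar", "inception", "the dark knight"]

    # A near-miss spelling still clears the 0.6 similarity cutoff...
    print(get_close_matches("interstelar", candidate_titles, n=1, cutoff=0.6))   # ['interstellar']

    # ...while an unrelated title matches nothing, which the API reports as "Script not found".
    print(get_close_matches("batman begins", candidate_titles, n=1, cutoff=0.6)) # []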