Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -406,53 +406,63 @@ def get_sub_classification(statement: str) -> str:
|
|
406 |
return sub_label
|
407 |
return "None" # Default to "None" if no keywords match
|
408 |
|
409 |
-
@app.post("/classify_with_subcategory/", description="
|
410 |
-
async def
|
411 |
try:
|
412 |
-
# Check
|
413 |
if not statement or statement.strip().lower() == "n/a":
|
414 |
-
return
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
sub_classification = "Direct"
|
429 |
-
else:
|
430 |
-
sub_classification = "None" # No sub-classification when keywords don’t match
|
431 |
else:
|
432 |
-
#
|
433 |
loop = asyncio.get_running_loop()
|
434 |
-
|
435 |
-
|
436 |
-
lambda: nlp_sequence_classification(statement,
|
437 |
)
|
438 |
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
|
451 |
except asyncio.TimeoutError:
|
|
|
452 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
453 |
except HTTPException as http_exc:
|
|
|
454 |
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
455 |
except Exception as e:
|
|
|
456 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
457 |
|
458 |
# Set up CORS middleware
|
|
|
406 |
return sub_label
|
407 |
return "None" # Default to "None" if no keywords match
|
408 |
|
409 |
+
@app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
    """Classify a statement into a main category and an optional sub-category.

    Resolution order:
      1. Empty, whitespace-only or "N/A" input short-circuits to
         "Notes not clear" with confidence 1.0.
      2. Keyword overrides ("Change to quote", then "Copy quote requested")
         are assigned with confidence 1.0.
      3. Otherwise the sequence-classification model is run in the default
         executor and its top label/score are used.

    A sub-classification is only looked up (via ``get_sub_classification``)
    when the main label is "Copy quote requested"; otherwise it is "None".

    Args:
        statement: Free-text statement submitted as a form field.

    Returns:
        A ``ClassificationResponse`` on success. On failure a ``JSONResponse``
        is returned instead: 504 for ``asyncio.TimeoutError``, the original
        status code for ``HTTPException``, and 500 for any other exception.
    """
    try:
        # Strip once so whitespace-only input is treated the same as empty
        # input instead of being sent to keyword matching / the model.
        normalized = (statement or "").strip()
        if not normalized or normalized.lower() == "n/a":
            return ClassificationResponse(
                classification="Notes not clear",
                sub_classification="None",
                confidence=1.0,
                scores={"main": 1.0}
            )

        # Keyword-based classification override
        if check_keywords(statement, change_to_quote_keywords):
            main_best_label = "Change to quote"
            main_best_score = 1.0  # High confidence since it's a direct match
        elif check_keywords(statement, copy_quote_requested_keywords):
            main_best_label = "Copy quote requested"
            main_best_score = 1.0
        else:
            # No keywords matched: run the model off the event loop so the
            # blocking inference call does not stall other requests.
            loop = asyncio.get_running_loop()
            main_classification_result = await loop.run_in_executor(
                None,
                lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
            )

            # The first entry of "labels"/"scores" is used as the best match.
            main_best_label = main_classification_result["labels"][0]
            main_best_score = main_classification_result["scores"][0]

        # Perform sub-classification only if the main classification is "Copy quote requested"
        if main_best_label == "Copy quote requested":
            best_sub_label = get_sub_classification(statement)
        else:
            best_sub_label = "None"

        # Gather the scores for response
        scores = {"main": main_best_score}
        if best_sub_label != "None":
            scores[best_sub_label] = 1.0  # Assign full confidence to sub-classification matches

        return ClassificationResponse(
            classification=main_best_label,
            sub_classification=best_sub_label,
            confidence=main_best_score,
            scores=scores
        )

    except asyncio.TimeoutError:
        # NOTE(review): nothing in this handler applies a timeout (there is no
        # asyncio.wait_for around the executor call), so this branch only
        # fires if the awaited work itself raises TimeoutError - confirm
        # whether a timeout was intended here.
        return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
    except HTTPException as http_exc:
        # Handle HTTP errors
        return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
    except Exception as e:
        # Handle any other errors
        return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
467 |
|
468 |
# Set up CORS middleware
|