MJobe commited on
Commit
694da95
·
verified ·
1 Parent(s): f0e6e2e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +45 -35
main.py CHANGED
@@ -406,53 +406,63 @@ def get_sub_classification(statement: str) -> str:
406
  return sub_label
407
  return "None" # Default to "None" if no keywords match
408
 
409
- @app.post("/classify_with_subcategory/", description="Quickly classify text into predefined categories.")
410
- async def fast_classify_text(statement: str = Form(...)):
411
  try:
412
- # Check for empty or "N/A" statements
413
  if not statement or statement.strip().lower() == "n/a":
414
- return {"classification": "Note not clear", "confidence": 1.0, "sub_classification": "None", "scores": {}}
415
-
416
- # Determine main classification based on keywords
417
- if any(keyword.lower() in statement.lower() for keyword in change_to_quote_keywords):
418
- main_classification = "Change to Quote"
419
- sub_classification = "None"
420
- elif any(keyword.lower() in statement.lower() for keyword in copy_quote_requested_keywords):
421
- main_classification = "Copy Quote Requested"
422
- # Perform sub-classification for Copy Quote Requested
423
- if "msrp" in statement.lower():
424
- sub_classification = "MRSP"
425
- elif "all pricing" in statement.lower():
426
- sub_classification = "All"
427
- elif "direct" in statement.lower():
428
- sub_classification = "Direct"
429
- else:
430
- sub_classification = "None" # No sub-classification when keywords don’t match
431
  else:
432
- # Call the Hugging Face model for cases where keywords don’t match
433
  loop = asyncio.get_running_loop()
434
- result = await loop.run_in_executor(
435
- executor,
436
- lambda: nlp_sequence_classification(statement, labels, multi_label=False)
437
  )
438
 
439
- main_classification = result["labels"][0]
440
- main_confidence = result["scores"][0]
441
- scores = dict(zip(result["labels"], result["scores"]))
442
- sub_classification = "None" # Set sub-classification to None for non-matching keywords
443
-
444
- return {
445
- "classification": main_classification,
446
- "confidence": main_confidence,
447
- "sub_classification": sub_classification,
448
- "scores": scores
449
- }
 
 
 
 
 
 
 
 
 
 
450
 
451
  except asyncio.TimeoutError:
 
452
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
453
  except HTTPException as http_exc:
 
454
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
455
  except Exception as e:
 
456
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
457
 
458
  # Set up CORS middleware
 
406
  return sub_label
407
  return "None" # Default to "None" if no keywords match
408
 
409
+ @app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
410
+ async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
411
  try:
412
+ # Check if the statement is empty or "N/A"
413
  if not statement or statement.strip().lower() == "n/a":
414
+ return ClassificationResponse(
415
+ classification="Notes not clear",
416
+ sub_classification="None",
417
+ confidence=1.0,
418
+ scores={"main": 1.0}
419
+ )
420
+
421
+ # Keyword-based classification override
422
+ if check_keywords(statement, change_to_quote_keywords):
423
+ main_best_label = "Change to quote"
424
+ main_best_score = 1.0 # High confidence since it's a direct match
425
+ elif check_keywords(statement, copy_quote_requested_keywords):
426
+ main_best_label = "Copy quote requested"
427
+ main_best_score = 1.0
 
 
 
428
  else:
429
+ # If no keywords matched, perform the main classification using the model
430
  loop = asyncio.get_running_loop()
431
+ main_classification_result = await loop.run_in_executor(
432
+ None,
433
+ lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
434
  )
435
 
436
+ # Extract the best main classification label and confidence score
437
+ main_best_label = main_classification_result["labels"][0]
438
+ main_best_score = main_classification_result["scores"][0]
439
+
440
+ # Perform sub-classification only if the main classification is "Copy quote requested"
441
+ if main_best_label == "Copy quote requested":
442
+ best_sub_label = get_sub_classification(statement)
443
+ else:
444
+ best_sub_label = "None"
445
+
446
+ # Gather the scores for response
447
+ scores = {"main": main_best_score}
448
+ if best_sub_label != "None":
449
+ scores[best_sub_label] = 1.0 # Assign full confidence to sub-classification matches
450
+
451
+ return ClassificationResponse(
452
+ classification=main_best_label,
453
+ sub_classification=best_sub_label,
454
+ confidence=main_best_score,
455
+ scores=scores
456
+ )
457
 
458
  except asyncio.TimeoutError:
459
+ # Handle timeout errors
460
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
461
  except HTTPException as http_exc:
462
+ # Handle HTTP errors
463
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
464
  except Exception as e:
465
+ # Handle any other errors
466
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
467
 
468
  # Set up CORS middleware