nhathuy07 committed on
Commit
1b8a458
·
verified ·
1 Parent(s): f3cb78d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -2
main.py CHANGED
@@ -3,7 +3,7 @@ from starlette.applications import Starlette
3
  from starlette.routing import Route
4
  from starlette.middleware import Middleware
5
  from starlette.middleware.cors import CORSMiddleware
6
-
7
  """Prompt templates for LLM"""
8
  from env import LLM_API_KEY
9
  import prompt
@@ -544,6 +544,45 @@ async def get_flashcards(request):
544
 
545
  return JSONResponse({"tldr": __tldr, "defs": __definitions, "imgs": await fetch_img_for_words(__keywords)})
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
  app = Starlette(debug=True,routes=[
549
  Route('/getFlashcards/{id}/{lang}', get_flashcards, methods=['GET']),
@@ -556,7 +595,7 @@ app = Starlette(debug=True,routes=[
556
  Route('/generateQuiz/{id}/{lang}', generate_questions, methods=['GET']),
557
  Route('/convert2md', convert2md, methods=['POST']),
558
  Route('/mltest', __mltest, methods=['GET'])
559
-
560
  ],
561
  middleware=middleware)
562
 
 
3
  from starlette.routing import Route
4
  from starlette.middleware import Middleware
5
  from starlette.middleware.cors import CORSMiddleware
6
+ from gensim.models import KeyedVectors
7
  """Prompt templates for LLM"""
8
  from env import LLM_API_KEY
9
  import prompt
 
544
 
545
  return JSONResponse({"tldr": __tldr, "defs": __definitions, "imgs": await fetch_img_for_words(__keywords)})
546
 
547
"""
Similarity validation
"""
# Tokenizers and vector helpers used by the similarity endpoint below.
from numpy import zeros, zeros_like
from scipy.spatial.distance import cosine
from nltk.tokenize import word_tokenize as word_tokenize_en
from underthesea import word_tokenize

# Vietnamese word2vec vectors, loaded once at import time from a binary
# word2vec-format file.
w2v_vi = KeyedVectors.load_word2vec_format('wiki.vi.model.bin', binary=True)
# Key -> index mapping of the loaded model; used for vocabulary membership tests.
vocab_vi = w2v_vi.key_to_index

# NOTE: an English model is not loaded yet. The disabled counterpart was:
#   w2v_en = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin',binary=True)
#   vocab_en = w2v_en.vocab
558
+
559
def _sum_token_vectors(tokens):
    """Return the sum of the word vectors of every token present in ``vocab_vi``.

    Tokens missing from the vocabulary are skipped, so the result is the zero
    vector when none of the tokens are known to the model.
    """
    # Size the accumulator from the model itself instead of looking up a
    # probe word, which would raise KeyError if that word were absent.
    total = zeros(w2v_vi.vector_size, dtype=w2v_vi.vectors.dtype)
    for token in tokens:
        if token in vocab_vi:
            total += w2v_vi.get_vector(token)
    return total


async def validate_similarity(request):
    """Report whether two sentences are similar according to summed word vectors.

    The request body must be JSON with two keys:
      * ``sentences``: exactly two strings.
      * ``lang``: language code; ``lang.VI_VN`` selects the Vietnamese
        tokenizer, anything else the NLTK tokenizer.

    Each sentence is lower-cased, tokenized, and reduced to the sum of the
    vectors of its in-vocabulary tokens. The sentences count as similar when
    the cosine similarity of the two sums is at least 0.8.

    Returns:
        JSONResponse ``{"isSimilar": "True"}`` or ``{"isSimilar": "False"}``
        (the flag is serialized as a string), or a 400 response carrying an
        ``error`` message when the body is malformed.
    """
    try:
        payload = await request.json()
        first, second = payload['sentences']
        language = payload['lang']
    except (KeyError, TypeError, ValueError):
        # Covers invalid JSON, a non-object body, missing keys, and a
        # 'sentences' value that does not hold exactly two items.
        return JSONResponse(
            {"error": "expected JSON body with 'sentences' (two strings) and 'lang'"},
            status_code=400,
        )
    if not isinstance(first, str) or not isinstance(second, str):
        return JSONResponse(
            {"error": "'sentences' must contain two strings"},
            status_code=400,
        )

    tokenize = word_tokenize if language == lang.VI_VN else word_tokenize_en
    # NOTE(review): both branches are scored against the Vietnamese vectors
    # (w2v_vi); only the tokenizer differs by language - confirm this is intended.
    vector_a = _sum_token_vectors(tokenize(first.lower()))
    vector_b = _sum_token_vectors(tokenize(second.lower()))

    if not vector_a.any() or not vector_b.any():
        # Cosine similarity is undefined for a zero vector (0/0 -> NaN plus a
        # runtime warning); treat "nothing recognised" as not similar.
        is_similar = False
    else:
        # scipy's cosine() is a distance, so similarity is 1 - distance;
        # a higher similarity means more alike sentences.
        is_similar = bool(1 - cosine(vector_a, vector_b) >= 0.8)

    return JSONResponse({"isSimilar": str(is_similar)})
586
 
587
  app = Starlette(debug=True,routes=[
588
  Route('/getFlashcards/{id}/{lang}', get_flashcards, methods=['GET']),
 
595
  Route('/generateQuiz/{id}/{lang}', generate_questions, methods=['GET']),
596
  Route('/convert2md', convert2md, methods=['POST']),
597
  Route('/mltest', __mltest, methods=['GET'])
598
+ Route('/validateSimilarity', validate_similarity, methods=['POST'])
599
  ],
600
  middleware=middleware)
601