import asyncio
import json
import os
from enum import Enum
from functools import partial
from io import BytesIO
from random import randint, sample
from re import sub, findall, match
from secrets import SystemRandom
from time import time
from typing import Any

import cv2
import httpx
import numpy as np
import pytesseract
import redis.asyncio as redis
from PIL import Image

from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import PlainTextResponse, JSONResponse, FileResponse
from starlette.routing import Route

from gensim.models import KeyedVectors

from env import get_llm_api_key
import lang
import prompt  # prompt templates for the LLM

pool = redis.ConnectionPool.from_url("redis://localhost")
r = redis.Redis.from_pool(pool)

import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format="[%(asctime)s %(levelname)s] %(message)s")

# Application logger
app_logger = logging.getLogger(__name__)

from flashcard_util import tldr, get_definitions_from_words, fetch_img_for_words

class QType(Enum):
    WH = 0
    STMT = 3
    FILL = 6

class APIGuardMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request, call_next):

        # Get current client url and client IP address
        client_url = request.url.path
        client_ip_addr = request.client.host
        
        # IP-based rate limiting
        async with r.pipeline(transaction=True) as pipeline:
            try:
                res = await pipeline.get(client_ip_addr).execute()
                res = int(res[-1])
            except (TypeError, ValueError):
                res = None

            if res is None:
                # lim = 25 if client_url.endswith('text2quiz-three.vercel.app') else 5
                # lim = 10 if client_url.endswith('localhost:8100') else 5

                app_logger.info(client_url)
                
                ok = await pipeline.set(client_ip_addr, 60).execute()
                await r.expire(client_ip_addr, 60)
            elif res > 0:
                ok = await pipeline.set(client_ip_addr, res-1).execute()
            else:
                raise HTTPException(status_code=429, detail="This IP address is rate-limited")
        
        # process the request and get the response    
        response = await call_next(request)
        return response
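
# NOTE (illustrative sketch, not wired in): the GET-then-SET sequence in
# APIGuardMiddleware.dispatch is not atomic, so two concurrent requests can
# read the same counter. A fixed-window limiter built on Redis INCR + EXPIRE
# avoids that race. This is a hypothetical helper using the same module-level
# client `r`, not part of the original flow.
async def _allow_request(ip: str, limit: int = 60, window_s: int = 60) -> bool:
    count = await r.incr(ip)          # atomic increment; creates the key at 1
    if count == 1:
        await r.expire(ip, window_s)  # start the window on the first hit
    return count <= limit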


sys_random = SystemRandom()

# TODO: Change to environment variable in prod.

#pytesseract.pytesseract.tesseract_cmd = r"C:\Users\Admin\AppData\Local\Programs\Tesseract-OCR\tesseract.exe"

async def __internal_tmp_w(id, content: Any):
    try:
        async with r.pipeline(transaction=True) as pipeline:
            ok = await pipeline.set(id, json.dumps(content).encode("utf-8")).execute()
            await r.expire(id, 600)

        return ok
    except Exception as e:
        app_logger.error(e)

async def __internal_tmp_r(id):
    try:
        async with r.pipeline(transaction=True) as pipeline:
            res = await pipeline.get(id).execute()
            if res[-1] is None:
                return [None, None, None]
            res = res[-1].decode("utf-8")
        return json.loads(res)
    except Exception as e:
        app_logger.error(e)
        return [None, None, None]

async def __internal_tmp_d(id):
    async with r.pipeline(transaction=True) as pipeline:
        await pipeline.delete(id).execute()
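
# Usage sketch for the temp-store helpers above (illustrative; the payload
# shape mirrors how __save_temp and __get_temp use them):
#
#     file_id = sys_random.randbytes(20).hex()
#     await __internal_tmp_w(file_id, ["title", "<p>content</p>", ["kw1", "kw2"]])
#     title, content, keywords = await __internal_tmp_r(file_id)
#     await __internal_tmp_d(file_id)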

async def __mltest(request):
    # Unimplemented test endpoint stub
    pass

async def __save_temp(request):
    file_id = sys_random.randbytes(20).hex()
    content = ""
    # async with request.form(max_fields=3) as form:
    form = await request.json()
    content = form['content']
    title = form['title']
    keywords = form['keywords']

    await __internal_tmp_w(file_id, [title, content, keywords])
    
    app_logger.info("saved temp file %s", file_id)
    return PlainTextResponse(file_id, 200)
    
async def __get_temp(request, entry=1):
    return JSONResponse(await __internal_tmp_r(request.path_params['id']))

async def __remove_temp(request):
    try:
        await __internal_tmp_d(request.path_params['id'])
    except Exception:
        return PlainTextResponse("", 500)

    return PlainTextResponse("", 200)

async def __convert_text(text, type_out="plain", type_in="html"):
    if not text:
        app_logger.info("__convert_text: nothing to convert!")
        return ""
    # Run pandoc as a subprocess, feeding the source through stdin
    process = await asyncio.create_subprocess_exec(
        'pandoc', '-f', type_in, '-t', type_out,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    stdout, _ = await process.communicate(input=text.encode())
    return stdout.decode("utf-8")
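
# Example (illustrative): __convert2md below uses this helper to turn HTML
# into GitHub-flavored markdown with raw HTML stripped:
#
#     md = await __convert_text("<h2>Cells</h2><p>A cell is ...</p>",
#                               type_out="gfm-raw_html", type_in="html")
#     # -> "## Cells\n\nA cell is ..."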

async def __convert_file(fname_in, type_out="plain"):
    proc = await asyncio.create_subprocess_exec(
        'pandoc', '-i', fname_in, '-t', type_out,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
        stdin=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate()
    return stdout.decode("utf-8")

def __ocr(im, file_id):
    # Perform image preprocessing
    processed_im = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2GRAY)

    cv2.imwrite(f"{file_id}.png", processed_im)

    out = pytesseract.image_to_string(f"{file_id}.png", lang="vie", config=r'--psm 4')
    os.remove(f"{file_id}.png")

    return out

def convert_links_to_text(text):
    txt = text

    # Anything that isn't a closing square bracket
    name_regex = "[^]]+"
    # http:// or https:// followed by anything but a closing paren
    url_regex = "http[s]?://[^)]+"

    markup_regex = r'\[({0})]\(\s*({1})\s*\)'.format(name_regex, url_regex)

    # Replace each markdown link [name](url) with just its display name
    for m in findall(markup_regex, txt):
        link_str = f"[{m[0]}]({m[1]})"
        txt = txt.replace(link_str, m[0])

    return txt
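
# Example: convert_links_to_text("see [Mitosis](https://en.wikipedia.org/wiki/Mitosis).")
# returns "see Mitosis." -- the markdown link is replaced by its display text.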

def remove_wikipedia_footnote_ptrs(text):
    txt = text
    wiki_footnote_regex = r'\\\[\d+\\]'
    txt = sub(wiki_footnote_regex, '', txt)
    return txt

async def __convert2md(inp):
    # Use gfm-raw_html to strip styling data from source file
    converted = await __convert_text(inp, "gfm-raw_html", "html")
    converted_without_link = convert_links_to_text(converted)
    converted_without_footnote_ptr = remove_wikipedia_footnote_ptrs(converted_without_link)
    print("[CONVERT]:", converted_without_footnote_ptr)
    return converted_without_footnote_ptr

async def __convert2plain(inp):
    return await __convert_text(inp, "plain", "html")

def convert2md(req):
    # Unimplemented endpoint stub (routed at /convert2md)
    pass

async def __parse_paragraphs (content: str, batching: bool = False):
    _p = ""
    _rp = content

    _rp = await __convert2md(_rp)
    _rp = _rp.replace('\r','')
    # remove empty lines and headers
    _p = [_x.strip() for _x in _rp.split('\n\n') if len(_x)!=0 and _x.strip().count('#') != len(_x)]

    _p_json = []
    h_cnt = 0
    header = ""
    for _n in _p:
        __h_cnt = 0
        # parse the header level (number of leading '#') for each paragraph
        try:
            for _c in _n:
                if _c == '#': __h_cnt += 1
                else: break

            if (__h_cnt >= 1 and len(_n) > __h_cnt):
                # this block is a header line; remember it for what follows
                header = _n
                h_cnt = __h_cnt
            elif (len(_n.replace('#', '').strip())):
                # merge consecutive paragraphs under the same header when batching,
                # and strip accidental newlines left over from the HTML conversion
                if (batching and len(_p_json) >= 1):
                    if (header == _p_json[-1]['header']):
                        _p_json[-1]['content'] += '\n'
                        _p_json[-1]['content'] += _n.replace('\n', ' ')
                        _p_json[-1]['count'] += 1
                        continue

                _p_json.append({'header': header, 'h_cnt': h_cnt, 'content': _n.replace('\n', ' '), 'count': 1})

        except Exception:
            continue

    return _p_json
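
# Illustrative shape of the output (values are made up):
#
#     [{'header': '## Cells', 'h_cnt': 2,
#       'content': 'A cell is the basic unit of life. ...', 'count': 2}]
#
# With batching=True, consecutive paragraphs under the same header are merged
# into a single entry and 'count' records how many were merged.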

async def __query_ml_predict(qtype: QType, content: str, header: str, token_limit: int, num_qs=5, l=lang.VI_VN):
    """Get prediction from a third-party Llama3-8B-Instruct deployment"""
    app_logger.info('[PROC] ML prediction started')
    stopwatch = time()

    match qtype:
        case QType.WH:
            
            # Make request to Awan LLM endpoint
            async with httpx.AsyncClient() as client:
                _r = await client.post(
                    url="https://api.awanllm.com/v1/chat/completions",
                    headers={'Content-Type': 'application/json', 'Authorization': f'Bearer {get_llm_api_key()}'},
                    data=json.dumps({
                        "model": "Meta-Llama-3-8B-Instruct",
                        "messages": [
                            {"role": "user", "content": prompt.gen_prompt_wh(content=content, header=header, num_qs=num_qs, lang=l)}
                        ],
                        "max_tokens": 4096,
                        "presence_penalty":0.3,
                        "temperature":0.55
                    }),
                    timeout=None
                )
            
            app_logger.info('[PROC] WH prediction took %.2fs', time() - stopwatch)
            if _r.status_code != 200:
                app_logger.info(_r.json())
                return {"content": "", "style": None, "success": False}

            try:
                return {"content": _r.json()['choices'][0]['message']['content'], "style": QType.WH, "success": True}
            except:
                return {"content": "", "style": None, "success": False}
        
        case QType.STMT:

            # Make request to Awan LLM endpoint
            async with httpx.AsyncClient() as client:
                _r = await client.post(
                    url="https://api.awanllm.com/v1/chat/completions",
                    headers={'Content-Type': 'application/json', 'Authorization': f'Bearer {get_llm_api_key()}'},
                    data=json.dumps({
                        "model": "Meta-Llama-3-8B-Instruct",
                        "messages": [
                            {"role": "user", "content": prompt.gen_prompt_statements(content=content, header=header, num_qs=num_qs, lang=l)}
                        ],
                        "max_tokens": 4096,
                        
                    }),
                    timeout=None
                )
            if _r.status_code//100 != 2: 
                app_logger.info(_r.json())
                return {"content": "", "style": QType.STMT, "success": False}

            try:
                _r_content = _r.json()['choices'][0]['message']['content']
            except:
                return {"content": "", "style":None, "success":False}
            
            try:
                _r_content = _r.json()['choices'][0]['message']['content'].split('\n\n',1)[1]
            except:
                _r_content = _r.json()['choices'][0]['message']['content'].split('\n',1)[1]
                
            async with httpx.AsyncClient() as client:
                _w = await client.post(
                    url="https://api.awanllm.com/v1/chat/completions",
                    headers={'Content-Type': 'application/json', 'Authorization': f'Bearer {get_llm_api_key()}'},
                    data=json.dumps({
                        "model": "Meta-Llama-3-8B-Instruct",
                        "messages": [
                            {"role": "user", "content": prompt.gen_prompt_statements_false(content=_r_content, lang=l)}
                        ],
                        "max_tokens": 4096,
                        
                    }),
                    timeout=None
            )

            try:
                _ch = _w.json()['choices'][0]['message']['content']
            except:
                return {"content": "", "style": None, "success":False}
            
            try:
                _w_content = _w.json()['choices'][0]['message']['content'].split('\n\n',1)[1]
                #print(time() - stopwatch)
                return {"content": f"{_r_content}\n{_w_content}", "style": QType.STMT, "success": True}
            except:
                _w_content = _w.json()['choices'][0]['message']['content'].split('\n',1)[1]
                #print(time() - stopwatch)
                return {"content": f"{_r_content}\n{_w_content}", "style": QType.STMT, "success": True}


async def parse_wh_question(raw_qa_list, pgph_i):
    __ANS_KEY_MAPPING = {'A': 1, 'B': 2, 'C': 3, 'D': 4}
    __parsed_outputs = []
    for x in raw_qa_list:
        try:
            segments = x
            raw_key = segments[5]
            # Pick the answer letter by containment, checking A, B, D and defaulting to C
            raw_key = 'A' if 'A' in raw_key else 'B' if 'B' in raw_key else 'D' if 'D' in raw_key else 'C'

            # print(segments) 
            match randint(0, 3):
                case 0 | 1:
    
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": segments[0],
                            "type": "MCQ",
                            "choices": segments[1:5],
                            "keys": [segments[__ANS_KEY_MAPPING[raw_key]],],
                        }
                    )
                
                case 2 | 3:
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": segments[0],
                            "type": "OPEN",
                            # Cleaning up ML output
                            "keys": [segments[__ANS_KEY_MAPPING[raw_key]].split(' ',1)[1]],
                            "choices": [segments[__ANS_KEY_MAPPING[raw_key]]]
                        }
                    )

        except Exception:
            app_logger.info("invalid segment: %s", x)
            continue

    return __parsed_outputs

async def parse_stmt_question(stmts: list[str], pgph_i, __lang:str):
    print("starting inference...")
    
    if (stmts[0].__contains__('True: ') or stmts[0].__contains__('False: ')):
        __correct_stmts = [r[5:].strip() for r in stmts if r.__contains__('True: ')]
        __false_stmts = [r[5:].strip() for r in stmts if r.__contains__('False: ')]
    else:
        __correct_stmts = stmts[:len(stmts)//2]
        __false_stmts = stmts[len(stmts)//2:]
    __parsed_outputs = []

    for c in range(0, len(__correct_stmts), 2):
        
        match randint(0, 6):
            
            case 6:
                try:
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": prompt.USER_PROMPTS['AMEND'] if __lang==lang.VI_VN else prompt.USER_PROMPTS_EN['AMEND'],
                            "type": "AMEND",
                            "keys": __correct_stmts[c],
                            "choices": [__false_stmts[c]]
                        }
                    )
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": prompt.USER_PROMPTS['AMEND'] if __lang==lang.VI_VN else prompt.USER_PROMPTS_EN['AMEND'],
                            "type": "AMEND",
                            "keys": __correct_stmts[c+1],
                            "choices": [__false_stmts[c+1]]
                        }
                    )
                except:
                    continue

            case 2|4:
                __c = __correct_stmts[c:c+2]
                # print(min(2, len(__false_stmts)))
                try:
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": prompt.USER_PROMPTS['MULT'] if __lang==lang.VI_VN else prompt.USER_PROMPTS_EN['MULT'],
                            "type": "MULT",
                            "keys": __c,
                            "choices": sample([*__c,  *sample( __false_stmts, min(2, len(__false_stmts)) )], min(2, len(__false_stmts)) + len(__c))
                        }
                    )
                except:
                    continue

            case 3|5:
                try:
                    __c = sample(__false_stmts, 2)
                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": prompt.USER_PROMPTS['MULT_INV'] if __lang==lang.VI_VN else prompt.USER_PROMPTS_EN['MULT_INV'],
                            "type": "MULT",
                            "keys": __c,
                            "choices": sample([*__c, __correct_stmts[0], __correct_stmts[1]], 2+len(__c))
                        }
                    )
                except:
                    continue

            case 0|1:
                for aa in range(2):
                    try:
                        _prompt = __correct_stmts[c+aa]
                    except:
                        continue
                    # print(_prompt)
                    # FIXME: To circumvent some quirky 3rd party lib bugs around chunking phrases with quote, strip them from the sentences for the time being.
                    _prompt = _prompt.replace("\"", "").replace("\'", "")
                    _content_w = []
                    if __lang == lang.VI_VN:
                        _, _content_w = prompt.parse_content_words([_prompt])
                    else:
                        _, _content_w = prompt.parse_content_words_nltk([_prompt])
                    # print(_proper_n)
                    
                    for ns in _content_w:
                        try:
                            initials = "...".join([w[0] for w in ns.split(" ") if w])
                        except:
                            initials = "..."
                        _prompt = _prompt.replace(ns, f"({initials}...)", 1)

                    __parsed_outputs.append(
                        {
                            "pgph_i": pgph_i,
                            "prompt": _prompt,
                            "type": "OPEN",
                            "keys": _content_w,
                            "choices": []
                        }
                    )


    return __parsed_outputs


async def generate_questions(request):

    # parse paragraphs from document file
    try:
        __cont = await __internal_tmp_r(request.path_params['id'])
        __ps = await __parse_paragraphs(__cont[1], batching=True)
    except Exception as e:
        app_logger.error(e)
        return JSONResponse({"err": str(e)}, 500)

    # Map the asynchronous ML prediction function over the list of paragraphs
    ptasks = [] 
    __raw_outputs = []
    __parsed_outputs = []

    # print(__ps)
    
    for z, _p in enumerate(__ps):
        # __query_ml_predict is an awaitable; alternate statement and WH questions
        ptasks.append(__query_ml_predict(
            qtype=(QType.STMT if z % 2 == 1 else QType.WH),
            content=_p['content'],
            header=_p['header'],
            l=request.path_params.get('lang', lang.VI_VN),
            num_qs=request.path_params.get('num_qs', 5 * _p.get('count', 1)),
            token_limit=int(1024 * _p.get('count', 1)),
        ))

    # Run all prediction tasks concurrently
    __raw_outputs = await asyncio.gather(*ptasks)
    # Parse raw ML output into question JSON
    for pgph_i, o in enumerate(__raw_outputs):
        if (not o['success']):
            continue
        if (o['style'] == QType.WH):
            
            # Walk the raw LLM output line by line with a small state machine:
            # expect a question line (Q...), choices A-D, then an answer-key line
            seg_index = 0
            seg_index_map = ['Q', 'A', 'B', 'C', 'D', '']
            raw_segmented = []
            raw_segmented_list = []

            for seg in o['content'].split('\n'):
                
                if seg.strip().startswith(seg_index_map[seg_index]):
                    # The final (answer-key) line must name one of the choices
                    if seg_index == 5:
                        if not ('A' in seg or 'B' in seg or 'C' in seg or 'D' in seg):
                            continue
                    raw_segmented.append(seg)
                    seg_index += 1
                if seg_index == 6:
                    raw_segmented_list.append(raw_segmented.copy())
                    raw_segmented = []
                    seg_index = 0

            __parsed_outputs.extend(await parse_wh_question(raw_segmented_list, pgph_i))

        elif (o['style'] == QType.STMT):
            # Keep numbered statement lines ("1. ..."), dropping trailing
            # " - ..." or " (...)" annotations the model sometimes appends
            stmts = [sub(r" - .*| \(.*\)", "", x.split('. ', 1)[1]) for x in o['content'].split('\n') if bool(match(r"^\d+\.", x))]
            __parsed_outputs.extend(await parse_stmt_question(stmts, pgph_i, request.path_params.get('lang', lang.VI_VN)))
    
    # Return the question data
    if len(__parsed_outputs):
        return JSONResponse({"questions": __parsed_outputs, "paragraphs": __ps, "title": __cont[0]})
    else:
        raise HTTPException(500)
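
# The WH parser above assumes the LLM emits blocks shaped like this
# (illustrative; the exact wording comes from prompt.gen_prompt_wh, and the
# final line is matched by letter containment, so its text must not contain
# stray A-D letters beyond the intended key):
#
#     Q1. What is the basic unit of life?
#     A. The atom
#     B. The cell
#     C. The organ
#     D. The tissue
#     Key: B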

async def scan2OCR(request):
    
    content = b''
    ret = []

    async with request.form(max_files=10, max_fields=20) as form:
        for i in range(int(form['uploads'])):

            # Get random file ID
            file_id = sys_random.randbytes(12).hex()

            # Load image using PIL and convert to opencv grayscale format
            im = Image.open(BytesIO(await form[f'upload_{i}'].read()))
            
            # Run the blocking OCR pipeline in a thread pool executor
            # so it does not stall the event loop
            _loop = asyncio.get_running_loop()
            out = await _loop.run_in_executor(None, partial(__ocr, im, file_id))
            # adapt the output text to the HTML-based rich text editor
            ret.append({"content": out.replace('\n','<br/>')})
    
    return JSONResponse(ret, 200)


async def convert2html(request):
    content = b''
    filename = ""
    output = ""
    files = []

    rets = []

    async with request.form(max_files=10, max_fields=20) as form:

        app_logger.debug("uploads: %s", form['uploads'])
        for i in range(int(form['uploads'])):
            # Get random file ID
            filename = sys_random.randbytes(12).hex()
            ext = form[f'upload_{i}'].filename.split(".")[-1]

            content = await form[f'upload_{i}'].read()

            with open(f"{filename}.{ext}", 'wb') as o:
                o.write(content)
            files.append(f"{filename}.{ext}")
    
    for file in files:
        try:
            output = await __convert_file(file, "html")

        except Exception as e:
            app_logger.error(e)
            return JSONResponse({"detail": ""}, status_code=422)
    
        # Extract image sources from document
        imgs = []
        start = -1

        for i in range(len(output)):
            if output[i:i+4] == "<img":
                
                start = i

            if output[i:i+2] == "/>" and start != -1:

                img_tag = output[start:i+2]
                imgs.append(img_tag)

                start = -1

        for x in imgs:
            output = output.replace(x, " ")
        
        # Remove upload file
        os.remove(file)
    
        rets.append({"content": output, "resources": imgs})

    return JSONResponse(rets)

async def llm_generate_text(request):
    async with httpx.AsyncClient() as client:
        o = await request.json()
        _r = await client.post(
                    url="https://api.awanllm.com/v1/chat/completions",
                    headers={'Content-Type': 'application/json', 'Authorization': f'Bearer {get_llm_api_key()}'},
                    data=json.dumps({
                        "model": "Meta-Llama-3-8B-Instruct",
                        "messages": [
                            {"role": "user", "content": f"Generate a study note about {o['prompt']}. Use {o['lang']}."}
                        ],
                        "max_tokens": 4096,
                        "presence_penalty":0.3,
                        "temperature":0.5
                    }),
                    timeout=None)

        # Check the status first, outside the try block, so the 429 below
        # is not swallowed and re-raised as a 422
        if _r.status_code != 200:
            raise HTTPException(status_code=429, detail=str(_r.json()))

        try:
            repl = _r.json()['choices'][0]['message']['content']
            # Convert the markdown reply to HTML for the rich text editor
            repl = await __convert_text(repl, "html-raw_html", "markdown")
            return PlainTextResponse(repl)
        except Exception as e:
            raise HTTPException(status_code=422, detail=str(e))

async def get_flashcards(request):
    # [title, content, keywords]

    __file = await __internal_tmp_r(request.path_params['id'])
    __content = __file[1]
    
    __lang = request.path_params['lang']
    __keywords = [r.strip() for r in __file[2] if len(r) > 0]

    __tldr = await tldr(__content, __lang)
    app_logger.debug(__tldr)
    __definitions = await get_definitions_from_words(__keywords, __tldr, __lang)
    app_logger.debug(__definitions)

    return JSONResponse({"tldr": __tldr, "defs": __definitions, "imgs": await fetch_img_for_words(__keywords)})

"""
Similarity validation
"""
w2v_vi = KeyedVectors.load_word2vec_format('wiki.vi.model.bin', binary=True)
# w2v_en = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin',binary=True)
vocab_vi = w2v_vi.key_to_index
# vocab_en = w2v_en.vocab
from underthesea import word_tokenize
from nltk.tokenize import word_tokenize as word_tokenize_en
from numpy import zeros_like
from scipy.spatial.distance import cosine
import warnings
async def validate_similarity(request):
    req = await request.json()
    sent1, sent2 = req['sentences']
    l = req['lang']

    if (l == lang.VI_VN):
        tokens1 = word_tokenize(sent1.lower())
        tokens2 = word_tokenize(sent2.lower())
    else:
        # NOTE: the English model is commented out above, so English tokens
        # are still looked up in the Vietnamese vectors below
        tokens1 = word_tokenize_en(sent1.lower())
        tokens2 = word_tokenize_en(sent2.lower())

    vect1 = zeros_like(w2v_vi.get_vector('an'))
    vect2 = zeros_like(w2v_vi.get_vector('an'))

    for t in tokens1:
        if t in vocab_vi:
            vect1 += w2v_vi.get_vector(t)
    
    
    for t in tokens2:
        if t in vocab_vi:
            vect2 += w2v_vi.get_vector(t)

    # Cosine similarity measures the cosine of the angle between the two
    # summed embedding vectors; higher means more similar sentences.
    with warnings.catch_warnings():
        warnings.simplefilter('error', RuntimeWarning)
        try:
            sim = 1 - cosine(vect1, vect2) >= 0.8
        except RuntimeWarning as e:
            return JSONResponse({"isSimilar": "False"})
            
        return JSONResponse({"isSimilar": str(sim)})

async def get_cached_img_from_disk(request):
    _fn = request.path_params['fn']
    # e.g. img_-3711971785602203114.webp
    if _fn.startswith('img_') and _fn.endswith('.webp'):
        return FileResponse(_fn)
    else:
        raise HTTPException(404)

async def generate_redemption(request):
    req = await request.json()
    paragraphs = req['pgphs']

    ret_questions = []

    ptasks = []
    for paragraph in paragraphs:
        ptasks.append(__query_ml_predict(QType.WH, paragraph["content"], paragraph["header"], 4096, num_qs=paragraph.get("count", 1)*5, l=req['lang']))

    raw_questions: list[str] = await asyncio.gather(*ptasks)
    for query in raw_questions:
        if not query['success']:
            continue

        q = query['content']

        raw_segments = [x.strip() for x in q.split('\n') if x.strip()]
        filtered = []
        seg_cnt = 0
        seg_map = ['Q', 'A', 'B', 'C', 'D', '']
        for r in raw_segments:
            if r.startswith(seg_map[seg_cnt]):
                if seg_cnt == 5:
                    if not ('A' in r or 'B' in r or 'C' in r or 'D' in r):
                        continue
                filtered.append(r)
                seg_cnt += 1
            if seg_cnt == 6:
                _r = filtered[5]
                ans_index = 1 if 'A' in _r else 2 if 'B' in _r else 4 if 'D' in _r else 3
                ans_key = filtered[ans_index][2:].strip()

                if randint(0, 1) == 1:
                    ret_questions.append({'prompt': filtered[0], 'keys': [ans_key,]})
                else:
                    prompt_format = "\n".join(filtered[0:5])
                    ret_questions.append({'prompt': prompt_format, 'keys': [seg_map[ans_index]]})

                filtered.clear()
                seg_cnt = 0

    return JSONResponse({'questions': ret_questions})



async def root(request):
    return PlainTextResponse("Success")


# Application entry point

middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=['http://localhost:8100', 'https://text2quiz-three.vercel.app'],
        allow_methods =['*'],
    ),
    Middleware(APIGuardMiddleware),
]

app = Starlette(debug=True, routes=[
    Route('/getFlashcards/{id}/{lang}', get_flashcards, methods=['GET']),
    Route('/convert2html',convert2html, methods=['POST']),
    Route('/scan2ocr', scan2OCR, methods=['POST']),
    Route('/temp', __save_temp, methods=['POST']),
    Route('/temp/{id}', __get_temp, methods=['GET']),
    Route('/temp/{id}', __remove_temp, methods=['DELETE']),
    # Route('/generateQuiz/{id}', generate_questions, methods=['GET']),
    Route('/generateQuiz/{id}/{lang}', generate_questions, methods=['GET']),
    # e.g. /images/img_-3711971785602203114.webp
    Route('/images/{fn}', get_cached_img_from_disk, methods=['GET']),
    Route('/convert2md', convert2md, methods=['POST']),
    Route('/mltest', __mltest, methods=['GET']),
    Route('/validateSimilarity', validate_similarity, methods=['POST']),
    Route('/llmGenerateText', llm_generate_text, methods=['POST']),
    Route('/generateRedemption', generate_redemption, methods=['POST']),
    Route('/', root, methods=['GET']),
],
middleware=middleware)
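
# To serve the app (illustrative; any ASGI server works, and the module name
# depends on this file's name -- assumed here to be server.py):
#
#     uvicorn server:app --host 0.0.0.0 --port 8000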

app_logger.info("running at: %s", os.getcwd())