linxy committed on
Commit
bb48ea5
ยท
1 Parent(s): 218f679
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **/*.pyc
LLM.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding=utf-8
2
+ import multiprocessing as mp
3
+ import warnings
4
+
5
+ import requests
6
+ import tiktoken
7
+ from tqdm import tqdm
8
+ from dataclasses import dataclass, field
9
+ from typing import (
10
+ AbstractSet,
11
+ Any,
12
+ Callable,
13
+ Collection,
14
+ Dict,
15
+ Generator,
16
+ List,
17
+ Literal,
18
+ Mapping,
19
+ Optional,
20
+ Set,
21
+ Tuple,
22
+ Union,
23
+ )
24
+ from pydantic import Extra, Field, root_validator
25
+ from loguru import logger
26
+
27
+ from langchain.llms.base import BaseLLM
28
+ from langchain.schema import Generation, LLMResult
29
+ from langchain.utils import get_from_dict_or_env
30
+
31
+ from langchain.callbacks.manager import (
32
+ AsyncCallbackManagerForLLMRun,
33
+ CallbackManagerForLLMRun,
34
+ )
35
+ import sys
36
+ import json
37
+
38
+
39
@dataclass(frozen=True)
class ChatGPTConfig:
    r"""Immutable settings used by :class:`ChatSession` for its chat requests.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`1.0`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        max_in_tokens (int, optional): Token budget for the system prompt
            plus the history messages sent with one request; older history
            is left out once the budget is reached. (default: :obj:`3200`)
        timeout (int, optional): Value passed as ``timeout`` to the HTTP
            request that carries the chat completion. (default: :obj:`20`)
    """
    temperature: float = 1.0  # openai default: 1.0
    top_p: float = 1.0
    max_in_tokens: int = 3200
    timeout: int = 20
92
+
93
+
94
def get_userid_and_token(
    url='http://avatar.aicubes.cn/vtuber/auth/api/oauth/v1/login',
    app_id='6027294018fd496693d0b8c77e2d20a1',
    app_secret='52806a6fff8a452497061b9dcc5779f4',
    timeout=20,
):
    """Log in to the auth endpoint and return ``(user_id, token)``.

    Args:
        url: Login endpoint that receives the JSON credentials.
        app_id: Application id sent as ``app_id``.
        app_secret: Application secret sent as ``app_secret``.
        timeout: Passed to ``requests.post`` so a stalled server cannot block
            the caller forever.

    Returns:
        The ``user_id`` and ``token`` values found under ``data`` in the
        JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # NOTE(review): the default app_id/app_secret are real-looking credentials
    # committed to the repository; consider loading them from the environment
    # and rotating these values.
    payload = {'app_id': app_id, 'app_secret': app_secret}
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    # Fail with the HTTP error itself instead of a confusing KeyError/TypeError below.
    response.raise_for_status()
    data = response.json()['data']
    return data['user_id'], data['token']
104
+
105
+
106
class ChatAPI:
    """HTTP client for the chat-completions endpoint used by this module.

    Creating an instance performs a login request (``get_userid_and_token``)
    and keeps the returned user id and token for later calls.
    """

    def __init__(self, timeout=20, verbose=False) -> None:
        """Store the request timeout / verbosity and log in."""
        self.timeout = timeout
        self.verbose = verbose
        self.user_id, self.token = get_userid_and_token()

    def create_chat_completion(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None) -> str:
        """Return the text content of the first choice of the completion."""
        res = self.create_chat_completion_response_data(messages, model, temperature, max_tokens)
        return res['choices'][0]['message']['content']

    def create_chat_completion_response_data(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None):
        """Return the ``data`` member of the JSON response body."""
        res = self.create_chat_completion_response(messages, model, temperature, max_tokens)
        return res.json()['data']

    def create_chat_completion_response(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None):
        """POST the chat request and return the raw ``requests`` response.

        When ``verbose`` is set, the reply content (or the whole body when
        ``data`` is ``None``) is written to the debug log.
        """
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        chat_header = {
            'Content-Type': 'application/json',
            'userId': self.user_id,
            'token': self.token
        }
        payload = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens,
        }
        res = requests.post(chat_url, json=payload, headers=chat_header, timeout=self.timeout)
        if self.verbose:
            # Decode the body once and reuse it for both logging branches.
            body = res.json()
            data = body["data"]
            if data is None:
                logger.debug(body)
            else:
                logger.debug(data["choices"][0]["message"]["content"])
        return res
143
+
144
+
145
class OpenAIChat(BaseLLM):
    """LangChain LLM wrapper that sends chat completions through ``ChatAPI``.

    Requests are made with the module's ``ChatAPI`` client; the ``openai``
    package and ``OPENAI_API_KEY`` are not used by this class.

    Keyword arguments that are not declared fields are collected into
    ``model_kwargs`` (see ``build_extra``).

    Example:
        .. code-block:: python

            openaichat = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.2)
    """

    model_name: str = "gpt-3.5-turbo"
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""
    # NOTE: evaluated when the class body runs, so the ChatAPI login request
    # is issued as soon as this module is imported.
    api = ChatAPI(timeout=60)
    generate_verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {f.alias for f in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Return the values unchanged; nothing is read from the environment."""
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return the extra model parameters collected in ``model_kwargs``."""
        return self.model_kwargs

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        """Build ``(messages, params)`` for a single prompt.

        Raises:
            ValueError: If more than one prompt is given, or ``stop`` is
                supplied both as an argument and in the default params.
        """
        if len(prompts) > 1:
            raise ValueError(
                f"OpenAIChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        if params.get("max_tokens") == -1:
            # for ChatGPT api, omitting max_tokens is equivalent to having no limit
            del params["max_tokens"]
        return messages, params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        """Send the single prompt through ``self.api`` and wrap the reply.

        ``temperature`` and ``max_tokens`` are taken from the computed
        params; the temperature falls back to 1.0 when none was supplied
        instead of failing with ``KeyError``.
        """
        messages, params = self._get_chat_params(prompts, stop)
        if self.generate_verbose:
            logger.debug(json.dumps(params, indent=2))
            for msg in messages:
                logger.debug(msg["role"] + " : " + msg["content"])
        # NOTE(review): ChatAPI has no `stop` argument, so stop sequences in
        # `params` are not forwarded to the endpoint.
        full_response = self.api.create_chat_completion_response_data(
            messages,
            self.model_name,
            params.get("temperature", 1.0),
            params.get("max_tokens"),
        )
        llm_output = {
            "token_usage": full_response["usage"],
            "model_name": self.model_name,
        }
        return LLMResult(
            generations=[
                [Generation(text=full_response["choices"][0]["message"]["content"])]
            ],
            llm_output=llm_output,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        """Async generation is not implemented for this wrapper."""
        raise NotImplementedError("Async not supported for OpenAIChat")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "openai-chat"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python < 3.8; compare the full version
        # tuple so the major version is taken into account as well.
        if sys.version_info < (3, 8):
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )
        # create a GPT-3.5-Turbo encoder instance
        enc = tiktoken.encoding_for_model("gpt-3.5-turbo")

        # encode the text using the GPT-3.5-Turbo encoder
        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)
306
+
307
class ChatSession:
    """Multi-turn chat session that keeps its own message history.

    Creating a session logs in (``get_userid_and_token``) and builds a
    tiktoken encoder used to count tokens of the messages it sends.
    """

    def __init__(self, prompt: str = '', chatgpt_config: ChatGPTConfig = ChatGPTConfig()) -> None:
        """Log in and initialise the history.

        Args:
            prompt: Optional system prompt; an empty string means none.
            chatgpt_config: Request settings for this session.
        """
        # Kept as a plain dict: it is indexed by key and spread into the
        # request payload in make_chat_session.
        self.chatgpt_config = chatgpt_config.__dict__
        self.user_id, self.token = self.get_userid_and_token()
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0301")
        self.count = lambda x: len(encoding.encode(x))
        self.history = []
        self.system = [self.make_msg("system", prompt)] if prompt else []

    def restart(self, prompt: str = '') -> None:
        """Replace the system prompt; the message history is left as it is."""
        self.system = [self.make_msg("system", prompt)] if prompt else []

    @staticmethod
    def make_msg(role: str, msg: str) -> Dict:
        """Build one chat message dict for ``role`` with content ``msg``."""
        assert role in {"system", "assistant", "user"}
        return {"role": role, "content": msg}

    @staticmethod
    def get_userid_and_token(
        url='http://avatar.aicubes.cn/vtuber/auth/api/oauth/v1/login',
        app_id='6027294018fd496693d0b8c77e2d20a1',
        app_secret='52806a6fff8a452497061b9dcc5779f4',
        timeout=20,
    ):
        """Log in and return ``(user_id, token)`` from the response ``data``.

        ``timeout`` is passed to ``requests.post`` so the login cannot hang
        indefinitely.
        """
        # NOTE(review): the default credentials are committed in source;
        # consider loading them from the environment.
        d = {'app_id': app_id, 'app_secret': app_secret}
        h = {'Content-Type': 'application/json'}
        r = requests.post(url, json=d, headers=h, timeout=timeout)
        data = r.json()['data']
        return data['user_id'], data['token']

    def make_chat_session(self, user_id: str, token: str, input_message: List[Dict[str, str]]):
        """POST ``input_message`` plus the session config; return the reply text."""
        chat_h = {
            'Content-Type': 'application/json',
            'userId': user_id,
            'token': token
        }
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        # Every config entry (including max_in_tokens and timeout) is sent in
        # the JSON body alongside the messages.
        res = requests.post(chat_url, json={
            'messages': input_message, **self.chatgpt_config
        }, headers=chat_h, timeout=self.chatgpt_config['timeout'])
        return res.json()['data']['choices'][0]['message']['content']

    def create_chat_completion(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None) -> str:
        """POST an explicit chat request and return the first choice's text."""
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        chat_header = {
            'Content-Type': 'application/json',
            'userId': self.user_id,
            'token': self.token
        }
        payload = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens,
        }
        timeout = self.chatgpt_config['timeout']
        res = requests.post(chat_url, json=payload, headers=chat_header, timeout=timeout)
        return res.json()['data']['choices'][0]['message']['content']

    def chat(self, msg: str):
        """Send ``msg`` as the next user turn and return the assistant reply.

        The request carries the system prompt plus the most recent history
        messages whose running token count stays below ``max_in_tokens``; if
        even the newest message exceeds the budget it is sent on its own.
        If the request fails, the user turn is removed from the history again
        and the exception is re-raised.
        """
        self.history.append(self.make_msg("user", msg))
        limit = self.chatgpt_config['max_in_tokens']
        used = self.count(self.system[0]['content']) if self.system else 0

        # Walk backwards from the newest message while the budget holds.
        start = len(self.history)
        while start > 0:
            used += self.count(self.history[start - 1]['content'])
            if used >= limit:
                break
            start -= 1
        if start >= len(self.history):
            # Nothing fitted: still send the newest message.
            start = -1

        try:
            res = self.make_chat_session(self.user_id, self.token, self.system + self.history[start:])
        except Exception:
            # Do not keep an unanswered user turn in the history.
            self.history.pop()
            raise
        self.history.append(self.make_msg("assistant", res))
        return res
377
+
378
+
379
def batch_chat(info_lst: List, request_num: int = 6) -> List:
    """Run ``single_chat`` over ``info_lst`` in a process pool.

    Args:
        info_lst: Items passed one by one to ``single_chat``.
        request_num: Number of worker processes.

    Returns:
        ``(id, text)`` pairs, in input order, for the items whose reply text
        is non-empty.
    """
    res = []
    # The context manager releases the worker processes once all results
    # have been consumed (the original never closed the pool).
    with mp.Pool(processes=request_num) as pool:
        answers = pool.imap(single_chat, info_lst)
        for item_id, res_text in tqdm(answers, desc="Asking API", total=len(info_lst)):
            if res_text:
                res.append((item_id, res_text))

    return res
387
+
388
+
389
def single_chat(info: Dict) -> Tuple[int, str]:
    """Answer one query in a fresh ``ChatSession``; best effort.

    Args:
        info: Dict read with the keys ``sys``, ``config``, ``query``, ``id``.

    Returns:
        ``(info['id'], reply)``; the reply is ``""`` when creating the
        session or sending the query raised.
    """
    try:
        # Creating the session performs a login request, so it belongs inside
        # the best-effort block as well.
        sess = ChatSession(info['sys'], info['config'])
        res = sess.chat(info['query'])
        return info['id'], res
    except Exception as e:
        print(e)
        return info['id'], ""
397
+
398
+
399
+ if __name__ == '__main__':
400
+
401
+ sys_prompt = """
402
+ ไฝ ๆ˜ฏไธ€ไฝไธฅๆ ผ็š„่ฏ„ๅˆ†ๅ‘˜๏ผŒๆˆ‘ไผš็ป™ไฝ ไธ€ไธชๆŒ‡ไปคๅ’Œ่ฟ™ไธชๆŒ‡ไปค็š„ๅ›žๅค๏ผŒไฝ ้œ€่ฆไป”็ป†ๆฃ€ๆŸฅๅ›žๅคๅนถ็ป™ๅ‡บๅˆ†ๆ•ฐ๏ผŒไฝ ๅฏไปฅไปŽๅคšไธช่ง’ๅบฆ่ฏ„ๅˆค่ฟ™ไธชๅ›žๅค๏ผŒๆฏ”ๅฆ‚๏ผš
403
+ ๅ›žๅคๆ˜ฏๅฆๅ‡†็กฎใ€ๆ˜ฏๅฆ่ฏฆๅฐฝใ€ๆ˜ฏๅฆๆ— ๅฎณใ€ๆ˜ฏๅฆๅฎŒๅ…จ็ฌฆๅˆๆŒ‡ไปค้‡Œ็š„่ฆๆฑ‚๏ผŒ็ญ‰็ญ‰ใ€‚ๅˆ†ๆ•ฐๅˆ†ไธบ5ไธช็ญ‰็บง๏ผŒ1ๅˆ†๏ผšๅฎŒๅ…จไธๅฏ็”จ๏ผŒ2ๅˆ†๏ผšไธๅฏ็”จไฝ†ๅฎŒๆˆไบ†้ƒจๅˆ†ๆŒ‡ไปค๏ผŒ
404
+ 3ๅˆ†๏ผšๅฏ็”จไฝ†ๆœ‰ๆ˜Žๆ˜พ็ผบ้™ท๏ผŒ4ๅˆ†๏ผšๅฏ็”จไฝ†ๆœ‰ๅฐ‘่ฎธ็ผบ้™ท๏ผŒ5ๅˆ†๏ผšๅฏ็”จไธ”ๆฒกๆœ‰็ผบ้™ทใ€‚ไฝ ๅœจๅทฅไฝœๆ—ถ้œ€่ฆๅŠ ๅ…ฅ่‡ชๅทฑ็š„ๆ€่€ƒ๏ผŒๅนถๅœจๆœ€ๅŽ็ป™ๅ‡บๅˆ†ๆ•ฐใ€‚
405
+ ไธ‹้ขๆ˜ฏไธ€ไธชไพ‹ๅญ๏ผš
406
+ User: \n\n<ๆŒ‡ไปค>้ฉฌไบ‘็š„ๅฆปๅญๆ˜ฏ่ฐ?</ๆŒ‡ไปค>\n\n<ๅ›žๅค>้ฉฌไบ‘็š„ๅฆปๅญๆ˜ฏๅผ ่‹ฑ็ชใ€‚</ๅ›žๅค>
407
+ Assistant: ่ฟ™ไธชๅ›žๅค้”™่ฏฏ๏ผŒ้ฉฌไบ‘ๆ˜ฏ้˜ฟ้‡Œๅทดๅทดๅˆ›ๅง‹ไบบ๏ผŒไป–็š„ๅฆปๅญๆ˜ฏๅผ ็‘›๏ผŒๅ› ๆญคๅ›žๅค้”™่ฏฏ๏ผŒๅ› ๆญค๏ผŒๆˆ‘็š„ๅˆ†ๆ•ฐๆ˜ฏ[1ๅˆ†]ใ€‚
408
+ """
409
+
410
+ aaa = """
411
+ fq(xm, m) = (Wqxm)e^(imฮธ)
412
+ fk(xn, n) = (Wkxn)e^(inฮธ)
413
+ g(xm, xn, m โˆ’ n) = Re[(Wqxm)(Wkxn)โˆ—e^(i(mโˆ’n)ฮธ)]
414
+ """
415
+ prompt = 'User: \n\n<ๆŒ‡ไปค>ๅงšๆ˜Žๅคš้ซ˜</ๆŒ‡ไปค>\n\n<ๅ›žๅค>18m</ๅ›žๅค>\nAssistant:'
416
+ bbb = "The given equation defines a function g(xm, xn, m-n) in terms of two complex functions fq(xm, m) and fk(xn, n) and their corresponding Fourier coefficients Wq and Wk, respectively. The function g(xm, xn, m-n) takes the real part of the product of the two complex exponential terms with phase angles m-theta and n-theta, respectively, where theta is an arbitrary constant angle. The term (m-n)theta in the exponent indicates that the two exponential terms are shifted by a phase difference of (m-n)theta."
417
+ session = ChatSession('่งฃ้‡Šๅ…ฌๅผ็š„ๅซไน‰')
418
+ # print(session.chat(aaa))
419
+ print(session.chat("ไฝ ๆ˜ฏ่ฐ๏ผŸ่ฐๅˆ›้€ ไบ†ไฝ ๏ผŸไฝ ็š„็Ÿฅ่ฏ†ๆˆชๆญขไบŽไป€ไนˆๆ—ถๅ€™๏ผŸไฝ ๅฏไปฅ็ป™่‡ชๅทฑๅ–ไธ€ไธชๅๅญ—๏ผŒ่ฏทๅ‘Š่ฏ‰ๆˆ‘ไฝ ็š„ๅๅญ—"))
420
+
app copy.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import huggingface_hub
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import shutil
6
+ import os
7
+ import datetime
8
+ from apscheduler.schedulers.background import BackgroundScheduler
9
+
10
+
11
+ DB_FILE = "./app.db"
12
+
13
+ TOKEN = os.environ.get('HUB_TOKEN')
14
+ repo = huggingface_hub.Repository(
15
+ local_dir="data",
16
+ repo_type="dataset",
17
+ clone_from="linxy/oh-my-words",
18
+ use_auth_token=TOKEN
19
+ )
20
+ repo.git_pull()
21
+
22
+ # Set db to latest
23
+ shutil.copyfile("./data/app.db", DB_FILE)
24
+
25
+
26
+ # Create table if it doesn't already exist
27
+
28
db = sqlite3.connect(DB_FILE)
try:
    # IF NOT EXISTS makes the bootstrap idempotent without reading every row
    # of the table just to find out whether it exists.
    db.execute(
        '''
        CREATE TABLE IF NOT EXISTS reviews (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
                              created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                              name TEXT, review INTEGER, comments TEXT)
        ''')
    db.commit()
finally:
    # Close the connection even when the statement fails.
    db.close()
41
+
42
+
43
def get_latest_reviews(db: sqlite3.Connection):
    """Return the ten newest reviews and the total number of reviews.

    Args:
        db: Open connection to a database containing the ``reviews`` table.

    Returns:
        ``(reviews, total_reviews)`` where ``reviews`` is a DataFrame with
        the columns id, date_created, name, review, comments (newest first)
        and ``total_reviews`` is the row count of the table.
    """
    # Columns are selected explicitly so they always line up with the
    # DataFrame labels below, whatever the table's physical column order.
    rows = db.execute(
        "SELECT id, created_at, name, review, comments FROM reviews ORDER BY id DESC limit 10"
    ).fetchall()
    total_reviews = db.execute("Select COUNT(id) from reviews").fetchone()[0]
    reviews = pd.DataFrame(rows, columns=["id", "date_created", "name", "review", "comments"])
    return reviews, total_reviews
48
+
49
+
50
def add_review(name: str, review: int, comments: str):
    """Insert one review and return the refreshed latest reviews and count.

    Returns:
        The ``(reviews, total_reviews)`` pair produced by
        ``get_latest_reviews`` after the insert was committed.
    """
    db = sqlite3.connect(DB_FILE)
    try:
        db.execute("INSERT INTO reviews(name, review, comments) VALUES(?,?,?)", [name, review, comments])
        db.commit()
        return get_latest_reviews(db)
    finally:
        # Release the connection even if the insert or the query fails.
        db.close()
58
+
59
def load_data():
    """Open the database and return ``get_latest_reviews``' result."""
    db = sqlite3.connect(DB_FILE)
    try:
        return get_latest_reviews(db)
    finally:
        # Release the connection even if the query fails.
        db.close()
64
+
65
+
66
+ with gr.Blocks() as demo:
67
+ with gr.Row():
68
+ with gr.Column():
69
+ name = gr.Textbox(label="Name", placeholder="What is your name?")
70
+ review = gr.Radio(label="How satisfied are you with using gradio?", choices=[1, 2, 3, 4, 5])
71
+ comments = gr.Textbox(label="Comments", lines=10, placeholder="Do you have any feedback on gradio?")
72
+ submit = gr.Button(value="Submit Feedback")
73
+ with gr.Column():
74
+ with gr.Box():
75
+ gr.Markdown("Most recently created 10 rows: See full dataset [here](https://huggingface.co/datasets/freddyaboulton/gradio-reviews)")
76
+ data = gr.Dataframe()
77
+ count = gr.Number(label="Total number of reviews")
78
+ submit.click(add_review, [name, review, comments], [data, count])
79
+ demo.load(load_data, None, [data, count])
80
+
81
+
82
def backup_db():
    """Copy the database into the dataset repo, export a CSV and push.

    The copy goes to ``./data/app.db``, the same file this script restores
    from at startup, so that pushed backups are the ones picked up again.
    """
    shutil.copyfile(DB_FILE, "./data/app.db")
    db = sqlite3.connect(DB_FILE)
    try:
        reviews = db.execute("SELECT * FROM reviews").fetchall()
    finally:
        # This job runs on a schedule; close the connection on every run.
        db.close()
    pd.DataFrame(reviews).to_csv("./data/reviews.csv", index=False)
    print("updating db")
    repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")
89
+
90
+
91
+ scheduler = BackgroundScheduler()
92
+ scheduler.add_job(func=backup_db, trigger="interval", seconds=60)
93
+ scheduler.start()
94
+
95
+
96
+ demo.launch()
app.py CHANGED
@@ -6,9 +6,10 @@ import shutil
6
  import os
7
  import datetime
8
  from apscheduler.schedulers.background import BackgroundScheduler
9
-
10
-
11
- DB_FILE = "./reviews.db"
 
12
 
13
  TOKEN = os.environ.get('HUB_TOKEN')
14
  repo = huggingface_hub.Repository(
@@ -18,73 +19,20 @@ repo = huggingface_hub.Repository(
18
  use_auth_token=TOKEN
19
  )
20
  repo.git_pull()
21
-
22
  # Set db to latest
23
- shutil.copyfile("./data/reviews.db", DB_FILE)
24
-
25
-
26
- # Create table if it doesn't already exist
27
-
28
- db = sqlite3.connect(DB_FILE)
29
- try:
30
- db.execute("SELECT * FROM reviews").fetchall()
31
- db.close()
32
- except sqlite3.OperationalError:
33
- db.execute(
34
- '''
35
- CREATE TABLE reviews (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
36
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
37
- name TEXT, review INTEGER, comments TEXT)
38
- ''')
39
- db.commit()
40
- db.close()
41
-
42
-
43
- def get_latest_reviews(db: sqlite3.Connection):
44
- reviews = db.execute("SELECT * FROM reviews ORDER BY id DESC limit 10").fetchall()
45
- total_reviews = db.execute("Select COUNT(id) from reviews").fetchone()[0]
46
- reviews = pd.DataFrame(reviews, columns=["id", "date_created", "name", "review", "comments"])
47
- return reviews, total_reviews
48
-
49
-
50
- def add_review(name: str, review: int, comments: str):
51
- db = sqlite3.connect(DB_FILE)
52
- cursor = db.cursor()
53
- cursor.execute("INSERT INTO reviews(name, review, comments) VALUES(?,?,?)", [name, review, comments])
54
- db.commit()
55
- reviews, total_reviews = get_latest_reviews(db)
56
- db.close()
57
- return reviews, total_reviews
58
-
59
- def load_data():
60
- db = sqlite3.connect(DB_FILE)
61
- reviews, total_reviews = get_latest_reviews(db)
62
- db.close()
63
- return reviews, total_reviews
64
-
65
-
66
- with gr.Blocks() as demo:
67
- with gr.Row():
68
- with gr.Column():
69
- name = gr.Textbox(label="Name", placeholder="What is your name?")
70
- review = gr.Radio(label="How satisfied are you with using gradio?", choices=[1, 2, 3, 4, 5])
71
- comments = gr.Textbox(label="Comments", lines=10, placeholder="Do you have any feedback on gradio?")
72
- submit = gr.Button(value="Submit Feedback")
73
- with gr.Column():
74
- with gr.Box():
75
- gr.Markdown("Most recently created 10 rows: See full dataset [here](https://huggingface.co/datasets/freddyaboulton/gradio-reviews)")
76
- data = gr.Dataframe()
77
- count = gr.Number(label="Total number of reviews")
78
- submit.click(add_review, [name, review, comments], [data, count])
79
- demo.load(load_data, None, [data, count])
80
-
81
 
82
  def backup_db():
83
- shutil.copyfile(DB_FILE, "./data/reviews.db")
84
- db = sqlite3.connect(DB_FILE)
85
- reviews = db.execute("SELECT * FROM reviews").fetchall()
86
- pd.DataFrame(reviews).to_csv("./data/reviews.csv", index=False)
87
- print("updating db")
 
 
 
 
88
  repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")
89
 
90
 
@@ -92,5 +40,14 @@ scheduler = BackgroundScheduler()
92
  scheduler.add_job(func=backup_db, trigger="interval", seconds=60)
93
  scheduler.start()
94
 
 
 
 
 
 
 
 
95
 
96
- demo.launch()
 
 
 
6
  import os
7
  import datetime
8
  from apscheduler.schedulers.background import BackgroundScheduler
9
+ from database.database import DATABASE_FILE as DB_FILE
10
+ from web import demo
11
+ from loguru import logger
12
+ from common.util import date_str
13
 
14
  TOKEN = os.environ.get('HUB_TOKEN')
15
  repo = huggingface_hub.Repository(
 
19
  use_auth_token=TOKEN
20
  )
21
  repo.git_pull()
22
+ DATASET_FILE = f"./data/{DB_FILE}"
23
  # Set db to latest
24
+ shutil.copyfile(DATASET_FILE, DB_FILE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def backup_db():
27
+ shutil.copyfile(DB_FILE, DATASET_FILE)
28
+ # db = sqlite3.connect(DB_FILE)
29
+ # pd.DataFrame(db.execute("SELECT * FROM words").fetchall()).to_csv("./data/words.csv", index=False)
30
+ # print("save word.csv")
31
+ # pd.DataFrame(db.execute("SELECT * FROM book").fetchall()).to_csv("./data/book.csv", index=False)
32
+ # print("save book.csv")
33
+ # pd.DataFrame(db.execute("SELECT * FROM unit").fetchall()).to_csv("./data/unit.csv", index=False)
34
+ # print("save unit.csv")
35
+ # db.close()
36
  repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")
37
 
38
 
 
40
  scheduler.add_job(func=backup_db, trigger="interval", seconds=60)
41
  scheduler.start()
42
 
43
+ # def load_data():
44
+ # db = sqlite3.connect(DB_FILE)
45
+ # reviews, total_reviews = get_latest_reviews(db)
46
+ # db.close()
47
+ # return reviews, total_reviews
48
+
49
+ # demo.load(load_data, None, [data, count])
50
 
51
+ if __name__ == "__main__":
52
+ logger.add(f"output/logs/web_{date_str}.log", rotation="1 day", retention="7 days", level="INFO")
53
+ demo.launch()
common/util.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from loguru import logger
3
+ from tqdm import tqdm
4
+ import pandas as pd
5
+ import datetime
6
+
7
+ import multiprocessing
8
+ from multiprocessing import Pool
9
+ cpu_num = multiprocessing.cpu_count()
10
+ logger.info(f"cpu_num: {cpu_num}")
11
+
12
+
13
+ date_str = datetime.datetime.now().strftime("%Y%m%d_%Hh%Mm%Ss")
14
+
15
+
16
def multiprocessing_mapping(
    mapping_func,
    items: List[Any],
    batch_size=1000,
    tmp_filepath=f"./output/multiprocessing_mapping_{date_str}_tmp.xlsx",
):
    """Map ``mapping_func`` over ``items`` in a process pool, batch by batch.

    After every batch all rows collected so far are written to
    ``tmp_filepath`` as an Excel checkpoint.

    Args:
        mapping_func: Callable applied to each item by the pool workers.
        items: Inputs to process.
        batch_size: Number of items handed to the pool per batch.
        tmp_filepath: Path of the checkpoint workbook.

    Returns:
        The list of all results, in input order.
    """
    import os

    # Make sure the checkpoint directory exists before the first write.
    checkpoint_dir = os.path.dirname(tmp_filepath)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    pool = Pool(processes=cpu_num)
    total_rows: List[Dict[str, str]] = []
    try:
        for start in tqdm(range(0, len(items), batch_size)):
            total_rows.extend(pool.map(mapping_func, items[start:start + batch_size]))
            pd.DataFrame(total_rows).to_excel(tmp_filepath, index=False)
    finally:
        # Release the workers even when a batch or the checkpoint write fails.
        pool.close()
        pool.join()
    return total_rows
database/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .database import SessionLocal, engine, Base
2
+ from .schema import *
3
+ from .operation import *
4
+
5
+
6
def get_db():
    """Yield a ``SessionLocal`` session and close it when the consumer is done."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
database/constant.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ email = "[email protected]"
2
+ password = "123456"
database/database.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import URL, create_engine
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+ from sqlalchemy.orm import sessionmaker
4
+
5
+ DATABASE_FILE = "app.db"
6
+ SQLALCHEMY_DATABASE_URL = f"sqlite:///./{DATABASE_FILE}"
7
+ engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
8
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
9
+
10
+ Base = declarative_base()
database/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class ItemBase(BaseModel):
    """Fields shared by the item models: a title and an optional description."""

    title: str
    description: Union[str, None] = None
9
+
10
+
11
class ItemCreate(ItemBase):
    """Item creation model; adds nothing to ``ItemBase``."""
13
+
14
+
15
class Item(ItemBase):
    """Item model with its id and owner id; ``orm_mode`` is enabled."""

    id: int
    owner_id: int

    class Config:
        orm_mode = True
21
+
22
+
23
class UserBase(BaseModel):
    """Field shared by the user models: the e-mail address."""

    email: str
25
+
26
+
27
class UserCreate(UserBase):
    """User creation model: ``UserBase`` plus a password."""

    password: str
29
+
30
+
31
class User(UserBase):
    """User model with id, active flag and items; ``orm_mode`` is enabled."""

    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
database/operation.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from sqlalchemy.orm import Session
3
+ from pydantic import BaseModel
4
+ from typing import List, Optional, Tuple, Dict
5
+ from . import schema
6
+ from sqlalchemy import func, or_
7
+ from sqlalchemy.orm import aliased
8
+ from sqlalchemy.orm import Query
9
+ from collections import defaultdict
10
+
11
+
12
+ # # Create an aliased Address
13
+ # AddressAlias = aliased(Address)
14
+
15
+ # # Create a subquery
16
+ # subquery = session.query(
17
+ # func.count(AddressAlias.id).label('address_count'),
18
+ # AddressAlias.user_id
19
+ # ).group_by(AddressAlias.user_id).subquery()
20
+
21
+ # # Use the subquery in a join query
22
+ # users = session.query(User, subquery.c.address_count).\
23
+ # outerjoin(subquery, User.id == subquery.c.user_id).\
24
+ # order_by(User.id).all()
25
+
26
+ # for user, address_count in users:
27
+ # print(f'User {user.name} has {address_count} addresses.')
28
+
29
+ # region User
30
class UserBase(BaseModel):
    """Fields shared by the user models: e-mail address and active flag."""

    email: str
    is_active: bool
33
+
34
+ class UserCreate(UserBase):
35
+ pass
36
+
37
+ class UserUpdate(UserBase):
38
+ hashed_password: Optional[str] # This is optional because you may not always want to update the password
39
+
40
+ class User(UserBase):
41
+ id: str
42
+
43
def create_user(db: Session, email: str, password: str):
    """Create an active user, set its password, commit, and return the refreshed row.

    The password is handed to ``schema.User.set_password`` rather than being
    stored directly.
    """
    new_user = schema.User(email=email, is_active=True)
    new_user.set_password(password)

    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
50
+
51
def get_user(db: Session, user_id: str) -> Optional[schema.User]:
    """Return the first user whose ``id`` equals ``user_id``."""
    by_id = db.query(schema.User).filter(schema.User.id == user_id)
    return by_id.first()
53
+
54
def get_user_by_email(db: Session, email: str) -> Optional[schema.User]:
    """Return the first user whose ``email`` equals ``email``."""
    by_email = db.query(schema.User).filter(schema.User.email == email)
    return by_email.first()
56
+
57
def get_users(db: Session, skip: int = 0, limit: int = 10) -> List[schema.User]:
    """Return one page of users: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.User).offset(skip).limit(limit)
    return page.all()
59
+
60
def get_all_users(db: Session) -> List[schema.User]:
    """Return every user row, without pagination."""
    everyone = db.query(schema.User)
    return everyone.all()
62
+
63
def update_user(db: Session, user_id: str, user: UserUpdate):
    """Apply the explicitly-set fields of ``user`` to the stored user.

    Args:
        db: Active database session.
        user_id: Primary key of the user to update.
        user: Update payload; only fields that were explicitly set are applied.

    Returns:
        The refreshed user row, or the falsy lookup result when no user
        matches ``user_id``.
    """
    db_user = db.query(schema.User).filter(schema.User.id == user_id).first()
    if not db_user:
        return db_user

    changes = user.dict(exclude_unset=True)

    # The ORM column is ``encrypted_password``; assigning ``hashed_password``
    # would only create a transient attribute that is never persisted.
    if "hashed_password" in changes:
        hashed = changes.pop("hashed_password")
        if hashed is not None:
            db_user.encrypted_password = hashed

    for key, value in changes.items():
        setattr(db_user, key, value)

    db.commit()
    db.refresh(db_user)
    return db_user
71
+
72
def delete_user(db: Session, user_id: str):
    """Delete the user with ``user_id`` if it exists and return the looked-up row."""
    target = db.query(schema.User).filter(schema.User.id == user_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
78
+ # endregion
79
+
80
+ # region Item
81
class ItemBase(BaseModel):
    """Fields shared by every item payload."""

    title: str
    description: str


class ItemCreate(ItemBase):
    """Payload for creating an item; adds no fields to ``ItemBase``."""


class ItemUpdate(ItemBase):
    """Payload for updating an item; adds no fields to ``ItemBase``."""


class Item(ItemBase):
    """Item payload that also carries its id and its owner's id."""

    id: str
    owner_id: str
94
+
95
def create_item(db: Session, item: ItemCreate, owner_id: str):
    """Insert an item built from ``item`` for ``owner_id`` and return the refreshed row."""
    new_item = schema.Item(**item.dict(), owner_id=owner_id)

    db.add(new_item)
    db.commit()
    db.refresh(new_item)
    return new_item
101
+
102
def get_item(db: Session, item_id: str):
    """Return the first item whose ``id`` equals ``item_id``."""
    by_id = db.query(schema.Item).filter(schema.Item.id == item_id)
    return by_id.first()
104
+
105
def get_items(db: Session, skip: int = 0, limit: int = 10):
    """Return one page of items: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.Item).offset(skip).limit(limit)
    return page.all()
107
+
108
def update_item(db: Session, item_id: str, item: ItemUpdate):
    """Copy the explicitly-set fields of ``item`` onto the stored item.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``item_id``.
    """
    target = db.query(schema.Item).filter(schema.Item.id == item_id).first()
    if not target:
        return target

    for field_name, new_value in item.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
116
+
117
def delete_item(db: Session, item_id: str):
    """Delete the item with ``item_id`` if it exists and return the looked-up row."""
    target = db.query(schema.Item).filter(schema.Item.id == item_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
123
+
124
+ # endregion
125
+
126
+ # region UserBook
127
class UserBookBase(BaseModel):
    """Fields shared by every user-book payload."""

    owner_id: str
    book_id: str
    title: str
    random: bool
    batch_size: int
    memorizing_batch: str = ""


class UserBookCreate(UserBookBase):
    """Payload for creating a user book; adds no fields to ``UserBookBase``."""


class UserBookUpdate(UserBookBase):
    """Payload for updating a user book; adds no fields to ``UserBookBase``."""


class UserBook(UserBookBase):
    """User-book payload that also carries the primary key."""

    id: str
143
+
144
def create_user_book(db: Session, user_book: UserBookCreate):
    """Insert a user book built from ``user_book`` and return the refreshed row."""
    new_user_book = schema.UserBook(**user_book.dict())

    db.add(new_user_book)
    db.commit()
    db.refresh(new_user_book)
    return new_user_book
150
+
151
def get_user_book(db: Session, user_book_id: str) -> schema.UserBook:
    """Return the first user book whose ``id`` equals ``user_book_id``."""
    by_id = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id)
    return by_id.first()
153
+
154
def get_user_books_by_owner_id(db: Session, owner_id: str) -> List[schema.UserBook]:
    """Return every user book whose ``owner_id`` equals ``owner_id``."""
    owned = db.query(schema.UserBook).filter(schema.UserBook.owner_id == owner_id)
    return owned.all()
156
+
157
def get_user_books(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserBook]:
    """Return one page of user books: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserBook).offset(skip).limit(limit)
    return page.all()
159
+
160
def update_user_book(db: Session, user_book_id: str, user_book: UserBookUpdate):
    """Copy the explicitly-set fields of ``user_book`` onto the stored user book.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``user_book_id``.
    """
    target = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if not target:
        return target

    for field_name, new_value in user_book.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
168
+
169
def update_user_book_memorizing_batch(db: Session, user_book_id: str, memorizing_batch: str):
    """Set ``memorizing_batch`` on the user book with ``user_book_id``.

    Returns the refreshed row, or the falsy lookup result when nothing matches.
    """
    target = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if not target:
        return target

    target.memorizing_batch = memorizing_batch
    db.commit()
    db.refresh(target)
    return target
176
+
177
def delete_user_book(db: Session, user_book_id: str):
    """Delete the user book with ``user_book_id`` if it exists and return the looked-up row."""
    target = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
183
+
184
+ # endregion
185
+
186
+ # region UserMemoryBatch
187
class UserMemoryBatchBase(BaseModel):
    """Fields shared by every memory-batch payload."""

    user_book_id: str
    story: str
    translated_story: str
    # NOTE(review): this literal looks mis-encoded in this copy of the file
    # (probably Chinese for "new words"); it is reproduced unchanged - confirm.
    batch_type: str = "ๆ–ฐ่ฏ"


class UserMemoryBatchCreate(UserMemoryBatchBase):
    """Payload for creating a memory batch; adds no fields to the base."""


class UserMemoryBatchUpdate(UserMemoryBatchBase):
    """Payload for updating a memory batch; adds no fields to the base."""


class UserMemoryBatch(UserMemoryBatchBase):
    """Memory-batch payload that also carries the primary key."""

    id: str
201
+
202
def create_user_memory_batch(db: Session, memory_batch: UserMemoryBatchCreate):
    """Insert a memory batch built from ``memory_batch`` and return the refreshed row."""
    new_batch = schema.UserMemoryBatch(**memory_batch.dict())

    db.add(new_batch)
    db.commit()
    db.refresh(new_batch)
    return new_batch
208
+
209
def get_user_memory_batch(db: Session, memory_batch_id: str):
    """Return the first memory batch whose ``id`` equals ``memory_batch_id``."""
    by_id = db.query(schema.UserMemoryBatch).filter(
        schema.UserMemoryBatch.id == memory_batch_id
    )
    return by_id.first()
211
+
212
def get_user_memory_batchs(db: Session, skip: int = 0, limit: int = 10):
    """Return one page of memory batches: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserMemoryBatch).offset(skip).limit(limit)
    return page.all()
214
+
215
def get_user_memory_batches_by_user_book_id(db: Session, user_book_id: str) -> List[schema.UserMemoryBatch]:
    """Return the memory batches of one user book, ordered by ``create_time``."""
    belongs_to_book = schema.UserMemoryBatch.user_book_id == user_book_id
    ordered = (
        db.query(schema.UserMemoryBatch)
        .filter(belongs_to_book)
        .order_by(schema.UserMemoryBatch.create_time)
    )
    return ordered.all()
219
+
220
def get_new_user_memory_batches_by_user_book_id(db: Session, user_book_id: str) -> List[schema.UserMemoryBatch]:
    """Return one user book's batches of the default batch type, ordered by ``create_time``.

    NOTE(review): the batch-type literal looks mis-encoded in this copy of the
    file (probably Chinese for "new words"); it is reproduced unchanged.
    """
    belongs_to_book = schema.UserMemoryBatch.user_book_id == user_book_id
    is_new_word_batch = schema.UserMemoryBatch.batch_type == "ๆ–ฐ่ฏ"
    ordered = (
        db.query(schema.UserMemoryBatch)
        .filter(belongs_to_book, is_new_word_batch)
        .order_by(schema.UserMemoryBatch.create_time)
    )
    return ordered.all()
225
+
226
def actions_infomation(db: Session, action_query: Query[schema.UserMemoryBatchAction]):
    """Resolve the batches referenced by ``action_query`` and their words.

    Returns:
        A 3-tuple ``(batches, batch_id_to_batch, batch_id_to_words)``. The
        words of each batch come from ``get_words_in_batch`` (one query per
        batch).
    """
    action_rows = action_query.distinct().subquery()
    batches = (
        db.query(schema.UserMemoryBatch)
        .join(action_rows, action_rows.c.batch_id == schema.UserMemoryBatch.id)
        .all()
    )

    batch_id_to_batch = {}
    batch_id_to_words = {}
    for batch in batches:
        batch_id_to_batch[batch.id] = batch
    for batch in batches:
        batch_id_to_words[batch.id] = get_words_in_batch(db, batch.id)
    return batches, batch_id_to_batch, batch_id_to_words
232
+
233
def get_user_memory_batch_history(db: Session, user_book_id: str):
    """Collect the "end" batch actions of a user book with their related data.

    Returns:
        A 4-tuple ``(actions, batch_id_to_batch, batch_id_to_words,
        batch_id_to_actions)``. ``actions`` is ordered by ``create_time``; the
        three dicts are keyed by batch id.
    """
    is_end_action = schema.UserMemoryBatchAction.action == "end"
    batch_of_action = schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id
    action_query = (
        db.query(schema.UserMemoryBatchAction)
        .filter(is_end_action)
        .join(schema.UserMemoryBatch, batch_of_action)
        .filter(schema.UserMemoryBatch.user_book_id == user_book_id)
    )
    actions = action_query.order_by(schema.UserMemoryBatchAction.create_time).all()

    action_rows = action_query.distinct().subquery()
    batches = (
        db.query(schema.UserMemoryBatch)
        .join(action_rows, action_rows.c.batch_id == schema.UserMemoryBatch.id)
        .all()
    )

    batch_id_to_batch = {}
    for batch in batches:
        batch_id_to_batch[batch.id] = batch
    batch_id_to_words = {}
    for batch in batches:
        batch_id_to_words[batch.id] = get_words_in_batch(db, batch.id)
    batch_id_to_actions = {}
    for batch in batches:
        batch_id_to_actions[batch.id] = get_user_memory_actions_in_batch(db, batch.id)
    return actions, batch_id_to_batch, batch_id_to_words, batch_id_to_actions
245
+
246
def get_user_memory_batch_history_in_minutes(db: Session, user_book_id: str, minutes: int):
    """Collect recent "end" batch actions of a user book with their batches and words.

    Args:
        db: Active database session.
        user_book_id: Id of the user book whose batches are inspected.
        minutes: Only actions created within this many minutes before now
            (``datetime.datetime.now()``) are considered.

    Returns:
        A 3-tuple ``(actions, batch_id_to_batch, batch_id_to_words)``.
        ``actions`` holds at most 20 rows ordered by ``create_time``; both
        dicts are keyed by batch id.
    """
    cutoff = datetime.datetime.now() - datetime.timedelta(minutes=minutes)
    # ORDER BY has to be applied before LIMIT: SQLAlchemy's Query rejects
    # order_by() on a query that already has a limit.
    action_query = (
        db.query(schema.UserMemoryBatchAction)
        .filter(schema.UserMemoryBatchAction.action == "end")
        .join(
            schema.UserMemoryBatch,
            schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id,
        )
        .filter(
            schema.UserMemoryBatch.user_book_id == user_book_id,
            schema.UserMemoryBatchAction.create_time > cutoff,
        )
        .order_by(schema.UserMemoryBatchAction.create_time)
        .limit(20)
    )
    actions = action_query.all()

    # Reuse the same ordered and limited query so the batches match ``actions``.
    distinct_actions = action_query.distinct().subquery()
    batches = (
        db.query(schema.UserMemoryBatch)
        .join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatch.id)
        .all()
    )
    batch_id_to_batch = {batch.id: batch for batch in batches}
    batch_id_to_words = {batch.id: get_words_in_batch(db, batch.id) for batch in batches}
    return actions, batch_id_to_batch, batch_id_to_words
261
+
262
+
263
def get_user_memory_word_history_in_minutes(db: Session, user_book_id: str, minutes: int):
    """Return the words of the batches that have an "end" action in a user book.

    ``minutes`` is currently unused: the time-window filter is disabled in
    this function, and only a limit of 200 action rows is applied.
    """
    is_end_action = schema.UserMemoryBatchAction.action == "end"
    batch_of_action = schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id
    action_query = (
        db.query(schema.UserMemoryBatchAction)
        .filter(is_end_action)
        .join(schema.UserMemoryBatch, batch_of_action)
        .filter(
            schema.UserMemoryBatch.user_book_id == user_book_id,
            # Disabled: create_time > now - timedelta(minutes=minutes)
        )
        .limit(200)
    )
    action_rows = action_query.distinct().subquery()
    batches = (
        db.query(schema.UserMemoryBatch)
        .join(action_rows, action_rows.c.batch_id == schema.UserMemoryBatch.id)
        .all()
    )

    # Fetch the words of every batch first, then flatten them into one list.
    words_per_batch = [get_words_in_batch(db, batch.id) for batch in batches]
    words = []
    for batch_words in words_per_batch:
        words = words + batch_words
    return words
275
+
276
def update_user_memory_batch(db: Session, memory_batch_id: str, memory_batch: UserMemoryBatchUpdate):
    """Copy the explicitly-set fields of ``memory_batch`` onto the stored batch.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``memory_batch_id``.
    """
    target = (
        db.query(schema.UserMemoryBatch)
        .filter(schema.UserMemoryBatch.id == memory_batch_id)
        .first()
    )
    if not target:
        return target

    for field_name, new_value in memory_batch.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
284
+
285
def delete_user_memory_batch(db: Session, memory_batch_id: str):
    """Delete the memory batch with ``memory_batch_id`` if it exists and return the looked-up row."""
    target = (
        db.query(schema.UserMemoryBatch)
        .filter(schema.UserMemoryBatch.id == memory_batch_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
291
+
292
+ # endregion
293
+
294
+ # region UserMemoryBatchAction
295
class UserMemoryBatchActionBase(BaseModel):
    """Fields shared by every batch-action payload."""

    batch_id: str
    action: str


class UserMemoryBatchActionCreate(UserMemoryBatchActionBase):
    """Payload for creating a batch action; adds no fields to the base."""


class UserMemoryBatchActionUpdate(UserMemoryBatchActionBase):
    """Payload for updating a batch action; adds no fields to the base."""


class UserMemoryBatchAction(UserMemoryBatchActionBase):
    """Batch-action payload that also carries its id and timestamps."""

    id: str
    create_time: str
    update_time: str
309
+
310
def create_user_memory_batch_action(db: Session, memory_batch_action: UserMemoryBatchActionCreate):
    """Insert a batch action built from ``memory_batch_action`` and return the refreshed row."""
    new_action = schema.UserMemoryBatchAction(**memory_batch_action.dict())

    db.add(new_action)
    db.commit()
    db.refresh(new_action)
    return new_action
316
+
317
def get_user_memory_batch_action(db: Session, memory_batch_action_id: str):
    """Return the first batch action whose ``id`` equals ``memory_batch_action_id``."""
    by_id = db.query(schema.UserMemoryBatchAction).filter(
        schema.UserMemoryBatchAction.id == memory_batch_action_id
    )
    return by_id.first()
319
+
320
def get_user_memory_batch_actions(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryBatchAction]:
    """Return one page of batch actions: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserMemoryBatchAction).offset(skip).limit(limit)
    return page.all()
322
+
323
def get_user_memory_batch_actions_by_user_memory_batch_id(db: Session, user_memory_batch_id: str) -> List[schema.UserMemoryBatchAction]:
    """Return every batch action whose ``batch_id`` equals ``user_memory_batch_id``."""
    in_batch = db.query(schema.UserMemoryBatchAction).filter(
        schema.UserMemoryBatchAction.batch_id == user_memory_batch_id
    )
    return in_batch.all()
325
+
326
def get_actions_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchAction]:
    """Return every batch action whose ``batch_id`` is in ``memory_batch_ids``."""
    in_any_batch = schema.UserMemoryBatchAction.batch_id.in_(memory_batch_ids)
    return db.query(schema.UserMemoryBatchAction).filter(in_any_batch).all()
328
+
329
def get_finished_actions_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchAction]:
    """Return the "end" actions of the given batches, ordered by ``create_time``."""
    in_any_batch = schema.UserMemoryBatchAction.batch_id.in_(memory_batch_ids)
    is_end_action = schema.UserMemoryBatchAction.action == "end"
    ordered = (
        db.query(schema.UserMemoryBatchAction)
        .filter(in_any_batch, is_end_action)
        .order_by(schema.UserMemoryBatchAction.create_time)
    )
    return ordered.all()
334
+
335
def update_user_memory_batch_action(db: Session, memory_batch_action_id: str, memory_batch_action: UserMemoryBatchActionUpdate):
    """Copy the explicitly-set fields of ``memory_batch_action`` onto the stored action.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``memory_batch_action_id``.
    """
    target = (
        db.query(schema.UserMemoryBatchAction)
        .filter(schema.UserMemoryBatchAction.id == memory_batch_action_id)
        .first()
    )
    if not target:
        return target

    for field_name, new_value in memory_batch_action.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
343
+
344
def delete_user_memory_batch_action(db: Session, memory_batch_action_id: str):
    """Delete the batch action with ``memory_batch_action_id`` if it exists and return the looked-up row."""
    target = (
        db.query(schema.UserMemoryBatchAction)
        .filter(schema.UserMemoryBatchAction.id == memory_batch_action_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
350
+ # endregion
351
+
352
+ # region UserMemoryBatchGenerationHistory
353
class UserMemoryBatchGenerationHistoryBase(BaseModel):
    """Fields shared by every generation-history payload."""

    batch_id: str
    story: str
    translated_story: str


class UserMemoryBatchGenerationHistoryCreate(UserMemoryBatchGenerationHistoryBase):
    """Payload for creating a generation-history entry; adds no fields to the base."""


class UserMemoryBatchGenerationHistoryUpdate(UserMemoryBatchGenerationHistoryBase):
    """Payload for updating a generation-history entry; adds no fields to the base."""


class UserMemoryBatchGenerationHistory(UserMemoryBatchGenerationHistoryBase):
    """Generation-history payload that also carries its id and timestamps."""

    id: str
    create_time: str
    update_time: str
368
+
369
def create_user_memory_batch_generation_history(db: Session, memory_batch_generation_history: UserMemoryBatchGenerationHistoryCreate):
    """Insert a generation-history entry built from the payload and return the refreshed row."""
    new_entry = schema.UserMemoryBatchGenerationHistory(**memory_batch_generation_history.dict())

    db.add(new_entry)
    db.commit()
    db.refresh(new_entry)
    return new_entry
375
+
376
def get_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str):
    """Return the first generation-history entry whose ``id`` equals the given id."""
    by_id = db.query(schema.UserMemoryBatchGenerationHistory).filter(
        schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id
    )
    return by_id.first()
378
+
379
def get_user_memory_batch_generation_historys(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryBatchGenerationHistory]:
    """Return one page of generation-history entries: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserMemoryBatchGenerationHistory).offset(skip).limit(limit)
    return page.all()
381
+
382
def get_user_memory_batch_generation_historys_by_user_memory_batch_id(db: Session, user_memory_batch_id: str) -> List[schema.UserMemoryBatchGenerationHistory]:
    """Return every generation-history entry whose ``batch_id`` equals ``user_memory_batch_id``."""
    in_batch = db.query(schema.UserMemoryBatchGenerationHistory).filter(
        schema.UserMemoryBatchGenerationHistory.batch_id == user_memory_batch_id
    )
    return in_batch.all()
384
+
385
def get_generation_historys_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchGenerationHistory]:
    """Return every generation-history entry whose ``batch_id`` is in ``memory_batch_ids``."""
    in_any_batch = schema.UserMemoryBatchGenerationHistory.batch_id.in_(memory_batch_ids)
    return db.query(schema.UserMemoryBatchGenerationHistory).filter(in_any_batch).all()
387
+
388
def get_generation_hostorys_by_user_book_id(db: Session, user_book_id: str) -> Dict[str, Tuple[List[schema.Word], List[schema.UserMemoryBatchGenerationHistory]]]:
    """Map each batch of a user book to its words and generation history.

    Batches with no generation-history entries are left out of the result.
    """
    batch_ids = [
        batch.id for batch in get_user_memory_batches_by_user_book_id(db, user_book_id)
    ]

    result = {}
    for batch_id in batch_ids:
        historys = get_user_memory_batch_generation_historys_by_user_memory_batch_id(db, batch_id)
        if len(historys) != 0:
            # Words are only looked up for batches that actually have history.
            result[batch_id] = (get_words_in_batch(db, batch_id), historys)
    return result
399
+
400
def update_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str, memory_batch_generation_history: UserMemoryBatchGenerationHistoryUpdate):
    """Copy the explicitly-set payload fields onto the stored generation-history entry.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``memory_batch_generation_history_id``.
    """
    target = (
        db.query(schema.UserMemoryBatchGenerationHistory)
        .filter(schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id)
        .first()
    )
    if not target:
        return target

    for field_name, new_value in memory_batch_generation_history.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
408
+
409
def delete_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str):
    """Delete the generation-history entry with the given id if it exists and return the looked-up row."""
    target = (
        db.query(schema.UserMemoryBatchGenerationHistory)
        .filter(schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
415
+
416
+ # endregion
417
+
418
+ # region UserMemoryWord
419
class UserMemoryWordBase(BaseModel):
    """Fields shared by every memory-word payload."""

    batch_id: str
    word_id: str


class UserMemoryWordCreate(UserMemoryWordBase):
    """Payload for creating a memory word; adds no fields to the base."""


class UserMemoryWordUpdate(UserMemoryWordBase):
    """Payload for updating a memory word; adds no fields to the base."""


class UserMemoryWord(UserMemoryWordBase):
    """Memory-word payload that also carries the primary key."""

    id: str
431
+
432
def create_user_memory_word(db: Session, memory_word: UserMemoryWordCreate):
    """Insert a memory word built from ``memory_word`` and return the refreshed row."""
    new_memory_word = schema.UserMemoryWord(**memory_word.dict())

    db.add(new_memory_word)
    db.commit()
    db.refresh(new_memory_word)
    return new_memory_word
438
+
439
def get_user_memory_word(db: Session, memory_word_id: str):
    """Return the first memory word whose ``id`` equals ``memory_word_id``."""
    by_id = db.query(schema.UserMemoryWord).filter(
        schema.UserMemoryWord.id == memory_word_id
    )
    return by_id.first()
441
+
442
def get_user_memory_words(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryWord]:
    """Return one page of memory words: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserMemoryWord).offset(skip).limit(limit)
    return page.all()
444
+
445
def get_user_memory_words_by_batch_id(db: Session, batch_id: str) -> List[schema.UserMemoryWord]:
    """Return every memory word whose ``batch_id`` equals ``batch_id``."""
    in_batch = db.query(schema.UserMemoryWord).filter(
        schema.UserMemoryWord.batch_id == batch_id
    )
    return in_batch.all()
447
+
448
def update_user_memory_word(db: Session, memory_word_id: str, memory_word: UserMemoryWordUpdate):
    """Copy the explicitly-set fields of ``memory_word`` onto the stored memory word.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``memory_word_id``.
    """
    target = (
        db.query(schema.UserMemoryWord)
        .filter(schema.UserMemoryWord.id == memory_word_id)
        .first()
    )
    if not target:
        return target

    for field_name, new_value in memory_word.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
456
+
457
def delete_user_memory_word(db: Session, memory_word_id: str):
    """Delete the memory word with ``memory_word_id`` if it exists and return the looked-up row."""
    target = (
        db.query(schema.UserMemoryWord)
        .filter(schema.UserMemoryWord.id == memory_word_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
463
+
464
+ # endregion
465
+
466
+ # region UserMemoryAction
467
class UserMemoryActionBase(BaseModel):
    """Fields shared by every memory-action payload."""

    batch_id: str
    word_id: str
    action: str


class UserMemoryActionCreate(UserMemoryActionBase):
    """Payload for creating a memory action; adds no fields to the base."""


class UserMemoryActionUpdate(UserMemoryActionBase):
    """Payload for updating a memory action; adds no fields to the base."""


class UserMemoryAction(UserMemoryActionBase):
    """Memory-action payload that also carries its id and timestamps."""

    id: str
    create_time: str
    update_time: str
482
+
483
+
484
def create_user_memory_action(db: Session, memory_action: UserMemoryActionCreate):
    """Insert a memory action built from ``memory_action`` and return the refreshed row."""
    new_action = schema.UserMemoryAction(**memory_action.dict())

    db.add(new_action)
    db.commit()
    db.refresh(new_action)
    return new_action
490
+
491
def get_user_memory_action(db: Session, memory_action_id: str) -> schema.UserMemoryAction:
    """Return the first memory action whose ``id`` equals ``memory_action_id``."""
    by_id = db.query(schema.UserMemoryAction).filter(
        schema.UserMemoryAction.id == memory_action_id
    )
    return by_id.first()
493
+
494
def get_user_memory_actions(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryAction]:
    """Return one page of memory actions: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.UserMemoryAction).offset(skip).limit(limit)
    return page.all()
496
+
497
def get_user_memory_actions_by_word_id(db: Session, word_id: str) -> List[schema.UserMemoryAction]:
    """Return every memory action whose ``word_id`` equals ``word_id``."""
    for_word = db.query(schema.UserMemoryAction).filter(
        schema.UserMemoryAction.word_id == word_id
    )
    return for_word.all()
499
+
500
def get_actions_at_each_word(db: Session, word_ids: List[str]) -> List[schema.UserMemoryAction]:
    """Return every memory action whose ``word_id`` is in ``word_ids``."""
    for_any_word = schema.UserMemoryAction.word_id.in_(word_ids)
    return db.query(schema.UserMemoryAction).filter(for_any_word).all()
502
+
503
def get_user_memory_actions_in_batch(db: Session, batch_id: str) -> List[schema.UserMemoryAction]:
    """Return every memory action whose ``batch_id`` equals ``batch_id``."""
    in_batch = db.query(schema.UserMemoryAction).filter(
        schema.UserMemoryAction.batch_id == batch_id
    )
    return in_batch.all()
505
+
506
def update_user_memory_action(db: Session, memory_action_id: str, memory_action: UserMemoryActionUpdate):
    """Copy the explicitly-set fields of ``memory_action`` onto the stored action.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``memory_action_id``.
    """
    target = (
        db.query(schema.UserMemoryAction)
        .filter(schema.UserMemoryAction.id == memory_action_id)
        .first()
    )
    if not target:
        return target

    for field_name, new_value in memory_action.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
514
+
515
def delete_user_memory_action(db: Session, memory_action_id: str):
    """Delete the memory action with ``memory_action_id`` if it exists and return the looked-up row."""
    target = (
        db.query(schema.UserMemoryAction)
        .filter(schema.UserMemoryAction.id == memory_action_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
521
+
522
+ # endregion
523
+
524
+ # region Book
525
class BookBase(BaseModel):
    """Fields shared by every book payload."""

    bk_order: float = 0
    bk_name: str
    bk_item_num: int
    bk_author: str = ""
    bk_comment: str = ""
    bk_organization: str = ""
    bk_publisher: str = ""
    bk_version: str = ""
    permission: str = "private"
    creator: str


class BookCreate(BookBase):
    """Payload for creating a book; adds no fields to ``BookBase``."""


class BookUpdate(BookBase):
    """Payload for updating a book; adds no fields to ``BookBase``."""


class Book(BookBase):
    """Book payload that also carries the primary key ``bk_id``."""

    bk_id: str
545
+
546
def create_book(db: Session, book: BookCreate):
    """Insert a book built from ``book`` and return the refreshed row."""
    new_book = schema.Book(**book.dict())

    db.add(new_book)
    db.commit()
    db.refresh(new_book)
    return new_book
552
+
553
def get_book(db: Session, book_id: str):
    """Return the first book whose ``bk_id`` equals ``book_id``."""
    by_id = db.query(schema.Book).filter(schema.Book.bk_id == book_id)
    return by_id.first()
555
+
556
def get_book_by_name(db: Session, book_name: str):
    """Return the first book whose ``bk_name`` equals ``book_name``."""
    by_name = db.query(schema.Book).filter(schema.Book.bk_name == book_name)
    return by_name.first()
558
+
559
def get_books(db: Session, skip: int = 0, limit: int = 10):
    """Return one page of books: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.Book).offset(skip).limit(limit)
    return page.all()
561
+
562
def get_all_books(db: Session):
    """Return every book row, without pagination."""
    every_book = db.query(schema.Book)
    return every_book.all()
564
+
565
def get_all_books_for_user(db: Session, user_id: str):
    """Return the books created by ``user_id`` plus all books with "public" permission.

    Results are ordered by ``permission`` and then by ``create_time``.
    """
    created_by_user = schema.Book.creator == user_id
    is_public = schema.Book.permission == "public"
    visible = db.query(schema.Book).filter(or_(created_by_user, is_public))
    return visible.order_by(schema.Book.permission, schema.Book.create_time).all()
569
+
570
def get_book_count(db: Session):
    """Return the number of rows produced by the plain ``schema.Book`` query."""
    every_book = db.query(schema.Book)
    return every_book.count()
572
+
573
def update_book(db: Session, book_id: str, book: BookUpdate):
    """Copy the explicitly-set fields of ``book`` onto the stored book.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``book_id``.
    """
    target = db.query(schema.Book).filter(schema.Book.bk_id == book_id).first()
    if not target:
        return target

    for field_name, new_value in book.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
581
+
582
def delete_book(db: Session, book_id: str):
    """Delete the book with ``book_id`` if it exists and return the looked-up row."""
    target = db.query(schema.Book).filter(schema.Book.bk_id == book_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
588
+
589
+ # endregion
590
+
591
+ # region Unit
592
class UnitBase(BaseModel):
    """Fields shared by every unit payload."""

    bv_book_id: str
    bv_voc_id: str
    bv_flag: int = 1
    bv_tag: str = ""
    bv_order: int = 1


class UnitCreate(UnitBase):
    """Payload for creating a unit; adds no fields to ``UnitBase``."""


class UnitUpdate(UnitBase):
    """Payload for updating a unit; adds no fields to ``UnitBase``."""


class Unit(UnitBase):
    """Unit payload; adds no fields to ``UnitBase``."""
607
+
608
def create_unit(db: Session, unit: UnitCreate):
    """Insert a unit built from ``unit`` and return the refreshed row."""
    new_unit = schema.Unit(**unit.dict())

    db.add(new_unit)
    db.commit()
    db.refresh(new_unit)
    return new_unit
614
+
615
def get_unit(db: Session, unit_id: str):
    """Return the first unit whose ``bv_id`` equals ``unit_id``."""
    by_id = db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id)
    return by_id.first()
617
+
618
def get_units(db: Session, skip: int = 0, limit: int = 10):
    """Return one page of units: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.Unit).offset(skip).limit(limit)
    return page.all()
620
+
621
def update_unit(db: Session, unit_id: str, unit: UnitUpdate):
    """Copy the explicitly-set fields of ``unit`` onto the stored unit.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``unit_id``.
    """
    target = db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id).first()
    if not target:
        return target

    for field_name, new_value in unit.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
629
+
630
def delete_unit(db: Session, unit_id: str):
    """Delete the unit with ``unit_id`` if it exists and return the looked-up row."""
    target = db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
636
+
637
+ # endregion
638
+
639
+ # region Word
640
class WordBase(BaseModel):
    """Fields shared by every word payload."""

    vc_id: str
    vc_vocabulary: str
    vc_phonetic_uk: str
    vc_phonetic_us: str
    vc_frequency: float
    vc_difficulty: float
    vc_acknowledge_rate: float


class WordCreate(WordBase):
    """Payload for creating a word; adds no fields to ``WordBase``."""


class WordUpdate(WordBase):
    """Payload for updating a word; adds no fields to ``WordBase``."""


class Word(WordBase):
    """Word payload; adds no fields to ``WordBase``."""
657
+
658
+
659
def create_word(db: Session, word: WordCreate, unit_id: str):
    """Insert a word built from ``word`` and return the refreshed row.

    Args:
        db: Active database session.
        word: Payload holding the word's fields (including ``vc_id``).
        unit_id: Value stored as the row's ``vc_id``; it takes precedence over
            ``word.vc_id``, as the original explicit keyword intended.
            NOTE(review): storing a "unit id" as the word's ``vc_id`` looks
            suspicious - confirm the intended meaning with callers.

    Returns:
        The newly created and refreshed ``schema.Word`` row.
    """
    # ``WordCreate`` already contains ``vc_id``; passing it a second time as an
    # explicit keyword raised "got multiple values for keyword argument".
    values = word.dict()
    values["vc_id"] = unit_id

    db_word = schema.Word(**values)
    db.add(db_word)
    db.commit()
    db.refresh(db_word)
    return db_word
665
+
666
def get_word(db: Session, word_id: str):
    """Return the first word whose ``vc_id`` equals ``word_id``."""
    by_id = db.query(schema.Word).filter(schema.Word.vc_id == word_id)
    return by_id.first()
668
+
669
def get_words(db: Session, skip: int = 0, limit: int = 10) -> List[schema.Word]:
    """Return one page of words: ``limit`` rows starting at offset ``skip``."""
    page = db.query(schema.Word).offset(skip).limit(limit)
    return page.all()
671
+
672
def get_words_by_vocabulary(db: Session, vocabulary: List[str]) -> List[schema.Word]:
    """Return every word whose ``vc_vocabulary`` is in ``vocabulary``."""
    spelled_like = schema.Word.vc_vocabulary.in_(vocabulary)
    return db.query(schema.Word).filter(spelled_like).all()
674
+
675
def get_words_by_ids(db: Session, ids: List[str]) -> List[schema.Word]:
    """Return every word whose ``vc_id`` is in ``ids``."""
    has_listed_id = schema.Word.vc_id.in_(ids)
    return db.query(schema.Word).filter(has_listed_id).all()
677
+
678
def get_words_in_batch(db: Session, batch_id: str) -> List[schema.Word]:
    """Return the words linked to ``batch_id`` through ``schema.UserMemoryWord``."""
    word_of_link = schema.UserMemoryWord.word_id == schema.Word.vc_id
    linked = (
        db.query(schema.Word)
        .join(schema.UserMemoryWord, word_of_link)
        .filter(schema.UserMemoryWord.batch_id == batch_id)
    )
    return linked.all()
680
+
681
def get_words_at_each_batch(db: Session, batch_ids: List[str]) -> List[schema.Word]:
    """Return the words linked to any of ``batch_ids`` through ``schema.UserMemoryWord``."""
    word_of_link = schema.UserMemoryWord.word_id == schema.Word.vc_id
    linked = (
        db.query(schema.Word)
        .join(schema.UserMemoryWord, word_of_link)
        .filter(schema.UserMemoryWord.batch_id.in_(batch_ids))
    )
    return linked.all()
683
+
684
def get_words_in_user_book(db: Session, user_book_id: str) -> List[schema.Word]:
    """Return the words of the batches found by ``get_new_user_memory_batches_by_user_book_id``."""
    batch_ids = []
    for batch in get_new_user_memory_batches_by_user_book_id(db, user_book_id):
        batch_ids.append(batch.id)
    return get_words_at_each_batch(db, batch_ids)
688
+
689
def update_word(db: Session, word_id: str, word: WordUpdate):
    """Copy the explicitly-set fields of ``word`` onto the stored word.

    Returns the refreshed row, or the falsy lookup result when nothing
    matches ``word_id``.
    """
    target = db.query(schema.Word).filter(schema.Word.vc_id == word_id).first()
    if not target:
        return target

    for field_name, new_value in word.dict(exclude_unset=True).items():
        setattr(target, field_name, new_value)
    db.commit()
    db.refresh(target)
    return target
697
+
698
def delete_word(db: Session, word_id: str):
    """Delete the word with ``word_id`` if it exists and return the looked-up row."""
    target = db.query(schema.Word).filter(schema.Word.vc_id == word_id).first()
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target
704
+ # endregion
database/schema.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float, DateTime
3
+ from sqlalchemy.orm import relationship
4
+
5
+ from .database import Base
6
+
7
+ import uuid
8
+ import bcrypt
9
+
10
+ # class ModelBase(Base):
11
+
12
+ # __abstract__ = True
13
+
14
+ # id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
15
+ # created_at = Column(DateTime, default=db.func.current_timestamp())
16
+ # updated_at = Column(DateTime,
17
+ # default=func.current_timestamp(),
18
+ # onupdate=func.current_timestamp())
19
+
20
class User(Base):
    """Account row of the ``users`` table; the password is kept as a bcrypt hash."""

    __tablename__ = "users"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    email = Column(String, unique=True, index=True)
    encrypted_password = Column(String)
    is_active = Column(Boolean, default=True)

    items = relationship("Item", back_populates="owner")
    # A commented-out SQL statement embedding a concrete bcrypt hash used to
    # live here; it was removed so no credential material stays in the source.

    def set_password(self, password: str):
        """Hash ``password`` with a fresh bcrypt salt and store it as UTF-8 text."""
        hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
        self.encrypted_password = str(hashed, encoding='utf-8')

    def verify_password(self, password: str):
        """Return whether ``password`` matches the stored bcrypt hash.

        A user without a stored hash never verifies; previously this case
        raised ``TypeError`` while converting ``None`` to bytes.
        """
        if not self.encrypted_password:
            return False
        return bcrypt.checkpw(password.encode(), bytes(self.encrypted_password, encoding='utf-8'))

    def __str__(self):
        return f"<User {self.email}>"

    def __repr__(self):
        return f"<User {self.email}>"
45
+
46
+
47
class Item(Base):
    """Row of the ``items`` table, linked to its owning ``User`` via ``owner_id``."""

    __tablename__ = "items"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=lambda: str(uuid.uuid4()),
    )
    title = Column(String)
    description = Column(String)
    owner_id = Column(String, ForeignKey("users.id"))

    owner = relationship("User", back_populates="items")
56
+
57
+ # region Word-memorization models
58
class UserBook(Base):
    """Row of the ``user_book`` table: links a user (``owner_id``) to a book (``book_id``)."""

    __tablename__ = "user_book"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=lambda: str(uuid.uuid4()),
    )
    owner_id = Column(String, ForeignKey("users.id"))
    book_id = Column(String, ForeignKey("book.bk_id"))

    title = Column(String)
    batch_size = Column(Integer, default=10)
    random = Column(Boolean, default=True)
    memorizing_batch = Column(String, default='')
69
+
70
class UserMemoryBatch(Base):
    """Row of the ``user_memory_batch`` table: one batch belonging to a user book."""

    __tablename__ = "user_memory_batch"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=lambda: str(uuid.uuid4()),
    )
    user_book_id = Column(String, ForeignKey("user_book.id"))

    story = Column(String, default="")
    translated_story = Column(String, default="")
    # The original comment lists three batch types.
    # NOTE(review): the default literal and that comment look mis-encoded in
    # this copy (probably Chinese for "new words", "recall", "review");
    # the literal is reproduced unchanged - confirm.
    batch_type = Column(String, default="ๆ–ฐ่ฏ")

    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    words = relationship("UserMemoryWord", back_populates="batch")
84
+
85
class UserMemoryBatchAction(Base):
    """Timestamped "start" / "end" event recorded for a memory batch.

    The time between a batch's start event and its end event is the time
    spent memorizing it, from which memorization efficiency can be computed.
    """

    __tablename__ = "user_memory_batch_action"

    # Primary key: a fresh UUID4 string is generated for every new row.
    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    # Foreign key to user_memory_batch.id.
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))

    action = Column(String, default="start")  # start or end

    # create_time is set on insert; update_time is also refreshed on every update.
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)
98
+
99
+
100
class UserMemoryBatchGenerationHistory(Base):
    """One generated story / translated-story pair recorded against a memory batch."""

    __tablename__ = "user_memory_batch_generation_history"

    # Primary key: a fresh UUID4 string is generated for every new row.
    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    # Foreign key to user_memory_batch.id.
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))

    # Story text and its translation; both default to the empty string.
    story = Column(String, default="")
    translated_story = Column(String, default="")
    # create_time is set on insert; update_time is also refreshed on every update.
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)
113
+
114
+
115
class UserMemoryWord(Base):
    """ORM model for the ``user_memory_word`` table: membership of a word in a memory batch."""

    __tablename__ = "user_memory_word"

    # Primary key: a fresh UUID4 string is generated for every new row.
    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    # Foreign key to user_memory_batch.id.
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))
    # Foreign key to word.vc_id.
    word_id = Column(String, ForeignKey("word.vc_id"))

    # Paired with the "words" relationship on UserMemoryBatch through back_populates.
    batch = relationship("UserMemoryBatch", back_populates="words")
123
+
124
class UserMemoryAction(Base):
    """A single "remember" or "forget" verdict for one word within one memory batch."""

    __tablename__ = "user_memory_action"

    # Primary key: a fresh UUID4 string is generated for every new row.
    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    # Foreign key to user_memory_batch.id.
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))
    # Foreign key to word.vc_id.
    word_id = Column(String, ForeignKey("word.vc_id"))

    # The default must be spelled "remember": code that tallies actions compares
    # against exactly that string and treats every other value as a forget.
    action = Column(String, default="remember")  # remember or forget
    # create_time is set on insert; update_time is also refreshed on every update.
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)
134
+
135
+ # endregion
136
+
137
class Book(Base):
    """ORM model for the ``book`` table: metadata describing one vocabulary book."""

    __tablename__ = "book"

    # Sample source record (Chinese text values replaced by English placeholders):
    # {'bk_id': 'd645920e395fedad7bbbed0e',
    #  'bk_parent_id': '6512bd43d9caa6e02c990b0a',
    #  'bk_level': 2,
    #  'bk_order': 2.0,
    #  'bk_name': '<book name>',
    #  'bk_item_num': 315,
    #  'bk_direct_item_num': 315,
    #  'bk_author': '<author>',
    #  'bk_book': '<full textbook title>',
    #  'bk_comment': '<free-text note about the word markers used in the book>',
    #  'bk_organization': '<organization>',
    #  'bk_publisher': '<publisher>',
    #  'bk_version': '<edition>',
    #  'bk_flag': '<per-marker word counts>'},
    # Only a subset of those source keys is mapped to columns below.
    bk_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    bk_order = Column(Float)
    bk_name = Column(String)
    bk_item_num = Column(Integer)
    bk_author = Column(String, default='')
    bk_comment = Column(String, default='')
    bk_organization = Column(String, default='')
    bk_publisher = Column(String, default='')
    bk_version = Column(String, default='')

    # Defaults to 'private'.
    permission = Column(String, default='private')
    # Foreign key to users.id.
    creator = Column(String, ForeignKey("users.id"))
    # create_time is set on insert; update_time is also refreshed on every update.
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)

    def __str__(self):
        """Return the book name followed, on a second line, by word count and version."""
        return f"{self.bk_name}\n [{self.bk_item_num} words, {self.bk_version}]"

    def __repr__(self):
        """Return a short ``<Book name>`` tag."""
        return f"<Book {self.bk_name}>"
173
+
174
class Unit(Base):
    """ORM model for the ``unit`` table: links one word to one book, with tag and order."""

    __tablename__ = "unit"

    # Source columns: ['bv_id', 'bv_book_id', 'bv_voc_id', 'bv_flag', 'bv_tag', 'bv_order']
    # Sample source record:
    # {'bv_id': '58450c828958a37d5c10f763',
    #  'bv_book_id': 'd645920e395fedad7bbbed0e',
    #  'bv_voc_id': '57067b9ca172044907c615d7',
    #  'bv_flag': 4,
    #  'bv_tag': 'Unit 1',
    #  'bv_order': 1},
    bv_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    # Foreign key to word.vc_id.
    bv_voc_id = Column(String, ForeignKey("word.vc_id"))
    # Foreign key to book.bk_id.
    bv_book_id = Column(String, ForeignKey("book.bk_id"))
    bv_flag = Column(Integer)
    bv_tag = Column(String)
    bv_order = Column(Integer)
190
+
191
class Word(Base):
    """ORM model for the ``word`` table: one vocabulary entry with translation and phonetics."""

    __tablename__ = "word"

    # Source column order:
    # vc_id>vc_vocabulary>vc_phonetic_uk>vc_phonetic_us>vc_frequency>vc_difficulty>vc_acknowledge_rate
    # Sample source row:
    # 57067c89a172044907c6698e>superspecies>[UK phonetic]>[US phonetic]>0.0>1>0.664122
    vc_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    vc_vocabulary = Column(String)
    vc_translation = Column(String)
    vc_phonetic_uk = Column(String)
    vc_phonetic_us = Column(String)
    vc_frequency = Column(Float)
    vc_difficulty = Column(Float)
    vc_acknowledge_rate = Column(Float)

    def __str__(self):
        """Return the word and translation, then both phonetic spellings on a second line."""
        return f"{self.vc_vocabulary} {self.vc_translation}\n[{self.vc_phonetic_uk}] [{self.vc_phonetic_us}]"

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return f"{self.vc_vocabulary} {self.vc_translation}\n[{self.vc_phonetic_uk}] [{self.vc_phonetic_us}]"
memorize.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy.orm import Session
2
+ from typing import List, Tuple
3
+
4
+ from tqdm import tqdm
5
+ from database.operation import *
6
+ from database import schema
7
+ import random
8
+ from loguru import logger
9
+ import math
10
+
11
+
12
+ # ่ฎฐๅ•่ฏ
13
+ from story_agent import generate_story_and_translated_story
14
+ from common.util import date_str, multiprocessing_mapping
15
+
16
+
17
def get_words_for_book(db: Session, user_book: UserBook) -> List[schema.Word]:
    """Return the words of the book behind ``user_book``, ordered by ``vc_difficulty``.

    Words are found by joining ``Word`` to ``Unit`` on the word id and keeping
    the units that belong to the book. When ``get_book`` finds no book for
    ``user_book.book_id`` a warning is logged and an empty list is returned.
    """
    target_book = get_book(db, user_book.book_id)
    if target_book is not None:
        return (
            db.query(schema.Word)
            .join(schema.Unit, schema.Unit.bv_voc_id == schema.Word.vc_id)
            .filter(schema.Unit.bv_book_id == target_book.bk_id)
            .order_by(schema.Word.vc_difficulty)
            .all()
        )
    logger.warning("book not found")
    return []
25
+
26
+
27
def save_words_as_book(db: Session, user_id: str, words: List[schema.Word], title: str):
    """Create a new book owned by ``user_id`` that contains exactly ``words``.

    A ``Book`` row is created first (its name is ``title`` plus a fixed
    suffix, its item count is ``len(words)``), then one ``Unit`` row per word
    links the word to that book.

    Args:
        db: Session used for all writes.
        user_id: Stored as the book's ``creator``.
        words: Words to attach to the new book.
        title: Prefix of the new book's name.

    Returns:
        The object returned by ``create_book``.
    """
    book = create_book(db, BookCreate(bk_name=f"{title}๏ผˆๅพ…ๅญฆๅ•่ฏ่‡ชๅŠจไฟๅญ˜ไธบๅ•่ฏไนฆ๏ผ‰", bk_item_num=len(words), creator=user_id))
    # Wrap the list itself (not the enumerate iterator) so tqdm knows the total
    # and can show real progress.
    for added, word in enumerate(tqdm(words), start=1):
        unit = UnitCreate(bv_book_id=book.bk_id, bv_voc_id=word.vc_id)
        db.add(schema.Unit(**unit.dict()))
        # Flush to the database every 500 rows; counting from 1 avoids a
        # pointless commit right after the very first row.
        if added % 500 == 0:
            db.commit()
    # Commit whatever is left after the last full chunk.
    db.commit()
    return book
37
+
38
def save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    """Generate a story for ``batch_words`` and persist the batch together with it.

    Returns whatever ``save_batch_words_with_story`` returns.
    """
    # Original author's note: only the first batch gets its story generated up
    # front; later batches are meant to get stories based on the user's
    # progress, prepared three batches ahead.
    vocabulary = [entry.vc_vocabulary for entry in batch_words]
    generated = generate_story_and_translated_story(vocabulary)
    story, translated_story = generated
    return save_batch_words_with_story(
        db, i, user_book_id, batch_words, story, translated_story
    )
43
+
44
+
45
def save_batch_words_with_story(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word], story: str, translated_story: str):
    """Persist one memory batch, its first generation-history entry and its word links.

    Args:
        db: Session used for all writes.
        i: Index of the batch; only used in the log line.
        user_book_id: Plan the batch belongs to.
        batch_words: Words that make up the batch.
        story: Generated story text.
        translated_story: Translation of ``story``.

    Returns:
        The object returned by ``create_user_memory_batch``.
    """
    batch_words_str_list = list(map(lambda entry: entry.vc_vocabulary, batch_words))
    logger.info(f"{i}, {batch_words_str_list}\n{story}")

    batch_payload = UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story
    )
    user_memory_batch = create_user_memory_batch(db, batch_payload)

    # The same story pair is also recorded as the batch's first history entry.
    history_payload = UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    )
    create_user_memory_batch_generation_history(db, history_payload)

    # One link row per word, committed together after the loop.
    for entry in batch_words:
        link = UserMemoryWordCreate(batch_id=user_memory_batch.id, word_id=entry.vc_id)
        db.add(schema.UserMemoryWord(**link.dict()))
    db.commit()
    return user_memory_batch
67
+
68
async def async_save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    """Coroutine wrapper around ``save_batch_words``; the result is discarded.

    The wrapped call is made directly (it is not awaited or offloaded), so it
    runs to completion inside this coroutine.
    """
    save_batch_words(db, i, user_book_id, batch_words)
70
+
71
+ import asyncio
72
async def async_save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    """Save every batch in ``batch_words_list`` and return once all of them are done.

    Batches are numbered from 1 for logging purposes.

    The previous implementation only scheduled the work with
    ``asyncio.ensure_future`` and returned immediately, so a caller using
    ``asyncio.run`` could see the loop shut down before any batch was saved.
    Awaiting the gathered coroutines guarantees completion and propagates
    the first exception to the caller.
    """
    await asyncio.gather(*(
        async_save_batch_words(db, index, user_book_id, batch_words)
        for index, batch_words in enumerate(batch_words_list, start=1)
    ))
75
+
76
def transform(batch_words: List[str]):
    """Generate a story for ``batch_words`` and bundle it with the words in a dict.

    The returned dict has the keys ``"story"``, ``"translated_story"`` and
    ``"words"`` (the input list, unchanged).
    """
    generated_story, generated_translation = generate_story_and_translated_story(batch_words)
    bundle = dict(
        story=generated_story,
        translated_story=generated_translation,
        words=batch_words,
    )
    return bundle
83
+
84
def save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    """Generate stories for all batches via ``multiprocessing_mapping`` and persist each batch.

    The vocabulary strings of every batch are handed to ``transform`` through
    ``multiprocessing_mapping``; each resulting story dict is then stored with
    the matching batch by ``save_batch_words_with_story``.
    """
    word_str_list = [
        [word.vc_vocabulary for word in batch_words]
        for batch_words in batch_words_list
    ]
    story_list = multiprocessing_mapping(transform, word_str_list, tmp_filepath=f"./output/logs/save_batch_words_list_{date_str}.xlsx")
    logger.info(f"story_list: {len(story_list)}")
    paired = zip(batch_words_list, story_list)
    for i, (batch_words, story) in tqdm(enumerate(paired)):
        save_batch_words_with_story(
            db, i, user_book_id, batch_words,
            story['story'], story['translated_story'],
        )
92
+
93
def track(db: Session, user_book: schema.UserBook, words: List[schema.Word]):
    """Turn ``words`` into memory batches for ``user_book`` and persist them.

    Steps:
      1. Order the words: shuffled when ``user_book.random`` is set, otherwise
         by descending ``vc_frequency``. ``words`` is reordered in place.
      2. Save the ordered words as a new book owned by the plan's owner.
      3. Slice the words into chunks of ``user_book.batch_size``.
      4. Save the first batch immediately and record its id on
         ``user_book.memorizing_batch``; then save the remaining batches.

    Returns ``None``. Nothing is batched when ``words`` is empty (the book
    from step 2 is still created).
    """
    batch_size = user_book.batch_size
    logger.debug(f"{[w.vc_vocabulary for w in words]}")
    logger.debug(f"batch_size: {batch_size}")
    logger.debug(f"words count: {len(words)}")
    if user_book.random:
        random.shuffle(words)
    else:
        # Sort by word frequency, most frequent first (original author's note:
        # frequent words are easier to remember). vc_frequency is a nullable
        # column, so a missing value is treated as 0 instead of letting the
        # comparison raise TypeError.
        words.sort(key=lambda w: w.vc_frequency or 0.0, reverse=True)
    logger.debug("saving words as book")
    save_words_as_book(db, user_book.owner_id, words, user_book.title)
    logger.debug(f"saved words as book [{user_book.title}]")
    batch_words_list = [
        words[offset:offset + batch_size]
        for offset in range(0, len(words), batch_size)
    ]
    logger.debug(f"batch_words_list: {len(batch_words_list)}")
    if not batch_words_list:
        return
    # The first batch is saved on its own so the plan can point at it right away.
    user_memory_batch = save_batch_words(db, 0, user_book.id, batch_words_list[0])
    user_book.memorizing_batch = user_memory_batch.id
    db.commit()
    save_batch_words_list(db, user_book.id, batch_words_list[1:])
117
+ # asyncio.run(async_save_batch_words_list(db, user_book.id, batch_words_list[1:]))
118
+
119
def remenber(db: Session, batch_id: str, word_id: str):
    """Record a "remember" action for ``word_id`` within batch ``batch_id``.

    Returns whatever ``create_user_memory_action`` returns.

    NOTE(review): the function name is misspelled ("remenber"); the stored
    action value is the correctly spelled "remember".
    """
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="remember"
    ))
125
+
126
def forget(db: Session, batch_id: str, word_id: str):
    """Record a "forget" action for ``word_id`` within batch ``batch_id``.

    Returns whatever ``create_user_memory_action`` returns.
    """
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="forget"
    ))
132
+
133
def save_memorizing_word_action(db: Session, batch_id: str, actions: List[Tuple[str, str]]):
    """Store one memory action per ``(word_id, action)`` pair and commit once at the end.

    actions: [(word_id, remember | forget)]
    """
    for pair in actions:
        target_word_id, verdict = pair
        payload = UserMemoryActionCreate(
            batch_id=batch_id,
            word_id=target_word_id,
            action=verdict,
        )
        db.add(schema.UserMemoryAction(**payload.dict()))
    db.commit()
146
+
147
def on_batch_start(db: Session, user_memory_batch_id: str):
    """Record a "start" batch action for the given memory batch.

    Returns whatever ``create_user_memory_batch_action`` returns.
    """
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="start"
    ))
152
+
153
def on_batch_end(db: Session, user_memory_batch_id: str):
    """Record an "end" batch action for the given memory batch.

    Returns whatever ``create_user_memory_batch_action`` returns.
    """
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="end"
    ))
158
+
159
+ # def generate_recall_batch(db: Session, user_book: schema.UserBook):
160
def _save_generated_batch(db: Session, user_book_id: str, batch_words: List[schema.Word], batch_type: str):
    """Generate a story for ``batch_words`` and persist it as a batch of ``batch_type``."""
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    story, translated_story = generate_story_and_translated_story(batch_words_str_list)
    user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story,
        batch_type=batch_type,
    ))
    create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    ))
    for word in batch_words:
        memory_word = UserMemoryWordCreate(
            batch_id=user_memory_batch.id,
            word_id=word.vc_id
        )
        db.add(schema.UserMemoryWord(**memory_word.dict()))
    db.commit()
    return user_memory_batch


def generate_next_batch(db: Session, user_book: schema.UserBook,
                        minutes: int = 60, k: int = 3):
    """Decide what the next batch should be and, if needed, create it.

    Looks at the words memorized in the last ``minutes`` minutes and computes
    each word's efficiency as remembered / (remembered + forgotten).

    Args:
        db: Session used for reads and writes.
        user_book: Plan whose recent history is inspected.
        minutes: Size of the look-back window.
        k: Minimum number of batches' worth of recent words required before
            a review or recall batch is considered.

    Returns:
        A newly created review or recall batch, or ``None`` when the caller
        should continue with the next new-word batch.
    """
    from collections import defaultdict

    left_bound, right_bound = 0.3, 0.6
    user_book_id = user_book.id
    batch_size = user_book.batch_size
    memorizing_words = get_user_memory_word_history_in_minutes(db, user_book_id, minutes)
    if len(memorizing_words) < k * batch_size:
        # 1. Too few recently memorized words: continue with a new-word batch.
        logger.info("ๆ–ฐ่ฏๆ‰น")
        return None

    # Tally remember / forget actions per word.
    memory_actions = get_actions_at_each_word(db, [w.vc_id for w in memorizing_words])
    remember_count = defaultdict(int)
    forget_count = defaultdict(int)
    for a in memory_actions:
        if a.action == "remember":
            remember_count[a.word_id] += 1
        else:
            forget_count[a.word_id] += 1

    word_id_to_efficiency = {}
    for word in memorizing_words:
        remembered = remember_count[word.vc_id]
        attempts = remembered + forget_count[word.vc_id]
        # A word with no recorded action would divide by zero; it counts as
        # never remembered.
        word_id_to_efficiency[word.vc_id] = remembered / attempts if attempts else 0.0
    # sorted() returns the list; list.sort() returns None, which is what used
    # to be logged here.
    logger.info(sorted(
        ((w.vc_vocabulary, word_id_to_efficiency[w.vc_id]) for w in memorizing_words),
        key=lambda pair: pair[1],
    ))

    if (all(efficiency > right_bound for efficiency in word_id_to_efficiency.values())
            and all(count > 3 for count in remember_count.values())):
        # 2. Efficiency is high everywhere: continue with a new-word batch.
        logger.info("ๆ–ฐ่ฏๆ‰น")
        return None

    # NOTE(review): words are picked in history order, as before; the id lists
    # used to be sorted by efficiency but that ordering was never used for the
    # selection. Confirm whether lowest-efficiency-first was intended.
    forgot_word_ids = {
        word_id for word_id, efficiency in word_id_to_efficiency.items()
        if efficiency < left_bound
    }
    if len(forgot_word_ids) >= batch_size:
        # 4. Enough forgotten words: build a review batch.
        logger.info("ๅคไน ๆ‰น")
        batch_words = [word for word in memorizing_words if word.vc_id in forgot_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty or 0.0, reverse=True)
        return _save_generated_batch(db, user_book_id, batch_words, "ๅคไน ")

    unfamiliar_word_ids = {
        word_id for word_id, efficiency in word_id_to_efficiency.items()
        if left_bound <= efficiency < right_bound
    }
    if len(unfamiliar_word_ids) < batch_size:
        # Also include words that were remembered only a few times. A set is
        # used so a word matching both criteria is counted once.
        unfamiliar_word_ids |= {word_id for word_id, count in remember_count.items() if count < 3}
    if len(unfamiliar_word_ids) >= batch_size:
        # 3. Enough shaky words: build a recall batch.
        logger.info("ๅ›žๅฟ†ๆ‰น")
        batch_words = [word for word in memorizing_words if word.vc_id in unfamiliar_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty or 0.0, reverse=True)
        return _save_generated_batch(db, user_book_id, batch_words, "ๅ›žๅฟ†")

    # 5. Nothing to revisit: continue with a new-word batch.
    logger.info("ๆ–ฐ่ฏๆ‰น")
    return None
262
+
263
+
264
+
output/logs/.gitkeep ADDED
File without changes
story_agent.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import List, Tuple
3
+
4
+ from loguru import logger
5
+ from pydantic import BaseModel, Field
6
+ from langchain.chains import LLMChain
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.output_parsers import (
9
+ PydanticOutputParser,
10
+ OutputFixingParser,
11
+ )
12
+ # LLM.py ๆ˜ฏๆˆ‘่‡ชๅทฑ็š„่ฏญ่จ€ๆจกๅž‹๏ผŒไฝ ๅฏไปฅ็›ดๆŽฅไฝฟ็”จ openai ็š„
13
+ # from langchain.llms.openai import OpenAIChat
14
+ from LLM import OpenAIChat
15
+
16
# Prompt text for story generation. It contains two placeholders, {words} and
# {format}, and asks for each target word to be wrapped in backticks in both
# the English and the Chinese version of the story.
TEMPLATE = """\
please write a story at least 5 sentences long, using the words [{words}].

{format}

Attention! The words should be highlighted, surrounded by "`". Therefore, the story should be in the following format.
English: ... `word1` ... `word2` ...
Chinese: ... `ๅ•่ฏ1` ... `ๅ•่ฏ2` ...
"""
25
+
26
class Story(BaseModel):
    """Structured LLM output: a story and its translation."""

    # The field descriptions below are part of the model definition passed to
    # the output parser.
    story: str = Field(description="the story")
    translated_story: str = Field(description="the translated story")
29
+
30
# Module-level pipeline: LLM -> prompt -> parser, assembled once at import time.
llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.3)
# Base parser that turns the raw completion into a Story instance.
parser = PydanticOutputParser(pydantic_object=Story)
prompt_template = PromptTemplate(
    template=TEMPLATE,
    input_variables=["words"],
    partial_variables={
        # Filled in once from the base parser; callers only supply "words".
        "format": parser.get_format_instructions(),
    }
)
# Rebind ``parser`` to a fixing wrapper around the base parser that is given
# the same LLM. This happens after the format instructions were taken above.
parser = OutputFixingParser.from_llm(parser=parser, llm=llm)
chain = LLMChain(
    llm=llm,
    prompt=prompt_template,
    output_parser=parser,
    verbose=False,
)
46
+
47
def tell_story(words: List[str], max_retries: int = 10):
    """Ask the chain for a story that uses ``words``, retrying on failure.

    An attempt counts as failed when the chain raises, or when either the
    story or its translation is empty after stripping whitespace.

    Args:
        words: Words to include; they are joined with ", " for the prompt.
        max_retries: Maximum number of attempts (default 10, the previously
            hard-coded value).

    Returns:
        The first ``Story`` with non-empty story and translation, or a
        ``Story`` with two empty strings when every attempt failed.
    """
    for _ in range(max_retries):
        try:
            resp: Story = chain.run(", ".join(words))
            # Accept only a response in which both texts are non-blank.
            if resp.story.strip() and resp.translated_story.strip():
                return resp
        except Exception as e:
            # Best effort: any failure is logged and the next attempt is made.
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error("retrying...")
    return Story(story="", translated_story="")
64
+
65
def generate_story_and_translated_story(words: List[str]) -> Tuple[str, str]:
    """Return ``(story, translated_story)`` taken from the result of ``tell_story(words)``."""
    resp = tell_story(words)
    return resp.story, resp.translated_story
68
+
69
+
web.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import gradio as gr
3
+ from database.operation import *
4
+ from memorize import *
5
+ from database import SessionLocal, engine, Base
6
+ import database.schema as schema
7
+ from database import constant
8
+ import time
9
+ import asyncio
10
+
11
+ import pandas as pd
12
+ from collections import defaultdict
13
+ from datetime import datetime
14
+
15
+
16
# Create the tables declared on Base against the configured engine.
Base.metadata.create_all(bind=engine)
# One session object, created at import time and used by the code below.
db = SessionLocal()
18
+
19
@contextmanager
def session_scope():
    """Yield the module-level ``db`` session; commit on success, roll back on error.

    The session is closed when the ``with`` block exits, whether or not an
    exception was raised. Exceptions are re-raised after the rollback.
    """
    try:
        yield db
        db.commit()
    except Exception:
        # Undo pending changes, then let the caller see the original exception.
        db.rollback()
        raise
    finally:
        # NOTE(review): this closes the shared module-level session rather than
        # a per-scope one — confirm that is intended.
        db.close()
29
+ intro = """\
30
+ ็›ฎๆ ‡ๅœบๆ™ฏ๏ผšๅช่€ƒ่™‘่ฎฐไฝๅ•่ฏๅŠๅ…ถๆ„ๆ€๏ผŒไฝฟๅพ—่ƒฝๆ— ้šœ็ข้˜…่ฏป๏ผŒไธ่€ƒ่™‘็”จไบŽๅ†™ไฝœใ€‚
31
+
32
+ ไธป่ฆๆƒณๆณ•๏ผšๆ‰น้‡่ฎฐๅ•่ฏ๏ผŒๆฏๆ‰น n ไธชๅ•่ฏ๏ผŒ่ฟ™ n ไธชๅ•่ฏ็”จ AI ็”Ÿๆˆๆ•…ไบ‹๏ผŒๅค่ฟฐๆ•…ไบ‹ๅณๅฏ่ฎฐไฝๅ•่ฏใ€‚
33
+
34
+ ไธบไป€ไนˆ๏ผŸ
35
+
36
+ - ๆ‰น้‡่ฎฐๅ•่ฏ๏ผŒไธ€ๆฌกๅฏไปฅ่ฎฐไฝ n ไธชๅ•่ฏ๏ผŒ่€Œไธๆ˜ฏไธ€ไธชไธ€ไธช่ฎฐ๏ผŒๆ•ˆ็Ž‡้ซ˜ใ€‚
37
+ - ๅค่ฟฐๆ•…ไบ‹๏ผŒๅณ่ดนๆ›ผๅญฆไน ๆณ•๏ผŒๆ•…ไบ‹ๆ˜ฏๅ•่ฏ็š„่ฎฐๅฟ†ไน‹้”šใ€‚
38
+ - ๅค่ฟฐๆ•…ไบ‹่€Œไธๆ˜ฏๅค่ฟฐๅ•่ฏ๏ผŒๆ•…ไบ‹ๅ…ทๆœ‰่ฟž็ปญๆ€ง๏ผŒๆ›ด็ฌฆๅˆไบบ็ฑปๅคฉๆ€ง๏ผŒๅฎนๆ˜“่ฎฐใ€‚
39
+
40
+ ### ไฝฟ็”จๅปบ่ฎฎ
41
+
42
+ 1. ่ฎฐๅ•่ฏๅ‰๏ผŒๅ…ˆๅฎŒๆ•ด่ฟ‡ไธ€้ๅ…จ้ƒจๅ•่ฏ๏ผŒๅ‰”้™คๅทฒ่ฎฐไฝ็š„ๅ•่ฏ๏ผŒไปŽ่€Œๆ้ซ˜ๆ–ฐ่ฏๅฏ†ๅบฆ
43
+ 2. ๅ…ˆ็œ‹ๅ•่ฏ่กจๆ ผ๏ผŒ็„ถๅŽ็œ‹่‹ฑๆ–‡ๆ•…ไบ‹๏ผŒๅฏน็…งไธญๆ–‡ๅฎŒๆˆ่ฎฐๅฟ†
44
+ 3. ่ฎฐๅฟ†ๅฎŒๆˆๅŽ้œ€่ฆไธ€ไธชไธ€ไธชๅ‹พ้€‰ๅทฒ่ฎฐไฝ็š„ๅ•่ฏ๏ผŒๅ‹พ้€‰ๆ—ถๅฐ่ฏ•ๅค่ฟฐๅ•่ฏๆ„ๆ€๏ผŒไปฅๆญคๆฅๆฃ€้ชŒ่ฎฐๅฟ†ๆ•ˆๆžœ
45
+
46
+ > ๆœฌ้กน็›ฎๅŸบไบŽ[ๅผ€ๆบๆ•ฐๆฎ้›†](https://github.com/LinXueyuanStdio/DictionaryData)๏ผŒๅนถไธ”[ๅผ€ๆบไปฃ็ ](https://github.com/LinXueyuanStdio/oh-my-words)๏ผŒๆฌข่ฟŽๅคงๅฎถ่ดก็Œฎไปฃ็ ๏ฝž
47
+ """
48
+
49
+ with gr.Blocks(title="ๆ‰น้‡่ฎฐๅ•่ฏ") as demo:
50
+ # gr.Markdown("# ๆ‰น้‡่ฎฐๅ•่ฏ")
51
+ gr.HTML("<h1 align=\"center\">ๆ‰น้‡่ฎฐๅ•่ฏ</h1>")
52
+ user = gr.State(value={})
53
+
54
+ # 0. ็™ปๅฝ•
55
+ with gr.Tab("ไธป้กต"):
56
+ gr.Markdown(intro)
57
+ gr.Markdown(f"ๅ…ฑ {get_book_count(db)} ๆœฌไนฆ")
58
+ gr.HTML("""<iframe src="https://ghbtns.com/github-btn.html?user=LinXueyuanStdio&repo=oh-my-words&type=star&count=true&size=small" frameborder="0" scrolling="0" width="170" height="30" title="GitHub"></iframe>""")
59
+ with gr.Row():
60
+ with gr.Column():
61
+ email = gr.TextArea(value=constant.email, lines=1, label="้‚ฎ็ฎฑ")
62
+ password = gr.TextArea(value=constant.password, lines=1, label="ๅฏ†็ ")
63
+ login_btn = gr.Button("็™ปๅฝ•")
64
+ with gr.Column():
65
+ register_email = gr.TextArea(value='', lines=1, label="้‚ฎ็ฎฑ")
66
+ register_password = gr.TextArea(value='', lines=1, label="ๅฏ†็ ")
67
+ register_btn = gr.Button("็ซ‹ๅณๆณจๅ†Œ", variant="primary")
68
+ user_status = gr.Textbox("", lines=1, label="็”จๆˆท็Šถๆ€")
69
+
70
+ # 1. ๅˆ›ๅปบ่ฎฐๅฟ†่ฎกๅˆ’
71
+ tab1 = gr.Tab("ๅˆ›ๅปบ่ฎฐๅฟ†่ฎกๅˆ’", visible=False)
72
+ with tab1:
73
+ select_book = gr.Dropdown([], label="ๅ•่ฏไนฆ", info="้€‰ๆ‹ฉไธ€ๆœฌๅ•่ฏไนฆ")
74
+ batch_size = gr.Number(value=10, label="ๆ‰นๆฌกๅคงๅฐ")
75
+ randomize = gr.Checkbox(value=True, label="ไปฅๅ•่ฏไนฑๅบ่ฟ›่กŒ่ฎฐๅฟ†")
76
+ title = gr.TextArea(value='ๅ•่ฏไนฆ', lines=1, label="่ฎฐๅฟ†่ฎกๅˆ’็š„ๅ็งฐ")
77
+ btn = gr.Button("ๅˆ›ๅปบ่ฎฐๅฟ†่ฎกๅˆ’")
78
+ status = gr.Textbox("", lines=1, label="็Šถๆ€")
79
+
80
+ def submit(user: Dict[str, str], book, title, randomize, batch_size):
81
+ user_id = user.get("id", None)
82
+ if user_id is None:
83
+ gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
84
+ return "่ฏทๅ…ˆ็™ปๅฝ•"
85
+ book_id = book.split(" [")[1][:-1]
86
+ user_book = create_user_book(db, UserBookCreate(
87
+ owner_id=user_id,
88
+ book_id=book_id,
89
+ title=title,
90
+ random=randomize,
91
+ batch_size=batch_size
92
+ ))
93
+ if user_book is not None:
94
+ return "ๆˆๅŠŸ"
95
+ else:
96
+ return "ๅคฑ่ดฅ"
97
+
98
+ btn.click(submit, [user, select_book, title, randomize, batch_size], [status])
99
+ def on_select(user: Dict[str, str], evt: gr.SelectData):
100
+ user_id = user.get("id", None)
101
+ new_options = []
102
+ if user_id is None:
103
+ return gr.Dropdown(choices=new_options), "่ฏทๅ…ˆ็™ปๅฝ•"
104
+ books = get_all_books_for_user(db, user_id)
105
+ new_options = [f"{'โญ ' if book.permission == 'private' else ''}{book.bk_name} (ๅ…ฑ {book.bk_item_num} ่ฏ) [{book.bk_id}]" for book in books]
106
+ return gr.Dropdown(choices=new_options), f"ๆ‚จๅฅฝ๏ผŒ{user['email']}"
107
+ tab1.select(on_select, [user], [select_book, status])
108
+
109
+ # 2. ้€‰ๆ‹ฉๅ•่ฏๅˆ†ๆ‰น
110
+ with gr.Tab("้€‰ๆ‹ฉๅ•่ฏๅˆ†ๆ‰น") as tab2:
111
+ select_user_book = gr.Dropdown(
112
+ [], label="่ฎฐๅฟ†่ฎกๅˆ’", info="่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
113
+ )
114
+ word_count = gr.Number(value=0, label="ๅ•่ฏไธชๆ•ฐ")
115
+ known_words = gr.CheckboxGroup(
116
+ [], label="ๅทฒๅญฆไผš็š„ๅ•่ฏ", info="ๆญฃๅผ่ฎฐๅฟ†ๅ‰ๅฐ†ๅŽป้™คๅทฒๅญฆไผš็š„ๅ•่ฏ๏ผŒๆ้ซ˜ๆฏไธชๆ‰นๆฌก็š„ๆ–ฐ่ฏๅฏ†ๅบฆ๏ผŒ่ฟ›่€Œๆ้ซ˜ๆ•ˆ็Ž‡"
117
+ )
118
+ btn = gr.Button("็”Ÿๆˆๆ‰นๆฌก")
119
+ status = gr.Textbox("3000 ่ฏๅคงๆฆ‚่ฆ 2 ๅฐๆ—ถๆ‰่ƒฝๅ†™ๅฎŒๆ‰€ๆœ‰็š„ๆ•…ไบ‹", lines=1, label="็”Ÿๆˆ็ป“ๆžœ")
120
+
121
+ def on_select_user(user):
122
+ user_id = user.get("id", None)
123
+ if user_id is None:
124
+ gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
125
+ return gr.Dropdown(choices=[]), "่ฏทๅ…ˆ็™ปๅฝ•"
126
+ new_options = []
127
+ user_book = get_user_books_by_owner_id(db, user_id)
128
+ new_options = [f"{book.title} | {book.batch_size}ไธชๅ•่ฏไธ€็ป„ [{book.id}]" for book in user_book]
129
+ return gr.Dropdown(choices=new_options), "3000 ่ฏๅคงๆฆ‚่ฆ 2 ๅฐๆ—ถๆ‰่ƒฝๅ†™ๅฎŒๆ‰€ๆœ‰็š„ๆ•…ไบ‹"
130
+
131
+ def on_select_user_book(user_book):
132
+ logger.debug(f'user_book {user_book}')
133
+ if user_book is None:
134
+ return 0, gr.CheckboxGroup(choices=[])
135
+ new_options = []
136
+ user_book_id = user_book.split(" [")[1][:-1]
137
+ user_book = get_user_book(db, user_book_id)
138
+ book_id = user_book.book_id
139
+ book = get_book(db, book_id)
140
+ if book is None:
141
+ return 0, gr.CheckboxGroup(choices=[])
142
+ words = get_words_for_book(db, user_book)
143
+ new_options = [f"{word.vc_vocabulary}" for word in words]
144
+ return len(words), gr.CheckboxGroup(choices=new_options)
145
+
146
+ select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[word_count, known_words])
147
+ tab2.select(on_select_user, [user], [select_user_book, status])
148
+
149
+ def submit(user_book, known_words):
150
+ start_time = time.time()
151
+ user_book_id = user_book.split(" [")[1][:-1]
152
+ user_book = get_user_book(db, user_book_id)
153
+ all_words = get_words_for_book(db, user_book)
154
+ unknown_words = []
155
+ for w in all_words:
156
+ if w.vc_vocabulary not in known_words:
157
+ unknown_words.append(w)
158
+ track(db, user_book, unknown_words)
159
+ end_time = time.time()
160
+ duration = end_time - start_time
161
+ return f"ๆˆๅŠŸ๏ผๅˆ†ไธบ {len(unknown_words) // user_book.batch_size} ไธชๆ‰นๆฌก๏ผŒๅ…ฑ {len(unknown_words)} ไธชๅ•่ฏ๏ผŒ่€—ๆ—ถ {duration:.2f} ็ง’"
162
+
163
+ btn.click(submit, [select_user_book, known_words], [status])
164
+
165
+ # 3. ่ฎฐๅฟ†
166
+ with gr.Tab("่ฎฐๅฟ†") as tab3:
167
+ select_user_book = gr.Dropdown(
168
+ [], label="่ฎฐๅฟ†่ฎกๅˆ’", info="่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
169
+ )
170
+ info = gr.Accordion(f"ๆ–ฐ่ฏ", open=False)
171
+ with info:
172
+ gr.Markdown(f"ๆ–ฐ่ฏ")
173
+
174
+ dataframe_header = ["ๅ•่ฏ", "ไธญๆ–‡่ฏๆ„", "่‹ฑๅผ้Ÿณๆ ‡", "็พŽๅผ้Ÿณๆ ‡", "่ฎฐๅฟ†้‡"]
175
+ memorizing_dataframe = gr.Dataframe(
176
+ headers=dataframe_header,
177
+ datatype=["str"] * len(dataframe_header),
178
+ col_count=(len(dataframe_header), "fixed"),
179
+ wrap=True,
180
+ )
181
+ batches = gr.State(value=[])
182
+ current_batch_index = gr.State(value=-1)
183
+ user_book_id = gr.State(value=None)
184
+ with gr.Row():
185
+ # story = gr.HighlightedText([])
186
+ # translated_story = gr.HighlightedText([])
187
+ # story = gr.Textbox()
188
+ # translated_story = gr.Textbox()
189
+ story = gr.Markdown()
190
+ translated_story = gr.Markdown()
191
+ # ่ฏ•ไบ†ไธ€ไธ‹๏ผŒ่ฟ˜ๆ˜ฏ markdown ็š„ๆ˜พ็คบๆ•ˆๆžœๅฅฝ
192
+
193
+ memorize_action = gr.CheckboxGroup(choices=[], label="่ฎฐไฝ็š„ๅ•่ฏ", info="่ƒฝๅคŸๅค่ฟฐๅ‡บๆ„ๆ€ๆ‰็ฎ—่ฎฐไฝ")
194
+ with gr.Row():
195
+ previous_batch_btn = gr.Button("ไธŠไธ€ๆ‰น")
196
+ regenerate_btn = gr.Button("้‡ๆ–ฐ็”Ÿๆˆๆ•…ไบ‹")
197
+ next_batch_btn = gr.Button("ไธ‹ไธ€ๆ‰น", variant="primary")
198
+ progress = gr.Slider(1, 1, value=1, step=1, label="่ฟ›ๅบฆ", info="")
199
+
200
+ def on_select_user(user):
201
+ user_id = user.get("id", None)
202
+ if user_id is None:
203
+ gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
204
+ return gr.Dropdown(choices=[])
205
+ new_options = []
206
+ user_book = get_user_books_by_owner_id(db, user_id)
207
+ new_options = [f"{book.title} | {book.batch_size}ไธชๅ•่ฏไธ€็ป„ [{book.id}]" for book in user_book]
208
+ return gr.Dropdown(choices=new_options)
209
+
210
def update_from_batch(memorizing_batch: UserMemoryBatch):
    """Compute the widget updates that display one memory batch.

    Collects the batch's words, tallies remember/forget actions per word,
    derives the time spent on the batch from its start/end actions, and makes
    sure the batch has a story (generating one when it is missing).

    Args:
        memorizing_batch: The batch to display.

    Returns:
        A 5-tuple in ``batch_widget`` order: an accordion update whose label
        is the batch type (plus speed figures when a positive duration was
        measured), the word table as a DataFrame, the story, the translated
        story, and a checkbox-group update listing the batch's vocabulary.
    """
    new_options = []
    word_df = []
    memorizing_words = get_user_memory_words_by_batch_id(db, memorizing_batch.id)
    word_ids = [w.word_id for w in memorizing_words]
    words = get_words_by_ids(db, word_ids)

    # Per-word tallies: "remember" actions versus every other action.
    actions = get_actions_at_each_word(db, word_ids)
    remember_count = defaultdict(int)
    forget_count = defaultdict(int)
    for a in actions:
        if a.action == "remember":
            remember_count[a.word_id] += 1
        else:
            forget_count[a.word_id] += 1

    # Time spent: sum of (end - most recent start) over the batch's actions,
    # walked in chronological order.
    batch_actions = get_user_memory_batch_actions_by_user_memory_batch_id(db, memorizing_batch.id)
    batch_actions.sort(key=lambda x: x.create_time)
    start = None
    total_duration = None
    for a in batch_actions:
        if a.action == "start":
            start = a.create_time
        elif a.action == "end":
            if start is None:
                # An "end" with no preceding "start" cannot be measured.
                continue
            elapsed = a.create_time - start
            total_duration = elapsed if total_duration is None else total_duration + elapsed

    memory_speed = f"{memorizing_batch.batch_type}"
    if total_duration is not None:
        minutes = total_duration.total_seconds() / 60
        # A zero-length session (start and end share a timestamp) would
        # otherwise divide by zero; show the bare batch type in that case.
        if minutes > 0:
            memory_speed += f"๏ผšๅฝ“ๅ‰ๆ‰นๆฌก่ฎฐๅฟ†ๆ•ˆ็Ž‡ {len(memorizing_words) / minutes:.2f} ่ฏ/ๅˆ†้’Ÿ๏ผŒ{minutes:.2f} ๅˆ†้’Ÿ/ๆ‰นๆฌก"

    # Word table and checkbox choices. For this one batch type the
    # translation column is left blank.
    hide_translation = memorizing_batch.batch_type == "ๅ›žๅฟ†"
    for w in words:
        new_options.append(f"{w.vc_vocabulary}")
        word_df.append([
            w.vc_vocabulary,
            "" if hide_translation else w.vc_translation,
            w.vc_phonetic_uk,
            w.vc_phonetic_us,
            f"{remember_count[w.vc_id]} / {remember_count[w.vc_id] + forget_count[w.vc_id]}",
        ])
    df = pd.DataFrame(word_df, columns=dataframe_header)

    # Story: generate on demand when either text is empty or absent.
    story = memorizing_batch.story
    translated_story = memorizing_batch.translated_story
    if not story or not translated_story:
        story, translated_story = regenerate_for_batch(memorizing_batch, words)

    logger.info("่ฎก็ฎ—ๆ‰นๆฌกไฟกๆฏ")
    logger.info(new_options)
    logger.info(story)
    logger.info(translated_story)
    logger.info("=" * 8)
    return (gr.Accordion(label=memory_speed), df, story, translated_story, gr.CheckboxGroup(choices=new_options))
276
+
277
def on_select_user_book(user_book_id: str):
    """Load the selected memory plan and display its current batch.

    Args:
        user_book_id: The dropdown label of the chosen plan; the plan id is
            the text inside the trailing ``[...]``.

    Returns:
        A 9-tuple matching the listener's outputs: the plan's batches, the
        0-based index of the current batch, the user-book record, the five
        values from ``update_from_batch``, and a slider update spanning
        ``1..len(batches)`` positioned on the current batch.

    Raises:
        gr.Error: If nothing is selected. Returning a short tuple here would
            not match the nine outputs wired to this handler.
    """
    logger.debug(f'user_book {user_book_id}')
    if user_book_id is None:
        raise gr.Error("่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’")
    # Split on the LAST " [" so a title that itself contains " [" still
    # yields the id.
    user_book_id = user_book_id.rsplit(" [", 1)[1][:-1]
    user_book = get_user_book(db, user_book_id)
    batches = get_new_user_memory_batches_by_user_book_id(db, user_book_id)
    batch_id = user_book.memorizing_batch
    memorizing_batch = get_user_memory_batch(db, batch_id)
    current_batch_index = -1
    if memorizing_batch is not None:
        for index, b in enumerate(batches):
            if b.id == memorizing_batch.id:
                current_batch_index = index
                break
    if current_batch_index == -1:
        # Nothing is being memorized yet, or the stored batch is not among
        # `batches`: start from the first batch and persist that choice.
        # NOTE(review): an empty `batches` list still raises IndexError here.
        current_batch_index = 0
        memorizing_batch = batches[0]
        batch_id = memorizing_batch.id
        user_book.memorizing_batch = batch_id
        update_user_book(db, user_book_id, UserBookUpdate(
            owner_id=user_book.owner_id,
            book_id=user_book.book_id,
            title=user_book.title,
            random=user_book.random,
            batch_size=user_book.batch_size,
            memorizing_batch=batch_id
        ))
    updates = update_from_batch(memorizing_batch)
    on_batch_start(db, memorizing_batch.id)
    asyncio.run(pregenerate(batches, current_batch_index))
    return (batches, current_batch_index, user_book) + updates + (
        gr.Slider(
            minimum=1,
            maximum=len(batches),
            # The slider is 1-based (minimum=1), the index is 0-based.
            value=current_batch_index + 1,
        ),)


# Widgets refreshed whenever the displayed batch changes, in the order
# update_from_batch returns them.
batch_widget = [info, memorizing_dataframe, story, translated_story, memorize_action]
tab3.select(on_select_user, inputs=[user], outputs=[select_user_book])
select_user_book.select(
    on_select_user_book,
    inputs=[select_user_book],
    outputs=[batches, current_batch_index, user_book_id] + batch_widget + [progress]
)
330
async def worker_regenerate_for_batch(batches: List[UserMemoryBatch], index: int):
    """Make sure ``batches[index]`` has a story, generating one if needed.

    The coroutine contains no ``await``; once started it runs straight
    through. Elapsed wall time is logged at the end.

    Args:
        batches: Batches of the current plan.
        index: Position of the batch to process.
    """
    t0 = time.monotonic()
    logger.info(f"started {index}")
    target = batches[index]
    story_text = target.story
    translated_text = target.translated_story
    needs_story = len(story_text) == 0 or len(translated_text) == 0
    if needs_story:
        regenerate_for_batch(target, get_words_in_batch(db, target.id))
    elapsed = time.monotonic() - t0
    logger.info(f'completed in {elapsed:.2f} seconds')
343
+
344
async def pregenerate(batches: List[UserMemoryBatch], current_batch_index: int):
    """Schedule story generation for the batches around the current one.

    Up to three batches after and three before ``current_batch_index`` are
    considered (following ones first); positions outside ``batches`` are
    skipped. Each remaining position gets a ``worker_regenerate_for_batch``
    task via ``asyncio.ensure_future``; the tasks are not awaited here.

    Args:
        batches: Batches of the current plan.
        current_batch_index: 0-based index of the batch being displayed.
    """
    logger.info("ๅผ€ๅง‹้ข„็”Ÿๆˆๆ•…ไบ‹")
    offsets = (1, 2, 3)
    following = [current_batch_index + step for step in offsets]
    preceding = [current_batch_index - step for step in offsets]
    in_range = [pos for pos in following + preceding if 0 <= pos < len(batches)]
    for pos in in_range:
        asyncio.ensure_future(worker_regenerate_for_batch(batches, pos))
    logger.info("็ป“ๆŸ้ข„็”Ÿๆˆๆ•…ไบ‹")
351
+
352
def submit_batch(batches: List[UserMemoryBatch], current_batch_index: int):
    """Display the batch found at ``current_batch_index``.

    Thin wrapper that looks the batch up by position and delegates to
    ``set_memorizing_batch``; returns whatever that function returns.
    """
    return set_memorizing_batch(batches, current_batch_index, batches[current_batch_index])
355
+
356
def set_memorizing_batch(batches: List[UserMemoryBatch], current_batch_index: int, memorizing_batch: UserMemoryBatch):
    """Show ``memorizing_batch`` and pre-generate stories for its neighbours.

    Args:
        batches: Batches of the current plan.
        current_batch_index: 0-based index reported back to the UI state.
        memorizing_batch: The batch whose content is displayed.

    Returns:
        The values from ``update_from_batch`` followed by a slider update
        (1-based position) and the 0-based index.
    """
    widget_updates = update_from_batch(memorizing_batch)
    asyncio.run(pregenerate(batches, current_batch_index))
    logger.info("pregenerated")
    slider_update = gr.Slider(value=current_batch_index + 1)
    return widget_updates + (slider_update, current_batch_index)
361
+
362
def save_progress(old_batch: UserMemoryBatch, memorize_action: List[str]):
    """Persist a remember/forget outcome for every word of ``old_batch``.

    Args:
        old_batch: The batch being left.
        memorize_action: Vocabulary strings the user ticked. A word whose
            vocabulary appears here is stored as "remember"; every other
            word of the batch is stored as "forget".
    """
    outcomes = [
        (word.vc_id, "remember" if word.vc_vocabulary in memorize_action else "forget")
        for word in get_words_in_batch(db, old_batch.id)
    ]
    save_memorizing_word_action(db, old_batch.id, outcomes)
372
+
373
def previous_batch(batches: List[UserMemoryBatch], current_batch_index: int, user_book: schema.UserBook, memorize_action: List[str]):
    """Step back one batch, saving progress on the batch being left.

    The index is clamped at 0. When it actually changes, the ticked words
    are saved for the old batch, the old batch is marked ended, the new one
    started, and the plan's current batch is updated. The resulting batch
    is then displayed via ``submit_batch`` either way.

    Returns:
        Whatever ``submit_batch`` returns for the resulting index.
    """
    target_index = max(current_batch_index - 1, 0)
    if target_index != current_batch_index:
        leaving = batches[current_batch_index]
        entering = batches[target_index]
        save_progress(leaving, memorize_action)
        on_batch_end(db, leaving.id)
        on_batch_start(db, entering.id)
        update_user_book_memorizing_batch(db, user_book.id, entering.id)
    return submit_batch(batches, target_index)
392
+
393
def next_batch(batches: List[UserMemoryBatch], current_batch_index: int, user_book: schema.UserBook, memorize_action: List[str]):
    """Advance to the next batch, saving progress on the one being left.

    The batch being left is the plan's stored ``memorizing_batch`` when it
    can be loaded, otherwise ``batches[current_batch_index]``. Its ticked
    words are saved and it is marked ended. ``generate_next_batch`` is then
    consulted: if it returns a batch, that batch is shown and the index is
    left unchanged; otherwise the next entry of ``batches`` is shown (the
    index is clamped at the last entry).

    Args:
        batches: Batches of the current plan.
        current_batch_index: 0-based index of the batch on screen.
        user_book: The plan record; its ``id`` and ``memorizing_batch`` are read.
        memorize_action: Vocabulary strings the user ticked as remembered.

    Returns:
        Whatever ``set_memorizing_batch`` returns for the batch now shown.
    """
    old_index = current_batch_index
    current_batch_index = min(current_batch_index + 1, len(batches) - 1)

    # Prefer the persisted batch; fall back to the on-screen one so a missing
    # record no longer crashes when already on the last batch.
    old_batch = get_user_memory_batch(db, user_book.memorizing_batch)
    if old_batch is None:
        old_batch = batches[old_index]
    current_batch = batches[current_batch_index]

    save_progress(old_batch, memorize_action)
    on_batch_end(db, old_batch.id)

    generated_batch = generate_next_batch(db, user_book, minutes=60, k=3)
    if generated_batch is not None:
        # A generated batch is shown in place of the next list entry, so the
        # position within `batches` does not advance.
        current_batch = generated_batch
        current_batch_index = old_index

    on_batch_start(db, current_batch.id)
    update_user_book_memorizing_batch(db, user_book.id, current_batch.id)
    return set_memorizing_batch(batches, current_batch_index, current_batch)


# Both navigation buttons share inputs and refresh the same widgets.
previous_batch_btn.click(
    previous_batch,
    inputs=[batches, current_batch_index, user_book_id, memorize_action],
    outputs=batch_widget + [progress, current_batch_index]
)
next_batch_btn.click(
    next_batch,
    inputs=[batches, current_batch_index, user_book_id, memorize_action],
    outputs=batch_widget + [progress, current_batch_index]
)
446
+
447
def regenerate_for_batch(memorizing_batch: UserMemoryBatch, batch_words: List[Word]):
    """Generate a fresh story pair for a batch and persist it.

    The new texts are written onto ``memorizing_batch``, committed, and also
    recorded as a generation-history entry.

    Args:
        memorizing_batch: Batch that receives the new story.
        batch_words: Words whose vocabulary is passed to the generator.

    Returns:
        ``(story, translated_story)`` as produced by the generator.
    """
    vocabulary = [word.vc_vocabulary for word in batch_words]
    logger.info(f"็”Ÿๆˆๆ•…ไบ‹ {vocabulary}")
    generated = generate_story_and_translated_story(vocabulary)
    story, translated_story = generated

    # Store on the batch first, then append the history record.
    memorizing_batch.story = story
    memorizing_batch.translated_story = translated_story
    db.commit()
    db.refresh(memorizing_batch)
    history_entry = UserMemoryBatchGenerationHistoryCreate(
        batch_id=memorizing_batch.id,
        story=story,
        translated_story=translated_story
    )
    create_user_memory_batch_generation_history(db, history_entry)

    logger.info(story)
    logger.info(translated_story)
    return story, translated_story
463
+
464
def regenerate(batches: List[UserMemoryBatch], current_batch_index: int):
    """Regenerate the story of the batch currently on screen.

    Returns:
        ``(story, translated_story)`` from ``regenerate_for_batch``.
    """
    target = batches[current_batch_index]
    target_words = get_words_in_batch(db, target.id)
    return regenerate_for_batch(target, target_words)


regenerate_btn.click(regenerate, inputs=[batches, current_batch_index], outputs=[story, translated_story])
471
+
472
+ # 4. Create a word book from a memory plan
473
+ with gr.Tab("ไปŽ่ฎฐๅฟ†่ฎกๅˆ’ไธญๅˆ›ๅปบๅ•่ฏไนฆ") as tab4:
474
# Widgets of this tab. The dropdown starts empty and is filled by
# on_select_user when the tab is selected.
select_user_book = gr.Dropdown(
    [], label="่ฎฐๅฟ†่ฎกๅˆ’", info="่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
)
# Number of words in the selected plan (set by on_select_user_book).
word_count = gr.Number(value=0, label="ๅ•่ฏไธชๆ•ฐ")
# Words ticked here are left out of the book that submit() creates.
known_words = gr.CheckboxGroup(
    [], label="ๅทฒๅญฆไผš็š„ๅ•่ฏ", info="่ฏทๆฃ€ๆŸฅๅทฒๅญฆไผš็š„ๅ•่ฏ๏ผŒ่ฟ™ไบ›ๅ•่ฏๅฐ†ไธไผš่ขซๅŒ…ๅซๅœจๆ–ฐ็š„ๅ•่ฏไนฆไธญ"
)
# Title for the new book.
title = gr.TextArea(value='ๅ•่ฏไนฆ', lines=1, label="ๅ•่ฏไนฆ็š„ๅ็งฐ")
btn = gr.Button("ไปŽ่ฎฐๅฟ†่ฎกๅˆ’ไธญๅˆ›ๅปบๅ•่ฏไนฆ")
# Receives the result message returned by submit().
status = gr.Textbox("", lines=1, label="็Šถๆ€")
484
+
485
def on_select_user(user):
    """Fill this tab's memory-plan dropdown with the books owned by ``user``.

    Args:
        user: Session dict of the current user; only its ``"id"`` key is read.

    Returns:
        A ``gr.Dropdown`` update with one choice per user book, each label
        ending in ``[<book id>]``.

    Raises:
        gr.Error: If the dict has no usable id: missing, ``None``, or the
            empty string that ``login`` stores after a failed attempt.
    """
    user_id = user.get("id", None)
    if user_id is None or user_id == "":
        # Constructing gr.Error without raising it has no visible effect.
        raise gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
    user_books = get_user_books_by_owner_id(db, user_id)
    labels = [f"{book.title} | {book.batch_size}ไธชๅ•่ฏไธ€็ป„ [{book.id}]" for book in user_books]
    return gr.Dropdown(choices=labels)
494
+
495
def on_select_user_book(user_book):
    """List the words of the selected plan so known ones can be ticked.

    Args:
        user_book: Dropdown label of the chosen plan, or ``None``; the plan
            id is the text inside the trailing ``[...]``.

    Returns:
        ``(word count, gr.CheckboxGroup update)``; ``(0, empty choices)`` when
        nothing is selected.
    """
    logger.debug(f'user_book {user_book}')
    if user_book is None:
        return 0, gr.CheckboxGroup(choices=[])
    # Split on the LAST " [" so a title containing " [" cannot shift the id.
    user_book_id = user_book.rsplit(" [", 1)[1][:-1]
    words = get_words_in_user_book(db, user_book_id)
    new_options = [f"{word.vc_vocabulary}" for word in words]
    return len(words), gr.CheckboxGroup(choices=new_options)


tab4.select(on_select_user, inputs=[user], outputs=[select_user_book])
select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[word_count, known_words])
507
+
508
def submit(user, user_book, known_words, title):
    """Create a new word book from the words of a plan the user has not learned.

    Args:
        user: Session dict of the current user; only ``"id"`` is read.
        user_book: Dropdown label of the source plan (id inside the trailing
            ``[...]``), or ``None`` when nothing is selected.
        known_words: Vocabulary strings ticked as already known; these are
            excluded from the new book.
        title: Title passed to ``save_words_as_book``.

    Returns:
        A status message for the status textbox.
    """
    user_id = user.get("id", None)
    if user_id is None or user_id == "":
        # Also covers the empty id stored after a failed login. The message
        # reaches the user through the status box.
        return "่ฏทๅ…ˆ็™ปๅฝ•"
    if user_book is None:
        # No plan chosen: answer with the dropdown's own prompt rather than
        # failing on None.split.
        return "่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
    # Split on the LAST " [" so a title containing " [" cannot shift the id.
    user_book_id = user_book.rsplit(" [", 1)[1][:-1]
    all_words = get_words_in_user_book(db, user_book_id)
    known = set(known_words or [])
    unknown_words = [w for w in all_words if w.vc_vocabulary not in known]
    book = save_words_as_book(db, user_id, unknown_words, title)
    if book is not None:
        return f"ๆˆๅŠŸ็”Ÿๆˆไธ€ๆœฌๅ•่ฏไนฆ๏ผš{book.bk_name}"
    return "ๅคฑ่ดฅ"


btn.click(submit, [user, select_user_book, known_words, title], [status])
527
+
528
+ # 5. Statistics
529
+ with gr.Tab("็ปŸ่ฎก") as tab5:
530
+ # 5.1. Story generation history
531
+ with gr.Tab("AI ๅŽ†ๅฒ่ฎฐๅฝ•") as tab51:
532
# Plan selector for this sub-tab; filled by on_select_user on tab selection.
select_user_book = gr.Dropdown(
    [], label="่ฎฐๅฟ†่ฎกๅˆ’", info="่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
)

# Column titles of the history table. on_select_user_book fills the rows in
# this order: words, story, translated story, creation time.
history_header = ["ๅ•่ฏ", "ๆ•…ไบ‹", "ไธญๆ–‡ๆ•…ไบ‹", "็”Ÿๆˆๆ—ถ้—ด"]
history_dataframe = gr.Dataframe(
    headers=history_header,
    datatype=["str"] * len(history_header),  # every column is shown as text
    col_count=(len(history_header), "fixed"),
    wrap=True,
    min_width=320,
    height=800,
)
545
+
546
def on_select_user(user):
    """Fill this sub-tab's plan dropdown with the books owned by ``user``.

    Args:
        user: Session dict of the current user; only its ``"id"`` key is read.

    Returns:
        A ``gr.Dropdown`` update with one choice per user book, each label
        ending in ``[<book id>]``.

    Raises:
        gr.Error: If the dict has no usable id: missing, ``None``, or the
            empty string that ``login`` stores after a failed attempt.
    """
    user_id = user.get("id", None)
    if user_id is None or user_id == "":
        # Constructing gr.Error without raising it has no visible effect.
        raise gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
    user_books = get_user_books_by_owner_id(db, user_id)
    labels = [f"{book.title} | {book.batch_size}ไธชๅ•่ฏไธ€็ป„ [{book.id}]" for book in user_books]
    return gr.Dropdown(choices=labels)
555
+
556
def on_select_user_book(user_book_id):
    """Build the story-generation history table for the selected plan.

    Args:
        user_book_id: Dropdown label of the chosen plan, or ``None``; the
            plan id is the text inside the trailing ``[...]``.

    Returns:
        A DataFrame with the ``history_header`` columns, one row per
        generated story; empty when nothing is selected.
    """
    logger.debug(f'user_book {user_book_id}')
    if user_book_id is None:
        # This listener has a single Dataframe output, so return exactly one
        # (empty) table.
        return pd.DataFrame([], columns=history_header)
    # Split on the LAST " [" so a title containing " [" cannot shift the id.
    user_book_id = user_book_id.rsplit(" [", 1)[1][:-1]
    batch_id_to_words_and_history = get_generation_hostorys_by_user_book_id(db, user_book_id)
    data = []
    for words, histories in batch_id_to_words_and_history.values():
        # The word list is the same for every history row of a batch.
        word = ", ".join(w.vc_vocabulary for w in words)
        for history in histories:
            data.append([word, history.story, history.translated_story, history.create_time])
    return pd.DataFrame(data, columns=history_header)


tab51.select(on_select_user, inputs=[user], outputs=[select_user_book])
select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[history_dataframe])
575
+
576
+ # 5.2. Memory history records
577
+ with gr.Tab("่ฎฐๅฟ†ๅŽ†ๅฒ่ฎฐๅฝ•") as tab52:
578
# Plan selector for this sub-tab; filled by on_select_user on tab selection.
select_user_book = gr.Dropdown(
    [], label="่ฎฐๅฟ†่ฎกๅˆ’", info="่ฏท้€‰ๆ‹ฉ่ฎฐๅฟ†่ฎกๅˆ’"
)

# Column titles of the batch-history table. on_select_user_book fills the
# rows in this order: words, story, translated story, batch type, memory
# status summary, creation time.
batch_history_header = ["ๅ•่ฏ", "ๆ•…ไบ‹", "ไธญๆ–‡ๆ•…ไบ‹", "ๆ‰นๆฌก็ฑปๅž‹", "่ฎฐๅฟ†ๆƒ…ๅ†ต", "็”Ÿๆˆๆ—ถ้—ด"]
batch_history_dataframe = gr.Dataframe(
    headers=batch_history_header,
    datatype=["str"] * len(batch_history_header),  # every column is shown as text
    col_count=(len(batch_history_header), "fixed"),
    wrap=True,
    min_width=320,
    height=800,
)
591
+
592
def on_select_user(user):
    """Fill this sub-tab's plan dropdown with the books owned by ``user``.

    Args:
        user: Session dict of the current user; only its ``"id"`` key is read.

    Returns:
        A ``gr.Dropdown`` update with one choice per user book, each label
        ending in ``[<book id>]``.

    Raises:
        gr.Error: If the dict has no usable id: missing, ``None``, or the
            empty string that ``login`` stores after a failed attempt.
    """
    user_id = user.get("id", None)
    if user_id is None or user_id == "":
        # Constructing gr.Error without raising it has no visible effect.
        raise gr.Error("่ฏทๅ…ˆ็™ปๅฝ•")
    user_books = get_user_books_by_owner_id(db, user_id)
    labels = [f"{book.title} | {book.batch_size}ไธชๅ•่ฏไธ€็ป„ [{book.id}]" for book in user_books]
    return gr.Dropdown(choices=labels)
601
+
602
def on_select_user_book(user_book_id):
    """Build the batch memory-history table for the selected plan.

    Each action returned by ``get_user_memory_batch_history`` becomes one
    row describing its batch: the batch's words, its stories and type, and
    which words carry a "remember" action versus the rest.

    Args:
        user_book_id: Dropdown label of the chosen plan, or ``None``; the
            plan id is the text inside the trailing ``[...]``.

    Returns:
        A DataFrame with the ``batch_history_header`` columns; empty when
        nothing is selected.
    """
    logger.debug(f'user_book {user_book_id}')
    if user_book_id is None:
        # This listener has a single Dataframe output, so return exactly one
        # (empty) table.
        return pd.DataFrame([], columns=batch_history_header)
    # Split on the LAST " [" so a title containing " [" cannot shift the id.
    user_book_id = user_book_id.rsplit(" [", 1)[1][:-1]
    actions, batch_id_to_batch, batch_id_to_words, batch_id_to_actions = get_user_memory_batch_history(db, user_book_id)
    data = []
    for action in actions:
        batch_id = action.batch_id

        words = batch_id_to_words[batch_id]
        word = ", ".join(w.vc_vocabulary for w in words)

        batch = batch_id_to_batch[batch_id]

        # Words with at least one "remember" action count as remembered;
        # every other word of the batch counts as forgotten.
        memory_actions = batch_id_to_actions.get(batch_id, [])
        remember_word_ids = {a.word_id for a in memory_actions if a.action == "remember"}
        remember_words = [w.vc_vocabulary for w in words if w.vc_id in remember_word_ids]
        forget_words = [w.vc_vocabulary for w in words if w.vc_id not in remember_word_ids]

        memory_status = f"่ฎฐไฝ {len(remember_words)} ไธชๅ•่ฏ๏ผŒๅฟ˜่ฎฐ {len(forget_words)} ไธชๅ•่ฏ"
        memory_status += f"๏ผŒ่ฎฐไฝ็š„ๅ•่ฏ๏ผš{', '.join(remember_words)}"
        memory_status += f"๏ผŒๅฟ˜่ฎฐ็š„ๅ•่ฏ๏ผš{', '.join(forget_words)}"

        data.append([word, batch.story, batch.translated_story, batch.batch_type, memory_status, action.create_time])
    return pd.DataFrame(data, columns=batch_history_header)


tab52.select(on_select_user, inputs=[user], outputs=[select_user_book])
select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[batch_history_dataframe])
641
+
642
+
643
# Components whose visibility depends on the login state: the six
# login/register form widgets followed by the first tab.
on_login_success_ui = [
    email,
    password,
    login_btn,
    register_email,
    register_password,
    register_btn,
    tab1,
]


def on_login(login_success):
    """Return visibility updates for ``on_login_success_ui``.

    The six form widgets are shown only while logged out; the tab is shown
    only while logged in.

    Args:
        login_success: Whether the user is now logged in.

    Returns:
        A 7-tuple of component updates in ``on_login_success_ui`` order.
    """
    show_forms = not login_success
    form_updates = (
        gr.TextArea(visible=show_forms),
        gr.TextArea(visible=show_forms),
        gr.Button(visible=show_forms),
        gr.TextArea(visible=show_forms),
        gr.TextArea(visible=show_forms),
        gr.Button(visible=show_forms),
    )
    tab_updates = (gr.Tab(visible=login_success),)
    return form_updates + tab_updates
658
def login(email, password):
    """Authenticate a user and report the outcome to the UI.

    Login fails when the password is empty, when no user has this email, or
    when the password does not verify.

    Args:
        email: Email typed into the login form.
        password: Password typed into the login form.

    Returns:
        The session dict (empty ``id``/``email`` on failure), a status
        message, then the visibility updates from ``on_login``.
    """
    account = get_user_by_email(db, email)
    password_missing = password is None or len(password) == 0
    if password_missing or account is None or not account.verify_password(password):
        anonymous = {
            "id": "",
            "email": "",
        }
        return anonymous, "็™ปๅฝ•ๅคฑ่ดฅ", *on_login(False)
    session = {
        "id": account.id,
        "email": account.email,
    }
    return session, "็™ปๅฝ•ๆˆๅŠŸ", *on_login(True)


login_btn.click(login, [email, password], [user, user_status] + on_login_success_ui)
675
def register(email, password):
    """Create an account for ``email`` and log it in immediately.

    Registration is refused when a user with this email already exists.

    Args:
        email: Email typed into the registration form.
        password: Password typed into the registration form.

    Returns:
        The session dict (empty ``id``/``email`` on refusal), a status
        message, then the visibility updates from ``on_login``.
    """
    existing = get_user_by_email(db, email)
    if existing is None:
        created = create_user(db, email=email, password=password)
        session = {
            "id": created.id,
            "email": created.email,
        }
        return session, "ๆณจๅ†Œๅนถ็™ปๅฝ•ๆˆๅŠŸ", *on_login(True)
    anonymous = {
        "id": "",
        "email": "",
    }
    return anonymous, "ๆณจๅ†Œๅคฑ่ดฅ๏ผŒ่ฏฅ้‚ฎ็ฎฑๅทฒ่ขซๆณจๅ†Œ", *on_login(False)


register_btn.click(register, [register_email, register_password], [user, user_status] + on_login_success_ui)
689
+
690
+
691
if __name__ == "__main__":
    # Disabled alternative for local debugging: bypass the proxy for
    # loopback addresses and serve on a fixed host/port in debug mode.
    # import os
    # os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
    # demo.launch(server_name="127.0.0.1", server_port=8090, debug=True)
    # Add a log file sink named after `date_str`, at INFO level, configured
    # with rotation="1 day" and retention="7 days".
    logger.add(f"output/logs/web_{date_str}.log", rotation="1 day", retention="7 days", level="INFO")
    # Start the Gradio app with default launch settings.
    demo.launch()