init

- .gitignore +1 -0
- LLM.py +420 -0
- app copy.py +96 -0
- app.py +25 -68
- common/util.py +31 -0
- database/__init__.py +11 -0
- database/constant.py +2 -0
- database/database.py +10 -0
- database/model.py +37 -0
- database/operation.py +704 -0
- database/schema.py +209 -0
- memorize.py +264 -0
- output/logs/.gitkeep +0 -0
- story_agent.py +69 -0
- web.py +696 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+**/*.pyc
LLM.py
ADDED
@@ -0,0 +1,420 @@
# encoding=utf-8
import multiprocessing as mp
import warnings

import requests
import tiktoken
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    Literal,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)
from pydantic import Extra, Field, root_validator
from loguru import logger

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
import sys
import json


@dataclass(frozen=True)
class ChatGPTConfig:
    r"""Defines the parameters for generating chat completions using the
    OpenAI API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. See more information about frequency and
            presence penalties. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim. See more information
            about frequency and presence penalties. (default: :obj:`0.0`)
        logit_bias (dict, optional): Modify the likelihood of specified tokens
            appearing in the completion. Accepts a json object that maps tokens
            (specified by their token ID in the tokenizer) to an associated
            bias value from :obj:`-100` to :obj:`100`. Mathematically, the bias
            is added to the logits generated by the model prior to sampling.
            The exact effect will vary per model, but values between :obj:`-1`
            and :obj:`1` should decrease or increase likelihood of selection;
            values like :obj:`-100` or :obj:`100` should result in a ban or
            exclusive selection of the relevant token. (default: :obj:`{}`)
        user (str, optional): A unique identifier representing your end-user,
            which can help OpenAI to monitor and detect abuse.
            (default: :obj:`""`)
    """
    temperature: float = 1.0  # openai default: 1.0
    top_p: float = 1.0
    max_in_tokens: int = 3200
    timeout: int = 20


def get_userid_and_token(
    url='http://avatar.aicubes.cn/vtuber/auth/api/oauth/v1/login',
    app_id='6027294018fd496693d0b8c77e2d20a1',
    app_secret='52806a6fff8a452497061b9dcc5779f4'
):
    d = {'app_id': app_id, 'app_secret': app_secret}
    h = {'Content-Type': 'application/json'}
    r = requests.post(url, json=d, headers=h)
    data = r.json()['data']
    return data['user_id'], data['token']


class ChatAPI:
    def __init__(self, timeout=20, verbose=False) -> None:
        self.timeout = timeout
        self.verbose = verbose
        self.user_id, self.token = get_userid_and_token()

    def create_chat_completion(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None) -> str:
        res = self.create_chat_completion_response_data(messages, model, temperature, max_tokens)
        return res['choices'][0]['message']['content']

    def create_chat_completion_response_data(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None):
        res = self.create_chat_completion_response(messages, model, temperature, max_tokens)
        res = res.json()['data']
        return res

    def create_chat_completion_response(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None):
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        chat_header = {
            'Content-Type': 'application/json',
            'userId': self.user_id,
            'token': self.token
        }
        payload = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens,
        }
        timeout = self.timeout
        res = requests.post(chat_url, json=payload, headers=chat_header, timeout=timeout)
        if self.verbose:
            data = res.json()["data"]
            if data is None:
                logger.debug(res.json())
            else:
                logger.debug(data["choices"][0]["message"]["content"])
        return res


class OpenAIChat(BaseLLM):
    """Wrapper around OpenAI Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAIChat
            openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
    """

    model_name: str = "gpt-3.5-turbo"
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    prefix_messages: List = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""
    api = ChatAPI(timeout=60)
    generate_verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return self.model_kwargs

    def _get_chat_params(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> Tuple:
        if len(prompts) > 1:
            raise ValueError(
                f"OpenAIChat currently only supports single prompt, got {prompts}"
            )
        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        if params.get("max_tokens") == -1:
            # for ChatGPT api, omitting max_tokens is equivalent to having no limit
            del params["max_tokens"]
        return messages, params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        messages, params = self._get_chat_params(prompts, stop)
        if self.generate_verbose:
            logger.debug(json.dumps(params, indent=2))
            for msg in messages:
                logger.debug(msg["role"] + " : " + msg["content"])
        resp = self.api.create_chat_completion_response_data(messages, self.model_name, self.model_kwargs['temperature'])
        full_response = resp
        llm_output = {
            "token_usage": full_response["usage"],
            "model_name": self.model_name,
        }
        return LLMResult(
            generations=[
                [Generation(text=full_response["choices"][0]["message"]["content"])]
            ],
            llm_output=llm_output,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        # messages, params = self._get_chat_params(prompts, stop)
        # full_response = await acompletion_with_retry(
        #     self, messages=messages, **params
        # )
        # llm_output = {
        #     "token_usage": full_response["usage"],
        #     "model_name": self.model_name,
        # }
        # return LLMResult(
        #     generations=[
        #         [Generation(text=full_response["choices"][0]["message"]["content"])]
        #     ],
        #     llm_output=llm_output,
        # )
        raise NotImplementedError("Async not supported for OpenAIChat")

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "openai-chat"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python < 3.8
        if sys.version_info[1] < 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )
        # create a GPT-3.5-Turbo encoder instance
        enc = tiktoken.encoding_for_model("gpt-3.5-turbo")

        # encode the text using the GPT-3.5-Turbo encoder
        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)


class ChatSession:
    def __init__(self, prompt: str = '', chatgpt_config: ChatGPTConfig = ChatGPTConfig()) -> None:
        self.chatgpt_config = chatgpt_config.__dict__
        self.user_id, self.token = self.get_userid_and_token()
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0301")
        self.count = lambda x: len(encoding.encode(x))
        self.history = []
        self.system = [self.make_msg("system", prompt)] if prompt else []

    def restart(self, prompt: str = '') -> None:
        self.system = [self.make_msg("system", prompt)] if prompt else []

    @staticmethod
    def make_msg(role: str, msg: str) -> Dict:
        assert role in {"system", "assistant", "user"}
        return {"role": role, "content": msg}

    @staticmethod
    def get_userid_and_token(
        url='http://avatar.aicubes.cn/vtuber/auth/api/oauth/v1/login',
        app_id='6027294018fd496693d0b8c77e2d20a1',
        app_secret='52806a6fff8a452497061b9dcc5779f4'
    ):
        d = {'app_id': app_id, 'app_secret': app_secret}
        h = {'Content-Type': 'application/json'}
        r = requests.post(url, json=d, headers=h)
        data = r.json()['data']
        return data['user_id'], data['token']

    def make_chat_session(self, user_id: str, token: str, input_message: List[Dict[str, str]]):
        chat_h = {
            'Content-Type': 'application/json',
            'userId': user_id,
            'token': token
        }
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        res = requests.post(chat_url, json={
            'messages': input_message, **self.chatgpt_config
        }, headers=chat_h, timeout=self.chatgpt_config['timeout'])
        return res.json()['data']['choices'][0]['message']['content']

    def create_chat_completion(self, messages: List[Dict[str, str]], model: str, temperature: float, max_tokens=None) -> str:
        chat_url = 'http://avatar.aicubes.cn/vtuber/ai_access/chatgpt/v1/chat/completions'
        chat_header = {
            'Content-Type': 'application/json',
            'userId': self.user_id,
            'token': self.token
        }
        payload = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens,
        }
        timeout = self.chatgpt_config['timeout']
        res = requests.post(chat_url, json=payload, headers=chat_header, timeout=timeout)
        return res.json()['data']['choices'][0]['message']['content']

    def chat(self, msg: str):
        self.history.append(self.make_msg("user", msg))
        # walk backwards through history until the token budget is exhausted,
        # so the request always fits within max_in_tokens
        init_tokenCnt = self.count(self.system[0]['content']) if self.system else 0
        inputStaMsgIdx, tokenCnt = len(self.history), init_tokenCnt
        while inputStaMsgIdx and (
                tokenCnt := tokenCnt + self.count(self.history[inputStaMsgIdx - 1]['content'])) < \
                self.chatgpt_config['max_in_tokens']:
            inputStaMsgIdx -= 1
        inputStaMsgIdx = inputStaMsgIdx if inputStaMsgIdx < len(self.history) else -1
        res = self.make_chat_session(self.user_id, self.token, self.system + self.history[inputStaMsgIdx:])
        self.history.append(self.make_msg("assistant", res))
        return res


def batch_chat(info_lst: List, request_num: int = 6) -> List:
    res = []
    pool = mp.Pool(processes=request_num)
    for id, res_text in tqdm(pool.imap(single_chat, info_lst), desc="Asking API", total=len(info_lst)):
        if res_text:
            res.append((id, res_text))

    return res


def single_chat(info: Dict) -> Tuple[int, str]:
    sess = ChatSession(info['sys'], info['config'])
    try:
        res = sess.chat(info['query'])
        return info['id'], res
    except Exception as e:
        print(e)
        return info['id'], ""


if __name__ == '__main__':

    sys_prompt = """
    You are a strict grader. I will give you an instruction and a reply to that instruction, and you need to examine the reply carefully and give a score. You may judge the reply from several angles, for example:
    whether it is accurate, thorough, and harmless, and whether it fully meets the requirements of the instruction, and so on. Scores fall into 5 levels: 1 point, completely unusable; 2 points, unusable but completes part of the instruction;
    3 points, usable but with obvious flaws; 4 points, usable with minor flaws; 5 points, usable with no flaws. Include your own reasoning as you work, and give the score at the end.
    Here is an example:
    User: \n\n<Instruction>Who is Jack Ma's wife?</Instruction>\n\n<Reply>Jack Ma's wife is Zhang Yingzhu.</Reply>
    Assistant: This reply is wrong: Jack Ma is the founder of Alibaba and his wife is Zhang Ying, so the reply is incorrect; therefore my score is [1 point].
    """

    aaa = """
    fq(xm, m) = (Wq xm) e^(imθ)
    fk(xn, n) = (Wk xn) e^(inθ)
    g(xm, xn, m - n) = Re[(Wq xm)(Wk xn)* e^(i(m-n)θ)]
    """
    prompt = 'User: \n\n<Instruction>How tall is Yao Ming?</Instruction>\n\n<Reply>18m</Reply>\nAssistant:'
    bbb = "The given equation defines a function g(xm, xn, m-n) in terms of two complex functions fq(xm, m) and fk(xn, n) and their corresponding Fourier coefficients Wq and Wk, respectively. The function g(xm, xn, m-n) takes the real part of the product of the two complex exponential terms with phase angles m-theta and n-theta, respectively, where theta is an arbitrary constant angle. The term (m-n)theta in the exponent indicates that the two exponential terms are shifted by a phase difference of (m-n)theta."
    session = ChatSession('Explain the meaning of the formula')
    # print(session.chat(aaa))
    print(session.chat("Who are you? Who created you? When does your knowledge end? You may give yourself a name; please tell me your name."))
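Note that `temperature` is not a declared field on `OpenAIChat`: the `build_extra` validator moves any extra constructor kwargs into `model_kwargs`, and `_generate` reads `self.model_kwargs['temperature']` back out, so a temperature must be supplied at construction time. A minimal usage sketch, assuming the avatar.aicubes.cn gateway is reachable and a LangChain version where `BaseLLM` exposes `__call__`:

    # hypothetical usage sketch; `temperature` lands in model_kwargs via build_extra
    from LLM import OpenAIChat

    llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.2)
    print(llm("Explain nucleus sampling in one sentence."))  # BaseLLM.__call__ -> _generate
    print(llm.get_num_tokens("hello world"))                 # counted with tiktoken

Also note the class attribute `api = ChatAPI(timeout=60)` performs the login request as soon as the module is imported.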
app copy.py
ADDED
@@ -0,0 +1,96 @@
import sqlite3
import huggingface_hub
import gradio as gr
import pandas as pd
import shutil
import os
import datetime
from apscheduler.schedulers.background import BackgroundScheduler


DB_FILE = "./app.db"

TOKEN = os.environ.get('HUB_TOKEN')
repo = huggingface_hub.Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="linxy/oh-my-words",
    use_auth_token=TOKEN
)
repo.git_pull()

# Set db to latest
shutil.copyfile("./data/app.db", DB_FILE)


# Create table if it doesn't already exist

db = sqlite3.connect(DB_FILE)
try:
    db.execute("SELECT * FROM reviews").fetchall()
    db.close()
except sqlite3.OperationalError:
    db.execute(
        '''
        CREATE TABLE reviews (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
                              created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                              name TEXT, review INTEGER, comments TEXT)
        ''')
    db.commit()
    db.close()


def get_latest_reviews(db: sqlite3.Connection):
    reviews = db.execute("SELECT * FROM reviews ORDER BY id DESC limit 10").fetchall()
    total_reviews = db.execute("Select COUNT(id) from reviews").fetchone()[0]
    reviews = pd.DataFrame(reviews, columns=["id", "date_created", "name", "review", "comments"])
    return reviews, total_reviews


def add_review(name: str, review: int, comments: str):
    db = sqlite3.connect(DB_FILE)
    cursor = db.cursor()
    cursor.execute("INSERT INTO reviews(name, review, comments) VALUES(?,?,?)", [name, review, comments])
    db.commit()
    reviews, total_reviews = get_latest_reviews(db)
    db.close()
    return reviews, total_reviews

def load_data():
    db = sqlite3.connect(DB_FILE)
    reviews, total_reviews = get_latest_reviews(db)
    db.close()
    return reviews, total_reviews


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            name = gr.Textbox(label="Name", placeholder="What is your name?")
            review = gr.Radio(label="How satisfied are you with using gradio?", choices=[1, 2, 3, 4, 5])
            comments = gr.Textbox(label="Comments", lines=10, placeholder="Do you have any feedback on gradio?")
            submit = gr.Button(value="Submit Feedback")
        with gr.Column():
            with gr.Box():
                gr.Markdown("Most recently created 10 rows: See full dataset [here](https://huggingface.co/datasets/freddyaboulton/gradio-reviews)")
                data = gr.Dataframe()
                count = gr.Number(label="Total number of reviews")
    submit.click(add_review, [name, review, comments], [data, count])
    demo.load(load_data, None, [data, count])


def backup_db():
    shutil.copyfile(DB_FILE, "./data/reviews.db")
    db = sqlite3.connect(DB_FILE)
    reviews = db.execute("SELECT * FROM reviews").fetchall()
    pd.DataFrame(reviews).to_csv("./data/reviews.csv", index=False)
    print("updating db")
    repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")


scheduler = BackgroundScheduler()
scheduler.add_job(func=backup_db, trigger="interval", seconds=60)
scheduler.start()


demo.launch()
app.py
CHANGED
@@ -6,9 +6,10 @@ import shutil
 import os
 import datetime
 from apscheduler.schedulers.background import BackgroundScheduler
-
-
-DB_FILE = "./app.db"
+from database.database import DATABASE_FILE as DB_FILE
+from web import demo
+from loguru import logger
+from common.util import date_str
 
 TOKEN = os.environ.get('HUB_TOKEN')
 repo = huggingface_hub.Repository(
@@ -18,73 +19,20 @@ repo = huggingface_hub.Repository(
     use_auth_token=TOKEN
 )
 repo.git_pull()
-
+DATASET_FILE = f"./data/{DB_FILE}"
 # Set db to latest
-shutil.copyfile("./data/app.db", DB_FILE)
-
-
-# Create table if it doesn't already exist
-
-db = sqlite3.connect(DB_FILE)
-try:
-    db.execute("SELECT * FROM reviews").fetchall()
-    db.close()
-except sqlite3.OperationalError:
-    db.execute(
-        '''
-        CREATE TABLE reviews (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
-                              created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
-                              name TEXT, review INTEGER, comments TEXT)
-        ''')
-    db.commit()
-    db.close()
-
-
-def get_latest_reviews(db: sqlite3.Connection):
-    reviews = db.execute("SELECT * FROM reviews ORDER BY id DESC limit 10").fetchall()
-    total_reviews = db.execute("Select COUNT(id) from reviews").fetchone()[0]
-    reviews = pd.DataFrame(reviews, columns=["id", "date_created", "name", "review", "comments"])
-    return reviews, total_reviews
-
-
-def add_review(name: str, review: int, comments: str):
-    db = sqlite3.connect(DB_FILE)
-    cursor = db.cursor()
-    cursor.execute("INSERT INTO reviews(name, review, comments) VALUES(?,?,?)", [name, review, comments])
-    db.commit()
-    reviews, total_reviews = get_latest_reviews(db)
-    db.close()
-    return reviews, total_reviews
-
-def load_data():
-    db = sqlite3.connect(DB_FILE)
-    reviews, total_reviews = get_latest_reviews(db)
-    db.close()
-    return reviews, total_reviews
-
-
-with gr.Blocks() as demo:
-    with gr.Row():
-        with gr.Column():
-            name = gr.Textbox(label="Name", placeholder="What is your name?")
-            review = gr.Radio(label="How satisfied are you with using gradio?", choices=[1, 2, 3, 4, 5])
-            comments = gr.Textbox(label="Comments", lines=10, placeholder="Do you have any feedback on gradio?")
-            submit = gr.Button(value="Submit Feedback")
-        with gr.Column():
-            with gr.Box():
-                gr.Markdown("Most recently created 10 rows: See full dataset [here](https://huggingface.co/datasets/freddyaboulton/gradio-reviews)")
-                data = gr.Dataframe()
-                count = gr.Number(label="Total number of reviews")
-    submit.click(add_review, [name, review, comments], [data, count])
-    demo.load(load_data, None, [data, count])
-
+shutil.copyfile(DATASET_FILE, DB_FILE)
 
 def backup_db():
-    shutil.copyfile(DB_FILE, "./data/reviews.db")
-    db = sqlite3.connect(DB_FILE)
-    reviews = db.execute("SELECT * FROM reviews").fetchall()
-    pd.DataFrame(reviews).to_csv("./data/reviews.csv", index=False)
-    print("updating db")
+    shutil.copyfile(DB_FILE, DATASET_FILE)
+    # db = sqlite3.connect(DB_FILE)
+    # pd.DataFrame(db.execute("SELECT * FROM words").fetchall()).to_csv("./data/words.csv", index=False)
+    # print("save word.csv")
+    # pd.DataFrame(db.execute("SELECT * FROM book").fetchall()).to_csv("./data/book.csv", index=False)
+    # print("save book.csv")
+    # pd.DataFrame(db.execute("SELECT * FROM unit").fetchall()).to_csv("./data/unit.csv", index=False)
+    # print("save unit.csv")
+    # db.close()
     repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")
@@ -92,5 +40,14 @@ scheduler = BackgroundScheduler()
 scheduler.add_job(func=backup_db, trigger="interval", seconds=60)
 scheduler.start()
 
+# def load_data():
+#     db = sqlite3.connect(DB_FILE)
+#     reviews, total_reviews = get_latest_reviews(db)
+#     db.close()
+#     return reviews, total_reviews
+
+# demo.load(load_data, None, [data, count])
 
-demo.launch()
+if __name__ == "__main__":
+    logger.add(f"output/logs/web_{date_str}.log", rotation="1 day", retention="7 days", level="INFO")
+    demo.launch()
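After this change, persistence is a whole-file copy of the sqlite database in both directions, rather than per-table CSV exports. A sketch of the resulting paths, assuming DATABASE_FILE = "app.db" as defined in database/database.py below:

    from database.database import DATABASE_FILE as DB_FILE

    DATASET_FILE = f"./data/{DB_FILE}"   # "./data/app.db" inside the dataset repo clone
    # restore on startup:   shutil.copyfile(DATASET_FILE, DB_FILE)
    # backup every minute:  shutil.copyfile(DB_FILE, DATASET_FILE), then repo.push_to_hub(...)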
common/util.py
ADDED
@@ -0,0 +1,31 @@
from typing import *
from loguru import logger
from tqdm import tqdm
import pandas as pd
import datetime

import multiprocessing
from multiprocessing import Pool
cpu_num = multiprocessing.cpu_count()
logger.info(f"cpu_num: {cpu_num}")


date_str = datetime.datetime.now().strftime("%Y%m%d_%Hh%Mm%Ss")


def multiprocessing_mapping(
    mapping_func,
    items: List[Any],
    batch_size=1000,
    tmp_filepath=f"./output/multiprocessing_mapping_{date_str}_tmp.xlsx",
):
    pool = Pool(processes=cpu_num)
    total_rows: List[Dict[str, str]] = []
    for i in tqdm(range(0, len(items), batch_size)):
        new_rows = pool.map(mapping_func, items[i:i+batch_size])
        total_rows += new_rows
        # checkpoint accumulated results after each batch
        df = pd.DataFrame(total_rows)
        df.to_excel(tmp_filepath, index=False)
    pool.close()
    pool.join()
    return total_rows
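`multiprocessing_mapping` dispatches each batch through `Pool.map`, so `mapping_func` must be picklable (a module-level function, not a lambda). A hypothetical usage sketch:

    from common.util import multiprocessing_mapping

    def to_row(word: str) -> dict:          # must be defined at module level
        return {"word": word, "length": len(word)}

    rows = multiprocessing_mapping(to_row, ["alpha", "beta", "gamma"], batch_size=2)
    # rows == [{"word": "alpha", "length": 5}, {"word": "beta", "length": 4}, ...]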
database/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .database import SessionLocal, engine, Base
from .schema import *
from .operation import *


def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
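`get_db` is a generator dependency in the FastAPI style; a sketch of driving it by hand outside a framework:

    from database import get_db

    db_gen = get_db()
    db = next(db_gen)      # opens a SessionLocal session
    try:
        pass               # run queries with `db` here
    finally:
        db_gen.close()     # unwinds the generator, hitting `finally: db.close()`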
database/constant.py
ADDED
@@ -0,0 +1,2 @@
email = "[email protected]"
password = "123456"
database/database.py
ADDED
@@ -0,0 +1,10 @@
from sqlalchemy import URL, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

DATABASE_FILE = "app.db"
SQLALCHEMY_DATABASE_URL = f"sqlite:///./{DATABASE_FILE}"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
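`check_same_thread=False` lets the single sqlite file be shared across the threads a web server spawns; the engine alone does not create tables. A sketch of the usual SQLAlchemy bootstrap, assuming the models in schema.py inherit from this `Base`:

    from database.database import Base, engine

    Base.metadata.create_all(bind=engine)   # emits CREATE TABLE for all Base subclasses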
database/model.py
ADDED
@@ -0,0 +1,37 @@
from typing import List, Union

from pydantic import BaseModel


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
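With `orm_mode = True`, these pydantic models can be built straight from SQLAlchemy rows (pydantic v1 API); a sketch assuming a matching ORM `user` instance:

    from database.model import User

    user_out = User.from_orm(user)   # reads attributes instead of dict keys
    print(user_out.json())           # nested items serialize recursively via Item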
database/operation.py
ADDED
@@ -0,0 +1,704 @@
import datetime
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import List, Optional, Tuple, Dict
from . import schema
from sqlalchemy import func, or_
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Query
from collections import defaultdict


# # Create an alias for Address
# AddressAlias = aliased(Address)

# # Build a subquery
# subquery = session.query(
#     func.count(AddressAlias.id).label('address_count'),
#     AddressAlias.user_id
# ).group_by(AddressAlias.user_id).subquery()

# # Query with the subquery and an outer join
# users = session.query(User, subquery.c.address_count).\
#     outerjoin(subquery, User.id == subquery.c.user_id).\
#     order_by(User.id).all()

# for user, address_count in users:
#     print(f'User {user.name} has {address_count} addresses.')

# region User
class UserBase(BaseModel):
    email: str
    is_active: bool

class UserCreate(UserBase):
    pass

class UserUpdate(UserBase):
    hashed_password: Optional[str]  # This is optional because you may not always want to update the password

class User(UserBase):
    id: str

def create_user(db: Session, email: str, password: str):
    db_user = schema.User(email=email, is_active=True)
    db_user.set_password(password)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

def get_user(db: Session, user_id: str) -> Optional[schema.User]:
    return db.query(schema.User).filter(schema.User.id == user_id).first()

def get_user_by_email(db: Session, email: str) -> Optional[schema.User]:
    return db.query(schema.User).filter(schema.User.email == email).first()

def get_users(db: Session, skip: int = 0, limit: int = 10) -> List[schema.User]:
    return db.query(schema.User).offset(skip).limit(limit).all()

def get_all_users(db: Session) -> List[schema.User]:
    return db.query(schema.User).all()

def update_user(db: Session, user_id: str, user: UserUpdate):
    db_user = db.query(schema.User).filter(schema.User.id == user_id).first()
    if db_user:
        for key, value in user.dict(exclude_unset=True).items():
            setattr(db_user, key, value)
        db.commit()
        db.refresh(db_user)
    return db_user

def delete_user(db: Session, user_id: str):
    db_user = db.query(schema.User).filter(schema.User.id == user_id).first()
    if db_user:
        db.delete(db_user)
        db.commit()
    return db_user
# endregion

# region Item
class ItemBase(BaseModel):
    title: str
    description: str

class ItemCreate(ItemBase):
    pass

class ItemUpdate(ItemBase):
    pass

class Item(ItemBase):
    id: str
    owner_id: str

def create_item(db: Session, item: ItemCreate, owner_id: str):
    db_item = schema.Item(**item.dict(), owner_id=owner_id)
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item

def get_item(db: Session, item_id: str):
    return db.query(schema.Item).filter(schema.Item.id == item_id).first()

def get_items(db: Session, skip: int = 0, limit: int = 10):
    return db.query(schema.Item).offset(skip).limit(limit).all()

def update_item(db: Session, item_id: str, item: ItemUpdate):
    db_item = db.query(schema.Item).filter(schema.Item.id == item_id).first()
    if db_item:
        for key, value in item.dict(exclude_unset=True).items():
            setattr(db_item, key, value)
        db.commit()
        db.refresh(db_item)
    return db_item

def delete_item(db: Session, item_id: str):
    db_item = db.query(schema.Item).filter(schema.Item.id == item_id).first()
    if db_item:
        db.delete(db_item)
        db.commit()
    return db_item

# endregion

# region UserBook
class UserBookBase(BaseModel):
    owner_id: str
    book_id: str
    title: str
    random: bool
    batch_size: int
    memorizing_batch: str = ""

class UserBookCreate(UserBookBase):
    pass

class UserBookUpdate(UserBookBase):
    pass

class UserBook(UserBookBase):
    id: str

def create_user_book(db: Session, user_book: UserBookCreate):
    db_user_book = schema.UserBook(**user_book.dict())
    db.add(db_user_book)
    db.commit()
    db.refresh(db_user_book)
    return db_user_book

def get_user_book(db: Session, user_book_id: str) -> schema.UserBook:
    return db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()

def get_user_books_by_owner_id(db: Session, owner_id: str) -> List[schema.UserBook]:
    return db.query(schema.UserBook).filter(schema.UserBook.owner_id == owner_id).all()

def get_user_books(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserBook]:
    return db.query(schema.UserBook).offset(skip).limit(limit).all()

def update_user_book(db: Session, user_book_id: str, user_book: UserBookUpdate):
    db_user_book = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if db_user_book:
        for key, value in user_book.dict(exclude_unset=True).items():
            setattr(db_user_book, key, value)
        db.commit()
        db.refresh(db_user_book)
    return db_user_book

def update_user_book_memorizing_batch(db: Session, user_book_id: str, memorizing_batch: str):
    db_user_book = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if db_user_book:
        db_user_book.memorizing_batch = memorizing_batch
        db.commit()
        db.refresh(db_user_book)
    return db_user_book

def delete_user_book(db: Session, user_book_id: str):
    db_user_book = db.query(schema.UserBook).filter(schema.UserBook.id == user_book_id).first()
    if db_user_book:
        db.delete(db_user_book)
        db.commit()
    return db_user_book

# endregion

# region UserMemoryBatch
class UserMemoryBatchBase(BaseModel):
    user_book_id: str
    story: str
    translated_story: str
    batch_type: str = "新词"  # "new words"

class UserMemoryBatchCreate(UserMemoryBatchBase):
    pass

class UserMemoryBatchUpdate(UserMemoryBatchBase):
    pass

class UserMemoryBatch(UserMemoryBatchBase):
    id: str

def create_user_memory_batch(db: Session, memory_batch: UserMemoryBatchCreate):
    db_memory_batch = schema.UserMemoryBatch(**memory_batch.dict())
    db.add(db_memory_batch)
    db.commit()
    db.refresh(db_memory_batch)
    return db_memory_batch

def get_user_memory_batch(db: Session, memory_batch_id: str):
    return db.query(schema.UserMemoryBatch).filter(schema.UserMemoryBatch.id == memory_batch_id).first()

def get_user_memory_batchs(db: Session, skip: int = 0, limit: int = 10):
    return db.query(schema.UserMemoryBatch).offset(skip).limit(limit).all()

def get_user_memory_batches_by_user_book_id(db: Session, user_book_id: str) -> List[schema.UserMemoryBatch]:
    return db.query(schema.UserMemoryBatch).filter(
        schema.UserMemoryBatch.user_book_id == user_book_id
    ).order_by(schema.UserMemoryBatch.create_time).all()

def get_new_user_memory_batches_by_user_book_id(db: Session, user_book_id: str) -> List[schema.UserMemoryBatch]:
    return db.query(schema.UserMemoryBatch).filter(
        schema.UserMemoryBatch.user_book_id == user_book_id,
        schema.UserMemoryBatch.batch_type == "新词",  # "new words"
    ).order_by(schema.UserMemoryBatch.create_time).all()

def actions_infomation(db: Session, action_query: Query[schema.UserMemoryBatchAction]):
    distinct_actions = action_query.distinct().subquery()
    batches = db.query(schema.UserMemoryBatch).join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatch.id).all()
    batch_id_to_batch = {batch.id: batch for batch in batches}
    batch_id_to_words = {batch.id: get_words_in_batch(db, batch.id) for batch in batches}
    return batches, batch_id_to_batch, batch_id_to_words

def get_user_memory_batch_history(db: Session, user_book_id: str):
    action_query = db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.action == "end").join(
        schema.UserMemoryBatch, schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id
    ).filter(schema.UserMemoryBatch.user_book_id == user_book_id)
    actions = action_query.order_by(schema.UserMemoryBatchAction.create_time).all()

    distinct_actions = action_query.distinct().subquery()
    batches = db.query(schema.UserMemoryBatch).join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatch.id).all()
    batch_id_to_batch = {batch.id: batch for batch in batches}
    batch_id_to_words = {batch.id: get_words_in_batch(db, batch.id) for batch in batches}
    batch_id_to_actions = {batch.id: get_user_memory_actions_in_batch(db, batch.id) for batch in batches}
    return actions, batch_id_to_batch, batch_id_to_words, batch_id_to_actions

def get_user_memory_batch_history_in_minutes(db: Session, user_book_id: str, minutes: int):
    action_query = db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.action == "end").join(
        schema.UserMemoryBatch, schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id
    ).filter(
        schema.UserMemoryBatch.user_book_id == user_book_id,
        schema.UserMemoryBatchAction.create_time > datetime.datetime.now() - datetime.timedelta(minutes=minutes),
    ).limit(20)
    actions = action_query.order_by(schema.UserMemoryBatchAction.create_time).all()
    distinct_actions = action_query.distinct().subquery()
    batches = db.query(schema.UserMemoryBatch).join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatch.id).all()
    batch_id_to_batch = {batch.id: batch for batch in batches}
    batch_id_to_words = {batch.id: get_words_in_batch(db, batch.id) for batch in batches}
    # batch_actions = db.query(schema.UserMemoryBatchAction).join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatchAction.batch_id).all()
    # return actions, batch_id_to_batch, batch_id_to_words, batch_actions
    return actions, batch_id_to_batch, batch_id_to_words


def get_user_memory_word_history_in_minutes(db: Session, user_book_id: str, minutes: int):
    action_query = db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.action == "end").join(
        schema.UserMemoryBatch, schema.UserMemoryBatch.id == schema.UserMemoryBatchAction.batch_id
    ).filter(
        schema.UserMemoryBatch.user_book_id == user_book_id,
        # schema.UserMemoryBatchAction.create_time > datetime.datetime.now() - datetime.timedelta(minutes=minutes),
    ).limit(200)
    distinct_actions = action_query.distinct().subquery()
    batches = db.query(schema.UserMemoryBatch).join(distinct_actions, distinct_actions.c.batch_id == schema.UserMemoryBatch.id).all()
    words = [get_words_in_batch(db, batch.id) for batch in batches]
    words = sum(words, [])
    return words

def update_user_memory_batch(db: Session, memory_batch_id: str, memory_batch: UserMemoryBatchUpdate):
    db_memory_batch = db.query(schema.UserMemoryBatch).filter(schema.UserMemoryBatch.id == memory_batch_id).first()
    if db_memory_batch:
        for key, value in memory_batch.dict(exclude_unset=True).items():
            setattr(db_memory_batch, key, value)
        db.commit()
        db.refresh(db_memory_batch)
    return db_memory_batch

def delete_user_memory_batch(db: Session, memory_batch_id: str):
    db_memory_batch = db.query(schema.UserMemoryBatch).filter(schema.UserMemoryBatch.id == memory_batch_id).first()
    if db_memory_batch:
        db.delete(db_memory_batch)
        db.commit()
    return db_memory_batch

# endregion

# region UserMemoryBatchAction
class UserMemoryBatchActionBase(BaseModel):
    batch_id: str
    action: str

class UserMemoryBatchActionCreate(UserMemoryBatchActionBase):
    pass

class UserMemoryBatchActionUpdate(UserMemoryBatchActionBase):
    pass

class UserMemoryBatchAction(UserMemoryBatchActionBase):
    id: str
    create_time: str
    update_time: str

def create_user_memory_batch_action(db: Session, memory_batch_action: UserMemoryBatchActionCreate):
    db_memory_batch_action = schema.UserMemoryBatchAction(**memory_batch_action.dict())
    db.add(db_memory_batch_action)
    db.commit()
    db.refresh(db_memory_batch_action)
    return db_memory_batch_action

def get_user_memory_batch_action(db: Session, memory_batch_action_id: str):
    return db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.id == memory_batch_action_id).first()

def get_user_memory_batch_actions(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryBatchAction]:
    return db.query(schema.UserMemoryBatchAction).offset(skip).limit(limit).all()

def get_user_memory_batch_actions_by_user_memory_batch_id(db: Session, user_memory_batch_id: str) -> List[schema.UserMemoryBatchAction]:
    return db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.batch_id == user_memory_batch_id).all()

def get_actions_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchAction]:
    return db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.batch_id.in_(memory_batch_ids)).all()

def get_finished_actions_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchAction]:
    return db.query(schema.UserMemoryBatchAction).filter(
        schema.UserMemoryBatchAction.batch_id.in_(memory_batch_ids),
        schema.UserMemoryBatchAction.action == "end",
    ).order_by(schema.UserMemoryBatchAction.create_time).all()

def update_user_memory_batch_action(db: Session, memory_batch_action_id: str, memory_batch_action: UserMemoryBatchActionUpdate):
    db_memory_batch_action = db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.id == memory_batch_action_id).first()
    if db_memory_batch_action:
        for key, value in memory_batch_action.dict(exclude_unset=True).items():
            setattr(db_memory_batch_action, key, value)
        db.commit()
        db.refresh(db_memory_batch_action)
    return db_memory_batch_action

def delete_user_memory_batch_action(db: Session, memory_batch_action_id: str):
    db_memory_batch_action = db.query(schema.UserMemoryBatchAction).filter(schema.UserMemoryBatchAction.id == memory_batch_action_id).first()
    if db_memory_batch_action:
        db.delete(db_memory_batch_action)
        db.commit()
    return db_memory_batch_action
# endregion

# region UserMemoryBatchGenerationHistory
class UserMemoryBatchGenerationHistoryBase(BaseModel):
    batch_id: str
    story: str
    translated_story: str

class UserMemoryBatchGenerationHistoryCreate(UserMemoryBatchGenerationHistoryBase):
    pass

class UserMemoryBatchGenerationHistoryUpdate(UserMemoryBatchGenerationHistoryBase):
    pass

class UserMemoryBatchGenerationHistory(UserMemoryBatchGenerationHistoryBase):
    id: str
    create_time: str
    update_time: str

def create_user_memory_batch_generation_history(db: Session, memory_batch_generation_history: UserMemoryBatchGenerationHistoryCreate):
    db_memory_batch_generation_history = schema.UserMemoryBatchGenerationHistory(**memory_batch_generation_history.dict())
    db.add(db_memory_batch_generation_history)
    db.commit()
    db.refresh(db_memory_batch_generation_history)
    return db_memory_batch_generation_history

def get_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str):
    return db.query(schema.UserMemoryBatchGenerationHistory).filter(schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id).first()

def get_user_memory_batch_generation_historys(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryBatchGenerationHistory]:
    return db.query(schema.UserMemoryBatchGenerationHistory).offset(skip).limit(limit).all()

def get_user_memory_batch_generation_historys_by_user_memory_batch_id(db: Session, user_memory_batch_id: str) -> List[schema.UserMemoryBatchGenerationHistory]:
    return db.query(schema.UserMemoryBatchGenerationHistory).filter(schema.UserMemoryBatchGenerationHistory.batch_id == user_memory_batch_id).all()

def get_generation_historys_at_each_batch(db: Session, memory_batch_ids: List[str]) -> List[schema.UserMemoryBatchGenerationHistory]:
    return db.query(schema.UserMemoryBatchGenerationHistory).filter(schema.UserMemoryBatchGenerationHistory.batch_id.in_(memory_batch_ids)).all()

def get_generation_hostorys_by_user_book_id(db: Session, user_book_id: str) -> Dict[str, Tuple[List[schema.Word], List[schema.UserMemoryBatchGenerationHistory]]]:
    batches = get_user_memory_batches_by_user_book_id(db, user_book_id)
    batch_ids = [batch.id for batch in batches]
    batch_id_to_words_and_history = {}
    for batch_id in batch_ids:
        historys = get_user_memory_batch_generation_historys_by_user_memory_batch_id(db, batch_id)
        if len(historys) == 0:
            continue
        words = get_words_in_batch(db, batch_id)
        batch_id_to_words_and_history[batch_id] = (words, historys)
    return batch_id_to_words_and_history

def update_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str, memory_batch_generation_history: UserMemoryBatchGenerationHistoryUpdate):
    db_memory_batch_generation_history = db.query(schema.UserMemoryBatchGenerationHistory).filter(schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id).first()
    if db_memory_batch_generation_history:
        for key, value in memory_batch_generation_history.dict(exclude_unset=True).items():
            setattr(db_memory_batch_generation_history, key, value)
        db.commit()
        db.refresh(db_memory_batch_generation_history)
    return db_memory_batch_generation_history

def delete_user_memory_batch_generation_history(db: Session, memory_batch_generation_history_id: str):
    db_memory_batch_generation_history = db.query(schema.UserMemoryBatchGenerationHistory).filter(schema.UserMemoryBatchGenerationHistory.id == memory_batch_generation_history_id).first()
    if db_memory_batch_generation_history:
        db.delete(db_memory_batch_generation_history)
        db.commit()
    return db_memory_batch_generation_history

# endregion

# region UserMemoryWord
class UserMemoryWordBase(BaseModel):
    batch_id: str
    word_id: str

class UserMemoryWordCreate(UserMemoryWordBase):
    pass

class UserMemoryWordUpdate(UserMemoryWordBase):
    pass

class UserMemoryWord(UserMemoryWordBase):
    id: str

def create_user_memory_word(db: Session, memory_word: UserMemoryWordCreate):
    db_memory_word = schema.UserMemoryWord(**memory_word.dict())
    db.add(db_memory_word)
    db.commit()
    db.refresh(db_memory_word)
    return db_memory_word

def get_user_memory_word(db: Session, memory_word_id: str):
    return db.query(schema.UserMemoryWord).filter(schema.UserMemoryWord.id == memory_word_id).first()

def get_user_memory_words(db: Session, skip: int = 0, limit: int = 10) -> List[schema.UserMemoryWord]:
    return db.query(schema.UserMemoryWord).offset(skip).limit(limit).all()

def get_user_memory_words_by_batch_id(db: Session, batch_id: str) -> List[schema.UserMemoryWord]:
    return db.query(schema.UserMemoryWord).filter(schema.UserMemoryWord.batch_id == batch_id).all()

def update_user_memory_word(db: Session, memory_word_id: str, memory_word: UserMemoryWordUpdate):
    db_memory_word = db.query(schema.UserMemoryWord).filter(schema.UserMemoryWord.id == memory_word_id).first()
    if db_memory_word:
        for key, value in memory_word.dict(exclude_unset=True).items():
            setattr(db_memory_word, key, value)
        db.commit()
        db.refresh(db_memory_word)
    return db_memory_word

def delete_user_memory_word(db: Session, memory_word_id: str):
    db_memory_word = db.query(schema.UserMemoryWord).filter(schema.UserMemoryWord.id == memory_word_id).first()
    if db_memory_word:
        db.delete(db_memory_word)
        db.commit()
    return db_memory_word

# endregion

# region UserMemoryAction
class UserMemoryActionBase(BaseModel):
    batch_id: str
    word_id: str
    action: str

class UserMemoryActionCreate(UserMemoryActionBase):
    pass

class UserMemoryActionUpdate(UserMemoryActionBase):
    pass

class UserMemoryAction(UserMemoryActionBase):
    id: str
    create_time: str
    update_time: str


def create_user_memory_action(db: Session, memory_action: UserMemoryActionCreate):
    db_memory_action = schema.UserMemoryAction(**memory_action.dict())
    db.add(db_memory_action)
    db.commit()
    db.refresh(db_memory_action)
    return db_memory_action

def get_user_memory_action(db: Session, memory_action_id: str) -> schema.UserMemoryAction:
    return db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.id == memory_action_id).first()
|
494 |
+
def get_user_memory_actions(db: Session, skip: int = 0, limit: int = 10)-> List[schema.UserMemoryAction]:
|
495 |
+
return db.query(schema.UserMemoryAction).offset(skip).limit(limit).all()
|
496 |
+
|
497 |
+
def get_user_memory_actions_by_word_id(db: Session, word_id: str)-> List[schema.UserMemoryAction]:
|
498 |
+
return db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.word_id == word_id).all()
|
499 |
+
|
500 |
+
def get_actions_at_each_word(db: Session, word_ids: List[str]) -> List[schema.UserMemoryAction]:
|
501 |
+
return db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.word_id.in_(word_ids)).all()
|
502 |
+
|
503 |
+
def get_user_memory_actions_in_batch(db: Session, batch_id: str) -> List[schema.UserMemoryAction]:
|
504 |
+
return db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.batch_id == batch_id).all()
|
505 |
+
|
506 |
+
def update_user_memory_action(db: Session, memory_action_id: str, memory_action: UserMemoryActionUpdate):
|
507 |
+
db_memory_action = db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.id == memory_action_id).first()
|
508 |
+
if db_memory_action:
|
509 |
+
for key, value in memory_action.dict(exclude_unset=True).items():
|
510 |
+
setattr(db_memory_action, key, value)
|
511 |
+
db.commit()
|
512 |
+
db.refresh(db_memory_action)
|
513 |
+
return db_memory_action
|
514 |
+
|
515 |
+
def delete_user_memory_action(db: Session, memory_action_id: str):
|
516 |
+
db_memory_action = db.query(schema.UserMemoryAction).filter(schema.UserMemoryAction.id == memory_action_id).first()
|
517 |
+
if db_memory_action:
|
518 |
+
db.delete(db_memory_action)
|
519 |
+
db.commit()
|
520 |
+
return db_memory_action
|
521 |
+
|
522 |
+
# endregion
|
523 |
+
|
524 |
+
# region Book
|
525 |
+
class BookBase(BaseModel):
|
526 |
+
bk_order: float = 0
|
527 |
+
bk_name: str
|
528 |
+
bk_item_num: int
|
529 |
+
bk_author: str = ""
|
530 |
+
bk_comment: str = ""
|
531 |
+
bk_organization: str = ""
|
532 |
+
bk_publisher: str = ""
|
533 |
+
bk_version: str = ""
|
534 |
+
permission: str = "private"
|
535 |
+
creator: str
|
536 |
+
|
537 |
+
class BookCreate(BookBase):
|
538 |
+
pass
|
539 |
+
|
540 |
+
class BookUpdate(BookBase):
|
541 |
+
pass
|
542 |
+
|
543 |
+
class Book(BookBase):
|
544 |
+
bk_id: str
|
545 |
+
|
546 |
+
def create_book(db: Session, book: BookCreate):
|
547 |
+
db_book = schema.Book(**book.dict())
|
548 |
+
db.add(db_book)
|
549 |
+
db.commit()
|
550 |
+
db.refresh(db_book)
|
551 |
+
return db_book
|
552 |
+
|
553 |
+
def get_book(db: Session, book_id: str):
|
554 |
+
return db.query(schema.Book).filter(schema.Book.bk_id == book_id).first()
|
555 |
+
|
556 |
+
def get_book_by_name(db: Session, book_name: str):
|
557 |
+
return db.query(schema.Book).filter(schema.Book.bk_name == book_name).first()
|
558 |
+
|
559 |
+
def get_books(db: Session, skip: int = 0, limit: int = 10):
|
560 |
+
return db.query(schema.Book).offset(skip).limit(limit).all()
|
561 |
+
|
562 |
+
def get_all_books(db: Session):
|
563 |
+
return db.query(schema.Book).all()
|
564 |
+
|
565 |
+
def get_all_books_for_user(db: Session, user_id: str):
|
566 |
+
return db.query(schema.Book).filter(
|
567 |
+
or_(schema.Book.creator == user_id, schema.Book.permission == "public")
|
568 |
+
).order_by(schema.Book.permission, schema.Book.create_time).all()
|
569 |
+
|
570 |
+
def get_book_count(db: Session):
|
571 |
+
return db.query(schema.Book).count()
|
572 |
+
|
573 |
+
def update_book(db: Session, book_id: str, book: BookUpdate):
|
574 |
+
db_book = db.query(schema.Book).filter(schema.Book.bk_id == book_id).first()
|
575 |
+
if db_book:
|
576 |
+
for key, value in book.dict(exclude_unset=True).items():
|
577 |
+
setattr(db_book, key, value)
|
578 |
+
db.commit()
|
579 |
+
db.refresh(db_book)
|
580 |
+
return db_book
|
581 |
+
|
582 |
+
def delete_book(db: Session, book_id: str):
|
583 |
+
db_book = db.query(schema.Book).filter(schema.Book.bk_id == book_id).first()
|
584 |
+
if db_book:
|
585 |
+
db.delete(db_book)
|
586 |
+
db.commit()
|
587 |
+
return db_book
|
588 |
+
|
589 |
+
# endregion
|
590 |
+
|
591 |
+
# region Unit
|
592 |
+
class UnitBase(BaseModel):
|
593 |
+
bv_book_id: str
|
594 |
+
bv_voc_id: str
|
595 |
+
bv_flag: int = 1
|
596 |
+
bv_tag: str = ""
|
597 |
+
bv_order: int = 1
|
598 |
+
|
599 |
+
class UnitCreate(UnitBase):
|
600 |
+
pass
|
601 |
+
|
602 |
+
class UnitUpdate(UnitBase):
|
603 |
+
pass
|
604 |
+
|
605 |
+
class Unit(UnitBase):
|
606 |
+
pass
|
607 |
+
|
608 |
+
def create_unit(db: Session, unit: UnitCreate):
|
609 |
+
db_unit = schema.Unit(**unit.dict())
|
610 |
+
db.add(db_unit)
|
611 |
+
db.commit()
|
612 |
+
db.refresh(db_unit)
|
613 |
+
return db_unit
|
614 |
+
|
615 |
+
def get_unit(db: Session, unit_id: str):
|
616 |
+
return db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id).first()
|
617 |
+
|
618 |
+
def get_units(db: Session, skip: int = 0, limit: int = 10):
|
619 |
+
return db.query(schema.Unit).offset(skip).limit(limit).all()
|
620 |
+
|
621 |
+
def update_unit(db: Session, unit_id: str, unit: UnitUpdate):
|
622 |
+
db_unit = db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id).first()
|
623 |
+
if db_unit:
|
624 |
+
for key, value in unit.dict(exclude_unset=True).items():
|
625 |
+
setattr(db_unit, key, value)
|
626 |
+
db.commit()
|
627 |
+
db.refresh(db_unit)
|
628 |
+
return db_unit
|
629 |
+
|
630 |
+
def delete_unit(db: Session, unit_id: str):
|
631 |
+
db_unit = db.query(schema.Unit).filter(schema.Unit.bv_id == unit_id).first()
|
632 |
+
if db_unit:
|
633 |
+
db.delete(db_unit)
|
634 |
+
db.commit()
|
635 |
+
return db_unit
|
636 |
+
|
637 |
+
# endregion
|
638 |
+
|
639 |
+
# region Word
|
640 |
+
class WordBase(BaseModel):
|
641 |
+
vc_id: str
|
642 |
+
vc_vocabulary: str
|
643 |
+
vc_phonetic_uk: str
|
644 |
+
vc_phonetic_us: str
|
645 |
+
vc_frequency: float
|
646 |
+
vc_difficulty: float
|
647 |
+
vc_acknowledge_rate: float
|
648 |
+
|
649 |
+
class WordCreate(WordBase):
|
650 |
+
pass
|
651 |
+
|
652 |
+
class WordUpdate(WordBase):
|
653 |
+
pass
|
654 |
+
|
655 |
+
class Word(WordBase):
|
656 |
+
pass
|
657 |
+
|
658 |
+
|
659 |
+
def create_word(db: Session, word: WordCreate, unit_id: str):
    # WordBase already includes vc_id, so spreading **word.dict() together with
    # vc_id=unit_id would raise "got multiple values for keyword argument 'vc_id'";
    # override the field explicitly instead.
    data = word.dict()
    data["vc_id"] = unit_id
    db_word = schema.Word(**data)
    db.add(db_word)
    db.commit()
    db.refresh(db_word)
    return db_word

def get_word(db: Session, word_id: str):
    return db.query(schema.Word).filter(schema.Word.vc_id == word_id).first()

def get_words(db: Session, skip: int = 0, limit: int = 10) -> List[schema.Word]:
    return db.query(schema.Word).offset(skip).limit(limit).all()

def get_words_by_vocabulary(db: Session, vocabulary: List[str]) -> List[schema.Word]:
    return db.query(schema.Word).filter(schema.Word.vc_vocabulary.in_(vocabulary)).all()

def get_words_by_ids(db: Session, ids: List[str]) -> List[schema.Word]:
    return db.query(schema.Word).filter(schema.Word.vc_id.in_(ids)).all()

def get_words_in_batch(db: Session, batch_id: str) -> List[schema.Word]:
    return db.query(schema.Word).join(schema.UserMemoryWord, schema.UserMemoryWord.word_id == schema.Word.vc_id).filter(schema.UserMemoryWord.batch_id == batch_id).all()

def get_words_at_each_batch(db: Session, batch_ids: List[str]) -> List[schema.Word]:
    return db.query(schema.Word).join(schema.UserMemoryWord, schema.UserMemoryWord.word_id == schema.Word.vc_id).filter(schema.UserMemoryWord.batch_id.in_(batch_ids)).all()

def get_words_in_user_book(db: Session, user_book_id: str) -> List[schema.Word]:
    batches = get_new_user_memory_batches_by_user_book_id(db, user_book_id)
    batch_ids = [batch.id for batch in batches]
    return get_words_at_each_batch(db, batch_ids)

def update_word(db: Session, word_id: str, word: WordUpdate):
    db_word = db.query(schema.Word).filter(schema.Word.vc_id == word_id).first()
    if db_word:
        for key, value in word.dict(exclude_unset=True).items():
            setattr(db_word, key, value)
        db.commit()
        db.refresh(db_word)
    return db_word

def delete_word(db: Session, word_id: str):
    db_word = db.query(schema.Word).filter(schema.Word.vc_id == word_id).first()
    if db_word:
        db.delete(db_word)
        db.commit()
    return db_word
# endregion
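
A minimal sketch of how the CRUD helpers above compose (an illustration, not part of the commit; `SessionLocal` comes from database/database.py, which is not shown in this excerpt, and the ids are made up):

# sketch: create a batch, record its generation history, and attach one word
from database import SessionLocal
from database.operation import (
    UserMemoryBatchCreate, UserMemoryBatchGenerationHistoryCreate, UserMemoryWordCreate,
    create_user_memory_batch, create_user_memory_batch_generation_history, create_user_memory_word,
)

db = SessionLocal()
batch = create_user_memory_batch(db, UserMemoryBatchCreate(
    user_book_id="ub-1", story="...", translated_story="..."))  # "ub-1" is a placeholder id
create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
    batch_id=batch.id, story=batch.story, translated_story=batch.translated_story))
create_user_memory_word(db, UserMemoryWordCreate(batch_id=batch.id, word_id="w-1"))  # "w-1" is a placeholder id
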
database/schema.py
ADDED
@@ -0,0 +1,209 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float, DateTime
from sqlalchemy.orm import relationship

from .database import Base

import uuid
import bcrypt

# class ModelBase(Base):
#     __abstract__ = True
#     id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
#     created_at = Column(DateTime, default=db.func.current_timestamp())
#     updated_at = Column(DateTime,
#                         default=func.current_timestamp(),
#                         onupdate=func.current_timestamp())

class User(Base):
    __tablename__ = "users"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    email = Column(String, unique=True, index=True)
    encrypted_password = Column(String)
    is_active = Column(Boolean, default=True)

    items = relationship("Item", back_populates="owner")
    # SQL to reset a user's password to a known bcrypt hash:
    # $2b$12$huJdmqFPzWU.9rumd2wpSOZUnCJ0bufmA4vl5T9PDc7V.xLgWAqSu
    # UPDATE = "UPDATE users SET encrypted_password = '$2b$12$huJdmqFPzWU.9rumd2wpSOZUnCJ0bufmA4vl5T9PDc7V.xLgWAqSu' WHERE email = '[email protected]'"

    def set_password(self, password: str):
        hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
        self.encrypted_password = str(hashed, encoding='utf-8')

    def verify_password(self, password: str):
        return bcrypt.checkpw(password.encode(), bytes(self.encrypted_password, encoding='utf-8'))

    def __str__(self):
        return f"<User {self.email}>"

    def __repr__(self):
        return f"<User {self.email}>"

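# Round-trip sketch for the password helpers above (illustrative, not part of
# the original file; the credentials are made up). bcrypt salts each hash, so
# two calls to set_password yield different ciphertexts, while verify_password
# still matches the original plaintext:
#
#   u = User(email="demo@example.com")
#   u.set_password("s3cret")
#   assert u.verify_password("s3cret")
#   assert not u.verify_password("wrong")
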
class Item(Base):
    __tablename__ = "items"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    title = Column(String)
    description = Column(String)
    owner_id = Column(String, ForeignKey("users.id"))

    owner = relationship("User", back_populates="items")

# region Memorizing words
class UserBook(Base):
    __tablename__ = "user_book"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    owner_id = Column(String, ForeignKey("users.id"))
    book_id = Column(String, ForeignKey("book.bk_id"))

    title = Column(String)
    batch_size = Column(Integer, default=10)
    random = Column(Boolean, default=True)
    memorizing_batch = Column(String, default='')

class UserMemoryBatch(Base):
    __tablename__ = "user_memory_batch"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    user_book_id = Column(String, ForeignKey("user_book.id"))

    story = Column(String, default="")
    translated_story = Column(String, default="")
    batch_type = Column(String, default="新词")  # "新词" (new words), "回忆" (recall) or "复习" (review)

    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)

    words = relationship("UserMemoryWord", back_populates="batch")

class UserMemoryBatchAction(Base):
    __tablename__ = "user_memory_batch_action"
    """
    Start/end timestamps of a memorization session; the span between them
    gives the memorization speed for the batch.
    """

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))

    action = Column(String, default="start")  # start or end

    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)


class UserMemoryBatchGenerationHistory(Base):
    __tablename__ = "user_memory_batch_generation_history"
    """
    One record per story generated for a batch, so regenerations are kept.
    """

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))

    story = Column(String, default="")
    translated_story = Column(String, default="")
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)


class UserMemoryWord(Base):
    __tablename__ = "user_memory_word"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))
    word_id = Column(String, ForeignKey("word.vc_id"))

    batch = relationship("UserMemoryBatch", back_populates="words")

class UserMemoryAction(Base):
    __tablename__ = "user_memory_action"

    id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    batch_id = Column(String, ForeignKey("user_memory_batch.id"))
    word_id = Column(String, ForeignKey("word.vc_id"))

    action = Column(String, default="remember")  # remember or forget
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)

# endregion

class Book(Base):
    __tablename__ = "book"

    # Sample record (values translated from Chinese):
    # {'bk_id': 'd645920e395fedad7bbbed0e',
    #  'bk_parent_id': '6512bd43d9caa6e02c990b0a',
    #  'bk_level': 2,
    #  'bk_order': 2.0,
    #  'bk_name': 'PEP Senior High English 1 (compulsory)',
    #  'bk_item_num': 315,
    #  'bk_direct_item_num': 315,
    #  'bk_author': 'Liu Daoyi',
    #  'bk_book': 'PEP standard experimental textbook for senior high school, English 1 (compulsory)',
    #  'bk_comment': 'Bold: key words and phrases of this unit; "△" marks curriculum words (mastery required); words without "△" are not required; many abbreviations, person names, place names and phrases appear (please hide them).',
    #  'bk_organization': "People's Education Press, Curriculum & Textbook Research Institute; English Curriculum & Textbook R&D Center",
    #  'bk_publisher': "People's Education Press",
    #  'bk_version': '2nd edition, January 2007',
    #  'bk_flag': 'default: 152; bold: 97; with △: 66'}
    bk_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    bk_order = Column(Float)
    bk_name = Column(String)
    bk_item_num = Column(Integer)
    bk_author = Column(String, default='')
    bk_comment = Column(String, default='')
    bk_organization = Column(String, default='')
    bk_publisher = Column(String, default='')
    bk_version = Column(String, default='')

    permission = Column(String, default='private')
    creator = Column(String, ForeignKey("users.id"))
    create_time = Column(DateTime, default=datetime.now)
    update_time = Column(DateTime, onupdate=datetime.now, default=datetime.now)

    def __str__(self):
        return f"{self.bk_name}\n [{self.bk_item_num} words, {self.bk_version}]"

    def __repr__(self):
        return f"<Book {self.bk_name}>"

class Unit(Base):
    __tablename__ = "unit"

    # ['bv_id', 'bv_book_id', 'bv_voc_id', 'bv_flag', 'bv_tag', 'bv_order']
    # Sample record:
    # {'bv_id': '58450c828958a37d5c10f763',
    #  'bv_book_id': 'd645920e395fedad7bbbed0e',
    #  'bv_voc_id': '57067b9ca172044907c615d7',
    #  'bv_flag': 4,
    #  'bv_tag': 'Unit 1',
    #  'bv_order': 1}
    bv_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    bv_voc_id = Column(String, ForeignKey("word.vc_id"))
    bv_book_id = Column(String, ForeignKey("book.bk_id"))
    bv_flag = Column(Integer)
    bv_tag = Column(String)
    bv_order = Column(Integer)

class Word(Base):
    __tablename__ = "word"

    # vc_id>vc_vocabulary>vc_phonetic_uk>vc_phonetic_us>vc_frequency>vc_difficulty>vc_acknowledge_rate
    # 57067c89a172044907c6698e>superspecies>[su:pərsˈpi:ʃi:z]>[supəsˈpiʃiz]>0.0>1>0.664122
    vc_id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
    vc_vocabulary = Column(String)
    vc_translation = Column(String)
    vc_phonetic_uk = Column(String)
    vc_phonetic_us = Column(String)
    vc_frequency = Column(Float)
    vc_difficulty = Column(Float)
    vc_acknowledge_rate = Column(Float)

    def __str__(self):
        return f"{self.vc_vocabulary} {self.vc_translation}\n[{self.vc_phonetic_uk}] [{self.vc_phonetic_us}]"

    def __repr__(self):
        return f"{self.vc_vocabulary} {self.vc_translation}\n[{self.vc_phonetic_uk}] [{self.vc_phonetic_us}]"
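
For orientation, a minimal bootstrap sketch for these models (it mirrors what web.py below actually does; the sample user is made up):

# sketch: create the tables, then store and verify one user
from database import SessionLocal, engine, Base
import database.schema as schema

Base.metadata.create_all(bind=engine)  # creates users, items, book, unit, word and the user_* tables
db = SessionLocal()
user = schema.User(email="demo@example.com")
user.set_password("s3cret")
db.add(user)
db.commit()
assert db.query(schema.User).filter(schema.User.email == "demo@example.com").first().verify_password("s3cret")
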
memorize.py
ADDED
@@ -0,0 +1,264 @@
from sqlalchemy.orm import Session
from typing import List, Tuple
from collections import defaultdict  # used by generate_next_batch below

from tqdm import tqdm
from database.operation import *
from database import schema
import random
from loguru import logger
import math


# Memorizing words
from story_agent import generate_story_and_translated_story
from common.util import date_str, multiprocessing_mapping


def get_words_for_book(db: Session, user_book: UserBook) -> List[schema.Word]:
    book = get_book(db, user_book.book_id)
    if book is None:
        logger.warning("book not found")
        return []
    q = db.query(schema.Word).join(schema.Unit, schema.Unit.bv_voc_id == schema.Word.vc_id)
    words = q.filter(schema.Unit.bv_book_id == book.bk_id).order_by(schema.Word.vc_difficulty).all()
    return words


def save_words_as_book(db: Session, user_id: str, words: List[schema.Word], title: str):
    book = create_book(db, BookCreate(bk_name=f"{title}(待学单词自动保存为单词书)", bk_item_num=len(words), creator=user_id))
    for i, word in tqdm(enumerate(words)):
        unit = UnitCreate(bv_book_id=book.bk_id, bv_voc_id=word.vc_id)
        db_unit = schema.Unit(**unit.dict())
        db.add(db_unit)
        if i % 500 == 0:
            db.commit()
    db.commit()
    return book

def save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    # Only the first batch gets its story generated here. Later batches get
    # stories based on the user's memorization state, pre-generated 3 batches ahead.
    story, translated_story = generate_story_and_translated_story(batch_words_str_list)
    return save_batch_words_with_story(db, i, user_book_id, batch_words, story, translated_story)


def save_batch_words_with_story(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word], story: str, translated_story: str):
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    logger.info(f"{i}, {batch_words_str_list}\n{story}")
    user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story
    ))
    create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    ))
    for word in batch_words:
        memory_word = UserMemoryWordCreate(
            batch_id=user_memory_batch.id,
            word_id=word.vc_id
        )
        db_memory_word = schema.UserMemoryWord(**memory_word.dict())
        db.add(db_memory_word)
    db.commit()
    return user_memory_batch

async def async_save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    save_batch_words(db, i, user_book_id, batch_words)

import asyncio
async def async_save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    for i, batch_words in enumerate(batch_words_list):
        asyncio.ensure_future(async_save_batch_words(db, i+1, user_book_id, batch_words))

def transform(batch_words: List[str]):
    story, translated_story = generate_story_and_translated_story(batch_words)
    return {
        "story": story,
        "translated_story": translated_story,
        "words": batch_words
    }

def save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    word_str_list = []
    for batch_words in batch_words_list:
        word_str_list.append([word.vc_vocabulary for word in batch_words])
    story_list = multiprocessing_mapping(transform, word_str_list, tmp_filepath=f"./output/logs/save_batch_words_list_{date_str}.xlsx")
    logger.info(f"story_list: {len(story_list)}")
    for i, (batch_words, story) in tqdm(enumerate(zip(batch_words_list, story_list))):
        save_batch_words_with_story(db, i, user_book_id, batch_words, story['story'], story['translated_story'])

def track(db: Session, user_book: schema.UserBook, words: List[schema.Word]):
    batch_size = user_book.batch_size
    logger.debug(f"{[w.vc_vocabulary for w in words]}")
    logger.debug(f"batch_size: {batch_size}")
    logger.debug(f"words count: {len(words)}")
    if user_book.random:
        random.shuffle(words)
    else:
        words.sort(key=lambda x: x.vc_frequency, reverse=True)  # sort by frequency; frequent words are easier to remember
    logger.debug(f"saving words as book")
    save_words_as_book(db, user_book.owner_id, words, user_book.title)
    logger.debug(f"saved words as book [{user_book.title}]")
    batch_words_list = []
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i+batch_size]
        batch_words_list.append(batch_words)
    logger.debug(f"batch_words_list: {len(batch_words_list)}")
    if len(batch_words_list) == 0:
        return
    first_batch_words = batch_words_list[0]
    user_memory_batch = save_batch_words(db, 0, user_book.id, first_batch_words)
    user_book.memorizing_batch = user_memory_batch.id
    db.commit()
    save_batch_words_list(db, user_book.id, batch_words_list[1:])
    # asyncio.run(async_save_batch_words_list(db, user_book.id, batch_words_list[1:]))

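# Worked example of the batching in track() above, with made-up numbers:
# 25 unknown words and batch_size=10 slice into words[0:10], words[10:20],
# words[20:25], i.e. batches of 10, 10 and 5. Only the first batch gets its
# story generated synchronously (save_batch_words); the remaining batches go
# through save_batch_words_list, which fans out story generation via
# multiprocessing_mapping from common/util.py (not shown in this excerpt).
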
def remenber(db: Session, batch_id: str, word_id: str):
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="remember"
    ))

def forget(db: Session, batch_id: str, word_id: str):
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="forget"
    ))

def save_memorizing_word_action(db: Session, batch_id: str, actions: List[Tuple[str, str]]):
    """
    actions: [(word_id, remember | forget)]
    """
    for word_id, action in actions:
        memory_action = UserMemoryActionCreate(
            batch_id=batch_id,
            word_id=word_id,
            action=action
        )
        db_memory_action = schema.UserMemoryAction(**memory_action.dict())
        db.add(db_memory_action)
    db.commit()

def on_batch_start(db: Session, user_memory_batch_id: str):
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="start"
    ))

def on_batch_end(db: Session, user_memory_batch_id: str):
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="end"
    ))

# def generate_recall_batch(db: Session, user_book: schema.UserBook):
def generate_next_batch(db: Session, user_book: schema.UserBook,
                        minutes: int = 60, k: int = 3):
    # Generate the next batch: a recall batch ("回忆") or a review batch ("复习").
    # Returns None when the next batch should simply be a new-word batch ("新词").
    left_bound, right_bound = 0.3, 0.6
    user_book_id = user_book.id
    batch_size = user_book.batch_size
    # actions, batch_id_to_batch, batch_id_to_words = get_user_memory_batch_history_in_minutes(db, user_book_id, minutes)
    # memorizing_words = sum(list(batch_id_to_words.values()), [])
    memorizing_words = get_user_memory_word_history_in_minutes(db, user_book_id, minutes)
    if len(memorizing_words) < k * batch_size:
        # 1. Too few recently memorized words -> new-word batch
        logger.info("新词批")
        return None
    # Compute per-word memorization efficiency
    memory_actions = get_actions_at_each_word(db, [w.vc_id for w in memorizing_words])
    remember_count = defaultdict(int)
    forget_count = defaultdict(int)
    for a in memory_actions:
        if a.action == "remember":
            remember_count[a.word_id] += 1
        else:
            forget_count[a.word_id] += 1
    word_id_to_efficiency = {}
    for word in memorizing_words:
        total = remember_count[word.vc_id] + forget_count[word.vc_id]
        # guard against words with no recorded action yet (avoids ZeroDivisionError)
        efficiency = remember_count[word.vc_id] / total if total else 0.0
        word_id_to_efficiency[word.vc_id] = efficiency
    # list.sort() returns None, so sort first and then log the result
    logger.info(sorted([(w.vc_vocabulary, word_id_to_efficiency[w.vc_id]) for w in memorizing_words], key=lambda x: x[1]))
    if all([efficiency > right_bound for efficiency in word_id_to_efficiency.values()] + [count > 3 for count in remember_count.values()]):
        # 2. Efficiency is uniformly high -> new-word batch
        logger.info("新词批")
        return None
    forgot_word_ids = [word_id for word_id, efficiency in word_id_to_efficiency.items() if efficiency < left_bound]
    forgot_word_ids.sort(key=lambda x: word_id_to_efficiency[x])
    if len(forgot_word_ids) >= batch_size:
        # 4. Normal case with enough forgotten words -> review batch
        logger.info("复习批")
        batch_words = [word for word in memorizing_words if word.vc_id in forgot_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty, reverse=True)
        batch_words_str_list = [word.vc_vocabulary for word in batch_words]
        story, translated_story = generate_story_and_translated_story(batch_words_str_list)
        user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
            user_book_id=user_book_id,
            story=story,
            translated_story=translated_story,
            batch_type="复习",
        ))
        create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
            batch_id=user_memory_batch.id,
            story=story,
            translated_story=translated_story
        ))
        for word in batch_words:
            memory_word = UserMemoryWordCreate(
                batch_id=user_memory_batch.id,
                word_id=word.vc_id
            )
            db_memory_word = schema.UserMemoryWord(**memory_word.dict())
            db.add(db_memory_word)
        db.commit()
        return user_memory_batch
    unfarmiliar_word_ids = [word_id for word_id, efficiency in word_id_to_efficiency.items() if left_bound <= efficiency < right_bound]
    unfarmiliar_word_ids.sort(key=lambda x: word_id_to_efficiency[x])
    if len(unfarmiliar_word_ids) < batch_size:
        # also pull in words that have been remembered fewer than 3 times
        unfarmiliar_word_ids += [word_id for word_id, count in remember_count.items() if count < 3]
        unfarmiliar_word_ids.sort(key=lambda x: word_id_to_efficiency[x])
    if len(unfarmiliar_word_ids) >= batch_size:
        # 3. Efficiency too low -> recall batch
        logger.info("回忆批")
        batch_words = [word for word in memorizing_words if word.vc_id in unfarmiliar_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty, reverse=True)
        batch_words_str_list = [word.vc_vocabulary for word in batch_words]
        story, translated_story = generate_story_and_translated_story(batch_words_str_list)
        user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
            user_book_id=user_book_id,
            story=story,
            translated_story=translated_story,
            batch_type="回忆",
        ))
        create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
            batch_id=user_memory_batch.id,
            story=story,
            translated_story=translated_story
        ))
        for word in batch_words:
            memory_word = UserMemoryWordCreate(
                batch_id=user_memory_batch.id,
                word_id=word.vc_id
            )
            db_memory_word = schema.UserMemoryWord(**memory_word.dict())
            db.add(db_memory_word)
        db.commit()
        return user_memory_batch
    # 5. Normal case otherwise -> new-word batch
    logger.info("新词批")
    return None
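
The batch-type decision in generate_next_batch hinges on the per-word efficiency remember / (remember + forget), measured against (left_bound, right_bound) = (0.3, 0.6). A standalone sketch of that classification, with made-up counts:

left_bound, right_bound = 0.3, 0.6

def classify(remember: int, forget: int) -> str:
    total = remember + forget
    efficiency = remember / total if total else 0.0  # same zero-action guard as above
    if efficiency < left_bound:
        return "forgotten: candidate for a review batch (复习批)"
    if efficiency < right_bound:
        return "unfamiliar: candidate for a recall batch (回忆批)"
    return "well remembered: excluded from recall/review batches"

print(classify(1, 3))  # 0.25 -> review candidate
print(classify(2, 2))  # 0.50 -> recall candidate
print(classify(4, 1))  # 0.80 -> well remembered
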
output/logs/.gitkeep
ADDED
File without changes
story_agent.py
ADDED
@@ -0,0 +1,69 @@
import traceback
from typing import List, Tuple

from loguru import logger
from pydantic import BaseModel, Field
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.output_parsers import (
    PydanticOutputParser,
    OutputFixingParser,
)
# LLM.py wraps my own language model; you can use the OpenAI one directly instead:
# from langchain.llms.openai import OpenAIChat
from LLM import OpenAIChat

TEMPLATE = """\
please write a story at least 5 sentences long, using the words [{words}].

{format}

Attention! The words should be highlighted, surrounded by "`". Therefore, the story should be in the following format.
English: ... `word1` ... `word2` ...
Chinese: ... `单词1` ... `单词2` ...
"""

class Story(BaseModel):
    story: str = Field(description="the story")
    translated_story: str = Field(description="the translated story")

llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.3)
parser = PydanticOutputParser(pydantic_object=Story)
prompt_template = PromptTemplate(
    template=TEMPLATE,
    input_variables=["words"],
    partial_variables={
        "format": parser.get_format_instructions(),
    }
)
parser = OutputFixingParser.from_llm(parser=parser, llm=llm)
chain = LLMChain(
    llm=llm,
    prompt=prompt_template,
    output_parser=parser,
    verbose=False,
)

def tell_story(words: List[str]):
    count = 0
    while count < 10:
        count += 1
        try:
            resp: Story = chain.run(", ".join(words))
            if len(resp.story.strip()) == 0:
                continue
            if len(resp.translated_story.strip()) == 0:
                continue
            return resp
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error("retrying...")
            continue
    return Story(story="", translated_story="")

def generate_story_and_translated_story(words: List[str]) -> Tuple[str, str]:
    resp = tell_story(words)
    return resp.story, resp.translated_story
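
A usage sketch for the module above (illustrative: it needs the model behind LLM.OpenAIChat and its API key to be reachable, and the sample words are made up):

from story_agent import generate_story_and_translated_story

story, translated = generate_story_and_translated_story(["apple", "river", "brave"])
print(story)       # English story with `apple`, `river`, `brave` highlighted in backticks
print(translated)  # Chinese translation with the same words highlighted
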
web.py
ADDED
@@ -0,0 +1,696 @@
from contextlib import contextmanager
import gradio as gr
from database.operation import *
from memorize import *
from database import SessionLocal, engine, Base
import database.schema as schema
from database import constant
import time
import asyncio

import pandas as pd
from collections import defaultdict
from datetime import datetime


Base.metadata.create_all(bind=engine)
db = SessionLocal()

@contextmanager
def session_scope():
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

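# Note: session_scope is defined here, but the handlers below talk to the
# module-level `db` directly. A minimal usage sketch of the context manager
# (commits on success, rolls back on error; the book fields are made up):
#
#   with session_scope() as s:
#       s.add(schema.Book(bk_name="demo", bk_item_num=0, creator="u-1"))
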
intro = """\
|
30 |
+
็ฎๆ ๅบๆฏ๏ผๅช่่่ฎฐไฝๅ่ฏๅๅ
ถๆๆ๏ผไฝฟๅพ่ฝๆ ้็ข้
่ฏป๏ผไธ่่็จไบๅไฝใ
|
31 |
+
|
32 |
+
ไธป่ฆๆณๆณ๏ผๆน้่ฎฐๅ่ฏ๏ผๆฏๆน n ไธชๅ่ฏ๏ผ่ฟ n ไธชๅ่ฏ็จ AI ็ๆๆ
ไบ๏ผๅค่ฟฐๆ
ไบๅณๅฏ่ฎฐไฝๅ่ฏใ
|
33 |
+
|
34 |
+
ไธบไปไน๏ผ
|
35 |
+
|
36 |
+
- ๆน้่ฎฐๅ่ฏ๏ผไธๆฌกๅฏไปฅ่ฎฐไฝ n ไธชๅ่ฏ๏ผ่ไธๆฏไธไธชไธไธช่ฎฐ๏ผๆ็้ซใ
|
37 |
+
- ๅค่ฟฐๆ
ไบ๏ผๅณ่ดนๆผๅญฆไน ๆณ๏ผๆ
ไบๆฏๅ่ฏ็่ฎฐๅฟไน้ใ
|
38 |
+
- ๅค่ฟฐๆ
ไบ่ไธๆฏๅค่ฟฐๅ่ฏ๏ผๆ
ไบๅ
ทๆ่ฟ็ปญๆง๏ผๆด็ฌฆๅไบบ็ฑปๅคฉๆง๏ผๅฎนๆ่ฎฐใ
|
39 |
+
|
40 |
+
### ไฝฟ็จๅปบ่ฎฎ
|
41 |
+
|
42 |
+
1. ่ฎฐๅ่ฏๅ๏ผๅ
ๅฎๆด่ฟไธ้ๅ
จ้จๅ่ฏ๏ผๅ้คๅทฒ่ฎฐไฝ็ๅ่ฏ๏ผไป่ๆ้ซๆฐ่ฏๅฏๅบฆ
|
43 |
+
2. ๅ
็ๅ่ฏ่กจๆ ผ๏ผ็ถๅ็่ฑๆๆ
ไบ๏ผๅฏน็
งไธญๆๅฎๆ่ฎฐๅฟ
|
44 |
+
3. ่ฎฐๅฟๅฎๆๅ้่ฆไธไธชไธไธชๅพ้ๅทฒ่ฎฐไฝ็ๅ่ฏ๏ผๅพ้ๆถๅฐ่ฏๅค่ฟฐๅ่ฏๆๆ๏ผไปฅๆญคๆฅๆฃ้ช่ฎฐๅฟๆๆ
|
45 |
+
|
46 |
+
> ๆฌ้กน็ฎๅบไบ[ๅผๆบๆฐๆฎ้](https://github.com/LinXueyuanStdio/DictionaryData)๏ผๅนถไธ[ๅผๆบไปฃ็ ](https://github.com/LinXueyuanStdio/oh-my-words)๏ผๆฌข่ฟๅคงๅฎถ่ดก็ฎไปฃ็ ๏ฝ
|
47 |
+
"""
|
48 |
+
|
49 |
+
with gr.Blocks(title="ๆน้่ฎฐๅ่ฏ") as demo:
|
50 |
+
# gr.Markdown("# ๆน้่ฎฐๅ่ฏ")
|
51 |
+
gr.HTML("<h1 align=\"center\">ๆน้่ฎฐๅ่ฏ</h1>")
|
52 |
+
user = gr.State(value={})
|
53 |
+
|
54 |
+
# 0. ็ปๅฝ
|
55 |
+
with gr.Tab("ไธป้กต"):
|
56 |
+
gr.Markdown(intro)
|
57 |
+
gr.Markdown(f"ๅ
ฑ {get_book_count(db)} ๆฌไนฆ")
|
58 |
+
gr.HTML("""<iframe src="https://ghbtns.com/github-btn.html?user=LinXueyuanStdio&repo=oh-my-words&type=star&count=true&size=small" frameborder="0" scrolling="0" width="170" height="30" title="GitHub"></iframe>""")
|
59 |
+
with gr.Row():
|
60 |
+
with gr.Column():
|
61 |
+
email = gr.TextArea(value=constant.email, lines=1, label="้ฎ็ฎฑ")
|
62 |
+
password = gr.TextArea(value=constant.password, lines=1, label="ๅฏ็ ")
|
63 |
+
login_btn = gr.Button("็ปๅฝ")
|
64 |
+
with gr.Column():
|
65 |
+
register_email = gr.TextArea(value='', lines=1, label="้ฎ็ฎฑ")
|
66 |
+
register_password = gr.TextArea(value='', lines=1, label="ๅฏ็ ")
|
67 |
+
register_btn = gr.Button("็ซๅณๆณจๅ", variant="primary")
|
68 |
+
user_status = gr.Textbox("", lines=1, label="็จๆท็ถๆ")
|
69 |
+
|
70 |
+
# 1. ๅๅปบ่ฎฐๅฟ่ฎกๅ
|
71 |
+
tab1 = gr.Tab("ๅๅปบ่ฎฐๅฟ่ฎกๅ", visible=False)
|
72 |
+
with tab1:
|
73 |
+
select_book = gr.Dropdown([], label="ๅ่ฏไนฆ", info="้ๆฉไธๆฌๅ่ฏไนฆ")
|
74 |
+
batch_size = gr.Number(value=10, label="ๆนๆฌกๅคงๅฐ")
|
75 |
+
randomize = gr.Checkbox(value=True, label="ไปฅๅ่ฏไนฑๅบ่ฟ่ก่ฎฐๅฟ")
|
76 |
+
title = gr.TextArea(value='ๅ่ฏไนฆ', lines=1, label="่ฎฐๅฟ่ฎกๅ็ๅ็งฐ")
|
77 |
+
btn = gr.Button("ๅๅปบ่ฎฐๅฟ่ฎกๅ")
|
78 |
+
status = gr.Textbox("", lines=1, label="็ถๆ")
|
79 |
+
|
80 |
+
def submit(user: Dict[str, str], book, title, randomize, batch_size):
|
81 |
+
user_id = user.get("id", None)
|
82 |
+
if user_id is None:
|
83 |
+
gr.Error("่ฏทๅ
็ปๅฝ")
|
84 |
+
return "่ฏทๅ
็ปๅฝ"
|
85 |
+
book_id = book.split(" [")[1][:-1]
|
86 |
+
user_book = create_user_book(db, UserBookCreate(
|
87 |
+
owner_id=user_id,
|
88 |
+
book_id=book_id,
|
89 |
+
title=title,
|
90 |
+
random=randomize,
|
91 |
+
batch_size=batch_size
|
92 |
+
))
|
93 |
+
if user_book is not None:
|
94 |
+
return "ๆๅ"
|
95 |
+
else:
|
96 |
+
return "ๅคฑ่ดฅ"
|
97 |
+
|
98 |
+
btn.click(submit, [user, select_book, title, randomize, batch_size], [status])
|
99 |
+
def on_select(user: Dict[str, str], evt: gr.SelectData):
|
100 |
+
user_id = user.get("id", None)
|
101 |
+
new_options = []
|
102 |
+
if user_id is None:
|
103 |
+
return gr.Dropdown(choices=new_options), "่ฏทๅ
็ปๅฝ"
|
104 |
+
books = get_all_books_for_user(db, user_id)
|
105 |
+
new_options = [f"{'โญ ' if book.permission == 'private' else ''}{book.bk_name} (ๅ
ฑ {book.bk_item_num} ่ฏ) [{book.bk_id}]" for book in books]
|
106 |
+
return gr.Dropdown(choices=new_options), f"ๆจๅฅฝ๏ผ{user['email']}"
|
107 |
+
tab1.select(on_select, [user], [select_book, status])
|
108 |
+
|
109 |
+
# 2. ้ๆฉๅ่ฏๅๆน
|
110 |
+
with gr.Tab("้ๆฉๅ่ฏๅๆน") as tab2:
|
111 |
+
select_user_book = gr.Dropdown(
|
112 |
+
[], label="่ฎฐๅฟ่ฎกๅ", info="่ฏท้ๆฉ่ฎฐๅฟ่ฎกๅ"
|
113 |
+
)
|
114 |
+
word_count = gr.Number(value=0, label="ๅ่ฏไธชๆฐ")
|
115 |
+
known_words = gr.CheckboxGroup(
|
116 |
+
[], label="ๅทฒๅญฆไผ็ๅ่ฏ", info="ๆญฃๅผ่ฎฐๅฟๅๅฐๅป้คๅทฒๅญฆไผ็ๅ่ฏ๏ผๆ้ซๆฏไธชๆนๆฌก็ๆฐ่ฏๅฏๅบฆ๏ผ่ฟ่ๆ้ซๆ็"
|
117 |
+
)
|
118 |
+
btn = gr.Button("็ๆๆนๆฌก")
|
119 |
+
status = gr.Textbox("3000 ่ฏๅคงๆฆ่ฆ 2 ๅฐๆถๆ่ฝๅๅฎๆๆ็ๆ
ไบ", lines=1, label="็ๆ็ปๆ")
|
120 |
+
|
121 |
+
def on_select_user(user):
|
122 |
+
user_id = user.get("id", None)
|
123 |
+
if user_id is None:
|
124 |
+
gr.Error("่ฏทๅ
็ปๅฝ")
|
125 |
+
return gr.Dropdown(choices=[]), "่ฏทๅ
็ปๅฝ"
|
126 |
+
new_options = []
|
127 |
+
user_book = get_user_books_by_owner_id(db, user_id)
|
128 |
+
new_options = [f"{book.title} | {book.batch_size}ไธชๅ่ฏไธ็ป [{book.id}]" for book in user_book]
|
129 |
+
return gr.Dropdown(choices=new_options), "3000 ่ฏๅคงๆฆ่ฆ 2 ๅฐๆถๆ่ฝๅๅฎๆๆ็ๆ
ไบ"
|
130 |
+
|
131 |
+
def on_select_user_book(user_book):
|
132 |
+
logger.debug(f'user_book {user_book}')
|
133 |
+
if user_book is None:
|
134 |
+
return 0, gr.CheckboxGroup(choices=[])
|
135 |
+
new_options = []
|
136 |
+
user_book_id = user_book.split(" [")[1][:-1]
|
137 |
+
user_book = get_user_book(db, user_book_id)
|
138 |
+
book_id = user_book.book_id
|
139 |
+
book = get_book(db, book_id)
|
140 |
+
if book is None:
|
141 |
+
return 0, gr.CheckboxGroup(choices=[])
|
142 |
+
words = get_words_for_book(db, user_book)
|
143 |
+
new_options = [f"{word.vc_vocabulary}" for word in words]
|
144 |
+
return len(words), gr.CheckboxGroup(choices=new_options)
|
145 |
+
|
146 |
+
select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[word_count, known_words])
|
147 |
+
tab2.select(on_select_user, [user], [select_user_book, status])
|
148 |
+
|
149 |
+
def submit(user_book, known_words):
|
150 |
+
start_time = time.time()
|
151 |
+
user_book_id = user_book.split(" [")[1][:-1]
|
152 |
+
user_book = get_user_book(db, user_book_id)
|
153 |
+
all_words = get_words_for_book(db, user_book)
|
154 |
+
unknown_words = []
|
155 |
+
for w in all_words:
|
156 |
+
if w.vc_vocabulary not in known_words:
|
157 |
+
unknown_words.append(w)
|
158 |
+
track(db, user_book, unknown_words)
|
159 |
+
end_time = time.time()
|
160 |
+
duration = end_time - start_time
|
161 |
+
return f"ๆๅ๏ผๅไธบ {len(unknown_words) // user_book.batch_size} ไธชๆนๆฌก๏ผๅ
ฑ {len(unknown_words)} ไธชๅ่ฏ๏ผ่ๆถ {duration:.2f} ็ง"
|
162 |
+
|
163 |
+
btn.click(submit, [select_user_book, known_words], [status])
|
164 |
+
|
165 |
+
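    # Note on the dropdown encoding used throughout this file: options are rendered
    # as f"{title} | {batch_size}个单词一组 [{id}]" and the id is recovered with
    # option.split(" [")[1][:-1]. For example (made-up values),
    # "我的计划 | 10个单词一组 [abc123]" parses back to "abc123". This means titles
    # must not themselves contain the substring " [".
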
# 3. ่ฎฐๅฟ
|
166 |
+
with gr.Tab("่ฎฐๅฟ") as tab3:
|
167 |
+
select_user_book = gr.Dropdown(
|
168 |
+
[], label="่ฎฐๅฟ่ฎกๅ", info="่ฏท้ๆฉ่ฎฐๅฟ่ฎกๅ"
|
169 |
+
)
|
170 |
+
info = gr.Accordion(f"ๆฐ่ฏ", open=False)
|
171 |
+
with info:
|
172 |
+
gr.Markdown(f"ๆฐ่ฏ")
|
173 |
+
|
174 |
+
dataframe_header = ["ๅ่ฏ", "ไธญๆ่ฏๆ", "่ฑๅผ้ณๆ ", "็พๅผ้ณๆ ", "่ฎฐๅฟ้"]
|
175 |
+
memorizing_dataframe = gr.Dataframe(
|
176 |
+
headers=dataframe_header,
|
177 |
+
datatype=["str"] * len(dataframe_header),
|
178 |
+
col_count=(len(dataframe_header), "fixed"),
|
179 |
+
wrap=True,
|
180 |
+
)
|
181 |
+
batches = gr.State(value=[])
|
182 |
+
current_batch_index = gr.State(value=-1)
|
183 |
+
user_book_id = gr.State(value=None)
|
184 |
+
with gr.Row():
|
185 |
+
# story = gr.HighlightedText([])
|
186 |
+
# translated_story = gr.HighlightedText([])
|
187 |
+
# story = gr.Textbox()
|
188 |
+
# translated_story = gr.Textbox()
|
189 |
+
story = gr.Markdown()
|
190 |
+
translated_story = gr.Markdown()
|
191 |
+
# ่ฏไบไธไธ๏ผ่ฟๆฏ markdown ็ๆพ็คบๆๆๅฅฝ
|
192 |
+
|
193 |
+
memorize_action = gr.CheckboxGroup(choices=[], label="่ฎฐไฝ็ๅ่ฏ", info="่ฝๅคๅค่ฟฐๅบๆๆๆ็ฎ่ฎฐไฝ")
|
194 |
+
with gr.Row():
|
195 |
+
previous_batch_btn = gr.Button("ไธไธๆน")
|
196 |
+
regenerate_btn = gr.Button("้ๆฐ็ๆๆ
ไบ")
|
197 |
+
next_batch_btn = gr.Button("ไธไธๆน", variant="primary")
|
198 |
+
progress = gr.Slider(1, 1, value=1, step=1, label="่ฟๅบฆ", info="")
|
199 |
+
|
200 |
+
def on_select_user(user):
|
201 |
+
user_id = user.get("id", None)
|
202 |
+
if user_id is None:
|
203 |
+
gr.Error("่ฏทๅ
็ปๅฝ")
|
204 |
+
return gr.Dropdown(choices=[])
|
205 |
+
new_options = []
|
206 |
+
user_book = get_user_books_by_owner_id(db, user_id)
|
207 |
+
new_options = [f"{book.title} | {book.batch_size}ไธชๅ่ฏไธ็ป [{book.id}]" for book in user_book]
|
208 |
+
return gr.Dropdown(choices=new_options)
|
209 |
+
|
210 |
+
def update_from_batch(memorizing_batch: UserMemoryBatch):
|
211 |
+
new_options = []
|
212 |
+
word_df = []
|
213 |
+
# logger.debug(get_user_memory_batch(db, memorizing_batch.id))
|
214 |
+
# logger.debug(memorizing_batch.id)
|
215 |
+
# logger.debug(get_user_memory_words_by_batch_id(db, memorizing_batch.id))
|
216 |
+
# logger.debug(get_words_by_ids(db, [w.word_id for w in get_user_memory_words_by_batch_id(db, memorizing_batch.id)]))
|
217 |
+
# words = get_words_in_batch(db, memorizing_batch.id)
|
218 |
+
# words = get_words_by_ids(db, [w.word_id for w in memorizing_words])
|
219 |
+
memorizing_words = get_user_memory_words_by_batch_id(db, memorizing_batch.id)
|
220 |
+
words = get_words_by_ids(db, [w.word_id for w in memorizing_words])
|
221 |
+
# ็ป่ฎก่ฎฐๅฟ้
|
222 |
+
actions = get_actions_at_each_word(db, [w.word_id for w in memorizing_words])
|
223 |
+
remember_count = defaultdict(int)
|
224 |
+
forget_count = defaultdict(int)
|
225 |
+
for a in actions:
|
226 |
+
if a.action == "remember":
|
227 |
+
remember_count[a.word_id] += 1
|
228 |
+
else:
|
229 |
+
forget_count[a.word_id] += 1
|
230 |
+
# ็ป่ฎก่ฎฐๅฟๆ็
|
231 |
+
batch_actions = get_user_memory_batch_actions_by_user_memory_batch_id(db, memorizing_batch.id)
|
232 |
+
batch_actions.sort(key=lambda x: x.create_time)
|
233 |
+
start, end = None, None
|
234 |
+
total_duration = None
|
235 |
+
for a in batch_actions:
|
236 |
+
if a.action == "start":
|
237 |
+
start: datetime = a.create_time
|
238 |
+
elif a.action == "end":
|
239 |
+
end: datetime = a.create_time
|
240 |
+
if start is None:
|
241 |
+
continue
|
242 |
+
if total_duration is None:
|
243 |
+
total_duration = end - start
|
244 |
+
else:
|
245 |
+
total_duration += end - start
|
246 |
+
memory_speed = f"{memorizing_batch.batch_type}"
|
247 |
+
if total_duration is not None:
|
248 |
+
sec = total_duration.total_seconds()
|
249 |
+
minutes = sec / 60
|
250 |
+
memory_speed += f"๏ผๅฝๅๆนๆฌก่ฎฐๅฟๆ็ {len(memorizing_words) / minutes:.2f} ่ฏ/ๅ้๏ผ{minutes:.2f} ๅ้/ๆนๆฌก"
|
251 |
+
# ๅ่ฏไฟกๆฏ่กจๆ ผไธๅพ้
|
252 |
+
for w in words:
|
253 |
+
new_options.append(f"{w.vc_vocabulary}")
|
254 |
+
word_df.append([
|
255 |
+
w.vc_vocabulary, # ๅ่ฏ
|
256 |
+
w.vc_translation, # ไธญๆ่ฏๆ
|
257 |
+
w.vc_phonetic_uk, # ่ฑๅผ้ณๆ
|
258 |
+
w.vc_phonetic_us, # ็พๅผ้ณๆ
|
259 |
+
f"{remember_count[w.vc_id]} / {remember_count[w.vc_id] + forget_count[w.vc_id]}", # ่ฎฐๅฟ้
|
260 |
+
])
|
261 |
+
df = pd.DataFrame(word_df, columns=dataframe_header)
|
262 |
+
if memorizing_batch.batch_type == "ๅๅฟ":
|
263 |
+
df = pd.DataFrame([[row[0], "", row[2], row[3], row[4]] for row in word_df], columns=dataframe_header)
|
264 |
+
# ๆ
ไบ
|
265 |
+
story = memorizing_batch.story
|
266 |
+
translated_story = memorizing_batch.translated_story
|
267 |
+
if len(story) == 0 or len(translated_story) == 0:
|
268 |
+
story, translated_story = regenerate_for_batch(memorizing_batch, words)
|
269 |
+
|
270 |
+
logger.info("่ฎก็ฎๆนๆฌกไฟกๆฏ")
|
271 |
+
logger.info(new_options)
|
272 |
+
logger.info(story)
|
273 |
+
logger.info(translated_story)
|
274 |
+
logger.info("=" * 8)
|
275 |
+
return (gr.Accordion(label=memory_speed), df, story, translated_story, gr.CheckboxGroup(choices=new_options))
|
276 |
+
|
277 |
+
def on_select_user_book(user_book_id: str):
|
278 |
+
"""
|
279 |
+
1. ๅฝๅๅ่ฏ
|
280 |
+
2. ๅฏนๅฝๅๅ่ฏ็ๆไฝ
|
281 |
+
3. ๆ
ไบ
|
282 |
+
"""
|
283 |
+
logger.debug(f'user_book {user_book_id}')
|
284 |
+
if user_book_id is None:
|
285 |
+
# ไธบไปไนไผ็ฉบ๏ผ่ฟ้่ฟๅ็ไธ่ฅฟๅฏ่ฝไผ็็ธ๏ผไฝๅฅฝๅๆง่กไธๅฐ่ฟ้
|
286 |
+
# ไธ็ฎกไบ๏ผๆพไธชๅ็คบ็ๅจ่ฟ้๏ผๅคงๅฎถ็่ง่ฟไธชๅ่ฏท็ป็่ตฐ
|
287 |
+
return [], gr.CheckboxGroup(choices=[])
|
288 |
+
user_book_id: str = user_book_id.split(" [")[1][:-1]
|
289 |
+
user_book = get_user_book(db, user_book_id)
|
290 |
+
batches = get_new_user_memory_batches_by_user_book_id(db, user_book_id) # ๅช็ผๅญๆฐ่ฏ
|
291 |
+
batch_id = user_book.memorizing_batch
|
292 |
+
memorizing_batch = get_user_memory_batch(db, batch_id)
|
293 |
+
current_batch_index = -1
|
294 |
+
if memorizing_batch is not None:
|
295 |
+
for index, b in enumerate(batches):
|
296 |
+
if b.id == memorizing_batch.id:
|
297 |
+
current_batch_index = index
|
298 |
+
break
|
299 |
+
if current_batch_index == -1:
|
300 |
+
# ๅฝๅ่ฟๆฒกๅผๅง่ฎฐๅฟ๏ผๆ่
ๅฝๅๆนๆฌกไธๆฏๆฐ่ฏๆนๆฌก
|
301 |
+
current_batch_index = 0
|
302 |
+
memorizing_batch = batches[0]
|
303 |
+
batch_id = memorizing_batch.id
|
304 |
+
user_book.memorizing_batch = batch_id
|
305 |
+
update_user_book(db, user_book_id, UserBookUpdate(
|
306 |
+
owner_id=user_book.owner_id,
|
307 |
+
book_id=user_book.book_id,
|
308 |
+
title=user_book.title,
|
309 |
+
random=user_book.random,
|
310 |
+
batch_size=user_book.batch_size,
|
311 |
+
memorizing_batch=batch_id
|
312 |
+
))
|
313 |
+
updates = update_from_batch(memorizing_batch)
|
314 |
+
on_batch_start(db, memorizing_batch.id)
|
315 |
+
asyncio.run(pregenerate(batches, current_batch_index))
|
316 |
+
return (batches, current_batch_index, user_book) + updates + (
|
317 |
+
gr.Slider(
|
318 |
+
minimum=1,
|
319 |
+
maximum=len(batches),
|
320 |
+
value=current_batch_index,
|
321 |
+
),)
|
322 |
+
|
323 |
+
batch_widget = [info, memorizing_dataframe, story, translated_story, memorize_action]
|
324 |
+
tab3.select(on_select_user, inputs=[user], outputs=[select_user_book])
|
325 |
+
select_user_book.select(
|
326 |
+
on_select_user_book,
|
327 |
+
inputs=[select_user_book],
|
328 |
+
outputs=[batches, current_batch_index, user_book_id] + batch_widget + [progress]
|
329 |
+
)
|
330 |
+
async def worker_regenerate_for_batch(batches: List[UserMemoryBatch], index: int):
|
331 |
+
started_at = time.monotonic()
|
332 |
+
logger.info(f"started {index}")
|
333 |
+
# start
|
334 |
+
batch = batches[index]
|
335 |
+
story = batch.story
|
336 |
+
translated_story = batch.translated_story
|
337 |
+
if len(story) == 0 or len(translated_story) == 0:
|
338 |
+
batch_words = get_words_in_batch(db, batch.id)
|
339 |
+
regenerate_for_batch(batch, batch_words)
|
340 |
+
# end
|
341 |
+
total = time.monotonic() - started_at
|
342 |
+
logger.info(f'completed in {total:.2f} seconds')
|
343 |
+
|
344 |
+
async def pregenerate(batches: List[UserMemoryBatch], current_batch_index: int):
|
345 |
+
logger.info("ๅผๅง้ข็ๆๆ
ไบ")
|
346 |
+
indexes = [current_batch_index+i+1 for i in range(3)]+[current_batch_index-i-1 for i in range(3)]
|
347 |
+
indexes = [i for i in indexes if 0 <= i < len(batches)]
|
348 |
+
for index in indexes:
|
349 |
+
asyncio.ensure_future(worker_regenerate_for_batch(batches, index))
|
350 |
+
logger.info("็ปๆ้ข็ๆๆ
ไบ")
|
351 |
+
|
352 | +        def submit_batch(batches: List[UserMemoryBatch], current_batch_index: int):
353 | +            memorizing_batch = batches[current_batch_index]
354 | +            return set_memorizing_batch(batches, current_batch_index, memorizing_batch)
355 | +
356 | +        def set_memorizing_batch(batches: List[UserMemoryBatch], current_batch_index: int, memorizing_batch: UserMemoryBatch):
357 | +            updates = update_from_batch(memorizing_batch)
358 | +            asyncio.run(pregenerate(batches, current_batch_index))
359 | +            logger.info("pregenerated")
360 | +            return updates + (gr.Slider(value=current_batch_index+1), current_batch_index)
361 | +
362 | +        def save_progress(old_batch: UserMemoryBatch, memorize_action: List[str]):
363 | +            # save the word memorization progress
364 | +            actions = []
365 | +            words = get_words_in_batch(db, old_batch.id)
366 | +            for word in words:
367 | +                if word.vc_vocabulary in memorize_action:
368 | +                    actions.append((word.vc_id, "remember"))
369 | +                else:
370 | +                    actions.append((word.vc_id, "forget"))
371 | +            save_memorizing_word_action(db, old_batch.id, actions)
372 | +
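save_progress turns the CheckboxGroup state into (word_id, action) pairs: every word whose spelling is ticked is recorded as "remember", everything else as "forget". The same mapping with stand-in data:

    # (vc_id, vc_vocabulary) pairs standing in for the ORM Word rows
    words = [("w1", "apple"), ("w2", "banana"), ("w3", "cherry")]
    memorize_action = ["apple", "cherry"]  # ticked checkboxes

    actions = [
        (vc_id, "remember" if vocab in memorize_action else "forget")
        for vc_id, vocab in words
    ]
    print(actions)  # [('w1', 'remember'), ('w2', 'forget'), ('w3', 'remember')]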
373 | +        def previous_batch(batches: List[UserMemoryBatch], current_batch_index: int, user_book: schema.UserBook, memorize_action: List[str]):
374 | +            old_index = current_batch_index
375 | +            if current_batch_index <= 0:
376 | +                current_batch_index = 0
377 | +            elif current_batch_index > 0:
378 | +                current_batch_index -= 1
379 | +            if current_batch_index != old_index:
380 | +                # memorization progress must be saved before paging back
381 | +                # logger.info("memorization progress must be saved before paging back")
382 | +                # logger.info(memorize_action)
383 | +                # save the batch progress
384 | +                old_batch = batches[old_index]
385 | +                current_batch = batches[current_batch_index]
386 | +                save_progress(old_batch, memorize_action)
387 | +                on_batch_end(db, old_batch.id)
388 | +                on_batch_start(db, current_batch.id)
389 | +                user_book_id = user_book.id
390 | +                update_user_book_memorizing_batch(db, user_book_id, current_batch.id)
391 | +            return submit_batch(batches, current_batch_index)
392 | +
393 | +        def next_batch(batches: List[UserMemoryBatch], current_batch_index: int, user_book: schema.UserBook, memorize_action: List[str]):
394 | +            old_index = current_batch_index
395 | +            if current_batch_index >= len(batches)-1:
396 | +                current_batch_index = len(batches)-1
397 | +            elif current_batch_index < len(batches) - 1:
398 | +                current_batch_index += 1
399 | +            if current_batch_index != old_index:
400 | +                # memorization progress must be saved before paging forward
401 | +                # logger.info("memorization progress must be saved before paging forward")
402 | +                # logger.info(memorize_action)
403 | +                # save the batch progress
404 | +                old_batch = batches[old_index]
405 | +                memorizing_batch = get_user_memory_batch(db, user_book.memorizing_batch)
406 | +                if memorizing_batch is not None:
407 | +                    old_batch = memorizing_batch
408 | +                current_batch = batches[current_batch_index]
409 | +                save_progress(old_batch, memorize_action)
410 | +                on_batch_end(db, old_batch.id)
411 | +                next_batch = generate_next_batch(db, user_book, minutes=60, k=3)
412 | +                if next_batch is not None:
413 | +                    current_batch = next_batch
414 | +                on_batch_start(db, current_batch.id)
415 | +                user_book_id = user_book.id
416 | +                update_user_book_memorizing_batch(db, user_book_id, current_batch.id)
417 | +                if next_batch is not None:
418 | +                    return set_memorizing_batch(batches, old_index, current_batch)
419 | +                else:
420 | +                    return set_memorizing_batch(batches, current_batch_index, current_batch)
421 | +            else:
422 | +                memorizing_batch = get_user_memory_batch(db, user_book.memorizing_batch)
423 | +                current_batch = batches[current_batch_index]
424 | +                save_progress(memorizing_batch, memorize_action)
425 | +                on_batch_end(db, memorizing_batch.id)
426 | +                next_batch = generate_next_batch(db, user_book, minutes=60, k=3)
427 | +                if next_batch is not None:
428 | +                    current_batch = next_batch
429 | +                on_batch_start(db, current_batch.id)
430 | +                user_book_id = user_book.id
431 | +                update_user_book_memorizing_batch(db, user_book_id, current_batch.id)
432 | +                if next_batch is not None:
433 | +                    return set_memorizing_batch(batches, old_index, current_batch)
434 | +                else:
435 | +                    return set_memorizing_batch(batches, current_batch_index, current_batch)
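The cursor movement in previous_batch and next_batch is written as paired if/elif branches; they reduce to a clamped decrement and increment. A tiny equivalent sketch:

    batches = list(range(5))  # stand-in batch list
    current_batch_index = 0
    current_batch_index = max(current_batch_index - 1, 0)                 # previous_batch: stays at 0
    current_batch_index = min(current_batch_index + 1, len(batches) - 1)  # next_batch: moves to 1

Incidentally, inside next_batch the local variable next_batch shadows the function's own name; harmless as written, since the handler never calls itself, but easy to trip over when refactoring.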
436 | +        previous_batch_btn.click(
437 | +            previous_batch,
438 | +            inputs=[batches, current_batch_index, user_book_id, memorize_action],
439 | +            outputs=batch_widget + [progress, current_batch_index]
440 | +        )
441 | +        next_batch_btn.click(
442 | +            next_batch,
443 | +            inputs=[batches, current_batch_index, user_book_id, memorize_action],
444 | +            outputs=batch_widget + [progress, current_batch_index]
445 | +        )
446 | +
447 | +        def regenerate_for_batch(memorizing_batch: UserMemoryBatch, batch_words: List[Word]):
448 | +            batch_words_str_list = [word.vc_vocabulary for word in batch_words]
449 | +            logger.info(f"Generating a story for {batch_words_str_list}")
450 | +            story, translated_story = generate_story_and_translated_story(batch_words_str_list)
451 | +            memorizing_batch.story = story
452 | +            memorizing_batch.translated_story = translated_story
453 | +            db.commit()
454 | +            db.refresh(memorizing_batch)
455 | +            create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
456 | +                batch_id=memorizing_batch.id,
457 | +                story=story,
458 | +                translated_story=translated_story
459 | +            ))
460 | +            logger.info(story)
461 | +            logger.info(translated_story)
462 | +            return story, translated_story
463 | +
464 | +        def regenerate(batches: List[UserMemoryBatch], current_batch_index: int):
465 | +            # regenerate the story
466 | +            memorizing_batch = batches[current_batch_index]
467 | +            batch_words = get_words_in_batch(db, memorizing_batch.id)
468 | +            story, translated_story = regenerate_for_batch(memorizing_batch, batch_words)
469 | +            return story, translated_story
470 | +        regenerate_btn.click(regenerate, inputs=[batches, current_batch_index], outputs=[story, translated_story])
471 | +
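regenerate_for_batch assumes only a narrow contract from generate_story_and_translated_story (imported from story_agent.py): a list of vocabulary strings in, an (English story, Chinese translation) pair out. A stub honouring that contract, handy for exercising the UI without an LLM call (the stub body is invented):

    from typing import List, Tuple

    def fake_generate_story_and_translated_story(words: List[str]) -> Tuple[str, str]:
        # deterministic stand-in for the LLM-backed generator
        story = "A short tale featuring " + ", ".join(words) + "."
        translated_story = "(Chinese translation of the story)"
        return story, translated_story

    story, translated = fake_generate_story_and_translated_story(["apple", "river"])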
472 | +    # 4. Create a wordbook from a memorization plan
473 | +    with gr.Tab("Create wordbook from memorization plan") as tab4:
474 | +        select_user_book = gr.Dropdown(
475 | +            [], label="Memorization plan", info="Please select a memorization plan"
476 | +        )
477 | +        word_count = gr.Number(value=0, label="Number of words")
478 | +        known_words = gr.CheckboxGroup(
479 | +            [], label="Words already learned", info="Check the words you already know; they will not be included in the new wordbook"
480 | +        )
481 | +        title = gr.TextArea(value='Wordbook', lines=1, label="Wordbook name")
482 | +        btn = gr.Button("Create wordbook from memorization plan")
483 | +        status = gr.Textbox("", lines=1, label="Status")
484 | +
485 | +        def on_select_user(user):
486 | +            user_id = user.get("id", None)
487 | +            if user_id is None:
488 | +                gr.Error("Please log in first")
489 | +                return gr.Dropdown(choices=[])
490 | +            new_options = []
491 | +            user_book = get_user_books_by_owner_id(db, user_id)
492 | +            new_options = [f"{book.title} | {book.batch_size} words per batch [{book.id}]" for book in user_book]
493 | +            return gr.Dropdown(choices=new_options)
494 | +
495 | +        def on_select_user_book(user_book):
496 | +            logger.debug(f'user_book {user_book}')
497 | +            if user_book is None:
498 | +                return 0, gr.CheckboxGroup(choices=[])
499 | +            new_options = []
500 | +            user_book_id = user_book.split(" [")[1][:-1]
501 | +            words = get_words_in_user_book(db, user_book_id)
502 | +            new_options = [f"{word.vc_vocabulary}" for word in words]
503 | +            return len(words), gr.CheckboxGroup(choices=new_options)
504 | +
505 | +        tab4.select(on_select_user, inputs=[user], outputs=[select_user_book])
506 | +        select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[word_count, known_words])
507 | +
508 | +        def submit(user, user_book, known_words, title):
509 | +            user_id = user.get("id", None)
510 | +            if user_id is None:
511 | +                gr.Error("Please log in first")
512 | +                return "Please log in first"
513 | +            user_book_id = user_book.split(" [")[1][:-1]
514 | +            all_words = get_words_in_user_book(db, user_book_id)
515 | +            unknown_words = []
516 | +            for w in all_words:
517 | +                if w.vc_vocabulary not in known_words:
518 | +                    unknown_words.append(w)
519 | +            # all_words = get_words_by_vocabulary(db, known_words)
520 | +            book = save_words_as_book(db, user_id, unknown_words, title)
521 | +            if book is not None:
522 | +                return f"Successfully created a wordbook: {book.bk_name}"
523 | +            else:
524 | +                return "Failed"
525 | +
526 | +        btn.click(submit, [user, select_user_book, known_words, title], [status])
527 | +
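submit inverts the checkbox selection: the ticked words are the known ones, so the new wordbook keeps only the complement. The filtering step in isolation (the SimpleNamespace rows mimic just the vc_vocabulary attribute of the ORM model):

    from types import SimpleNamespace

    all_words = [SimpleNamespace(vc_vocabulary=v) for v in ("apple", "banana", "cherry")]
    known_words = ["banana"]  # CheckboxGroup value

    unknown_words = [w for w in all_words if w.vc_vocabulary not in known_words]
    print([w.vc_vocabulary for w in unknown_words])  # ['apple', 'cherry']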
528 | +    # 5. Statistics
529 | +    with gr.Tab("Statistics") as tab5:
530 | +        # 5.1. Story generation history
531 | +        with gr.Tab("AI history") as tab51:
532 | +            select_user_book = gr.Dropdown(
533 | +                [], label="Memorization plan", info="Please select a memorization plan"
534 | +            )
535 | +
536 | +            history_header = ["Words", "Story", "Chinese story", "Generated at"]
537 | +            history_dataframe = gr.Dataframe(
538 | +                headers=history_header,
539 | +                datatype=["str"] * len(history_header),
540 | +                col_count=(len(history_header), "fixed"),
541 | +                wrap=True,
542 | +                min_width=320,
543 | +                height=800,
544 | +            )
545 | +
546 | +            def on_select_user(user):
547 | +                user_id = user.get("id", None)
548 | +                if user_id is None:
549 | +                    gr.Error("Please log in first")
550 | +                    return gr.Dropdown(choices=[])
551 | +                new_options = []
552 | +                user_book = get_user_books_by_owner_id(db, user_id)
553 | +                new_options = [f"{book.title} | {book.batch_size} words per batch [{book.id}]" for book in user_book]
554 | +                return gr.Dropdown(choices=new_options)
555 | +
556 | +            def on_select_user_book(user_book_id):
557 | +                logger.debug(f'user_book {user_book_id}')
558 | +                if user_book_id is None:
559 | +                    return 0, gr.CheckboxGroup(choices=[])
560 | +                user_book_id = user_book_id.split(" [")[1][:-1]
561 | +                batch_id_to_words_and_history = get_generation_hostorys_by_user_book_id(db, user_book_id)
562 | +                data = []
563 | +                for batch_id, (words, histories) in batch_id_to_words_and_history.items():
564 | +                    for history in histories:
565 | +                        word = ", ".join([w.vc_vocabulary for w in words])
566 | +                        story = history.story
567 | +                        translated_story = history.translated_story
568 | +                        create_time = history.create_time
569 | +                        data.append([word, story, translated_story, create_time])
570 | +                df = pd.DataFrame(data, columns=history_header)
571 | +                return df
572 | +
573 | +            tab51.select(on_select_user, inputs=[user], outputs=[select_user_book])
574 | +            select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[history_dataframe])
575 | +
576 | +        # 5.2. Memorization history
577 | +        with gr.Tab("Memorization history") as tab52:
578 | +            select_user_book = gr.Dropdown(
579 | +                [], label="Memorization plan", info="Please select a memorization plan"
580 | +            )
581 | +
582 | +            batch_history_header = ["Words", "Story", "Chinese story", "Batch type", "Memorization status", "Generated at"]
583 | +            batch_history_dataframe = gr.Dataframe(
584 | +                headers=batch_history_header,
585 | +                datatype=["str"] * len(batch_history_header),
586 | +                col_count=(len(batch_history_header), "fixed"),
587 | +                wrap=True,
588 | +                min_width=320,
589 | +                height=800,
590 | +            )
591 | +
592 | +            def on_select_user(user):
593 | +                user_id = user.get("id", None)
594 | +                if user_id is None:
595 | +                    gr.Error("Please log in first")
596 | +                    return gr.Dropdown(choices=[])
597 | +                new_options = []
598 | +                user_book = get_user_books_by_owner_id(db, user_id)
599 | +                new_options = [f"{book.title} | {book.batch_size} words per batch [{book.id}]" for book in user_book]
600 | +                return gr.Dropdown(choices=new_options)
601 | +
602 | +            def on_select_user_book(user_book_id):
603 | +                logger.debug(f'user_book {user_book_id}')
604 | +                if user_book_id is None:
605 | +                    return 0, gr.CheckboxGroup(choices=[])
606 | +                user_book_id = user_book_id.split(" [")[1][:-1]
607 | +                actions, batch_id_to_batch, batch_id_to_words, batch_id_to_actions = get_user_memory_batch_history(db, user_book_id)
608 | +                data = []
609 | +                for action in actions:
610 | +                    batch_id = action.batch_id
611 | +
612 | +                    words = batch_id_to_words[batch_id]
613 | +                    word = ", ".join([w.vc_vocabulary for w in words])
614 | +
615 | +                    batch = batch_id_to_batch[batch_id]
616 | +                    story = batch.story
617 | +                    translated_story = batch.translated_story
618 | +                    batch_type = batch.batch_type
619 | +
620 | +                    memory_actions = batch_id_to_actions.get(batch_id, [])
621 | +                    remember_word_ids = {a.word_id for a in memory_actions if a.action == "remember"}
622 | +                    remember_words = []
623 | +                    forget_words = []
624 | +                    for w in words:
625 | +                        if w.vc_id in remember_word_ids:
626 | +                            remember_words.append(w.vc_vocabulary)
627 | +                        else:
628 | +                            forget_words.append(w.vc_vocabulary)
629 | +                    memory_status = f"Remembered {len(remember_words)} words, forgot {len(forget_words)} words"
630 | +                    memory_status += f"; remembered: {', '.join(remember_words)}"
631 | +                    memory_status += f"; forgotten: {', '.join(forget_words)}"
632 | +
633 | +                    create_time = action.create_time
634 | +
635 | +                    data.append([word, story, translated_story, batch_type, memory_status, create_time])
636 | +                df = pd.DataFrame(data, columns=batch_history_header)
637 | +                return df
638 | +
639 | +            tab52.select(on_select_user, inputs=[user], outputs=[select_user_book])
640 | +            select_user_book.select(on_select_user_book, inputs=[select_user_book], outputs=[batch_history_dataframe])
641 | +
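The tab52 handler partitions each batch's words by whether a "remember" action was recorded for them, then renders a one-line summary string. The same partition-and-format step with stand-in rows:

    from types import SimpleNamespace

    words = [SimpleNamespace(vc_id=i, vc_vocabulary=v) for i, v in enumerate(["apple", "banana"])]
    memory_actions = [SimpleNamespace(word_id=0, action="remember")]

    remember_word_ids = {a.word_id for a in memory_actions if a.action == "remember"}
    remember_words = [w.vc_vocabulary for w in words if w.vc_id in remember_word_ids]
    forget_words = [w.vc_vocabulary for w in words if w.vc_id not in remember_word_ids]
    memory_status = (
        f"Remembered {len(remember_words)} words, forgot {len(forget_words)} words"
        f"; remembered: {', '.join(remember_words)}"
        f"; forgotten: {', '.join(forget_words)}"
    )
    print(memory_status)  # Remembered 1 words, forgot 1 words; remembered: apple; forgotten: banana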
642 | +
643 | +    on_login_success_ui = [email, password, login_btn, register_email, register_password, register_btn]
644 | +    on_login_success_ui += [tab1]
645 | +
646 | +    def on_login(login_success):
647 | +        return (
648 | +            gr.TextArea(visible=not login_success),
649 | +            gr.TextArea(visible=not login_success),
650 | +            gr.Button(visible=not login_success),
651 | +            gr.TextArea(visible=not login_success),
652 | +            gr.TextArea(visible=not login_success),
653 | +            gr.Button(visible=not login_success),
654 | +            # gr.Accordion(visible=not login_success),
655 | +        ) + (
656 | +            gr.Tab(visible=login_success),
657 | +        )
658 | +    def login(email, password):
659 | +        user = get_user_by_email(db, email)
660 | +        if password is None or len(password) == 0:
661 | +            return {
662 | +                "id": "",
663 | +                "email": "",
664 | +            }, "Login failed", *on_login(False)
665 | +        if user is None or not user.verify_password(password):
666 | +            return {
667 | +                "id": "",
668 | +                "email": "",
669 | +            }, "Login failed", *on_login(False)
670 | +        return {
671 | +            "id": user.id,
672 | +            "email": user.email,
673 | +        }, "Login succeeded", *on_login(True)
674 | +    login_btn.click(login, [email, password], [user, user_status] + on_login_success_ui)
675 | +    def register(email, password):
676 | +        user = get_user_by_email(db, email)
677 | +        if user is not None:
678 | +            return {
679 | +                "id": "",
680 | +                "email": "",
681 | +            }, "Registration failed: this email is already registered", *on_login(False)
682 | +        else:
683 | +            user = create_user(db, email=email, password=password)
684 | +            return {
685 | +                "id": user.id,
686 | +                "email": user.email,
687 | +            }, "Registered and logged in", *on_login(True)
688 | +    register_btn.click(register, [register_email, register_password], [user, user_status] + on_login_success_ui)
689 | +
690 | +
691 | +if __name__ == "__main__":
692 | +    # import os
693 | +    # os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
694 | +    # demo.launch(server_name="127.0.0.1", server_port=8090, debug=True)
695 | +    logger.add(f"output/logs/web_{date_str}.log", rotation="1 day", retention="7 days", level="INFO")
696 | +    demo.launch()