Spaces:
Running
on
T4
Running
on
T4
update
Browse files- semantic_search.py +4 -5
semantic_search.py
CHANGED
@@ -23,8 +23,10 @@ from huggingface import dataset_utils
|
|
23 |
|
24 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
|
25 |
|
|
|
26 |
hf_token = os.environ.get("HF_TOKEN", None)
|
27 |
hf_dataset = os.environ.get("HF_DATASET", None)
|
|
|
28 |
|
29 |
if hf_token is not None and hf_dataset is not None:
|
30 |
global_data_path = dataset_utils.get_global_data_path()
|
@@ -59,9 +61,6 @@ db_data_types = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный з
|
|
59 |
|
60 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
61 |
|
62 |
-
# access token huggingface. Если задан, то используется модель с HF
|
63 |
-
hf_token = os.environ.get("HF_TOKEN", "")
|
64 |
-
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
|
65 |
|
66 |
llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
|
67 |
|
@@ -405,8 +404,8 @@ class SemanticSearch:
|
|
405 |
|
406 |
def load_model(self):
|
407 |
if hf_token and hf_model_name:
|
408 |
-
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=
|
409 |
-
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=
|
410 |
else:
|
411 |
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
|
412 |
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
|
|
|
23 |
|
24 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
|
25 |
|
26 |
+
# access token huggingface. Если задан, то используется модель с HF
|
27 |
hf_token = os.environ.get("HF_TOKEN", None)
|
28 |
hf_dataset = os.environ.get("HF_DATASET", None)
|
29 |
+
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
|
30 |
|
31 |
if hf_token is not None and hf_dataset is not None:
|
32 |
global_data_path = dataset_utils.get_global_data_path()
|
|
|
61 |
|
62 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
63 |
|
|
|
|
|
|
|
64 |
|
65 |
llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
|
66 |
|
|
|
404 |
|
405 |
def load_model(self):
|
406 |
if hf_token and hf_model_name:
|
407 |
+
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
|
408 |
+
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
|
409 |
else:
|
410 |
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
|
411 |
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
|