muryshev committed on
Commit
5fc439f
·
1 Parent(s): 81201dd
Files changed (1) hide show
  1. semantic_search.py +4 -5
semantic_search.py CHANGED
@@ -23,8 +23,10 @@ from huggingface import dataset_utils
23
 
24
  global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
25
 
 
26
  hf_token = os.environ.get("HF_TOKEN", None)
27
  hf_dataset = os.environ.get("HF_DATASET", None)
 
28
 
29
  if hf_token is not None and hf_dataset is not None:
30
  global_data_path = dataset_utils.get_global_data_path()
@@ -59,9 +61,6 @@ db_data_types = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный з
59
 
60
  device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
61
 
62
- # access token huggingface. Если задан, то используется модель с HF
63
- hf_token = os.environ.get("HF_TOKEN", "")
64
- hf_model_name = os.environ.get("HF_MODEL_NAME", "")
65
 
66
  llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
67
 
@@ -405,8 +404,8 @@ class SemanticSearch:
405
 
406
  def load_model(self):
407
  if hf_token and hf_model_name:
408
- self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
409
- self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
410
  else:
411
  self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
412
  self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
 
23
 
24
  global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
25
 
26
+ # access token huggingface. Если задан, то используется модель с HF
27
  hf_token = os.environ.get("HF_TOKEN", None)
28
  hf_dataset = os.environ.get("HF_DATASET", None)
29
+ hf_model_name = os.environ.get("HF_MODEL_NAME", "")
30
 
31
  if hf_token is not None and hf_dataset is not None:
32
  global_data_path = dataset_utils.get_global_data_path()
 
61
 
62
  device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
63
 
 
 
 
64
 
65
  llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
66
 
 
404
 
405
  def load_model(self):
406
  if hf_token and hf_model_name:
407
+ self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
408
+ self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
409
  else:
410
  self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
411
  self.model = AutoModel.from_pretrained(global_model_path).to(self.device)