NeonBohdan commited on
Commit
4fe68bc
·
1 Parent(s): c5b0047

Added Client class

Browse files
Files changed (1) hide show
  1. shared.py +47 -1
shared.py CHANGED
@@ -2,6 +2,10 @@ import yaml
2
 
3
  from typing import Dict
4
  from pydantic import BaseModel, ValidationError
 
 
 
 
5
 
6
 
7
 
@@ -30,4 +34,46 @@ class ModelConfig(BaseModel):
30
  try:
31
  return cls(**data)
32
  except ValidationError as e:
33
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from typing import Dict
4
  from pydantic import BaseModel, ValidationError
5
+ from huggingface_hub import hf_hub_download
6
+ from huggingface_hub.utils import EntryNotFoundError
7
+
8
+ from openai import OpenAI
9
 
10
 
11
 
 
34
  try:
35
  return cls(**data)
36
  except ValidationError as e:
37
+ raise e
38
+
39
+
40
+ class Client:
41
+ def __init__(self, api_url, api_key):
42
+ self.api_url = api_url
43
+ self.api_key = api_key
44
+
45
+ self.init_all()
46
+
47
+ def init_all(self):
48
+ self.init_client()
49
+ self.get_metadata()
50
+ self.get_personas()
51
+
52
+ def init_client(self):
53
+ self.client = OpenAI(
54
+ base_url=f"{self.api_url}/v1",
55
+ api_key=self.api_key,
56
+ )
57
+
58
+ def get_metadata(self):
59
+ models = self.client.models.list()
60
+ vllm_model_name = models.data[0].id
61
+
62
+ model_name, *suffix = vllm_model_name.split("@")
63
+ revision = dict(enumerate(suffix)).get(0, None)
64
+
65
+ self.vllm_model_name = vllm_model_name
66
+ self.model_name = model_name
67
+ self.revision = revision
68
+
69
+ def get_personas(self):
70
+ try:
71
+ config_path = hf_hub_download(self.model_name, "config.yaml",
72
+ subfolder="datasets",
73
+ revision=self.revision)
74
+ self.config = ModelConfig.from_yaml(config_path)
75
+ self.personas = self.config.pile.persona2system
76
+ except EntryNotFoundError:
77
+ self.personas = {}
78
+
79
+ self.personas["default"] = None