Spaces:
Sleeping
Sleeping
File size: 2,595 Bytes
d098bff c5b0047 d098bff c5b0047 4fe68bc c5b0047 d098bff c5b0047 aa1b602 d098bff c5b0047 aa1b602 d098bff c5b0047 d098bff c5b0047 4fe68bc 91fbea0 4fe68bc 2e517b5 4fe68bc 5b981d0 4fe68bc 5b981d0 4fe68bc 4bb483c 4fe68bc 0af8d1d 2e517b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from dataclasses import dataclass
from enum import IntEnum
import yaml
from typing import Dict, Optional, List
from pydantic import BaseModel, ValidationError
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from openai import OpenAI
class OAuthProvider(IntEnum):
    """OAuth identity provider recorded on a ``User``.

    Being an ``IntEnum``, each member also compares equal to its int value.
    """

    NONE = 0    # no OAuth provider
    GOOGLE = 1  # Google
@dataclass()
class User:
    """A user record.

    Attributes:
        oauth: The OAuth provider associated with this user.
        username: The user's name (presumably as reported by the provider — verify).
        permissions_id: Identifier for permission lookups (format defined by callers — verify).
    """

    oauth: OAuthProvider
    username: str
    permissions_id: str
class PileConfig(BaseModel):
    """The ``pile`` section of the model's YAML config.

    All fields are required; pydantic validates them on construction.
    """

    # presumably: data file name -> persona name (verify against the dataset)
    file2persona: Dict[str, str]
    # presumably: data file name -> prefix text (verify against the dataset)
    file2prefix: Dict[str, str]
    # persona name -> system prompt; Client.get_personas exposes this mapping
    persona2system: Dict[str, str]
    prompt: str
class PermissionsConfig(BaseModel):
    """Optional access-control settings of the inference config."""

    # presumably the e-mail domains allowed for Google sign-in — verify;
    # None when the key is absent from the YAML
    google_domains: Optional[List[str]] = None
class InferenceConfig(BaseModel):
    """The ``inference`` section of the model's YAML config."""

    # required; a chat template string (format not visible here — verify)
    chat_template: str
    # optional permissions block; None when the key is absent
    permissions: Optional[PermissionsConfig] = None
class RepoConfig(BaseModel):
    """The ``repo`` section of the model's YAML config."""

    # required; presumably a repository id — verify against the YAML
    name: str
class ModelConfig(BaseModel):
    """Top-level model configuration with ``pile``, ``inference`` and ``repo`` sections."""

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a ``ModelConfig`` from a YAML file.

        Args:
            yaml_file: Path to the YAML file.

        Returns:
            The validated configuration.

        Raises:
            OSError: if the file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
            TypeError: if the top-level YAML value is not a mapping.
            pydantic.ValidationError: if fields are missing or invalid,
                including when the file is empty.
        """
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        if data is None:
            # An empty document parses to None; validate an empty mapping so
            # pydantic reports the missing sections instead of `**None` failing.
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{yaml_file}: expected a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
class Client:
    """Wrapper around an OpenAI-compatible endpoint plus the served model's personas.

    Construction creates the API client, reads the first served model's id
    (``"<model_name>@<revision>"``) and builds ``self.personas``, a mapping
    from persona name to system prompt (or ``None``).
    """

    def __init__(self, api_url, api_key, personas=None):
        """Connect to the endpoint and load model metadata and personas.

        Args:
            api_url: Base URL of the server; ``/v1`` is appended for the client.
            api_key: API key passed to the ``OpenAI`` client.
            personas: Optional extra personas (name -> system prompt). Personas
                loaded from the model repo win on name clashes. ``None``
                (the default) means no extra personas.
        """
        self.api_url = api_url
        self.api_key = api_key
        # None instead of a shared mutable `{}` default.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """Initialise the API client, then model metadata, then personas."""
        self.init_client()
        self.get_metadata()  # sets model_name / revision used by get_personas
        self.get_personas()

    def init_client(self):
        """Create the ``OpenAI`` client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the first served model id and split it into name and revision.

        Sets ``vllm_model_name`` (the full id), ``model_name`` (text before the
        first ``@``) and ``revision`` (text between the first and second ``@``,
        or ``None`` when the id has no ``@``).

        Raises:
            IndexError: if the server lists no models.
        """
        models = self.openai.models.list()
        if not models.data:
            raise IndexError(f"No models are served at {self.api_url}/v1")
        vllm_model_name = models.data[0].id
        model_name, *suffix = vllm_model_name.split("@")
        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = suffix[0] if suffix else None

    def get_personas(self):
        """Build ``self.personas`` from the repo config and the user-supplied personas.

        When a revision is known, ``datasets/config.yaml`` is downloaded from
        the Hugging Face repo ``self.model_name`` at that revision, parsed
        into ``self.config``, and its ``pile.persona2system`` mapping is used.
        A config file missing from the repo is tolerated. A ``"vanilla"``
        persona with no system prompt is always present.
        """
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                pass  # no config at this revision: only user personas + vanilla
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so adding "vanilla" below does not mutate the loaded config.
                personas = dict(self.config.pile.persona2system)
        personas["vanilla"] = None
        self.personas = self.input_personas | personas