File size: 2,595 Bytes
d098bff
 
 
c5b0047
 
d098bff
c5b0047
4fe68bc
 
 
 
c5b0047
 
d098bff
 
 
 
 
 
 
 
 
 
 
c5b0047
 
 
 
 
 
 
aa1b602
 
 
d098bff
c5b0047
 
aa1b602
d098bff
c5b0047
 
 
 
d098bff
c5b0047
 
 
 
 
 
 
 
 
 
 
 
4fe68bc
 
 
 
91fbea0
4fe68bc
 
2e517b5
4fe68bc
 
 
 
 
 
 
 
 
5b981d0
4fe68bc
 
 
 
 
5b981d0
4fe68bc
 
 
 
 
 
 
 
 
 
4bb483c
 
 
 
 
 
 
 
 
 
4fe68bc
0af8d1d
2e517b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from dataclasses import dataclass
from enum import IntEnum

import yaml

from typing import Dict, Optional, List
from pydantic import BaseModel, ValidationError
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from openai import OpenAI


class OAuthProvider(IntEnum):
    """Integer-valued identifier of the OAuth provider attached to a `User`."""

    # NOTE(review): presumably means "no OAuth provider in use" — confirm with callers.
    NONE = 0
    # Google as the OAuth provider.
    GOOGLE = 1


@dataclass
class User:
    """Plain record describing a user.

    Instances are built positionally or by keyword in the field order below;
    `@dataclass` generates `__init__`, `__repr__` and `__eq__`.
    """

    # OAuth provider associated with this user (see `OAuthProvider`).
    oauth: OAuthProvider
    # Name identifying the user.
    username: str
    # NOTE(review): presumably a key used to look up the user's permissions
    # elsewhere — its format is not visible in this file; confirm.
    permissions_id: str


class PileConfig(BaseModel):
    """Schema of the `pile` section of the model configuration.

    All four fields are required. `persona2system` is the mapping that
    `Client.get_personas` reads to build its persona table.
    """

    # NOTE(review): by name, maps a file identifier to a persona name — confirm.
    file2persona: Dict[str, str]
    # NOTE(review): by name, maps a file identifier to a text prefix — confirm.
    file2prefix: Dict[str, str]
    # Persona name -> string; presumably the system prompt for that persona
    # (TODO confirm against the code that consumes `Client.personas`).
    persona2system: Dict[str, str]
    # A single prompt string; how it is used is not visible in this file.
    prompt: str

class PermissionsConfig(BaseModel):
    """Schema of the optional `permissions` entry inside `InferenceConfig`."""

    # Optional list of strings, `None` when omitted.
    # NOTE(review): presumably the Google account domains that are granted
    # access — the enforcement code is not visible here; confirm.
    google_domains: Optional[List[str]] = None


class InferenceConfig(BaseModel):
    """Schema of the `inference` section of the model configuration."""

    # Required string.
    # NOTE(review): presumably the chat template applied at inference time —
    # whether this is a template name or its text is not visible here; confirm.
    chat_template: str
    # Optional nested permissions block; `None` when the section is absent.
    permissions: Optional[PermissionsConfig] = None


class RepoConfig(BaseModel):
    """Schema of the `repo` section of the model configuration."""

    # Required repository name.
    # NOTE(review): presumably a Hugging Face Hub repo id — confirm.
    name: str


class ModelConfig(BaseModel):
    """Top-level model configuration with `pile`, `inference` and `repo` sections.

    All three sections are required; use `from_yaml` to load and validate a
    configuration file in one step.
    """

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file = "datasets/config.yaml"):
        """Load a YAML file and validate it against this schema.

        Args:
            yaml_file: Path of the YAML file to read.

        Returns:
            A validated instance of `cls`.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not well-formed YAML.
            TypeError: If the top-level YAML value is not a mapping.
            ValidationError: If the mapping does not satisfy the schema. An
                empty file is treated as an empty mapping, so it is reported
                here as missing required fields.
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)

        # `safe_load` returns None for an empty document; validating an empty
        # mapping instead yields a precise "field required" error.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {yaml_file!r}, "
                f"got {type(data).__name__}"
            )

        # A ValidationError raised here propagates to the caller unchanged.
        return cls(**data)
        

class Client:
    """Client for an OpenAI-compatible server that also resolves persona data.

    Construction immediately connects to ``{api_url}/v1``, reads the id of the
    first model the server lists, splits it into a model name and an optional
    ``@revision`` suffix, and builds the `personas` mapping.

    Attributes populated during initialisation:
        openai: The underlying `OpenAI` client.
        vllm_model_name: The model id exactly as reported by the server.
        model_name: The part of the id before the first ``@``.
        revision: The part after the first ``@``, or None if there is none.
        config: The `ModelConfig` downloaded for the model, or None when no
            revision is known or the repository has no config file.
        personas: Persona mapping; always contains ``"vanilla": None``.
    """

    def __init__(self, api_url, api_key, personas=None):
        """Store the connection settings and run the full initialisation.

        Args:
            api_url: Base URL of the server, without the ``/v1`` suffix.
            api_key: API key passed to the `OpenAI` client.
            personas: Optional mapping of extra personas. Entries with the
                same key in the downloaded config take precedence.
        """
        self.api_url = api_url
        self.api_key = api_key
        # A fresh dict per instance: a `{}` default would be shared by every
        # Client created without an explicit `personas` argument.
        self.input_personas = {} if personas is None else personas

        self.init_all()

    def init_all(self):
        """Create the API client, then load model metadata and personas."""
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create the `OpenAI` client pointing at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    @staticmethod
    def _parse_model_id(vllm_model_name: str) -> tuple[str, Optional[str]]:
        """Split ``name[@revision[@...]]`` into ``(name, revision or None)``."""
        model_name, *suffix = vllm_model_name.split("@")
        # Only the first "@"-separated suffix is the revision; further ones
        # are ignored.
        revision = suffix[0] if suffix else None
        return model_name, revision

    def get_metadata(self):
        """Read the first served model's id and derive name and revision.

        Raises:
            IndexError: If the server lists no models.
        """
        models = self.openai.models.list()
        if not models.data:
            raise IndexError(f"No models are served at {self.api_url}/v1")
        vllm_model_name = models.data[0].id

        model_name, revision = self._parse_model_id(vllm_model_name)

        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = revision

    def get_personas(self):
        """Build `self.personas` from the input personas and the model config.

        When a revision is known, ``datasets/config.yaml`` is downloaded from
        the model repository at that revision and its ``pile.persona2system``
        mapping is used. A repository without that file is not an error. The
        result always maps ``"vanilla"`` to None, and downloaded entries
        override input personas with the same key.
        """
        # Always define `config` so later reads never hit AttributeError.
        self.config = None
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(self.model_name, "config.yaml",
                                              subfolder="datasets",
                                              revision=self.revision)
            except EntryNotFoundError:
                # The repository has no config file: keep the defaults.
                pass
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so that adding "vanilla" below does not mutate the
                # loaded config's own mapping.
                personas = dict(self.config.pile.persona2system)

        personas["vanilla"] = None
        self.personas = self.input_personas | personas