File size: 1,697 Bytes
b177a48
afdb110
 
 
 
 
b177a48
 
afdb110
 
 
 
 
 
b177a48
afdb110
 
b177a48
afdb110
b177a48
afdb110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe8637
b177a48
 
afdb110
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from openai import OpenAI
from PIL import Image
import requests
import io
import os
import base64


class OpenaiModel:
    """Callable wrapper around OpenAI's DALL-E image-generation endpoint.

    Only ``model_type == "text2image"`` is supported. The concrete DALL-E
    variant is chosen by substring match on ``model_name`` ("Dalle-3" or
    "Dalle-2").

    Calling an instance generates one image for ``prompt`` and downloads it
    from the returned URL. The result is a ``PIL.Image.Image``.
    """

    # (substring of model_name, API model id, requested image size).
    # Checked in order, so "Dalle-3" takes precedence, matching the original
    # if/elif ordering.
    _VARIANTS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    # Seconds to wait when downloading the generated image. Without a timeout,
    # requests.get may block indefinitely on a stalled connection.
    download_timeout = 60

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate an image from ``kwargs["prompt"]``.

        Positional ``args`` are accepted but ignored.

        Returns:
            The generated image, opened with ``PIL.Image.open``.

        Raises:
            ValueError: if ``model_type`` is not ``"text2image"`` or
                ``prompt`` is missing.
            NotImplementedError: if ``model_name`` names no supported variant.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError("model_type must be text2image")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if "prompt" not in kwargs:
            raise ValueError("prompt is required for text2image model")

        # Resolve the variant before creating a client, so an unsupported
        # name fails fast without touching the API.
        api_model, size = self._resolve_variant()

        client = OpenAI()
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )

        result_url = generation.data[0].url
        download = requests.get(result_url, timeout=self.download_timeout)
        # Surface HTTP failures directly, rather than as a PIL decode error on
        # an error page's body.
        download.raise_for_status()
        return Image.open(io.BytesIO(download.content))

    def _resolve_variant(self):
        """Return ``(api_model, size)`` for the first variant token in ``model_name``."""
        for token, api_model, size in self._VARIANTS:
            if token in self.model_name:
                return api_model, size
        raise NotImplementedError(
            f"unsupported OpenAI image model: {self.model_name!r}"
        )



def load_openai_model(model_name, model_type):
    """Construct an ``OpenaiModel`` for the given model name and model type."""
    model = OpenaiModel(model_name=model_name, model_type=model_type)
    return model


if __name__ == "__main__":
    # Manual smoke test. It calls the OpenAI API and downloads the result, so it
    # needs network access and OpenAI credentials available to `OpenAI()`.
    generator = load_openai_model('Dalle-2', 'text2image')
    image = generator(prompt='draw a tiger')
    print(image)