"""Thin callable wrapper around the OpenAI image-generation endpoint."""

import base64
import io
import os

import requests
from openai import OpenAI
from PIL import Image

# Seconds to wait on the image download before giving up; without a timeout
# ``requests.get`` can block indefinitely on a stalled connection.
_DOWNLOAD_TIMEOUT_SECONDS = 60

# (marker searched for in ``model_name``, API model id, requested image size).
# Entries are tried in order, so a name containing both markers resolves to
# the first match (Dalle-3), exactly like the if/elif chain this replaces.
_TEXT2IMAGE_VARIANTS = (
    ("Dalle-3", "dall-e-3", "1024x1024"),
    ("Dalle-2", "dall-e-2", "512x512"),
)


class OpenaiModel():
    """Callable that turns a text prompt into a PIL image via OpenAI.

    Args:
        model_name: Name used to pick the backend; it must contain
            ``"Dalle-3"`` or ``"Dalle-2"`` (substring match).
        model_type: Task type. Only ``"text2image"`` is handled by
            ``__call__``.
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate one image for ``kwargs["prompt"]`` and return it.

        Positional arguments are accepted but ignored.

        Returns:
            The image opened with ``PIL.Image.open`` from the downloaded bytes.

        Raises:
            ValueError: If ``model_type`` is not ``"text2image"`` or the
                ``prompt`` keyword argument is missing.
            NotImplementedError: If ``model_name`` matches no known variant.
            requests.HTTPError: If downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError("model_type must be text2image or image2image")
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``.
        if "prompt" not in kwargs:
            raise ValueError("prompt is required for text2image model")

        # Resolve the variant before creating a client so an unsupported
        # model name is reported as such, not as a client/credentials error.
        api_model, size = self._resolve_variant()

        client = OpenAI()
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )
        image_url = generation.data[0].url

        download = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        # Fail on HTTP errors here rather than feeding an error page to PIL.
        download.raise_for_status()
        return Image.open(io.BytesIO(download.content))

    def _resolve_variant(self):
        """Return ``(api_model, size)`` for the first marker in ``model_name``."""
        for marker, api_model, size in _TEXT2IMAGE_VARIANTS:
            if marker in self.model_name:
                return api_model, size
        raise NotImplementedError(
            f"unsupported text2image model: {self.model_name!r}"
        )


def load_openai_model(model_name, model_type):
    """Build an ``OpenaiModel`` for the given model name and task type."""
    return OpenaiModel(model_name, model_type)


if __name__ == "__main__":
    # Smoke test: performs a real API call, so it needs working credentials.
    pipe = load_openai_model('Dalle-2', 'text2image')
    result = pipe(prompt='draw a tiger')
    print(result)