|
|
|
import torch |
|
import io |
|
import json |
|
from fireworks.flumina import FluminaModule, main as flumina_main |
|
from fireworks.flumina.route import post |
|
import pydantic |
|
from pydantic import BaseModel |
|
from fastapi import Header |
|
from fastapi.responses import Response |
|
import math |
|
import os |
|
import re |
|
import PIL.Image as Image |
|
from typing import Optional, Set, Tuple |
|
|
|
from flux_pipeline import FluxPipeline |
|
from util import load_config, ModelVersion |
|
|
|
|
|
def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]: |
|
""" |
|
Convert specified aspect ratio to a height/width pair. |
|
""" |
|
if ":" not in aspect_ratio: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
w, h = aspect_ratio.split(":") |
|
try: |
|
w, h = int(w), int(h) |
|
except ValueError: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
valid_aspect_ratios = [ |
|
(1, 1), |
|
(21, 9), |
|
(16, 9), |
|
(3, 2), |
|
(5, 4), |
|
(4, 5), |
|
(2, 3), |
|
(9, 16), |
|
(9, 21), |
|
] |
|
if (w, h) not in valid_aspect_ratios: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}" |
|
) |
|
|
|
|
|
TARGET_SIZE_MP = 1 |
|
target_size = TARGET_SIZE_MP * 2**20 |
|
|
|
width = math.sqrt(target_size / (w * h)) * w |
|
height = math.sqrt(target_size / (w * h)) * h |
|
|
|
PAD_MULTIPLE = 64 |
|
|
|
if PAD_MULTIPLE: |
|
width = width // PAD_MULTIPLE * PAD_MULTIPLE |
|
height = height // PAD_MULTIPLE * PAD_MULTIPLE |
|
|
|
return int(width), int(height) |
|
|
|
|
|
def encode_image(
    image: Image.Image, mime_type: str, jpeg_quality: int = 95
) -> bytes:
    """
    Serialize ``image`` into an in-memory byte string.

    Args:
        image: Image to encode; its ``save`` method does the actual work.
        mime_type: Either "image/jpeg" or "image/png".
        jpeg_quality: Quality passed to the JPEG encoder. Only validated
            and used when ``mime_type`` is "image/jpeg".

    Returns:
        The encoded image bytes.

    Raises:
        ValueError: For an unsupported ``mime_type`` or a JPEG quality
            outside 0..100.
    """
    # Reject unknown formats up front so the remaining code only has to
    # distinguish JPEG from PNG.
    if mime_type not in ("image/jpeg", "image/png"):
        raise ValueError(f"invalid mime_type {mime_type}")

    save_options = {"format": "PNG"}
    if mime_type == "image/jpeg":
        if jpeg_quality > 100 or jpeg_quality < 0:
            raise ValueError(
                f"jpeg_quality must be between 0 and 100, not {jpeg_quality}"
            )
        save_options = {"format": "JPEG", "quality": jpeg_quality}

    sink = io.BytesIO()
    image.save(sink, **save_options)
    return sink.getvalue()
|
|
|
|
|
def parse_accept_header(accept: str) -> str:
    """
    Pick the response image MIME type from an HTTP ``Accept`` header value.

    Each comma-separated entry is a media type optionally followed by
    ``;``-separated parameters. The ``q`` parameter (default 1.0) ranks the
    entries; whitespace around ``;`` and ``=`` is tolerated. Entries with
    ``q=0`` are treated as not acceptable. The wildcard ranges ``*/*`` and
    ``image/*`` resolve to "image/jpeg". At equal weight, an explicit media
    type wins over a wildcard; otherwise header order breaks ties.

    Args:
        accept: Raw value of the ``Accept`` header.

    Returns:
        "image/jpeg" or "image/png".

    Raises:
        ValueError: If an entry has no media type, or if no acceptable entry
            matches a supported image type.
    """
    supported_types = ("image/jpeg", "image/png")
    wildcard_types = ("*/*", "image/*")
    wildcard_result = "image/jpeg"

    weighted_types = []
    for part in accept.split(","):
        part = part.strip()
        media_type, *params = [piece.strip() for piece in part.split(";")]
        if not media_type:
            raise ValueError(f"Malformed Accept header value: {part}")

        # Default weight is 1.0; a q parameter whose value is not a plain
        # decimal number is ignored rather than rejected.
        q_factor = 1.0
        for param in params:
            q_match = re.fullmatch(
                r"q\s*=\s*(\d+(?:\.\d+)?)", param, re.IGNORECASE
            )
            if q_match:
                q_factor = float(q_match.group(1))
                break

        # Media types are case-insensitive, so normalise before comparing.
        weighted_types.append((media_type.lower(), q_factor))

    # q=0 explicitly marks a type as unacceptable.
    acceptable = [entry for entry in weighted_types if entry[1] > 0]

    # Highest weight first; at equal weight explicit types precede wildcards.
    # sorted() is stable, so header order decides any remaining ties.
    ranked = sorted(
        acceptable, key=lambda entry: (-entry[1], entry[0] in wildcard_types)
    )

    for media_type, _ in ranked:
        if media_type in supported_types:
            return media_type
        if media_type in wildcard_types:
            return wildcard_result

    raise ValueError("Accept header did not include ones of supported MIME types: image/jpeg, image/png")
|
|
|
|
|
|
|
class Text2ImageRequest(BaseModel):
    # Request body for the /text_to_image route. Each field is forwarded by
    # text_to_image to self.pipeline.generate as noted below.

    # Text prompt (passed as `prompt`).
    prompt: str
    # Aspect ratio in "w:h" form; converted to a pixel size by
    # _aspect_ratio_to_width_height before generation.
    aspect_ratio: str = "16:9"
    # Passed as `guidance`.
    guidance_scale: float = 3.5
    # Passed as `num_steps`; also reported as `steps` in the billing info.
    num_inference_steps: int = 30
    # Passed as `seed`.
    seed: int = 0
|
|
|
|
|
class Error(BaseModel):
    # Error payload embedded in ErrorResponse. `object` and `type` carry
    # fixed default tags; `message` is required and holds the description.
    object: str = "error"
    type: str = "invalid_request_error"
    message: str
|
|
|
|
|
class ErrorResponse(BaseModel):
    # JSON envelope for error replies: {"error": {...}}.
    # NOTE(review): default_factory=Error calls Error() with no arguments,
    # but Error.message has no default, so relying on this default looks
    # like it would fail validation. The one caller in this file
    # (_error_response) always passes `error` explicitly — confirm whether
    # the default is needed at all.
    error: Error = pydantic.Field(default_factory=Error)
|
|
|
|
|
class BillingInfo(BaseModel):
    # Per-request usage record. _image_response serializes it to JSON and
    # attaches it as the 'Fireworks-Billing-Properties' response header.

    # Number of inference steps requested.
    steps: int
    # Output image height in pixels.
    height: int
    # Output image width in pixels.
    width: int
    # Defaults to False; nothing in this file sets it to True.
    is_control_net: bool = False
|
|
|
|
|
class FluminaModule(FluminaModule):
    """
    Flumina module that serves a FLUX text-to-image pipeline.

    On construction it reads ``config.json`` from the working directory and
    builds ``self.pipeline`` from it. It exposes a ``/text_to_image`` POST
    route and the addon hooks (load / unload / activate / deactivate) for
    LoRA adapters.

    Note: this class intentionally reuses the name of the imported base
    class; the module-level entry points below refer to it by this name.
    """

    def __init__(self):
        """Load ``config.json`` and build the pipeline it describes."""
        super().__init__()

        with open('config.json', 'r') as f:
            config_data = json.load(f)

        if 'config_path' in config_data:
            # A full pipeline config file was supplied; use it as-is.
            self.pipeline = FluxPipeline.load_pipeline_from_config_path(
                config_data['config_path'],
                flow_model_path=config_data.get('flow_model_path', None)
            )
        else:
            # Otherwise assemble a config from individual keys. Any
            # model_version other than "flux-dev" selects flux_schnell.
            model_version = (
                ModelVersion.flux_dev
                if config_data.get('model_version', 'flux-dev') == "flux-dev"
                else ModelVersion.flux_schnell
            )
            config = load_config(
                model_version,
                flux_path=config_data.get('flow_model_path', None),
                flux_device=config_data.get('flux_device', 'cuda:0'),
                ae_path=config_data.get('autoencoder_path', None),
                ae_device=config_data.get('autoencoder_device', 'cuda:0'),
                text_enc_path=config_data.get('text_enc_path', None),
                text_enc_device=config_data.get('text_enc_device', 'cuda:0'),
                flow_dtype="float16",
                text_enc_dtype="bfloat16",
                ae_dtype="bfloat16",
                num_to_quant=config_data.get('num_to_quant', 20),
                compile_extras=config_data.get('compile', False),
                compile_blocks=config_data.get('compile', False),
                # "bf16" is mapped to None; any other value is passed through.
                quant_text_enc=(
                    None
                    if config_data.get('quant_text_enc', 'qfloat8') == "bf16"
                    else config_data.get('quant_text_enc', 'qfloat8')
                ),
                quant_ae=config_data.get('quant_ae', False),
                offload_flow=config_data.get('offload_flow', False),
                offload_ae=config_data.get('offload_ae', True),
                offload_text_enc=config_data.get('offload_text_enc', True),
                prequantized_flow=config_data.get('prequantized_flow', False),
                quantize_modulation=config_data.get('quantize_modulation', True),
                quantize_flow_embedder_layers=config_data.get(
                    'quantize_flow_embedder_layers', False
                ),
            )
            self.pipeline = FluxPipeline.load_pipeline_from_config(config)

        # Qualified names ("accounts/<id>/models/<id>") of loaded LoRA adapters.
        self.lora_adapters: Set[str] = set()
        # Qualified name of the adapter marked active, if any.
        # NOTE(review): within this class the active adapter is only
        # recorded; nothing here passes it to the pipeline — confirm where
        # it is applied.
        self.active_lora_adapter: Optional[str] = None
        # When True, the response helpers return raw payloads instead of
        # fastapi Response objects (used by the __main__ smoke test).
        self._test_return_sync_response = False

    def _error_response(self, code: int, message: str) -> Response:
        """
        Build a JSON error reply with the given HTTP status code.

        In test mode (``_test_return_sync_response``) the JSON string itself
        is returned instead of a ``Response``.
        """
        response_json = ErrorResponse(
            error=Error(message=message),
        ).json()
        if self._test_return_sync_response:
            return response_json
        else:
            return Response(
                response_json,
                status_code=code,
                media_type="application/json",
            )

    def _image_response(
        self, image_bytes: bytes, mime_type: str, billing_info: BillingInfo
    ):
        """
        Build the success reply carrying the encoded image.

        The billing info is attached as the 'Fireworks-Billing-Properties'
        header. In test mode the raw image bytes are returned instead.
        """
        if self._test_return_sync_response:
            return image_bytes
        else:
            headers = {'Fireworks-Billing-Properties': billing_info.json()}
            return Response(
                image_bytes, status_code=200, media_type=mime_type, headers=headers
            )

    @post('/text_to_image')
    async def text_to_image(
        self,
        body: Text2ImageRequest,
        accept: str = Header("image/jpeg"),
    ):
        """
        Generate an image for ``body.prompt`` and return it.

        An unsupported ``Accept`` header or aspect ratio yields a 400 JSON
        error response rather than an unhandled exception.
        """
        # Both validators signal bad client input with ValueError; report
        # that as a client error instead of letting it escape the handler.
        try:
            mime_type = parse_accept_header(accept)
            width, height = _aspect_ratio_to_width_height(body.aspect_ratio)
        except ValueError as e:
            return self._error_response(400, str(e))

        # NOTE(review): mime_type only sets the response media type here;
        # the bytes come straight from the pipeline — confirm the pipeline
        # output format matches the negotiated type.
        img_bio = self.pipeline.generate(
            prompt=body.prompt,
            height=height,
            width=width,
            guidance=body.guidance_scale,
            num_steps=body.num_inference_steps,
            seed=body.seed,
        )

        billing_info = BillingInfo(
            steps=body.num_inference_steps,
            height=height,
            width=width,
        )
        return self._image_response(img_bio.getvalue(), mime_type, billing_info)

    @property
    def supported_addon_types(self):
        """Addon types this module accepts."""
        return ['lora']

    def load_addon(
        self,
        addon_account_id: str,
        addon_model_id: str,
        addon_type: str,
        addon_data_path: os.PathLike,
    ):
        """
        Load an addon's weights into the pipeline and register it.

        Raises:
            ValueError: If ``addon_type`` is not in ``supported_addon_types``.
            NotImplementedError: If the type is listed as supported but has
                no loader here.
        """
        if addon_type not in self.supported_addon_types:
            raise ValueError(
                f"Invalid addon type {addon_type}. Supported types: {self.supported_addon_types}"
            )

        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if addon_type == 'lora':
            self.pipeline.load_lora_weights(addon_data_path, adapter_name=qualname)
            self.lora_adapters.add(qualname)
        else:
            raise NotImplementedError(
                f'Addon support for type {addon_type} not implemented'
            )

    def unload_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str
    ):
        """
        Remove a previously loaded addon from the pipeline.

        If the addon is the active LoRA adapter, the active marker is
        cleared as well.

        Raises:
            AssertionError: If the addon was never loaded.
            NotImplementedError: For addon types other than 'lora'.
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if addon_type == 'lora':
            # Raised explicitly (not via `assert`) so the check still runs
            # when Python is started with -O.
            if qualname not in self.lora_adapters:
                raise AssertionError(f'Addon {qualname} not loaded!')
            self.pipeline.delete_adapters([qualname])
            self.lora_adapters.remove(qualname)
            # Do not leave the active marker pointing at a deleted adapter.
            if self.active_lora_adapter == qualname:
                self.active_lora_adapter = None
        else:
            raise NotImplementedError(
                f'Addon support for type {addon_type} not implemented'
            )

    def activate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Mark a loaded LoRA adapter as the active one.

        Raises:
            ValueError: If the addon is unknown, or if another adapter is
                already active (only one may be active at a time).
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if qualname in self.lora_adapters:
            if self.active_lora_adapter is not None:
                raise ValueError(
                    f"LoRA adapter {self.active_lora_adapter} already active. Multi-LoRA not yet supported"
                )

            self.active_lora_adapter = qualname
            return

        raise ValueError(f"Unknown addon {qualname}")

    def deactivate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Clear the active LoRA adapter marker.

        Raises:
            AssertionError: If the named addon is not the active adapter.
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if self.active_lora_adapter == qualname:
            self.active_lora_adapter = None
        else:
            raise AssertionError(f'Addon {qualname} not loaded!')
|
|
|
|
|
# Entry point for when this file runs with __name__ set to
# "__flumina_main__" (presumably done by the Flumina launcher — TODO
# confirm): build the module and hand it to flumina_main.
if __name__ == "__flumina_main__":
    f = FluminaModule()
    flumina_main(f)
|
|
|
# Local smoke test: run one generation directly, bypassing the HTTP layer,
# and write the resulting image bytes to disk.
if __name__ == "__main__":
    import asyncio

    module = FluminaModule()
    # Make the handler return raw bytes instead of a fastapi Response.
    module._test_return_sync_response = True

    request = Text2ImageRequest(prompt="test")
    image_bytes = asyncio.run(
        module.text_to_image(body=request, accept="image/png")
    )

    with open("out_image.png", "wb") as out_file:
        out_file.write(image_bytes)
|
|
|
|