# flumina.py
import torch
import io
import json
from fireworks.flumina import FluminaModule, main as flumina_main
from fireworks.flumina.route import post
import pydantic
from pydantic import BaseModel
from fastapi import Header
from fastapi.responses import Response
import math
import os
import re
import PIL.Image as Image
from typing import Optional, Set, Tuple, Union
from flux_pipeline import FluxPipeline
from util import load_config, ModelVersion


# Util

# MIME types the endpoint can label a response with, in order of preference.
_SUPPORTED_MIME_TYPES = ("image/jpeg", "image/png")

# Accept-header wildcards that are satisfied by any supported image type.
_WILDCARD_MIME_TYPES = ("*/*", "image/*")

# Shape of a q-value accepted in an Accept header (e.g. "1", "0.8").
_Q_VALUE_RE = re.compile(r"\d+(?:\.\d+)?")


def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]:
    """
    Convert specified aspect ratio to a width/height pair.

    The pair is chosen so that ``width * height`` is close to one megapixel
    (taken here as 2**20 pixels), with each side rounded down to a multiple
    of 64.

    Args:
        aspect_ratio: Ratio in ``w:h`` form, e.g. ``"16:9"``. Only the ratios
            listed in ``valid_aspect_ratios`` below are accepted.

    Returns:
        ``(width, height)`` in pixels.

    Raises:
        ValueError: If the string is not in ``w:h`` integer form or the ratio
            is not one of the supported ones.
    """
    format_error = (
        f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9"
    )
    if ":" not in aspect_ratio:
        raise ValueError(format_error)

    try:
        # Unpacking inside the try so that inputs such as "1:2:3" report the
        # format error instead of a bare "too many values to unpack".
        w_text, h_text = aspect_ratio.split(":")
        w, h = int(w_text), int(h_text)
    except ValueError:
        raise ValueError(format_error) from None

    valid_aspect_ratios = [
        (1, 1),
        (21, 9),
        (16, 9),
        (3, 2),
        (5, 4),
        (4, 5),
        (2, 3),
        (9, 16),
        (9, 21),
    ]
    if (w, h) not in valid_aspect_ratios:
        raise ValueError(
            f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}"
        )

    # We consider megapixel not 10^6 pixels but 2^20 (1024x1024) pixels
    TARGET_SIZE_MP = 1
    target_size = TARGET_SIZE_MP * 2**20

    scale = math.sqrt(target_size / (w * h))
    width = scale * w
    height = scale * h

    # Round each side down to a multiple of PAD_MULTIPLE.
    PAD_MULTIPLE = 64
    width = width // PAD_MULTIPLE * PAD_MULTIPLE
    height = height // PAD_MULTIPLE * PAD_MULTIPLE

    return int(width), int(height)


def encode_image(
    image: Image.Image, mime_type: str, jpeg_quality: int = 95
) -> bytes:
    """
    Serialize a PIL image to bytes in the requested format.

    Args:
        image: Image to encode.
        mime_type: Either ``"image/jpeg"`` or ``"image/png"``.
        jpeg_quality: JPEG quality in the range 0..100; only used for JPEG.

    Returns:
        The encoded image bytes.

    Raises:
        ValueError: If ``mime_type`` is unsupported or ``jpeg_quality`` is
            out of range.
    """
    buffered = io.BytesIO()
    if mime_type == "image/jpeg":
        if jpeg_quality < 0 or jpeg_quality > 100:
            raise ValueError(
                f"jpeg_quality must be between 0 and 100, not {jpeg_quality}"
            )
        image.save(buffered, format="JPEG", quality=jpeg_quality)
    elif mime_type == "image/png":
        image.save(buffered, format="PNG")
    else:
        raise ValueError(f"invalid mime_type {mime_type}")
    return buffered.getvalue()


def parse_accept_header(accept: str) -> str:
    """
    Pick the response MIME type from an HTTP ``Accept`` header value.

    Entries are ranked by their ``q`` weight (default 1.0, ties keep header
    order) and the first supported entry wins. The wildcards ``*/*`` and
    ``image/*`` resolve to ``image/jpeg``.

    Args:
        accept: Raw header value, e.g. ``"image/png;q=0.9, image/jpeg"``.

    Returns:
        ``"image/jpeg"`` or ``"image/png"``.

    Raises:
        ValueError: If an entry has no media type, or no entry matches a
            supported MIME type.
    """
    weighted_types = []
    for raw_part in accept.split(","):
        part = raw_part.strip()
        # "type/subtype;param=value;q=0.8" -> media type followed by params.
        media_type, *params = (piece.strip() for piece in part.split(";"))
        if not media_type:
            raise ValueError(f"Malformed Accept header value: {part}")

        q_factor = 1.0
        for param in params:
            name, _, value = param.partition("=")
            value = value.strip()
            # An unparseable q-value is ignored and the default weight kept,
            # which is how such entries have always been treated here.
            if name.strip().lower() == "q" and _Q_VALUE_RE.fullmatch(value):
                q_factor = float(value)

        # Media types are case-insensitive.
        weighted_types.append((media_type.lower(), q_factor))

    # Sort the media types by q-factor, descending
    sorted_types = sorted(weighted_types, key=lambda x: x[1], reverse=True)

    for media_type, _ in sorted_types:
        if media_type in _SUPPORTED_MIME_TYPES:
            return media_type
        if media_type in _WILDCARD_MIME_TYPES:
            return _SUPPORTED_MIME_TYPES[0]

    raise ValueError(
        "Accept header did not include ones of supported MIME types: image/jpeg, image/png"
    )


# Define request and response schemata
class Text2ImageRequest(BaseModel):
    prompt: str
    aspect_ratio: str = "16:9"
    guidance_scale: float = 3.5
    num_inference_steps: int = 30
    seed: int = 0


class Error(BaseModel):
    object: str = "error"
    type: str = "invalid_request_error"
    message: str


class ErrorResponse(BaseModel):
    # NOTE(review): `Error.message` has no default, so this default factory
    # would fail if `error` were ever omitted; callers in this file always
    # pass `error=` explicitly.
    error: Error = pydantic.Field(default_factory=Error)


class BillingInfo(BaseModel):
    steps: int
    height: int
    width: int
    is_control_net: bool = False


# NOTE: this subclass intentionally reuses the imported base-class name, so
# after this statement `FluminaModule` refers to the class defined here.
class FluminaModule(FluminaModule):
    """Flumina module exposing a FLUX text-to-image endpoint with LoRA addons."""

    def __init__(self):
        """Build the FLUX pipeline from ``config.json`` in the working directory."""
        super().__init__()

        # Read configuration from config.json
        with open('config.json', 'r', encoding='utf-8') as f:
            config_data = json.load(f)

        # Now, we need to construct the config and load the model
        if 'config_path' in config_data:
            self.pipeline = FluxPipeline.load_pipeline_from_config_path(
                config_data['config_path'],
                flow_model_path=config_data.get('flow_model_path', None),
            )
        else:
            model_version = (
                ModelVersion.flux_dev
                if config_data.get('model_version', 'flux-dev') == "flux-dev"
                else ModelVersion.flux_schnell
            )

            # "bf16" means "do not quantize the text encoder".
            quant_text_enc = config_data.get('quant_text_enc', 'qfloat8')
            if quant_text_enc == "bf16":
                quant_text_enc = None

            # A single `compile` switch drives both compile options.
            compile_enabled = config_data.get('compile', False)

            config = load_config(
                model_version,
                flux_path=config_data.get('flow_model_path', None),
                flux_device=config_data.get('flux_device', 'cuda:0'),
                ae_path=config_data.get('autoencoder_path', None),
                ae_device=config_data.get('autoencoder_device', 'cuda:0'),
                text_enc_path=config_data.get('text_enc_path', None),
                text_enc_device=config_data.get('text_enc_device', 'cuda:0'),
                flow_dtype="float16",
                text_enc_dtype="bfloat16",
                ae_dtype="bfloat16",
                num_to_quant=config_data.get('num_to_quant', 20),
                compile_extras=compile_enabled,
                compile_blocks=compile_enabled,
                quant_text_enc=quant_text_enc,
                quant_ae=config_data.get('quant_ae', False),
                offload_flow=config_data.get('offload_flow', False),
                offload_ae=config_data.get('offload_ae', True),
                offload_text_enc=config_data.get('offload_text_enc', True),
                prequantized_flow=config_data.get('prequantized_flow', False),
                quantize_modulation=config_data.get('quantize_modulation', True),
                quantize_flow_embedder_layers=config_data.get(
                    'quantize_flow_embedder_layers', False
                ),
            )
            self.pipeline = FluxPipeline.load_pipeline_from_config(config)

        # Initialize LoRA adapters
        self.lora_adapters: Set[str] = set()
        self.active_lora_adapter: Optional[str] = None

        # When True, the response helpers return raw payloads instead of
        # `Response` objects (used by the `__main__` smoke test below).
        self._test_return_sync_response = False

    def _error_response(self, code: int, message: str) -> Union[Response, str]:
        """
        Build a JSON error response.

        Returns the serialized JSON string itself in test mode, otherwise a
        `Response` carrying it with the given status code.
        """
        response_json = ErrorResponse(
            error=Error(message=message),
        ).json()
        if self._test_return_sync_response:
            return response_json
        return Response(
            response_json,
            status_code=code,
            media_type="application/json",
        )

    def _image_response(
        self, image_bytes: bytes, mime_type: str, billing_info: BillingInfo
    ) -> Union[Response, bytes]:
        """
        Build the image response.

        Returns the raw bytes in test mode, otherwise a 200 `Response` with
        the billing info serialized into the `Fireworks-Billing-Properties`
        header.
        """
        if self._test_return_sync_response:
            return image_bytes
        headers = {'Fireworks-Billing-Properties': billing_info.json()}
        return Response(
            image_bytes, status_code=200, media_type=mime_type, headers=headers
        )

    @post('/text_to_image')
    async def text_to_image(
        self,
        body: Text2ImageRequest,
        accept: str = Header("image/jpeg"),
    ):
        """
        Generate an image from a text prompt.

        Invalid `Accept` headers or aspect ratios produce a 400 error
        response rather than an unhandled exception.
        """
        try:
            mime_type = parse_accept_header(accept)
            width, height = _aspect_ratio_to_width_height(body.aspect_ratio)
        except ValueError as err:
            return self._error_response(400, str(err))

        img_bio = self.pipeline.generate(
            prompt=body.prompt,
            height=height,
            width=width,
            guidance=body.guidance_scale,
            num_steps=body.num_inference_steps,
            seed=body.seed,
        )

        billing_info = BillingInfo(
            steps=body.num_inference_steps,
            height=height,
            width=width,
        )
        # NOTE(review): the bytes from `pipeline.generate` are forwarded
        # as-is and only labelled with `mime_type`; confirm the pipeline's
        # output format matches what was negotiated.
        return self._image_response(img_bio.getvalue(), mime_type, billing_info)

    @property
    def supported_addon_types(self):
        """Addon types this module can load."""
        return ['lora']

    # Addon interface methods adjusted to remove ControlNet support
    def load_addon(
        self,
        addon_account_id: str,
        addon_model_id: str,
        addon_type: str,
        addon_data_path: os.PathLike,
    ):
        """
        Load an addon into the pipeline and register it by qualified name.

        Raises:
            ValueError: If `addon_type` is not in `supported_addon_types`.
            NotImplementedError: If the type is listed as supported but has
                no loader here.
        """
        if addon_type not in self.supported_addon_types:
            raise ValueError(
                f"Invalid addon type {addon_type}. Supported types: {self.supported_addon_types}"
            )

        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if addon_type == 'lora':
            self.pipeline.load_lora_weights(addon_data_path, adapter_name=qualname)
            self.lora_adapters.add(qualname)
        else:
            raise NotImplementedError(
                f'Addon support for type {addon_type} not implemented'
            )

    def unload_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str
    ):
        """
        Remove a previously loaded addon from the pipeline.

        Raises:
            AssertionError: If the addon was never loaded.
            NotImplementedError: If `addon_type` is not `'lora'`.
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if addon_type != 'lora':
            raise NotImplementedError(
                f'Addon support for type {addon_type} not implemented'
            )

        # Raised explicitly (rather than via `assert`) so the check still
        # runs when Python is started with -O.
        if qualname not in self.lora_adapters:
            raise AssertionError(f'Addon {qualname} not loaded!')

        self.pipeline.delete_adapters([qualname])
        self.lora_adapters.remove(qualname)

        # Do not leave the active marker pointing at a deleted adapter;
        # otherwise no other adapter could be activated afterwards.
        if self.active_lora_adapter == qualname:
            self.active_lora_adapter = None

    def activate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Mark a loaded LoRA adapter as the active one.

        Raises:
            ValueError: If the addon is unknown, or another adapter is
                already active.
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if qualname not in self.lora_adapters:
            raise ValueError(f"Unknown addon {qualname}")

        if self.active_lora_adapter is not None:
            raise ValueError(
                f"LoRA adapter {self.active_lora_adapter} already active. Multi-LoRA not yet supported"
            )

        self.active_lora_adapter = qualname

    def deactivate_addon(self, addon_account_id: str, addon_model_id: str):
        """
        Clear the active LoRA adapter.

        Raises:
            AssertionError: If the named addon is not the active one.
        """
        qualname = f"accounts/{addon_account_id}/models/{addon_model_id}"

        if self.active_lora_adapter != qualname:
            raise AssertionError(f'Addon {qualname} not loaded!')

        self.active_lora_adapter = None


if __name__ == "__flumina_main__":
    f = FluminaModule()
    flumina_main(f)

if __name__ == "__main__":
    # Local smoke test: generate one image and write it to disk.
    f = FluminaModule()
    f._test_return_sync_response = True

    import asyncio

    out = asyncio.run(
        f.text_to_image(
            body=Text2ImageRequest(
                prompt="test"
            ),
            accept="image/png",
        )
    )

    # Use a separate name for the file handle so the module instance `f`
    # is not rebound.
    with open("out_image.png", "wb") as out_file:
        out_file.write(out)