import argparse import uvicorn from api import app from flux_pipeline import FluxPipeline from util import load_config, ModelVersion def parse_args(): parser = argparse.ArgumentParser(description="Launch Flux API server") parser.add_argument( "-c", "--config-path", type=str, help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments", ) parser.add_argument( "-p", "--port", type=int, default=8088, help="Port to run the server on", ) parser.add_argument( "-H", "--host", type=str, default="0.0.0.0", help="Host to run the server on", ) parser.add_argument( "-f", "--flow-model-path", type=str, help="Path to the flow model" ) parser.add_argument( "-t", "--text-enc-path", type=str, help="Path to the text encoder" ) parser.add_argument( "-a", "--autoencoder-path", type=str, help="Path to the autoencoder" ) parser.add_argument( "-m", "--model-version", type=str, choices=["flux-dev", "flux-schnell"], default="flux-dev", help="Choose model version", ) parser.add_argument( "-F", "--flux-device", type=str, default="cuda:0", help="Device to run the flow model on", ) parser.add_argument( "-T", "--text-enc-device", type=str, default="cuda:0", help="Device to run the text encoder on", ) parser.add_argument( "-A", "--autoencoder-device", type=str, default="cuda:0", help="Device to run the autoencoder on", ) parser.add_argument( "-q", "--num-to-quant", type=int, default=20, help="Number of linear layers in flow transformer (the 'unet') to quantize", ) parser.add_argument( "-C", "--compile", action="store_true", default=False, help="Compile the flow model with extra optimizations", ) return parser.parse_args() def main(): args = parse_args() if args.config_path: app.state.model = FluxPipeline.load_pipeline_from_config_path( args.config_path, flow_model_path=args.flow_model_path ) else: model_version = ( ModelVersion.flux_dev if args.model_version == "flux-dev" else ModelVersion.flux_schnell ) config = load_config( model_version, flux_path=args.flow_model_path, flux_device=args.flux_device, ae_path=args.autoencoder_path, ae_device=args.autoencoder_device, text_enc_path=args.text_enc_path, text_enc_device=args.text_enc_device, flow_dtype="float16", text_enc_dtype="bfloat16", ae_dtype="bfloat16", num_to_quant=args.num_to_quant, compile_extras=args.compile, compile_blocks=args.compile, ) app.state.model = FluxPipeline.load_pipeline_from_config(config) uvicorn.run(app, host=args.host, port=args.port) if __name__ == "__main__": main()