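# Usage sketch (added for illustration): the script filename ("main.py") and all
# paths below are placeholders, not values taken from this repository; only the
# flags themselves come from the argparse options defined in parse_args() below.
#
#   python main.py --config-path configs/config-dev.json
#
# or, without a config file:
#
#   python main.py --port 8088 --host 0.0.0.0 \
#       --flow-model-path /path/to/flux-dev.safetensors \
#       --text-enc-path /path/to/text_encoder \
#       --autoencoder-path /path/to/ae.safetensors \
#       --model-version flux-dev --num-to-quant 20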
import argparse

import uvicorn

from api import app
from flux_pipeline import FluxPipeline
from util import load_config, ModelVersion


def parse_args():
    parser = argparse.ArgumentParser(description="Launch Flux API server")
    parser.add_argument(
        "-c",
        "--config-path",
        type=str,
        help="Path to the configuration file; if not provided, the model will be loaded from the command line arguments",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8088,
        help="Port to run the server on",
    )
    parser.add_argument(
        "-H",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to run the server on",
    )
    parser.add_argument(
        "-f", "--flow-model-path", type=str, help="Path to the flow model"
    )
    parser.add_argument(
        "-t", "--text-enc-path", type=str, help="Path to the text encoder"
    )
    parser.add_argument(
        "-a", "--autoencoder-path", type=str, help="Path to the autoencoder"
    )
    parser.add_argument(
        "-m",
        "--model-version",
        type=str,
        choices=["flux-dev", "flux-schnell"],
        default="flux-dev",
        help="Choose model version",
    )
    parser.add_argument(
        "-F",
        "--flux-device",
        type=str,
        default="cuda:0",
        help="Device to run the flow model on",
    )
    parser.add_argument(
        "-T",
        "--text-enc-device",
        type=str,
        default="cuda:0",
        help="Device to run the text encoder on",
    )
    parser.add_argument(
        "-A",
        "--autoencoder-device",
        type=str,
        default="cuda:0",
        help="Device to run the autoencoder on",
    )
    parser.add_argument(
        "-q",
        "--num-to-quant",
        type=int,
        default=20,
        help="Number of linear layers in flow transformer (the 'unet') to quantize",
    )
    parser.add_argument(
        "-C",
        "--compile",
        action="store_true",
        default=False,
        help="Compile the flow model with extra optimizations",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    if args.config_path:
        # A config file takes precedence; only the flow model path can be overridden from the CLI.
        app.state.model = FluxPipeline.load_pipeline_from_config_path(
            args.config_path, flow_model_path=args.flow_model_path
        )
    else:
        # No config file: build a config entirely from the CLI arguments.
        model_version = (
            ModelVersion.flux_dev
            if args.model_version == "flux-dev"
            else ModelVersion.flux_schnell
        )
        config = load_config(
            model_version,
            flux_path=args.flow_model_path,
            flux_device=args.flux_device,
            ae_path=args.autoencoder_path,
            ae_device=args.autoencoder_device,
            text_enc_path=args.text_enc_path,
            text_enc_device=args.text_enc_device,
            flow_dtype="float16",
            text_enc_dtype="bfloat16",
            ae_dtype="bfloat16",
            num_to_quant=args.num_to_quant,
            compile_extras=args.compile,
            compile_blocks=args.compile,
        )
        app.state.model = FluxPipeline.load_pipeline_from_config(config)
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()