Fix non-prequantized inference
Changed files:
- cublas_linear.py (+0, -1)
- modules/flux_model.py (+1, -1)
- modules/flux_model_f8.py (+1, -1)
- util.py (+2, -0)

This commit deletes the cublas_ops import shim, sets torch.float16 as the default dtype in both Flux constructors, and casts non-prequantized flow models to the configured flow dtype after loading.
cublas_linear.py (DELETED)

```diff
@@ -1 +0,0 @@
-from cublas_ops import CublasLinear
```
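The deleted file was a one-line import shim, so dropping it removes an unconditional top-level import of `cublas_ops`, which would raise `ImportError` on installs without that optional package. For reference, a guarded import along these lines is the usual pattern for optional dependencies (a sketch only; this commit simply deletes the shim and does not add such code):

```python
# Hypothetical optional-import guard, not part of this commit.
try:
    from cublas_ops import CublasLinear  # optional fused-linear dependency
    HAS_CUBLAS_OPS = True
except ImportError:
    CublasLinear = None
    HAS_CUBLAS_OPS = False
```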
modules/flux_model.py (CHANGED)

```diff
@@ -369,7 +369,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """
 
-    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.
+    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
         super().__init__()
 
         self.dtype = dtype
```
modules/flux_model_f8.py (CHANGED)

```diff
@@ -370,7 +370,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """
 
-    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.
+    def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
         super().__init__()
 
         self.dtype = dtype
```
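Both model files get the same one-line change: `Flux.__init__` now defaults its `dtype` argument to `torch.float16`. Because the module stores this in `self.dtype` and builds its layers from it, the default only matters for callers that omit the argument. A minimal stand-in (not the real `Flux` class) illustrating the effect:

```python
import torch
import torch.nn as nn

# Stand-in for the real Flux module, showing why the constructor default
# matters: parameters are created in self.dtype when no dtype is passed.
class TinyFlux(nn.Module):
    def __init__(self, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.dtype = dtype
        self.proj = nn.Linear(64, 64, dtype=self.dtype)

model = TinyFlux()  # no dtype given -> float16 weights
assert model.proj.weight.dtype == torch.float16
```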
util.py (CHANGED)

```diff
@@ -211,6 +211,8 @@ def load_flow_model(config: ModelSpec) -> Flux:
     sd = load_sft(ckpt_path, device="cpu")
     missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
     print_load_warning(missing, unexpected)
+    if not config.prequantized_flow:
+        model.type(into_dtype(config.flow_dtype))
     return model
 
 
```
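This hunk is the actual fix: after the state dict is loaded, a model built from a non-prequantized checkpoint is cast to the configured flow dtype, while prequantized checkpoints are left untouched so their quantized weights are not up-cast and broken. A sketch of the guard in isolation, assuming `ModelSpec` carries a `prequantized_flow` boolean and a `flow_dtype` string as the diff suggests (the `into_dtype` body below is an assumption; the real helper lives in this repo's util.py and may differ):

```python
import torch
import torch.nn as nn

def into_dtype(name: str) -> torch.dtype:
    # Assumed mapping; the repo's real into_dtype may differ.
    return {"float16": torch.float16, "bfloat16": torch.bfloat16,
            "float32": torch.float32}[name]

def finalize_flow_model(model: nn.Module, prequantized_flow: bool,
                        flow_dtype: str) -> nn.Module:
    # Mirrors the new lines in load_flow_model: only non-prequantized
    # models are cast, so quantized weights keep their storage dtype.
    if not prequantized_flow:
        model.type(into_dtype(flow_dtype))  # nn.Module.type casts params and buffers in place
    return model
```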