aredden commited on
Commit
a035930
·
1 Parent(s): b6617b1

Fix non-prequantized inference

Browse files
cublas_linear.py DELETED
@@ -1 +0,0 @@
1
- from cublas_ops import CublasLinear
 
 
modules/flux_model.py CHANGED
@@ -369,7 +369,7 @@ class Flux(nn.Module):
369
  Transformer model for flow matching on sequences.
370
  """
371
 
372
- def __init__(self, params: FluxParams, dtype: torch.dtype = torch.bfloat16):
373
  super().__init__()
374
 
375
  self.dtype = dtype
 
369
  Transformer model for flow matching on sequences.
370
  """
371
 
372
+ def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
373
  super().__init__()
374
 
375
  self.dtype = dtype
modules/flux_model_f8.py CHANGED
@@ -370,7 +370,7 @@ class Flux(nn.Module):
370
  Transformer model for flow matching on sequences.
371
  """
372
 
373
- def __init__(self, params: FluxParams, dtype: torch.dtype = torch.bfloat16):
374
  super().__init__()
375
 
376
  self.dtype = dtype
 
370
  Transformer model for flow matching on sequences.
371
  """
372
 
373
+ def __init__(self, params: FluxParams, dtype: torch.dtype = torch.float16):
374
  super().__init__()
375
 
376
  self.dtype = dtype
util.py CHANGED
@@ -211,6 +211,8 @@ def load_flow_model(config: ModelSpec) -> Flux:
211
  sd = load_sft(ckpt_path, device="cpu")
212
  missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
213
  print_load_warning(missing, unexpected)
 
 
214
  return model
215
 
216
 
 
211
  sd = load_sft(ckpt_path, device="cpu")
212
  missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
213
  print_load_warning(missing, unexpected)
214
+ if not config.prequantized_flow:
215
+ model.type(into_dtype(config.flow_dtype))
216
  return model
217
 
218