madhavanvenkatesh
commited on
CUDA kernels incompatible with standard PyTorch device movement with 4bit/8bit, necessitating device-specific handling
Browse files- geneformer/perturber_utils.py +60 -70
geneformer/perturber_utils.py
CHANGED
@@ -117,83 +117,73 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
117 |
model_type = "MTLCellClassifier"
|
118 |
quantize = True
|
119 |
|
120 |
-
|
121 |
-
output_hidden_states = True
|
122 |
-
elif mode == "train":
|
123 |
-
output_hidden_states = False
|
124 |
|
125 |
-
|
|
|
126 |
if model_type == "MTLCellClassifier":
|
127 |
-
|
128 |
-
|
129 |
-
"bnb_config": BitsAndBytesConfig(
|
130 |
-
load_in_8bit=True,
|
131 |
-
),
|
132 |
-
}
|
133 |
else:
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
output_hidden_states=output_hidden_states,
|
180 |
-
output_attentions=False,
|
181 |
-
quantization_config=quantize["bnb_config"],
|
182 |
-
)
|
183 |
-
# if eval mode, put the model in eval mode for fwd pass
|
184 |
if mode == "eval":
|
185 |
model.eval()
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
model = model.to(
|
192 |
-
|
|
|
193 |
model.enable_input_require_grads()
|
194 |
-
model = get_peft_model(model,
|
195 |
-
return model
|
196 |
|
|
|
197 |
|
198 |
def quant_layers(model):
|
199 |
layer_nums = []
|
|
|
117 |
model_type = "MTLCellClassifier"
|
118 |
quantize = True
|
119 |
|
120 |
+
output_hidden_states = (mode == "eval")
|
|
|
|
|
|
|
121 |
|
122 |
+
# Quantization logic
|
123 |
+
if quantize:
|
124 |
if model_type == "MTLCellClassifier":
|
125 |
+
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
126 |
+
peft_config = None
|
|
|
|
|
|
|
|
|
127 |
else:
|
128 |
+
quantize_config = BitsAndBytesConfig(
|
129 |
+
load_in_4bit=True,
|
130 |
+
bnb_4bit_use_double_quant=True,
|
131 |
+
bnb_4bit_quant_type="nf4",
|
132 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
133 |
+
)
|
134 |
+
peft_config = LoraConfig(
|
135 |
+
lora_alpha=128,
|
136 |
+
lora_dropout=0.1,
|
137 |
+
r=64,
|
138 |
+
bias="none",
|
139 |
+
task_type="TokenClassification",
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
quantize_config = None
|
143 |
+
peft_config = None
|
144 |
+
|
145 |
+
# Model class selection
|
146 |
+
model_classes = {
|
147 |
+
"Pretrained": BertForMaskedLM,
|
148 |
+
"GeneClassifier": BertForTokenClassification,
|
149 |
+
"CellClassifier": BertForSequenceClassification,
|
150 |
+
"MTLCellClassifier": BertForMaskedLM
|
151 |
+
}
|
152 |
+
|
153 |
+
model_class = model_classes.get(model_type)
|
154 |
+
if not model_class:
|
155 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
156 |
+
|
157 |
+
# Model loading
|
158 |
+
model_args = {
|
159 |
+
"pretrained_model_name_or_path": model_directory,
|
160 |
+
"output_hidden_states": output_hidden_states,
|
161 |
+
"output_attentions": False,
|
162 |
+
}
|
163 |
+
|
164 |
+
if model_type != "Pretrained":
|
165 |
+
model_args["num_labels"] = num_classes
|
166 |
+
|
167 |
+
if quantize_config:
|
168 |
+
model_args["quantization_config"] = quantize_config
|
169 |
+
|
170 |
+
# Load the model
|
171 |
+
model = model_class.from_pretrained(**model_args)
|
172 |
+
|
|
|
|
|
|
|
|
|
|
|
173 |
if mode == "eval":
|
174 |
model.eval()
|
175 |
+
|
176 |
+
# Handle device placement and PEFT
|
177 |
+
if not quantize:
|
178 |
+
# Only move non-quantized models
|
179 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
180 |
+
model = model.to(device)
|
181 |
+
elif peft_config:
|
182 |
+
# Apply PEFT for quantized models (except MTLCellClassifier)
|
183 |
model.enable_input_require_grads()
|
184 |
+
model = get_peft_model(model, peft_config)
|
|
|
185 |
|
186 |
+
return model
|
187 |
|
188 |
def quant_layers(model):
|
189 |
layer_nums = []
|