Update esm_scripts/extract.py
esm_scripts/extract.py (+2 -10)
@@ -132,38 +132,31 @@ def run(args):
 
 
 def run_demo(protein_name, protein_seq, model, alphabet, include,
-             repr_layers
-
+             repr_layers=[-1], truncation_seq_length=1022, toks_per_batch=4096):
     dataset = FastaBatchedDataset([protein_name], [protein_seq])
     batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
     )
     print(f"Read sequences")
-
     # output_dir.mkdir(parents=True, exist_ok=True)
     return_contacts = "contacts" in include
-
     assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
     repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
-
     with torch.no_grad():
         for batch_idx, (labels, strs, toks) in enumerate(data_loader):
             print(
                 f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
             )
-            if torch.cuda.is_available()
+            if torch.cuda.is_available():
                 toks = toks.to(device="cuda", non_blocking=True)
-
             out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
-
             logits = out["logits"].to(device="cpu")
             representations = {
                 layer: t.to(device="cpu") for layer, t in out["representations"].items()
             }
             if return_contacts:
                 contacts = out["contacts"].to(device="cpu")
-
             for i, label in enumerate(labels):
                 result = {"label": label}
                 truncate_len = min(truncation_seq_length, len(strs[i]))
@@ -185,7 +178,6 @@ def run_demo(protein_name, protein_seq, model, alphabet, include,
             }
             if return_contacts:
                 result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
-
     return result['representations'][36]
 
 
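For reference, a minimal sketch of how the patched run_demo might be called. The hard-coded result['representations'][36] suggests a 36-layer ESM-2 checkpoint such as esm2_t36_3B_UR50D; that model name, the import path, the demo sequence, and the include value below are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch, assuming the 36-layer ESM-2 checkpoint
# esm2_t36_3B_UR50D (run_demo hard-codes layer 36) and that run_demo
# is importable from esm_scripts/extract.py.
import torch
import esm
from esm_scripts.extract import run_demo  # assumed import path

# Load a pretrained model and its alphabet (assumption: ESM-2 3B, 36 layers).
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
model.eval()
if torch.cuda.is_available():
    model = model.cuda()  # run_demo moves tokens to CUDA when available

# Example inputs; any name/sequence pair works.
protein_name = "demo_protein"
protein_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQV"

# include is only checked for "contacts" in the shown hunk; passing
# ["per_tok"] here skips contact prediction. repr_layers defaults to [-1],
# which normalizes to the final layer (36 for this checkpoint).
embedding = run_demo(protein_name, protein_seq, model, alphabet,
                     include=["per_tok"])
print(embedding.shape)  # expected: (sequence_length, embedding_dim)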