luost26's picture
Update
753e275
raw
history blame
2.1 kB
import os
import argparse
import ray
import shelve
import time
import pandas as pd
from typing import Mapping
from tools.eval.base import EvalTask, TaskScanner
from tools.eval.similarity import eval_similarity
from tools.eval.energy import eval_interface_energy
@ray.remote(num_cpus=1)
def evaluate(task, args):
    """Run the evaluation pipeline on one task as a Ray remote function.

    Similarity metrics are always computed. Interface energy is computed too
    unless ``args.no_energy`` is set. Each evaluator is called with the task,
    and its return value is what the next evaluator receives.

    Args:
        task: The evaluation task to score.
        args: Parsed command-line namespace; only ``no_energy`` is read here.

    Returns:
        Whatever the last evaluator in the pipeline returned.
    """
    if args.no_energy:
        pipeline = (eval_similarity,)
    else:
        pipeline = (eval_similarity, eval_interface_energy)

    for step in pipeline:
        task = step(task)
    return task
def dump_db(db: Mapping[str, 'EvalTask'], path):
    """Write a CSV summary of every task in ``db`` and return it as a DataFrame.

    Each row comes from ``task.to_report_dict()``. When ``path`` contains the
    substring ``'abopt'`` (Antibody Optimization mode), tasks whose
    ``scores['seqid']`` is >= 100.0 are skipped. Those are sequences identical
    to the wild-type.

    The file is written atomically. Rows go to a temporary sibling file,
    which then replaces ``path``. An interrupted write therefore never leaves
    a truncated summary behind.

    Args:
        db: Mapping of keys to evaluated tasks (``main`` passes the shelve DB).
        path: Destination CSV path.

    Returns:
        The ``pandas.DataFrame`` that was written.
    """
    rows = []
    for task in db.values():
        if 'abopt' in path and task.scores['seqid'] >= 100.0:
            # In abopt (Antibody Optimization) mode, ignore sequences identical to the wild-type
            continue
        rows.append(task.to_report_dict())
    table = pd.DataFrame(rows)

    # Keep the original extension last so pandas' compression inference
    # (e.g. ".csv.gz") behaves exactly as it would for ``path``.
    stem, ext = os.path.splitext(path)
    tmp_path = f'{stem}.tmp{ext}'
    try:
        table.to_csv(tmp_path, index=False, float_format='%.6f')
        os.replace(tmp_path, path)
    except BaseException:
        # Covers KeyboardInterrupt too: do not leave a partial temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return table
def main():
    """Continuously evaluate results found under ``--root``.

    Each pass of the loop does the following:
    - Asks ``TaskScanner`` for tasks.
    - Evaluates them in parallel as Ray remote calls.
    - Saves each finished task into the ``shelve`` DB at
      ``<root>/evaluation_db``.
    - Refreshes ``<root>/summary.csv``.
    - Sleeps for one second.

    The loop has no exit condition, so the process runs until it is
    interrupted.

    Command-line options:
        --root: Directory to scan and to hold the DB and summary
            (default ``./results``).
        --pfx: Postfix passed to ``TaskScanner`` (default ``rosetta``).
        --no_energy: Skip the interface-energy evaluation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./results')
    parser.add_argument('--pfx', type=str, default='rosetta')
    parser.add_argument('--no_energy', action='store_true', default=False)
    args = parser.parse_args()
    ray.init()

    db_path = os.path.join(args.root, 'evaluation_db')
    summary_path = os.path.join(args.root, 'summary.csv')
    with shelve.open(db_path) as db:
        scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db)
        # Start "dirty" so the summary is written once at startup,
        # reflecting whatever the DB already holds.
        db_changed = True
        while True:
            tasks = scanner.scan()
            futures = [evaluate.remote(t, args) for t in tasks]
            if futures:
                print(f'Submitted {len(futures)} tasks.')
            while futures:
                # Handle results one at a time, as soon as any worker finishes.
                done_ids, futures = ray.wait(futures, num_returns=1)
                for done_id in done_ids:
                    done_task = ray.get(done_id)
                    done_task.save_to_db(db)
                    db_changed = True
                    print(f'Remaining {len(futures)}. Finished {done_task.in_path}')
                db.sync()

            # dump_db re-reads every entry in the shelf and rewrites the CSV;
            # skip that work on idle polls where nothing new was saved.
            if db_changed:
                dump_db(db, summary_path)
                db_changed = False
            time.sleep(1.0)
if __name__ == '__main__':
    # Start the evaluation loop only when run as a script, not on import.
    main()