|
from .imports import * |
|
import os |
|
from .collators import DataCollatorForMultitaskCellClassification |
|
|
|
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """Load a dataset from disk and turn it into per-cell model records.

    Task names ``task1 .. taskN`` are derived from ``config["task_columns"]``
    and written back to ``config["task_names"]`` as a side effect.

    For labelled data (``is_test=False``) a label -> index mapping is built
    for each task from the sorted unique values of its column, printed, and
    pickled to ``<results_dir>/task_label_mappings.pkl``. For test data
    (``is_test=True``) that pickle is read back instead and every record gets
    the placeholder label ``-1`` for each task.

    Args:
        dataset_path: Path handed to ``load_from_disk``.
        config: Mapping providing ``"task_columns"`` and ``"results_dir"``;
            ``"task_names"`` is stored into it.
        is_test: Treat the data as unlabelled and reuse the saved mappings.
        dataset_type: Used only as the prefix of the printed mapping summary.

    Returns:
        ``(records, cell_id_mapping, num_labels_list)``. ``records`` is a list
        of dicts with keys ``"input_ids"`` (``torch.long`` tensor),
        ``"cell_id"`` (row index) and ``"label"`` (task name -> label index).
        ``cell_id_mapping`` maps the row index to the record's
        ``"unique_cell_id"`` (or to the row index when that key is absent).
        ``num_labels_list`` holds the number of classes per task.
        On any failure the error is printed and ``(None, None, None)`` is
        returned instead of raising.
    """
    try:
        dataset = load_from_disk(dataset_path)

        task_columns = config["task_columns"]
        task_names = [f"task{i}" for i in range(1, len(task_columns) + 1)]
        task_to_column = dict(zip(task_names, task_columns))
        config["task_names"] = task_names

        if not is_test:
            # Fail early, with the list of real columns, when a configured
            # label column does not exist in the dataset.
            available_columns = set(dataset.column_names)
            for column in task_to_column.values():
                if column not in available_columns:
                    raise KeyError(
                        f"Column {column} not found in the dataset. "
                        f"Available columns: {list(available_columns)}"
                    )

            # Sorting the unique values makes the label -> index assignment
            # deterministic for a given set of labels.
            task_label_mappings = {}
            for task, column in task_to_column.items():
                unique_values = sorted(set(dataset[column]))
                task_label_mappings[task] = {
                    label: idx for idx, label in enumerate(unique_values)
                }

            for task, mapping in task_label_mappings.items():
                print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}")

            # NOTE(review): every labelled call writes this same file, so when
            # this runs for train and then validation the validation mapping
            # overwrites the training one, and the two can differ if the
            # splits do not contain the same label values.
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
                pickle.dump(task_label_mappings, f)
        else:
            # NOTE: pickle.load can execute arbitrary code; results_dir must
            # only ever contain files written by this pipeline.
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
                task_label_mappings = pickle.load(f)

        num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]

        # Single pass over the dataset: collect the external cell ids and
        # build the model records together.
        cell_id_mapping = {}
        transformed_dataset = []
        for idx, record in enumerate(dataset):
            cell_id_mapping[idx] = record.get('unique_cell_id', idx)

            if is_test:
                # Dummy labels: test data is treated as unlabelled.
                label_dict = {task: -1 for task in task_names}
            else:
                label_dict = {
                    task: task_label_mappings[task][record[column]]
                    for task, column in task_to_column.items()
                }

            transformed_dataset.append(
                {
                    'input_ids': torch.tensor(record['input_ids'], dtype=torch.long),
                    # Index-based id for internal tracking; the external id
                    # is recoverable through cell_id_mapping.
                    'cell_id': idx,
                    'label': label_dict,
                }
            )

        return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        print(f"Missing configuration or dataset key: {e}")
    except Exception as e:
        print(f"An error occurred while loading or preprocessing data: {e}")
    return None, None, None
|
|
|
def preload_and_process_data(config):
    """Load and preprocess the training and validation splits once.

    Reads ``config["train_path"]`` and ``config["val_path"]`` through
    ``load_and_preprocess_data``.

    Returns:
        ``(train_dataset, train_cell_id_mapping, val_dataset,
        val_cell_id_mapping, num_labels_list)``; ``num_labels_list`` is the
        one reported for the training split, the validation one is discarded.
    """
    train_records, train_ids, label_counts = load_and_preprocess_data(
        config["train_path"], config, dataset_type="train"
    )
    val_records, val_ids, _val_label_counts = load_and_preprocess_data(
        config["val_path"], config, dataset_type="validation"
    )
    return train_records, train_ids, val_records, val_ids, label_counts
|
|
|
def get_data_loader(preprocessed_dataset, batch_size, shuffle=True):
    """Wrap preprocessed records in a ``DataLoader`` with the multitask collator.

    Args:
        preprocessed_dataset: Dataset object handed straight to ``DataLoader``.
        batch_size: Number of records per batch.
        shuffle: Whether the loader shuffles the data. Defaults to ``True``,
            which was previously hard-coded; pass ``False`` for evaluation
            when a stable order is wanted.

    Returns:
        A ``DataLoader`` using ``DataCollatorForMultitaskCellClassification``
        as ``collate_fn``, one worker per CPU, and pinned memory.
    """
    # os.cpu_count() returns None when the count cannot be determined, and
    # DataLoader needs an integer; 0 loads the data in the main process.
    num_workers = os.cpu_count() or 0

    return DataLoader(
        preprocessed_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=DataCollatorForMultitaskCellClassification(),
        num_workers=num_workers,
        pin_memory=True,
    )
|
def preload_data(config):
    """Build the training and validation data loaders from ``config``.

    The previous body called ``get_data_loader("train", config)``, passing a
    split name as the dataset and the whole config as the batch size, which
    does not match ``get_data_loader(preprocessed_dataset, batch_size)``.
    The splits are now loaded and preprocessed first and the batch size is
    read from ``config["batch_size"]``.

    Args:
        config: Mapping providing ``"train_path"``, ``"val_path"``,
            ``"batch_size"`` and the keys needed by
            ``load_and_preprocess_data``.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        RuntimeError: If either split could not be loaded (the loader
            reports failure by returning ``None``).
    """
    train_dataset, _, val_dataset, _, _ = preload_and_process_data(config)
    if train_dataset is None or val_dataset is None:
        # load_and_preprocess_data prints the cause and returns None on
        # failure; stop here rather than build a loader over None.
        raise RuntimeError("Failed to load the training or validation dataset.")

    batch_size = config["batch_size"]
    train_loader = get_data_loader(train_dataset, batch_size)
    val_loader = get_data_loader(val_dataset, batch_size)
    return train_loader, val_loader
|
|
|
def load_and_preprocess_test_data(config):
    """Load the dataset at ``config["test_path"]`` as unlabelled test data.

    Delegates to ``load_and_preprocess_data`` with ``is_test=True`` and
    returns its result unchanged.
    """
    test_path = config["test_path"]
    return load_and_preprocess_data(test_path, config, is_test=True)
|
|
|
def prepare_test_loader(config):
    """Build the data loader for the test split.

    Returns:
        ``(test_loader, cell_id_mapping, num_labels_list)``, where the last
        two values come from ``load_and_preprocess_test_data`` and the loader
        batches by ``config['batch_size']``.
    """
    records, id_lookup, label_counts = load_and_preprocess_test_data(config)
    batch_size = config['batch_size']
    return get_data_loader(records, batch_size), id_lookup, label_counts