import os
import pickle

import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader

from .imports import *
from .collators import DataCollatorForMultitaskCellClassification

# The explicit imports above cover the names used directly in this module;
# they may duplicate what `.imports` already re-exports.
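# Expected keys in `config`, inferred from how it is used in this module
# (a summary for readers, not an enforced schema):
#   train_path / val_path / test_path : paths to datasets saved with save_to_disk
#   task_columns                      : list of label column names, one per task
#   results_dir                       : directory for the task_label_mappings.pkl file
#   batch_size                        : DataLoader batch size
# Note that the loader also writes config["task_names"] (task1, task2, ...) as a
# side effect.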
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """
    Load a Hugging Face dataset from disk and convert it into the format
    expected by DataCollatorForMultitaskCellClassification.

    Returns (transformed_dataset, cell_id_mapping, num_labels_list), or
    (None, None, None) if loading or preprocessing fails.
    """
    try:
        dataset = load_from_disk(dataset_path)

        # Map generic task names (task1, task2, ...) onto the configured label columns
        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names

        if not is_test:
            available_columns = set(dataset.column_names)
            for column in task_to_column.values():
                if column not in available_columns:
                    raise KeyError(
                        f"Column {column} not found in the dataset. "
                        f"Available columns: {list(available_columns)}"
                    )

        label_mappings = {}
        task_label_mappings = {}
        cell_id_mapping = {}
        num_labels_list = []

        # Load or create task label mappings
        if not is_test:
            for task, column in task_to_column.items():
                unique_values = sorted(set(dataset[column]))  # sort for run-to-run consistency
                label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)}
                task_label_mappings[task] = label_mappings[column]
                num_labels_list.append(len(unique_values))

            # Print the mappings for each task with dataset type prefix
            # (sanity check for train/validation splits)
            for task, mapping in task_label_mappings.items():
                print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}")

            # Save the task label mappings as a pickle file so the test split can reuse them
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
                pickle.dump(task_label_mappings, f)
        else:
            # Load task label mappings from the pickle file written during training
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
                task_label_mappings = pickle.load(f)

            # Infer num_labels_list from task_label_mappings
            for task, mapping in task_label_mappings.items():
                num_labels_list.append(len(mapping))

        # Store unique cell IDs in a separate dictionary, keyed by dataset index
        for idx, record in enumerate(dataset):
            cell_id = record.get("unique_cell_id", idx)
            cell_id_mapping[idx] = cell_id

        # Transform records to the format expected by the collator
        transformed_dataset = []
        for idx, record in enumerate(dataset):
            transformed_record = {}
            transformed_record["input_ids"] = torch.tensor(record["input_ids"], dtype=torch.long)

            # Use index-based cell ID for internal tracking
            transformed_record["cell_id"] = idx

            if not is_test:
                # Map each task's raw label value to its integer index
                label_dict = {}
                for task, column in task_to_column.items():
                    label_value = record[column]
                    label_index = task_label_mappings[task][label_value]
                    label_dict[task] = label_index
            else:
                # Create dummy labels for test data
                label_dict = {task: -1 for task in config["task_names"]}
            transformed_record["label"] = label_dict
            transformed_dataset.append(transformed_record)

        return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        print(f"Missing configuration or dataset key: {e}")
    except Exception as e:
        print(f"An error occurred while loading or preprocessing data: {e}")
    return None, None, None
def preload_and_process_data(config):
    """Load and preprocess the train and validation splits once, up front."""
    train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
        config["train_path"], config, dataset_type="train"
    )
    # Note: this second call recomputes and re-saves the label mappings from the
    # validation split, so it assumes train and validation share the same label sets.
    val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
        config["val_path"], config, dataset_type="validation"
    )
    return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list
def get_data_loader(preprocessed_dataset, batch_size):
    """Wrap a preprocessed dataset in a shuffling DataLoader."""
    nproc = os.cpu_count()  # one worker per core; the workers are mostly doing I/O
    data_collator = DataCollatorForMultitaskCellClassification()
    return DataLoader(
        preprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=nproc,
        pin_memory=True,
    )
def preload_data(config):
    """Preprocess the data and build train/validation loaders before the Optuna trials start."""
    train_dataset, _, val_dataset, _, _ = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])
    return train_loader, val_loader
def load_and_preprocess_test_data(config):
    """
    Load and preprocess test data, treating it as unlabeled.
    """
    return load_and_preprocess_data(config["test_path"], config, is_test=True)
def prepare_test_loader(config):
    """
    Prepare DataLoader for the test dataset.
    """
    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
    test_loader = get_data_loader(test_dataset, config["batch_size"])
    return test_loader, cell_id_mapping, num_labels_list
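

if __name__ == "__main__":
    # Minimal usage sketch tying the functions above together. The config
    # values below are illustrative placeholders, not defaults shipped with
    # this module; the paths must point at tokenized datasets saved with
    # `datasets.Dataset.save_to_disk`, and the task columns are hypothetical.
    example_config = {
        "train_path": "data/train",
        "val_path": "data/val",
        "test_path": "data/test",
        "task_columns": ["cell_type", "disease_state"],  # placeholder label columns
        "results_dir": "results",
        "batch_size": 16,
    }
    train_loader, val_loader = preload_data(example_config)
    test_loader, test_cell_ids, num_labels_list = prepare_test_loader(example_config)
    print(f"Labels per task: {num_labels_list}")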