import os
import pickle

from .collators import DataCollatorForMultitaskCellClassification
from .imports import *  # expected to supply torch, load_from_disk, DataLoader, etc.


def validate_columns(dataset, required_columns, dataset_type):
    """Ensures required columns are present in the dataset."""
    missing_columns = [col for col in required_columns if col not in dataset.column_names]
    if missing_columns:
        raise KeyError(
            f"Missing columns in {dataset_type} dataset: {missing_columns}. "
            f"Available columns: {dataset.column_names}"
        )


def create_label_mappings(dataset, task_to_column):
    """Creates label mappings for the dataset."""
    task_label_mappings = {}
    num_labels_list = []
    for task, column in task_to_column.items():
        unique_values = sorted(set(dataset[column]))
        mapping = {label: idx for idx, label in enumerate(unique_values)}
        task_label_mappings[task] = mapping
        num_labels_list.append(len(unique_values))
    return task_label_mappings, num_labels_list
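
# Illustrative sketch (hypothetical column name and values, not from real data):
# with task_to_column == {"task1": "cell_type"} and a "cell_type" column
# containing ["B cell", "T cell", "B cell"], create_label_mappings yields
#   task_label_mappings == {"task1": {"B cell": 0, "T cell": 1}}
#   num_labels_list == [2]
# because labels are enumerated in sorted order of the unique values.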


def save_label_mappings(mappings, path):
    """Saves label mappings to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(mappings, f)


def load_label_mappings(path):
    """Loads label mappings from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
    """Transforms the dataset to the required format."""
    transformed_dataset = []
    cell_id_mapping = {}
    for idx, record in enumerate(dataset):
        transformed_record = {
            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
            "cell_id": idx,  # Index-based cell ID
        }
        if not is_test:
            label_dict = {
                task: task_label_mappings[task][record[column]]
                for task, column in task_to_column.items()
            }
        else:
            label_dict = {task: -1 for task in config["task_names"]}
        transformed_record["label"] = label_dict
        transformed_dataset.append(transformed_record)
        cell_id_mapping[idx] = record.get("unique_cell_id", idx)
    return transformed_dataset, cell_id_mapping
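
# Shape sketch of one transformed record (field values are hypothetical):
#   {
#       "input_ids": tensor([3, 17, 42, ...], dtype=torch.long),
#       "cell_id": 0,
#       "label": {"task1": 2, "task2": 0},  # each value is -1 when is_test=True
#   }
# cell_id_mapping maps the index-based cell_id back to the dataset's
# "unique_cell_id" column when present, otherwise to the index itself.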


def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """Main function to load and preprocess data."""
    try:
        dataset = load_from_disk(dataset_path)

        # Set up task and column mappings
        task_names = [f"task{i + 1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names

        label_mappings_path = os.path.join(
            config["results_dir"],
            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl",
        )

        if not is_test:
            validate_columns(dataset, task_to_column.values(), dataset_type)

            # Create and save label mappings
            task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
            save_label_mappings(task_label_mappings, label_mappings_path)
        else:
            # Load existing (train) mappings for test data
            task_label_mappings = load_label_mappings(label_mappings_path)
            num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]

        # Transform dataset
        transformed_dataset, cell_id_mapping = transform_dataset(
            dataset, task_to_column, task_label_mappings, config, is_test
        )
        return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        raise ValueError(f"Configuration error or dataset key missing: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Error during data loading or preprocessing: {e}") from e
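
# Minimal config sketch for calling load_and_preprocess_data directly (keys are
# inferred from the functions in this module; paths and column names below are
# hypothetical placeholders):
#
#   config = {
#       "train_path": "data/train_dataset",  # saved via Dataset.save_to_disk
#       "val_path": "data/val_dataset",
#       "test_path": "data/test_dataset",
#       "task_columns": ["cell_type", "disease_state"],
#       "results_dir": "results",
#       "batch_size": 32,
#   }
#   train_set, train_cell_ids, num_labels_list = load_and_preprocess_data(
#       config["train_path"], config, dataset_type="train"
#   )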


def preload_and_process_data(config):
    """Preloads and preprocesses train and validation datasets."""
    # Process train data and save mappings
    train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")

    # Process validation data and save mappings
    val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")

    # Validate that the train and validation mappings match
    validate_label_mappings(config)

    # Returns (train_dataset, train_cell_id_mapping, num_labels_list,
    #          val_dataset, val_cell_id_mapping)
    return (*train_data, *val_data[:2])


def validate_label_mappings(config):
    """Ensures train and validation label mappings are consistent."""
    train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
    val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
    train_mappings = load_label_mappings(train_mappings_path)
    val_mappings = load_label_mappings(val_mappings_path)
    for task_name in config["task_names"]:
        if train_mappings[task_name] != val_mappings[task_name]:
            raise ValueError(
                f"Mismatch in label mappings for task '{task_name}'.\n"
                f"Train Mapping: {train_mappings[task_name]}\n"
                f"Validation Mapping: {val_mappings[task_name]}"
            )


def get_data_loader(preprocessed_dataset, batch_size):
    """Creates a DataLoader with shuffling, parallel workers, and pinned memory."""
    return DataLoader(
        preprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=DataCollatorForMultitaskCellClassification(),
        num_workers=os.cpu_count(),
        pin_memory=True,
    )


def preload_data(config):
    """Preprocesses train and validation data for trials."""
    # preload_and_process_data returns (train_dataset, train_cell_id_mapping,
    # num_labels_list, val_dataset, val_cell_id_mapping); preprocess once and
    # build both loaders from the relevant entries.
    data = preload_and_process_data(config)
    train_loader = get_data_loader(data[0], config["batch_size"])
    val_loader = get_data_loader(data[3], config["batch_size"])
    return train_loader, val_loader


def load_and_preprocess_test_data(config):
    """Loads and preprocesses test data."""
    return load_and_preprocess_data(config["test_path"], config, is_test=True)


def prepare_test_loader(config):
    """Prepares DataLoader for test data."""
    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
    test_loader = get_data_loader(test_dataset, config["batch_size"])
    return test_loader, cell_id_mapping, num_labels_list
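
# End-to-end usage sketch (config as sketched above; all values hypothetical):
#
#   train_loader, val_loader = preload_data(config)  # fits label mappings on train/val
#   test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
#   for batch in train_loader:
#       ...  # feed batches to the multitask classifier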