from .imports import *
import os
from .collators import DataCollatorForMultitaskCellClassification

def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """
    Load a dataset from disk and convert it into records for multitask cell
    classification. Returns (transformed_dataset, cell_id_mapping, num_labels_list),
    or (None, None, None) if loading or preprocessing fails.
    """
    try:
        dataset = load_from_disk(dataset_path)

        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names

        if not is_test:
            available_columns = set(dataset.column_names)
            for column in task_to_column.values():
                if column not in available_columns:
                    raise KeyError(f"Column {column} not found in the dataset. Available columns: {list(available_columns)}")

        label_mappings = {}
        task_label_mappings = {}
        cell_id_mapping = {}
        num_labels_list = []

        # Load or create task label mappings
        if not is_test:
            for task, column in task_to_column.items():
                unique_values = sorted(set(dataset[column]))  # Ensure consistency
                label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)}
                task_label_mappings[task] = label_mappings[column]
                num_labels_list.append(len(unique_values))

            # Print the mappings for each task with dataset type prefix
            for task, mapping in task_label_mappings.items():
                print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}")  # sanity check, for train/validation splits

            # Save the task label mappings as a pickle file
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
                pickle.dump(task_label_mappings, f)
        else:
            # Load task label mappings from pickle file for test data
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
                task_label_mappings = pickle.load(f)
            
            # Infer num_labels_list from task_label_mappings
            for task, mapping in task_label_mappings.items():
                num_labels_list.append(len(mapping))

        # Store unique cell IDs in a separate dictionary
        for idx, record in enumerate(dataset):
            cell_id = record.get('unique_cell_id', idx)
            cell_id_mapping[idx] = cell_id

        # Transform records to the desired format
        transformed_dataset = []
        for idx, record in enumerate(dataset):
            transformed_record = {}
            transformed_record['input_ids'] = torch.tensor(record['input_ids'], dtype=torch.long)
            
            # Use index-based cell ID for internal tracking
            transformed_record['cell_id'] = idx

            if not is_test:
                # Prepare labels
                label_dict = {}
                for task, column in task_to_column.items():
                    label_value = record[column]
                    label_index = task_label_mappings[task][label_value]
                    label_dict[task] = label_index
                transformed_record['label'] = label_dict
            else:
                # Create dummy labels for test data
                label_dict = {task: -1 for task in config["task_names"]}
                transformed_record['label'] = label_dict

            transformed_dataset.append(transformed_record)

        return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        print(f"Missing configuration or dataset key: {e}")
        return None, None, None
    except Exception as e:
        print(f"An error occurred while loading or preprocessing data: {e}")
        return None, None, None
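
# For reference, a sketch of the config keys this module relies on. Only the key
# names are taken from how `config` is accessed above; the example values are
# hypothetical and should be replaced with real paths and column names.
#
#   config = {
#       "task_columns": ["cell_type", "disease_state"],  # label columns in the stored dataset
#       "results_dir": "/path/to/results",   # where task_label_mappings.pkl is written/read
#       "train_path": "/path/to/train_dataset",
#       "val_path": "/path/to/val_dataset",
#       "test_path": "/path/to/test_dataset",
#       "batch_size": 32,
#   }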

def preload_and_process_data(config):
    # Load and preprocess data once
    train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
    val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
    return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list

def get_data_loader(preprocessed_dataset, batch_size):
    nproc = os.cpu_count()  # use all available cores for data-loading I/O

    data_collator = DataCollatorForMultitaskCellClassification()

    loader = DataLoader(preprocessed_dataset, batch_size=batch_size, shuffle=True,
                        collate_fn=data_collator, num_workers=nproc, pin_memory=True)
    return loader

def preload_data(config):
    # Preprocess the data and build DataLoaders before the Optuna trials start
    train_dataset, _, val_dataset, _, _ = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])
    return train_loader, val_loader

def load_and_preprocess_test_data(config):
    """
    Load and preprocess test data, treating it as unlabeled.
    """
    return load_and_preprocess_data(config["test_path"], config, is_test=True)

def prepare_test_loader(config):
    """
    Prepare DataLoader for the test dataset.
    """
    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
    test_loader = get_data_loader(test_dataset, config['batch_size'])
    return test_loader, cell_id_mapping, num_labels_list
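
# Minimal usage sketch. The config below is a hypothetical example (the paths,
# column name, and batch size are assumptions, not part of this module); it only
# shows how the helpers above compose into train/validation loaders and a test loader.
if __name__ == "__main__":
    example_config = {
        "task_columns": ["cell_type"],  # assumed label column in the stored dataset
        "results_dir": "./results",
        "train_path": "./data/train",
        "val_path": "./data/val",
        "test_path": "./data/test",
        "batch_size": 16,
    }
    # Training-side preprocessing writes task_label_mappings.pkl, which the
    # test-side preprocessing then reads, so it must run first.
    train_loader, val_loader = preload_data(example_config)
    test_loader, test_cell_id_mapping, num_labels_list = prepare_test_loader(example_config)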