# UniMTS / data.py
import torch
import numpy as np
import random
import os
import json
from scipy.signal import resample
import clip
from torch.utils.data import Dataset
class CLIPDataset(Dataset):
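    """Paired IMU/text dataset: matches motion .npy files under `sim/` with
    augmented caption .txt files under `aug_texts/` by shared filename stem."""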
def __init__(self, args):
imu_dirs = [
f'{args.data_path}/sim/',
]
text_dirs = [
f'{args.data_path}/aug_texts/',
]
self.paths = []
for imu_dir, text_dir in zip(imu_dirs, text_dirs):
imu_files = [f.split('.')[0] for f in os.listdir(imu_dir) if os.path.isfile(os.path.join(imu_dir, f))]
text_files = [f.split('.')[0] for f in os.listdir(text_dir) if os.path.isfile(os.path.join(text_dir, f))]
common_files = [f for f in imu_files if f in text_files]
for f in common_files:
self.paths.append((os.path.join(imu_dir, f + '.npy'), os.path.join(text_dir, f + '.txt')))
self.args = args
if args.sample < 1:
self.paths = random.sample(self.paths, int(len(self.paths) * args.sample))
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
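        """Return a dict with the flattened IMU window, a random joint mask, and one caption line."""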
# load imu
imu_path, text_path = self.paths[idx]
imu = np.load(imu_path)
imu[np.isnan(imu)] = 0
# padding
if len(imu) < self.args.padding_size:
imu = np.pad(imu, ((0, self.args.padding_size - len(imu)), (0, 0), (0, 0)), mode='wrap')
imu = imu[:self.args.padding_size]
# random masking
mask = np.zeros_like(imu)
        k = np.random.randint(1, 6)  # randomly select k joints, with k in [1, 5]
selected_joints = np.random.choice(22, k, replace=False)
mask[:,selected_joints] = 1
imu = imu.reshape(len(imu), -1)
mask = mask.reshape(len(mask), -1)
# load text
with open(text_path, 'r') as file:
lines = file.readlines()
        text = random.choice(lines).split('#')[0].strip()  # pick one caption line and drop any trailing "#" comment
batch = {}
batch['imu'] = imu
batch['text'] = text
batch['mask'] = mask
return batch
def select_samples(data, masks, labels, k, name, data_path):
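    """Gather a pre-computed few-shot subset: for each class, take the sample indices
    stored in `{data_path}/few_shot_data_2/{name}_k={k}.pth`."""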
unique_labels = torch.unique(labels)
selected_data = []
selected_masks = []
selected_labels = []
all_indices = torch.load(f'{data_path}/few_shot_data_2/{name}_k={k}.pth')
for i, label in enumerate(unique_labels):
selected_indices = all_indices[i]
selected_data.append(data[selected_indices])
selected_masks.append(masks[selected_indices])
selected_labels.append(labels[selected_indices])
selected_data = torch.cat(selected_data, dim=0)
selected_masks = torch.cat(selected_masks, dim=0)
selected_labels = torch.cat(selected_labels, dim=0)
return selected_data, selected_masks, selected_labels
def load(dataset, padding_size, data_path, split='test', k=None):
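    """Load one benchmark dataset: map its sensor channels onto the 22-joint x 6-channel
    layout, resample to 20 Hz, pad/mask to `padding_size`, and CLIP-tokenize the label names.

    Returns (inputs, masks, labels, label_list, tokenized_text, num_classes).
    """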
print(dataset)
X = np.load(f'{data_path}/{dataset}/X_{split}.npy')
real_labels = torch.from_numpy(np.load(f'{data_path}/{dataset}/y_{split}.npy'))
with open(f'{data_path}/{dataset}/{dataset}.json', 'r') as file:
data = json.load(file)
all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))
if dataset == 'PAMAP':
all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
all_X[:,:,11] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
all_X[:,:,7] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
original_sampling_rate = 100
num_classes = 12
elif dataset == 'USCHAD':
all_X[:,:,5] = np.concatenate((X[:,:,0:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
original_sampling_rate = 100
num_classes = 12
elif dataset == 'UCIHAR':
        all_X[:,:,9] = np.concatenate((X[:,:,6:9] * 9.80665, X[:,:,3:6]), axis=-1)  # X layout: linear accel, gyro, total accel; use total accel (g -> m/s^2) and gyro
original_sampling_rate = 50
num_classes = 6
elif dataset == 'Opp_g':
        all_X[:,:,10] = np.concatenate((X[:,:,0:3] / 1000 * 9.8, X[:,:,3:6] / 1000), axis=-1)  # convert acceleration from milli-g to m/s^2
all_X[:,:,19] = np.concatenate((X[:,:,9:12] / 1000 * 9.8, X[:,:,12:15] / 1000), axis=-1)
all_X[:,:,20] = np.concatenate((X[:,:,18:21] / 1000 * 9.8, X[:,:,21:24] / 1000), axis=-1)
all_X[:,:,15] = np.concatenate((X[:,:,27:30] / 1000 * 9.8, X[:,:,30:33] / 1000), axis=-1)
all_X[:,:,16] = np.concatenate((X[:,:,36:39] / 1000 * 9.8, X[:,:,39:42] / 1000), axis=-1)
original_sampling_rate = 30
num_classes = 4 # locomotion
elif dataset == 'WISDM':
all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
original_sampling_rate = 20
num_classes = 18
elif dataset == 'DSADS':
all_X[:,:,11] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
all_X[:,:,21] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
all_X[:,:,17] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
all_X[:,:,6] = np.concatenate((X[:,:,27:30], X[:,:,30:33]), axis=-1)
all_X[:,:,2] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
original_sampling_rate = 25
num_classes = 19
elif dataset == 'Harth':
all_X[:,:,9,:3] = X[:,:,:3] * 9.80665
all_X[:,:,6,:3] = X[:,:,3:6] * 9.80665
original_sampling_rate = 50
num_classes = 12
elif dataset == 'Wharf':
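        # rescale raw accelerometer values in [0, 63] to +/-14.709 m/s^2 (i.e. +/-1.5 g)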
X = -14.709 + X / 63 * (2 * 14.709)
all_X[:,:,21,:3] = X
original_sampling_rate = 32
num_classes = 14
elif dataset == 'Mhealth':
all_X[:,:,11,:3] = X[:,:,0:3]
all_X[:,:,3] = np.concatenate((X[:,:,6:9], X[:,:,9:12] / 180 * np.pi), axis=-1)
all_X[:,:,21] = np.concatenate((X[:,:,15:18], X[:,:,18:21] / 180 * np.pi), axis=-1)
original_sampling_rate = 50
num_classes = 12
elif dataset == 'UTD-MHAD':
all_X[real_labels < 21,:,21,:] = np.concatenate((X[real_labels < 21,:,0:3] * 9.80665, X[real_labels < 21,:,3:6] / 180 * np.pi), axis=-1)
all_X[real_labels >= 21,:,5,:] = np.concatenate((X[real_labels >= 21,:,0:3] * 9.80665, X[real_labels >= 21,:,3:6] / 180 * np.pi), axis=-1)
original_sampling_rate = 50
num_classes = 27
elif dataset == 'MotionSense':
all_X[:,:,5] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
all_X[:,:,1] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
original_sampling_rate = 50
num_classes = 6
elif dataset == 'w-HAR':
all_X[:,:,7] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
original_sampling_rate = 250
num_classes = 7
elif dataset == 'Shoaib':
all_X[:,:,1] = X[:,:,:6]
all_X[:,:,5] = X[:,:,6:12]
all_X[:,:,21] = X[:,:,12:18]
all_X[:,:,20] = X[:,:,18:24]
all_X[:,:,0] = X[:,:,24:30]
original_sampling_rate = 50
num_classes = 7
elif dataset == 'har70plus':
all_X[:,:,0,:3] = X[:,:,:3] * 9.80665
all_X[:,:,5,:3] = X[:,:,3:6] * 9.80665
original_sampling_rate = 50
num_classes = 7
elif dataset == 'MMAct':
all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
all_X[:,:,21,:3] = X[:,:,6:9]
original_sampling_rate = 50
num_classes = 35
elif dataset == 'realworld':
all_X[:,:,14] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
all_X[:,:,16] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
all_X[:,:,13] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
all_X[:,:,1] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
all_X[:,:,15] = np.concatenate((X[:,:,30:33], X[:,:,33:36]), axis=-1)
all_X[:,:,9] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
original_sampling_rate = 50
num_classes = 8
elif dataset == 'TNDA-HAR':
all_X[:,:,20] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
all_X[:,:,2] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
all_X[:,:,21] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
all_X[:,:,11] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
original_sampling_rate = 50
num_classes = 8
elif dataset == 'ut-complex':
all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
all_X[:,:,21] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
original_sampling_rate = 50
num_classes = 13
all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)
# resample real data to 20 Hz
new_sampling_rate = 20
new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])
    # pad real data to padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 22*6)
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # padded time steps are zero-masked
real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()
if split == 'train' and k and k < len(real_inputs):
real_inputs, real_masks, real_labels = select_samples(real_inputs, real_masks, real_labels, k, dataset, data_path)
print(real_inputs.shape, real_labels.shape)
# load text
label_dictionary = data['label_dictionary']
label_list = [' '.join(labels) for labels in label_dictionary.values()]
all_text = clip.tokenize(label_list).cuda()
return real_inputs, real_masks, real_labels, label_list, all_text, num_classes
def load_multiple(dataset_list, padding_size, data_path, split='test', k=None):
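    """Call `load` for every dataset in `dataset_list` and collect the outputs into parallel lists."""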
real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list = [], [], [], [], [], []
for dataset in dataset_list:
real_inputs, real_masks, real_labels, label_list, all_text, num_classes = load(dataset, padding_size, data_path, split, k)
real_inputs_list.append(real_inputs)
real_masks_list.append(real_masks)
real_labels_list.append(real_labels)
label_list_list.append(label_list)
all_text_list.append(all_text)
num_classes_list.append(num_classes)
return real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list
def load_custom_data(X_path, y_path, config_path, joint_list, original_sampling_rate, padding_size=200, split='test', k=None, few_shot_path=None):
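    """Load a custom dataset given an explicit sensor-to-joint mapping (`joint_list`),
    with optional generated or precomputed few-shot index selection."""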
X = np.load(X_path)
real_labels = torch.from_numpy(np.load(y_path))
with open(config_path, 'r') as file:
data = json.load(file)
all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))
for i, joint in enumerate(joint_list):
all_X[:,:,joint] = np.concatenate((X[:,:,6*i:6*i+3], X[:,:,6*i+3:6*i+6]), axis=-1)
all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)
# resample real data to 20 Hz
new_sampling_rate = 20
new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])
    # pad real data to padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 22*6)
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # padded time steps are zero-masked
real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()
if split == 'train' and k and k < len(real_inputs):
unique_labels = torch.unique(real_labels)
if few_shot_path is None:
print('Generating few shot indices ...')
all_indices = []
for i, label in enumerate(unique_labels):
indices = torch.where(real_labels == label)[0]
selected_indices = indices[torch.randperm(len(indices))[:k]]
all_indices.append(selected_indices)
else:
print('Loading existing few shot indices ...')
all_indices = torch.load(few_shot_path)
selected_data = []
selected_masks = []
selected_labels = []
for i, label in enumerate(unique_labels):
selected_indices = all_indices[i]
selected_data.append(real_inputs[selected_indices])
selected_masks.append(real_masks[selected_indices])
selected_labels.append(real_labels[selected_indices])
selected_data = torch.cat(selected_data, dim=0)
selected_masks = torch.cat(selected_masks, dim=0)
selected_labels = torch.cat(selected_labels, dim=0)
real_inputs, real_masks, real_labels = selected_data, selected_masks, selected_labels
print(real_inputs.shape, real_labels.shape)
# load text
label_dictionary = data['label_dictionary']
label_list = [' '.join(labels) for labels in label_dictionary.values()]
all_text = clip.tokenize(label_list).cuda()
return real_inputs, real_masks, real_labels, label_list, all_text
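

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. It assumes a local
    # `./data` directory laid out as `load` expects (e.g. ./data/UCIHAR/X_test.npy,
    # ./data/UCIHAR/y_test.npy, ./data/UCIHAR/UCIHAR.json) and a CUDA device, since
    # the tokenized label text is moved to the GPU.
    inputs, masks, labels, label_list, text_tokens, num_classes = load(
        'UCIHAR', padding_size=200, data_path='./data', split='test'
    )
    print(inputs.shape, masks.shape, labels.shape, num_classes)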