PyTorch 的 TensorDataset 接口
class TensorDataset(Dataset):
    """Dataset wrapping a data tensor and a target tensor.

    Each sample is retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).

    Raises:
        AssertionError: if the two tensors differ in size along the first
            dimension.
    """

    def __init__(self, data_tensor, target_tensor):
        # Sample i pairs data_tensor[i] with target_tensor[i], so both must
        # hold the same number of entries along dimension 0.
        assert data_tensor.size(0) == target_tensor.size(0), (
            "Size mismatch between tensors"
        )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the number of samples (size of the first dimension)."""
        return self.data_tensor.size(0)
用于 HDF5 的 H5Dataset 接口
class H5Dataset(Dataset):
    """Dataset wrapping a pair of array-like data and target containers.

    Each sample is retrieved by indexing both containers along the first
    dimension. Unlike a tensor-only wrapper, this class relies only on a
    ``.shape`` attribute and ``[]`` indexing.

    Arguments:
        data_tensor: array-like holding sample data; must expose ``.shape``
            and support indexing.
        target_tensor: array-like holding sample targets (labels); must
            expose ``.shape`` and support indexing.

    Raises:
        AssertionError: if the two containers differ in length along the
            first dimension.
    """

    def __init__(self, data_tensor, target_tensor):
        # Sample i pairs data_tensor[i] with target_tensor[i], so both must
        # hold the same number of entries along dimension 0.
        assert data_tensor.shape[0] == target_tensor.shape[0], (
            "Size mismatch between data and target"
        )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the number of samples (length of the first dimension)."""
        return self.data_tensor.shape[0]
对应的DataLoader(把TensorDataset改成H5Dataset即可)
def load_data():
    """Build the training and testing DataLoaders for ``./dataset/CAVE.h5``.

    The HDF5 file is expected to contain ``train`` and ``test`` groups, each
    with ``MS`` and ``RGB`` datasets. In both splits the ``RGB`` dataset is
    passed as the sample data and the ``MS`` dataset as the target.

    Batch sizes and worker count are read from the module-level ``opt``
    (``opt.batchSize``, ``opt.testBatchSize``, ``opt.threads``).

    Returns:
        tuple: ``(training_data_loader, testing_data_loader)``. The training
        loader shuffles its samples; the testing loader does not.
    """
    # The file handle is intentionally left open: the datasets below are
    # handed to H5Dataset as-is and are indexed later by the loaders, so
    # closing the file here would invalidate them.
    # NOTE(review): an h5py handle opened in the parent process may not be
    # safe to share with DataLoader workers when opt.threads > 0 -- confirm.
    f = h5py.File("./dataset/CAVE.h5", 'r')

    MS_train = f['train']["MS"]
    RGB_train = f['train']["RGB"]
    MS_test = f['test']["MS"]
    RGB_test = f['test']["RGB"]

    train_set = H5Dataset(RGB_train, MS_train)
    test_set = H5Dataset(RGB_test, MS_test)

    training_data_loader = DataLoader(
        dataset=train_set,
        num_workers=opt.threads,
        batch_size=opt.batchSize,
        pin_memory=True,
        shuffle=True,
    )
    testing_data_loader = DataLoader(
        dataset=test_set,
        num_workers=opt.threads,
        batch_size=opt.testBatchSize,
        pin_memory=True,
        shuffle=False,
    )
    return training_data_loader, testing_data_loader