博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用于pytorch的H5Dataset接口(类比TensorDataset接口)
阅读量:6278 次
发布时间:2019-06-22

本文共 2292 字,大约阅读时间需要 7 分钟。

pytorch的TensorDataset接口

1 class TensorDataset(Dataset): 2     """Dataset wrapping data and target tensors. 3     Each sample will be retrieved by indexing both tensors along the first 4     dimension. 5     Arguments: 6         data_tensor (Tensor): contains sample data. 7         target_tensor (Tensor): contains sample targets (labels). 8     """ 9 10     def __init__(self, data_tensor, target_tensor):11         assert data_tensor.size(0) == target_tensor.size(0)12         self.data_tensor = data_tensor13         self.target_tensor = target_tensor14 15     def __getitem__(self, index):16         return self.data_tensor[index], self.target_tensor[index]17 18     def __len__(self):19 return self.data_tensor.size(0)

用于hdf5的H5Dataset接口

1 class H5Dataset(Dataset): 2     """Dataset wrapping data and target tensors. 3  4     Each sample will be retrieved by indexing both tensors along the first 5     dimension. 6  7     Arguments: 8         data_tensor (Tensor): contains sample data. 9         target_tensor (Tensor): contains sample targets (labels).10     """11 12     def __init__(self, data_tensor, target_tensor):13         assert data_tensor.shape[0] == target_tensor.shape[0]14         self.data_tensor = data_tensor15         self.target_tensor = target_tensor16 17     def __getitem__(self, index):18         # print(index)19         return self.data_tensor[index], self.target_tensor[index]20 21     def __len__(self):22         return self.data_tensor.shape[0]

对应的DataLoader(把TensorDataset改成H5Dataset即可)

1 def load_data(): 2     f = h5py.File("./dataset/CAVE.h5", 'r') 3     MS_train = f['train']["MS"] 4     RGB_train = f['train']["RGB"] 5     MS_test = f['test']["MS"] 6     RGB_test = f['test']["RGB"] 7     train_set = H5Dataset(RGB_train, MS_train) 8     test_set = H5Dataset(RGB_test, MS_test) 9     training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, pin_memory=True,10                                       shuffle=True)11     testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, pin_memory=True,12                                      shuffle=False)13     return training_data_loader, testing_data_loader

 

转载于:https://www.cnblogs.com/nwpuxuezha/p/7846751.html

你可能感兴趣的文章
Android 传感器
查看>>
【js】函数问题
查看>>
postgresql----数组类型和函数
查看>>
聚集索引,非聚集索引,唯一索引
查看>>
github命令行实用操作
查看>>
进程同步
查看>>
DRF 分页组件
查看>>
Https 与http
查看>>
c++ explicit 修饰构造函数
查看>>
HDU 3018 Ant Trip
查看>>
每天一个linux命令(4) df命令
查看>>
jchdl - GSL实例 - Counter
查看>>
23 设计模式
查看>>
linux
查看>>
hessian 在spring中的使用 (bean 如 Dao无法注入的问题)
查看>>
leetcode Nim game
查看>>
leetcode 189. Rotate Array
查看>>
24. Spring Boot 自定义Starter (未整理,待续)
查看>>
Lua用于游戏运行期热更(不重启游戏客户端)
查看>>
Openresty+Lua+Redis灰度发布
查看>>