TensorDataset是一个组合了多个Tensor的数据集类。它继承自Dataset类,与Dataset类相比,TensorDataset在处理数据集时更加方便。
在监督学习的训练过程中,我们需要知道训练数据和对应的标签,此时TensorDataset就可以排上用场,代码如下所示:
import torch
from torch.utils.data import DataLoader, TensorDataset
# 训练数据 X 和标签 Y
X = torch.randn(100, 3) # 100个样本,每个样本3个特征
Y = torch.randn(100, 1) # 100个样本的标签
# 创建 TensorDataset
dataset = TensorDataset(X, Y)
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 迭代 DataLoader
for i, (x, y) in enumerate(dataloader):
print("当前的批次为:", i)
print("训练数据为:", x)
print("标签为:", y)