找回密码
 立即注册
首页 业界区 安全 初识Dataset

初识Dataset

扈梅风 昨天 21:16
Dataset与Dataloader

Dataset主要是提供一种方式去获取数据以及label,主要实现如何获取每一个数据及其label,告诉我们总共有多少的数据;
Dataloader为后面的网络提供不同的数据类型;
Dataset

1.是一个抽象的类
2.可重写__getitiem__与__len__类
1.png

可以通过控制台,看到很多变量的属性。
运用dataset

2.png

添加标签

用的是蚂蚁和蜜蜂的数据集,并没有标签,当然可以手动添加简单的,下面提供代码形式自动增加
其中目录为:需手动添加文件夹
3.png

完整代码

点击查看代码
  1. from torch.utils.data import Dataset  #引入Dataset这个类
  2. from PIL import Image   #读取我们的图片
  3. #Image.open:读入该图片;
  4. #.size:图片的大小;
  5. #。show:打开看看图片
  6. import os   #获取所有图片的地址
  7. #os.listdir(路径):获得列表所有文件的地址
  8. #os.path.join:将两个路径合起来
  9. class Mydata(Dataset):   #定义了一个类
  10.     #def:函数,传入参数后,赋予其性质或者功能
  11.     def __init__(self, root_dir,label_dir):  #获取这个图片需要什么,就定义什么变量
  12.         self.root_dir=root_dir
  13.         self.label_dir=label_dir
  14.         self.path=os.path.join(self.root_dir,self.label_dir)
  15.         self.img_path=os.listdir(self.path)
  16.     def __getitem__(self,idx):
  17.         img_name=self.img_path[idx]
  18.         img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
  19.         img=Image.open(img_item_path)
  20.         label =self.label_dir
  21.         return img,label
  22.     def  __len_(self):  #读取传入东西的长度,数量等
  23.         return len(self.img_path)
  24. root_dir="dataset/train"
  25. ants_label_dir="ants"
  26. bees_label_dir="bees"
  27. ants_dataset=Mydata(root_dir,ants_label_dir)
  28. bees_dataset=Mydata(root_dir,bees_label_dir)
  29. train_dataset=ants_dataset+bees_dataset
复制代码
### 添加标签代码点击查看代码
  1. #快速将所有图片都添加label
  2. import os
  3. root_dir="dataset/train"
  4. target_dir="bees_image"
  5. img_path=os.listdir(os.path.join(root_dir,target_dir))
  6. label=target_dir.split('_')[0]
  7. out_dir="bees_label"
  8. for i in img_path:
  9.     file_name=i.split('.jpg')[0]
  10.     with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:
  11.         f.write((label))
复制代码
来源

练手数据集
密码: 5suq
笔记学习视频

来源:豆瓜网用户自行投稿发布,如果侵权,请联系站长删除

相关推荐

您需要登录后才可以回帖 登录 | 立即注册