惯性聚合 高效追踪和阅读你感兴趣的博客、新闻、科技资讯
阅读原文 在惯性聚合中打开

推荐订阅源

N
News | PayPal Newsroom
云风的 BLOG
云风的 BLOG
GbyAI
GbyAI
Engineering at Meta
Engineering at Meta
B
Blog RSS Feed
钛媒体:引领未来商业与生活新知
钛媒体:引领未来商业与生活新知
The Register - Security
The Register - Security
L
LangChain Blog
A
About on SuperTechFans
S
Schneier on Security
博客园 - 三生石上(FineUI控件)
Stack Overflow Blog
Stack Overflow Blog
The Hacker News
The Hacker News
AWS News Blog
AWS News Blog
博客园 - 司徒正美
Scott Helme
Scott Helme
K
Kaspersky official blog
Cyberwarzone
Cyberwarzone
T
Tenable Blog
腾讯CDC
Recorded Future
Recorded Future
cs.CL updates on arXiv.org
cs.CL updates on arXiv.org
G
GRAHAM CLULEY
Security Latest
Security Latest
S
Securelist
D
Darknet – Hacking Tools, Hacker News & Cyber Security
aimingoo的专栏
aimingoo的专栏
Google DeepMind News
Google DeepMind News
V
Vulnerabilities – Threatpost
雷峰网
雷峰网
T
The Exploit Database - CXSecurity.com
freeCodeCamp Programming Tutorials: Python, JavaScript, Git & More
V
V2EX
T
The Blog of Author Tim Ferriss
D
Docker
S
Security Affairs
F
Full Disclosure
Know Your Adversary
Know Your Adversary
N
News and Events Feed by Topic
N
News and Events Feed by Topic
T
Tor Project blog
Hugging Face - Blog
Hugging Face - Blog
www.infosecurity-magazine.com
www.infosecurity-magazine.com
Microsoft Security Blog
Microsoft Security Blog
Simon Willison's Weblog
Simon Willison's Weblog
Recent Announcements
Recent Announcements
博客园_首页
博客园 - 聂微东
让小产品的独立变现更简单 - ezindie.com
让小产品的独立变现更简单 - ezindie.com
S
Security @ Cisco Blogs

极客兔兔

Go sync.Cond | Go 语言高性能编程 Go 死码消除与调试(debug)模式 | Go 语言高性能编程 Go sync.Once | Go 语言高性能编程 Go 逃逸分析 | Go 语言高性能编程 2020 年终总结 | 极客兔兔 Go struct 内存对齐 | Go 语言高性能编程 Go 空结构体 struct{} 的使用 | Go 语言高性能编程 控制协程(goroutine)的并发数量 | Go 语言高性能编程 | 极客兔兔 如何退出协程 goroutine (其他场景) | Go 语言高性能编程 如何退出协程 goroutine (超时场景) | Go 语言高性能编程 Go 语言陷阱 - 数组和切片 | Go 语言高性能编程 减小 Go 代码编译后的二进制体积 | Go 语言高性能编程 Go Reflect 提高反射性能 | Go 语言高性能编程 读写锁和互斥锁的性能比较 | Go 语言高性能编程 | 极客兔兔 for 和 range 的性能比较 | Go 语言高性能编程 切片(slice)性能及陷阱 | Go 语言高性能编程 | 极客兔兔 字符串拼接性能及原理 | Go 语言高性能编程 | 极客兔兔 pprof 性能分析 | Go 语言高性能编程 benchmark 基准测试 | Go 语言高性能编程 Go 语言高性能编程 | 极客兔兔 Go 接口型函数的使用场景 | 极客兔兔 Python 简明教程 | 快速入门 | 极客兔兔 Go 语言笔试面试题(代码输出) | 极客面试 | 极客兔兔 动手写RPC框架 - GeeRPC第七天 服务发现与注册中心(registry) | 极客兔兔 动手写RPC框架 - GeeRPC第六天 负载均衡(load balance) 动手写RPC框架 - GeeRPC第五天 支持HTTP协议 | 极客兔兔 动手写RPC框架 - GeeRPC第四天 超时处理(timeout) | 极客兔兔 动手写RPC框架 - GeeRPC第三天 服务注册(service register) 动手写RPC框架 - GeeRPC第二天 支持并发与异步的客户端 | 极客兔兔 动手写RPC框架 - GeeRPC第一天 服务端与消息编码 | 极客兔兔 7天用Go从零实现RPC框架GeeRPC | 极客兔兔 Go 语言笔试面试题(并发编程) | 极客面试 | 极客兔兔 Go 语言笔试面试题(基础语法) | 极客面试 | 极客兔兔 Go 语言笔试面试题汇总 | 极客面试 | 极客兔兔 Go Context 并发编程简明教程 | 快速入门 Go Mmap 文件内存映射简明教程 | 快速入门 动手写ORM框架 - GeeORM第七天 数据库迁移(Migrate) | 极客兔兔 动手写ORM框架 - GeeORM第六天 支持事务(Transaction) | 极客兔兔 动手写ORM框架 - GeeORM第五天 实现钩子(Hooks) | 极客兔兔 动手写ORM框架 - GeeORM第四天 链式操作与更新删除 | 极客兔兔 动手写ORM框架 - GeeORM第三天 记录新增和查询 | 极客兔兔 动手写ORM框架 - GeeORM第二天 对象表结构映射 | 极客兔兔 动手写ORM框架 - GeeORM第一天 database/sql 基础 SQLite 常用命令 | 速查表(Cheat Sheet) 7天用Go从零实现ORM框架GeeORM | 极客兔兔 动手写分布式缓存 - GeeCache第七天 使用 Protobuf 通信 动手写分布式缓存 - GeeCache第六天 防止缓存击穿 | 极客兔兔 动手写分布式缓存 - GeeCache第五天 分布式节点 | 极客兔兔 动手写分布式缓存 - GeeCache第四天 一致性哈希(hash) | 极客兔兔 Go Mock (gomock)简明教程 | 快速入门 动手写分布式缓存 - GeeCache第三天 HTTP 服务端 动手写分布式缓存 - GeeCache第二天 单机并发缓存 | 极客兔兔 Go Test 单元测试简明教程 | 快速入门 7天用Go从零实现分布式缓存GeeCache | 极客兔兔 Go WebAssembly (Wasm) 简明教程 | 快速入门 Go RPC & TLS 鉴权简明教程 | 快速入门 Go Protobuf 简明教程 | 快速入门 Go语言动手写Web框架 - Gee第七天 错误恢复(Panic Recover) WSL, Git, Mircosoft Terminal 等常用工具配置 Rust 简明教程 | 快速入门 | 极客兔兔 Go语言动手写Web框架 - Gee第六天 模板(HTML Template) 百宝箱 - 值得收藏的工具网站 | 极客兔兔 Go语言动手写Web框架 - Gee第五天 中间件Middleware | 极客兔兔 Go语言动手写Web框架 - Gee第四天 分组控制Group | 极客兔兔 Go语言动手写Web框架 - Gee第三天 前缀树路由Router | 极客兔兔 博客折腾记(七) - Gitalk Plus | 极客兔兔 Go语言动手写Web框架 - Gee第二天 上下文Context | 极客兔兔 Go2 新特性简明教程 | 快速入门 | 极客兔兔 博客折腾记(六) - 不要为了流量忘记了初心 | 极客兔兔 Go语言动手写Web框架 - Gee第一天 http.Handler | 极客兔兔 7天用Go从零实现Web框架Gee教程 | 极客兔兔 Go Gin 简明教程 | 快速入门 Go 语言简明教程 | 快速入门 | 极客兔兔 机器学习笔试面试题 11-20 | 极客面试 | 极客兔兔 机器学习笔试面试题 1-10 | 极客面试 | 极客兔兔 机器学习笔试面试题汇总 | 极客面试 | 极客兔兔 TensorFlow 2 中文文档 - RNN LSTM 文本分类 TensorFlow 2 中文文档 - TFHub 迁移学习 TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10 TensorFlow 2 中文文档 - 保存与加载模型 TensorFlow 2 中文文档 - 过拟合与欠拟合 TensorFlow 2 中文文档 - 回归预测燃油效率 TensorFlow 2 中文文档 - 特征工程结构化数据分类 TensorFlow 2 中文文档 - IMDB 文本分类 TensorFlow 2 中文文档 - MNIST 图像分类 TensorFlow 2 / 2.0 中文文档 TensorFlow 2.0 (九) - 强化学习 70行代码实战 Policy Gradient 博客折腾记(五) - 友链这件事,没那么简单 | 极客兔兔 博客折腾记(四) - 原创资格是争取来的 | 极客兔兔 TensorFlow 2.0 (八) - 强化学习 DQN 玩转 gym Mountain Car TensorFlow 2.0 (七) - 强化学习 Q-Learning 玩转 OpenAI gym 博客折腾记(三) - 主题设计、彩蛋与阅读量翻倍 | 极客兔兔 TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game 博客折腾记(二) - 对搜索引擎的理解 | 极客兔兔 博客折腾记(一) - 极致性能的尝试 | 极客兔兔 Pandas 数据处理(三) - Cheat Sheet 中文版 TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) TensorFlow入门(三) - mnist手写数字识别(可视化训练) | 极客兔兔 Pandas 数据处理(二) - 筛选数据 | 极客兔兔 Pandas 数据处理(一) - DataFrame 与 Series
TensorFlow入门(四) - mnist手写数字识别(制作h5py训练集) | 极客兔兔
2018-04-02 · via 极客兔兔

源代码/数据集已上传到 Github - tensorflow-tutorial-samples

这篇文章是 TensorFlow Tutorial 入门教程的第四篇文章。

在之前的几篇文章中,我们都是通过 tensorflow.examples.tutorials.mnist来使用mnist训练集集,制作训练集主要有2个目的,一是加快训练时读取的速度,而是支持随机批读取。假如,每次训练时,都是直接读取图片,再将图片转为矩阵进行训练,那这样效率无疑是非常低下的。

这篇文章将使用numpy 和 h5py(HDF5文件格式)2种方式来制作训练集,并对这两种方式进行对比。

准备图片

mnist-images

直接读取tensorflow中mnist数据集,将数据集还原为图片。

在这里,使用 pillow库将矩阵转为图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import numpy as np
from PIL import Image
from tensorflow.examples.tutorials.mnist import input_data


def gen_image(arr, index, label):

matrix = (np.reshape(1.0 - arr, (28, 28)) * 255).astype(np.uint8)
img = Image.fromarray(matrix, 'L')

img.save("./images/{}_{}.png".format(label, index))


data = input_data.read_data_sets('../mnist/data_set')
x, y = data.train.next_batch(200)
for i, (arr, label) in enumerate(zip(x, y)):
print(i, label)
gen_image(arr, i, label)

这样,就得到了200张 28*28的图片供下一步制作训练集。

制作npy格式的数据集

numpy能够将矩阵保存为文件,也能从文件中读取矩阵,因此可以考虑使用numpy制作数据集。

1
2
3
4
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

1. 图片转为矩阵并保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
x, y = [], []

for i, image_path in enumerate(os.listdir('./images')):

label = int(image_path.split('_')[0])
label_one_hot = [0 if i != label else 1 for i in range(10)]
y.append(label_one_hot)


image = Image.open('./images/{}'.format(image_path)).convert('L')
image_arr = 1 - np.reshape(image, 784) / 255.0
x.append(image_arr)

np.save('data_set/X.npy', np.array(x))
np.save('data_set/Y.npy', np.array(y))

2. 读取文件随机批处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class DataSet:
def __init__(self):
x, y = np.load('data_set/X.npy'), np.load('data_set/Y.npy')
self.train_x, self.test_x, self.train_y, self.test_y = \
train_test_split(x, y, test_size=0.2, random_state=0)

self.train_size = len(self.train_x)

def get_train_batch(self, batch_size=64):

choice = np.random.randint(self.train_size, size=batch_size)
batch_x = self.train_x[choice, :]
batch_y = self.train_y[choice, :]

return batch_x, batch_y

def get_test_set(self):
return self.test_x, self.test_y
  • 一般情况下,我们会用随机批梯度下降的方式去进行训练,因此需要实现随机获取 batch_size个数据的功能。
  • 为了测试模型的泛化能力,测试集一般不与测试集交叉,常用 sklearn库中的train_test_split去分离训练数据与测试数据。

3. 如何使用

1
2
3
4
data_source = DataSet()
for i in range(1000):
train_x, train_y = data_source.get_train_batch(batch_size=32)
// ...

制作HDF5格式的数据集

HDF 是用于存储和分发科学数据的一种自我描述、多对象文件格式。HDF 是由美国国家超级计算应用中心(NCSA)创建的,以满足不同群体的科学家在不同工程项目领域之需要。一个HDF5文件就是一个由两种基本数据对象(groups and datasets)存放多种科学数据的容器:

  • HDF5 group: 包含0个或多个HDF5对象以及支持元数据(metadata)的一个群组结构。
  • HDF5 dataset: 数据元素的一个多维数组以及支持元数据(metadata)

直观理解,一个HDF5文件可以存储多个数据(value),并用索引(key)找到,支持层级嵌套,类似于Python中的字典。

Python中h5py来制作和使用HDF5格式的文件。

1
2
3
4
5
import os
import h5py
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

1. 图片转为矩阵并保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
x, y = [], []

for i, image_path in enumerate(os.listdir('./images')):

label = int(image_path.split('_')[0])
label_one_hot = [0 if i != label else 1 for i in range(10)]
y.append(label_one_hot)


image = Image.open('./images/{}'.format(image_path)).convert('L')
image_arr = 1 - np.reshape(image, 784) / 255.0
x.append(image_arr)

with h5py.File('./data_set/data.h5', 'w') as f:
f.create_dataset('x_data', data=np.array(x))
f.create_dataset('y_data', data=np.array(y))

2. 读取文件随机批处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class DataSet:
def __init__(self):
with h5py.File('./data_set/data.h5', 'r') as f:
x, y = f['x_data'].value, f['y_data'].value

self.train_x, self.test_x, self.train_y, self.test_y = \
train_test_split(x, y, test_size=0.2, random_state=0)

self.train_size = len(self.train_x)

def get_train_batch(self, batch_size=64):

choice = np.random.randint(self.train_size, size=batch_size)
batch_x = self.train_x[choice, :]
batch_y = self.train_y[choice, :]

return batch_x, batch_y

def get_test_set(self):
return self.test_x, self.test_y

f[‘x_data’] 是一个datasets,拥有 name, shape, value 属性

可以看到,我们只用了1个HDF5文件就将x 和 y存下来了。假如在保存文件前对训练集和测试集进行拆分,同样能够将 train_x, train_y, test_x, test_y 一起保存在一个 HDF5文件中,使用非常方便。

npy格式与hdf5格式的对比

# 读取(1000次/ms) 存储空间(M)
npy 1204 1.3
hdf5 1665 1.3

使用 200 张 28 * 28的图片对比,可以发现在没有使用任何压缩辅助的情况下,两种格式的数据占据的磁盘空间是一样的,HDF5的读取速度比npy慢了1/3,训练集如果能一次读取内存,启动训练前的读取时间可以忽略不计,但是HDF5格式的文件因为能够存储metadata和支持层级嵌套,键索引,使用起来更方便。

觉得还不错,不要吝惜你的star,支持是持续不断更新的动力。



上一篇 « TensorFlow入门(三) - mnist手写数字识别(可视化训练) 下一篇 » TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络)