惯性聚合 高效追踪和阅读你感兴趣的博客、新闻、科技资讯
阅读原文 在惯性聚合中打开

推荐订阅源

宝玉的分享
宝玉的分享
S
SegmentFault 最新的问题
Google DeepMind News
Google DeepMind News
OSCHINA 社区最新新闻
OSCHINA 社区最新新闻
aimingoo的专栏
aimingoo的专栏
The Cloudflare Blog
博客园 - Franky
阮一峰的网络日志
阮一峰的网络日志
I
InfoQ
V
V2EX
P
Proofpoint News Feed
F
Fortinet All Blogs
freeCodeCamp Programming Tutorials: Python, JavaScript, Git & More
酷 壳 – CoolShell
酷 壳 – CoolShell
D
DataBreaches.Net
cs.AI updates on arXiv.org
cs.AI updates on arXiv.org
L
Lohrmann on Cybersecurity
Recent Announcements
Recent Announcements
Latest news
Latest news
P
Palo Alto Networks Blog
博客园_首页
cs.CL updates on arXiv.org
cs.CL updates on arXiv.org
S
Securelist
Cyber Security Advisories - MS-ISAC
Cyber Security Advisories - MS-ISAC
博客园 - 【当耐特】
Threat Intelligence Blog | Flashpoint
Threat Intelligence Blog | Flashpoint
MongoDB | Blog
MongoDB | Blog
Blog — PlanetScale
Blog — PlanetScale
NISL@THU
NISL@THU
博客园 - 聂微东
Hugging Face - Blog
Hugging Face - Blog
V
Visual Studio Blog
云风的 BLOG
云风的 BLOG
P
Privacy & Cybersecurity Law Blog
C
Cybersecurity and Infrastructure Security Agency CISA
Cisco Talos Blog
Cisco Talos Blog
月光博客
月光博客
Security Latest
Security Latest
P
Proofpoint News Feed
小众软件
小众软件
T
Threat Research - Cisco Blogs
A
About on SuperTechFans
博客园 - 三生石上(FineUI控件)
C
Cisco Blogs
T
The Exploit Database - CXSecurity.com
爱范儿
爱范儿
罗磊的独立博客
Project Zero
Project Zero
W
WeLiveSecurity
U
Unit 42

极客兔兔

Go sync.Cond | Go 语言高性能编程 Go 死码消除与调试(debug)模式 | Go 语言高性能编程 Go sync.Once | Go 语言高性能编程 Go 逃逸分析 | Go 语言高性能编程 2020 年终总结 | 极客兔兔 Go struct 内存对齐 | Go 语言高性能编程 Go 空结构体 struct{} 的使用 | Go 语言高性能编程 控制协程(goroutine)的并发数量 | Go 语言高性能编程 | 极客兔兔 如何退出协程 goroutine (其他场景) | Go 语言高性能编程 如何退出协程 goroutine (超时场景) | Go 语言高性能编程 Go 语言陷阱 - 数组和切片 | Go 语言高性能编程 减小 Go 代码编译后的二进制体积 | Go 语言高性能编程 Go Reflect 提高反射性能 | Go 语言高性能编程 读写锁和互斥锁的性能比较 | Go 语言高性能编程 | 极客兔兔 for 和 range 的性能比较 | Go 语言高性能编程 切片(slice)性能及陷阱 | Go 语言高性能编程 | 极客兔兔 字符串拼接性能及原理 | Go 语言高性能编程 | 极客兔兔 pprof 性能分析 | Go 语言高性能编程 benchmark 基准测试 | Go 语言高性能编程 Go 语言高性能编程 | 极客兔兔 Go 接口型函数的使用场景 | 极客兔兔 Python 简明教程 | 快速入门 | 极客兔兔 Go 语言笔试面试题(代码输出) | 极客面试 | 极客兔兔 动手写RPC框架 - GeeRPC第七天 服务发现与注册中心(registry) | 极客兔兔 动手写RPC框架 - GeeRPC第六天 负载均衡(load balance) 动手写RPC框架 - GeeRPC第五天 支持HTTP协议 | 极客兔兔 动手写RPC框架 - GeeRPC第四天 超时处理(timeout) | 极客兔兔 动手写RPC框架 - GeeRPC第三天 服务注册(service register) 动手写RPC框架 - GeeRPC第二天 支持并发与异步的客户端 | 极客兔兔 动手写RPC框架 - GeeRPC第一天 服务端与消息编码 | 极客兔兔 7天用Go从零实现RPC框架GeeRPC | 极客兔兔 Go 语言笔试面试题(并发编程) | 极客面试 | 极客兔兔 Go 语言笔试面试题(基础语法) | 极客面试 | 极客兔兔 Go 语言笔试面试题汇总 | 极客面试 | 极客兔兔 Go Context 并发编程简明教程 | 快速入门 Go Mmap 文件内存映射简明教程 | 快速入门 动手写ORM框架 - GeeORM第七天 数据库迁移(Migrate) | 极客兔兔 动手写ORM框架 - GeeORM第六天 支持事务(Transaction) | 极客兔兔 动手写ORM框架 - GeeORM第五天 实现钩子(Hooks) | 极客兔兔 动手写ORM框架 - GeeORM第四天 链式操作与更新删除 | 极客兔兔 动手写ORM框架 - GeeORM第三天 记录新增和查询 | 极客兔兔 动手写ORM框架 - GeeORM第二天 对象表结构映射 | 极客兔兔 动手写ORM框架 - GeeORM第一天 database/sql 基础 SQLite 常用命令 | 速查表(Cheat Sheet) 7天用Go从零实现ORM框架GeeORM | 极客兔兔 动手写分布式缓存 - GeeCache第七天 使用 Protobuf 通信 动手写分布式缓存 - GeeCache第六天 防止缓存击穿 | 极客兔兔 动手写分布式缓存 - GeeCache第五天 分布式节点 | 极客兔兔 动手写分布式缓存 - GeeCache第四天 一致性哈希(hash) | 极客兔兔 Go Mock (gomock)简明教程 | 快速入门 动手写分布式缓存 - GeeCache第三天 HTTP 服务端 动手写分布式缓存 - GeeCache第二天 单机并发缓存 | 极客兔兔 Go Test 单元测试简明教程 | 快速入门 7天用Go从零实现分布式缓存GeeCache | 极客兔兔 Go WebAssembly (Wasm) 简明教程 | 快速入门 Go RPC & TLS 鉴权简明教程 | 快速入门 Go Protobuf 简明教程 | 快速入门 Go语言动手写Web框架 - Gee第七天 错误恢复(Panic Recover) WSL, Git, Mircosoft Terminal 等常用工具配置 Rust 简明教程 | 快速入门 | 极客兔兔 Go语言动手写Web框架 - Gee第六天 模板(HTML Template) 百宝箱 - 值得收藏的工具网站 | 极客兔兔 Go语言动手写Web框架 - Gee第五天 中间件Middleware | 极客兔兔 Go语言动手写Web框架 - Gee第四天 分组控制Group | 极客兔兔 Go语言动手写Web框架 - Gee第三天 前缀树路由Router | 极客兔兔 博客折腾记(七) - Gitalk Plus | 极客兔兔 Go语言动手写Web框架 - Gee第二天 上下文Context | 极客兔兔 Go2 新特性简明教程 | 快速入门 | 极客兔兔 博客折腾记(六) - 不要为了流量忘记了初心 | 极客兔兔 Go语言动手写Web框架 - Gee第一天 http.Handler | 极客兔兔 7天用Go从零实现Web框架Gee教程 | 极客兔兔 Go Gin 简明教程 | 快速入门 Go 语言简明教程 | 快速入门 | 极客兔兔 机器学习笔试面试题 11-20 | 极客面试 | 极客兔兔 机器学习笔试面试题 1-10 | 极客面试 | 极客兔兔 机器学习笔试面试题汇总 | 极客面试 | 极客兔兔 TensorFlow 2 中文文档 - RNN LSTM 文本分类 TensorFlow 2 中文文档 - TFHub 迁移学习 TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10 TensorFlow 2 中文文档 - 保存与加载模型 TensorFlow 2 中文文档 - 过拟合与欠拟合 TensorFlow 2 中文文档 - 特征工程结构化数据分类 TensorFlow 2 中文文档 - IMDB 文本分类 TensorFlow 2 中文文档 - MNIST 图像分类 TensorFlow 2 / 2.0 中文文档 TensorFlow 2.0 (九) - 强化学习 70行代码实战 Policy Gradient 博客折腾记(五) - 友链这件事,没那么简单 | 极客兔兔 博客折腾记(四) - 原创资格是争取来的 | 极客兔兔 TensorFlow 2.0 (八) - 强化学习 DQN 玩转 gym Mountain Car TensorFlow 2.0 (七) - 强化学习 Q-Learning 玩转 OpenAI gym 博客折腾记(三) - 主题设计、彩蛋与阅读量翻倍 | 极客兔兔 TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game 博客折腾记(二) - 对搜索引擎的理解 | 极客兔兔 博客折腾记(一) - 极致性能的尝试 | 极客兔兔 Pandas 数据处理(三) - Cheat Sheet 中文版 TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) TensorFlow入门(四) - mnist手写数字识别(制作h5py训练集) | 极客兔兔 TensorFlow入门(三) - mnist手写数字识别(可视化训练) | 极客兔兔 Pandas 数据处理(二) - 筛选数据 | 极客兔兔 Pandas 数据处理(一) - DataFrame 与 Series
TensorFlow 2 中文文档 - 回归预测燃油效率
2019-07-11 · via 极客兔兔

源代码/数据集已上传到 Github - tensorflow2-docs-zh

TF2.0 TensorFlow 2 / 2.0 文档:Regression 回归

主要内容:使用回归预测烟油效率。

回归通常用来预测连续值,比如价格和概率。分类问题不一样,类别是固定的,目的是判断属于哪一类。比如给你一堆猫和狗的图片,判断一张图片是猫还是狗就是一个典型的分类问题。

接下来使用的是经典的 Auto MPG 数据集,这个数据集包括气缸(cylinders),排量(displayment),马力(horsepower) 和重量(weight)等属性。我们需要利用这些属性搭建模型,预测汽车的燃油效率(fuel efficiency)。

模型搭建使用tf.keras API。

1
2
3
4
5
6
7
8
import pathlib

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Auto MPG 数据集

获取数据

1
2
3
4
5
6
7
8
9
10
11
12
13

url = "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
dataset_path = keras.utils.get_file("auto-mpg.data", url)


column_names = ['MPG','气缸','排量','马力','重量','加速度', '年份', '产地']
raw_dataset = pd.read_csv(dataset_path, names=column_names,
na_values = "?", comment='\t',
sep=" ", skipinitialspace=True)

dataset = raw_dataset.copy()

dataset.head(3)
MPG 气缸 排量 马力 重量 加速度 年份 产地
0 18.0 8 307.0 130.0 3504.0 12.0 70 1
1 15.0 8 350.0 165.0 3693.0 11.5 70 1
2 18.0 8 318.0 150.0 3436.0 11.0 70 1

清洗数据

检查是否有 NA 值。

1
dataset.isna().sum()
1
2
3
4
5
6
7
8
9
MPG    0
气缸 0
排量 0
马力 6
重量 0
加速度 0
年份 0
产地 0
dtype: int64

直接去除含有NA值的行(马力)

1
dataset = dataset.dropna()

在获取的数据集中,Origin(产地)不是数值类型,需转为独热编码。

1
2
3
4
5
6
origin = dataset.pop('产地')
dataset['美国'] = (origin == 1)*1.0
dataset['欧洲'] = (origin == 2)*1.0
dataset['日本'] = (origin == 3)*1.0

dataset.head(3)
MPG 气缸 排量 马力 重量 加速度 年份 美国 欧洲 日本
0 18.0 8 307.0 130.0 3504.0 12.0 70 1.0 0.0 0.0
1 15.0 8 350.0 165.0 3693.0 11.5 70 1.0 0.0 0.0
2 18.0 8 318.0 150.0 3436.0 11.0 70 1.0 0.0 0.0

划分训练集与测试集

1
2
3

train_dataset = dataset.sample(frac=0.8, random_state=0)
test_dataset = dataset.drop(train_dataset.index)

检查数据

快速看一看训练集中属性两两之间的关系吧。

1
2
3
4
5

plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

sns.pairplot(train_dataset[["MPG", "气缸", "排量", "重量"]], diag_kind="kde")

matplotlib 中文乱码看这里:matplotlib图例中文乱码?

Seaborn four feature relation

你还可以使用train_dataset.describle()快速浏览每一属性的平均值、标准差、最小值、最大值等信息,能够帮助你快速地识别出不合理的数据。

1
2
3
4
train_stats = train_dataset.describe()
train_stats.pop("MPG")
train_stats = train_stats.transpose()
train_stats
count mean std min 25% 50% 75% max
气缸 314.0 5.477707 1.699788 3.0 4.00 4.0 8.00 8.0
排量 314.0 195.318471 104.331589 68.0 105.50 151.0 265.75 455.0
….

分离 label

1
2
3

train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')

归一化数据

通常训练前需要归一化数据,不同属性使用的计量单位不一样,值的范围不一样,训练就会很困难。比如其中一个属性的范围是[0.1, 0.5],而另一个属性的范围是[1000, 5000],那数值大的属性就容易对训练产生干扰,很可能导致训练不能收敛,或者是数值小的属性在模型中几乎没有发挥作用。归一化将不同范围的数据映射到[0,1]的空间内,可以有效地避免这个问题。

1
2
3
4
def norm(x):
return (x - train_stats['mean']) / train_stats['std']
normed_train_data = norm(train_dataset)
normed_test_data = norm(test_dataset)

模型

搭建模型

我们的模型包含2个全连接的隐藏层构成,输出层返回一个连续值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def build_model():
input_dim = len(train_dataset.keys())
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=[input_dim,]),
layers.Dense(64, activation='relu'),
layers.Dense(1)
])

model.compile(loss='mse', metrics=['mae', 'mse'],
optimizer=tf.keras.optimizers.RMSprop(0.001))
return model

model = build_model()

model.summary()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param
=================================================================
dense_4 (Dense) (None, 64) 640
_________________________________________________________________
dense_5 (Dense) (None, 32) 2080
_________________________________________________________________
dense_6 (Dense) (None, 1) 33
=================================================================
Total params: 2,753
Trainable params: 2,753
Non-trainable params: 0
_________________________________________________________________

训练模型

在之前的案例,比如结构化数据分类,我们调用model.fit会打印出训练的进度。我们可以禁用默认的行为,并自定义训练进度条。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import sys


EPOCHS = 1000

class ProgressBar(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs):

self.draw_progress_bar(epoch + 1, EPOCHS)

def draw_progress_bar(self, cur, total, bar_len=50):
cur_len = int(cur / total * bar_len)
sys.stdout.write("\r")
sys.stdout.write("[{:<{}}] {}/{}".format("=" * cur_len, bar_len, cur, total))
sys.stdout.flush()

history = model.fit(
normed_train_data, train_labels,
epochs=EPOCHS, validation_split = 0.2, verbose=0,
callbacks=[ProgressBar()])
1
[==================================================] 1000/1000

训练过程都存储在了history对象中,我们可以借助 matplotlib 将训练过程可视化。

1
2
3
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
hist.tail(3)
loss mae mse val_loss val_mae val_mse epoch
997 3.132053 1.142280 3.132053 9.711935 2.361466 9.711935 997
998 3.021109 1.093424 3.021109 9.488593 2.298264 9.488593 998
999 3.028849 1.132241 3.028849 9.453931 2.275017 9.453931 999
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def plot_history(history):
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
plt.figure()
plt.xlabel('epoch')
plt.ylabel('metric - MSE')
plt.plot(hist['epoch'], hist['mse'], label='训练集')
plt.plot(hist['epoch'], hist['val_mse'], label = '验证集')
plt.ylim([0, 20])
plt.legend()

plt.figure()
plt.xlabel('epoch')
plt.ylabel('metric - MAE')
plt.plot(hist['epoch'], hist['mae'], label='训练集')
plt.plot(hist['epoch'], hist['val_mae'], label = '验证集')
plt.ylim([0, 5])
plt.legend()

plot_history(history)

MAE

从图中,我们可以看到,从100 epoch开始,训练集的loss仍旧继续降低,但验证集的loss却在升高,说明过拟合了,训练应该早一点结束。接下来,我们使用 keras.callbacks.EarlyStopping,每一波(epoch)训练结束时,测试训练情况,如果训练不再有效果(验证集的loss,即val_loss 不再下降),则自动地停止训练。

1
2
3
4
5
6
model = build_model()
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,
validation_split = 0.2, verbose=0,
callbacks=[early_stop, ProgressBar()])
plot_history(history)
1
[===                                               ] 70/1000

在第 70 epoch 时,停止了训练。

MAE

接下来使用测试集来评估训练效果。

1
2
3
loss, mae, mse = model.evaluate(normed_test_data, test_labels, verbose=0)
print("测试集平均绝对误差(MAE): {:5.2f} MPG".format(mae))

从图中我们也可以看出,1.9比验证集还略低一点。

预测

最后,我们使用测试集中的数据来预测 MPG 值。

1
2
3
4
5
6
7
8
9
10
test_pred = model.predict(normed_test_data).flatten()

plt.scatter(test_labels, test_pred)
plt.xlabel('真实值')
plt.ylabel('预测值')
plt.axis('equal')
plt.axis('square')
plt.xlim([0,plt.xlim()[1]])
plt.ylim([0,plt.ylim()[1]])
plt.plot([-100, 100], [-100, 100])

看起来,模型训练得还不错。

Test True Pred

结论

  • 均方误差(Mean Squared Error, MSE) 常作为回归问题的损失函数(loss function),与分类问题不太一样。
  • 同样,评价指标(evaluation metrics)也不一样,分类问题常用准确率(accuracy),回归问题常用平均绝对误差 (Mean Absolute Error, MAE)
  • 每一列数据都有不同的范围,每一列,即每一个feature的数据需要分别缩放到相同的范围。常用归一化的方式,缩放到[0, 1]。
  • 如果训练数据过少,最好搭建一个隐藏层少的小的神经网络,避免过拟合。
  • 早停法(Early Stoping)也是防止过拟合的一种方式。

返回文档首页

完整代码:Github - auto_mpg_regression.ipynb
参考文档:Regression: Predict fuel efficiency

附 推荐



上一篇 « TensorFlow 2 中文文档 - 特征工程结构化数据分类 下一篇 » TensorFlow 2 中文文档 - 过拟合与欠拟合