在线学习与增量学习

写在前面

最近在公司调研了下在线学习方面的知识,在此做个记录。

在线学习和增量学习

在线学习包含增量学习和减量学习

  • 增量学习:是指一个学习系统能不断地从新样本中学习新的知识,并能保存大部分以前已经学习到的知识
  • 减量学习:抛弃“价值最低”的保留的训练样本

在线学习主要两个作用:

  • 可以利用全部的数据
  • 及时利用新的数据

避免离线更新模型的问题

在线学习算法

理论上说,只要能是用SGD或者类似mini_batch SGD进行优化的模型都是可以实现增量训练的,其中以Linear Model和Deep Learning最为典型,下面就简要介绍现成工具的使用吧。

sklearn中的增量学习

sklearn中提供了很多增量学习算法。虽然不是所有的算法都可以增量学习,但是学习器提供了 partial_fit的函数的都可以进行增量学习。

  • Classification
  • sklearn.naive_bayes.MultinomialNB
  • sklearn.naive_bayes.BernoulliNB
  • sklearn.linear_model.Perceptron
  • sklearn.linear_model.SGDClassifier
  • sklearn.linear_model.PassiveAggressiveClassifier
  • Regression
  • sklearn.linear_model.SGDRegressor
  • sklearn.linear_model.PassiveAggressiveRegressor
  • Clustering
  • sklearn.cluster.MiniBatchKMeans
  • Decomposition / feature Extraction
  • sklearn.decomposition.MiniBatchDictionaryLearning
  • sklearn.decomposition.IncrementalPCA
  • sklearn.decomposition.LatentDirichletAllocation
  • sklearn.cluster.MiniBatchKMeans

lightGBM的增量学习方法

主要是通过init_model和keep_training_booster两个参数实现增量训练

1
2
3
4
5
6
7
8
9
gbm = lgb.train(params,
lgb_train,
num_boost_round=1000,
valid_sets=lgb_eval,
init_model=gbm, #如果gbm不为None,那么就是在上次的基础上接着训练
feature_name=x_cols,
early_stopping_rounds=10,
verbose_eval=False,
keep_training_booster=True) # 增量训练

xgboost的增量学习方法

两种增量训练的方式

  • 在当前迭代树的基础上增加新树,原树不变
  • 当前迭代树的结构不变,重新计算叶节点权重,同时也可增加新树
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import xgboost as xgb
from sklearn.datasets import load_digits

xgb_params_01 = {}
xgb_params_02 = {'process_type': 'update',
'updater': 'refresh',
'refresh_leaf': True}


digits_2class = load_digits(2)

X_2class = digits_2class['data']
y_2class = digits_2class['target']

dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
gbdt = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=3)
print(gbdt.get_dump())
gbdt = xgb.train(xgb_params_02, dtrain_2class, num_boost_round=3, xgb_model=gbdt)
print(gbdt.get_dump())

keras的增量学习方法

在深度学习中,由于训练周期长,如果因为异常退出要重新训练代价很大。
keras中增量训练是每次将一部分数据丢进网络训练,然后保存网络,下次新的数据过来再加载网络接着训练。

1
2
3
4
5
6
7
8
9
10
11
12
model = None
# 模型保存路径
model_file = '/path/model.h5'

if model == None:
model = Sequential()
model.add(Dense(64,activation='relu', input_shape=()))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='mape', optimizer='adam', metrics=['mse', 'mae', 'mape'])
else:
model = load_model(model_file)
history = model.fit(x=x_train, y=y_train, batch_size=64, epoces=10)

Ad Click Prediction: a View from the Trenches

这篇论文是Google在FTRL上的工程实践,但这是针对数据量大、特征维度大的情况。主要讲解了在线学习在LR上进行广告点击预估上的应用。

参考文献

您的支持将鼓励我继续创作!