DeepLearnToolbox DBN源码介绍
DeepLearnToolbox DBN源码介绍
这几天看了下DeepLearnToolbox的源码,在此记录一下自己对DBN代码的理解。
test_example_DBN.m:测试代码
function test_example_DBN
% TEST_EXAMPLE_DBN  End-to-end example: pre-train a DBN, unfold it into a
%   feed-forward NN, fine-tune the NN with labels and check the test error.
%
%   Takes no arguments and returns nothing. Loads the MNIST split from
%   ../data/mnist_40000_10000 and fails with an assertion error when the
%   test error is not below 0.10.

load ../data/mnist_40000_10000;
addpath('../DBN');
addpath('../NN');
addpath('../util');

% Convert to double and rescale by 1/255.
% NOTE(review): presumably the raw pixels are 0..255 so this maps them to
% [0, 1] -- confirm against the data file.
train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

% Fix the random seed so runs are repeatable.
rand('state',0)

% train dbn
% Hidden-layer sizes of the DBN: v1 is the raw pixel input, h1/v2 has 100
% units and h2/v3 has 200 units.
dbn.sizes = [100 200];
opts.numepochs = 3;
opts.batchsize = 100;
% Momentum blends the previous update direction into the current one to
% speed up learning; 0 disables it.
opts.momentum  = 0;
opts.alpha     = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

% unfold dbn to nn (adds an output layer of size 10)
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

% train nn
% The DBN supplies the initial parameters; the NN is then fine-tuned.
opts.numepochs = 3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');
dbnsetup.m:建立DBN网络
function dbn = dbnsetup(dbn, x, opts)
% DBNSETUP  Build the stack of RBMs that makes up a DBN.
%   dbn = dbnsetup(dbn, x, opts) prepends the input width size(x, 2) to
%   dbn.sizes and creates one RBM, dbn.rbm{u}, for every pair of adjacent
%   layers.
%
%   Inputs:
%     dbn.sizes      row vector of hidden-layer sizes
%     x              data matrix; only its column count is used here
%     opts.alpha     value copied into every rbm.alpha
%     opts.momentum  value copied into every rbm.momentum
%
%   Output:
%     dbn            same struct with the extended dbn.sizes and dbn.rbm{u}
%                    holding zero-initialised W, b, c and the matching
%                    update buffers vW, vb, vc.

    n = size(x, 2);
    dbn.sizes = [n, dbn.sizes];   % e.g. [784, 100, 200] for 784 input columns

    % Initialise W, b, c (and their update buffers) for every RBM.
    for u = 1 : numel(dbn.sizes) - 1
        n_visible = dbn.sizes(u);
        n_hidden  = dbn.sizes(u + 1);

        dbn.rbm{u}.alpha    = opts.alpha;
        dbn.rbm{u}.momentum = opts.momentum;

        dbn.rbm{u}.W  = zeros(n_hidden, n_visible);
        dbn.rbm{u}.vW = zeros(n_hidden, n_visible);

        dbn.rbm{u}.b  = zeros(n_visible, 1);   % visible-layer bias
        dbn.rbm{u}.vb = zeros(n_visible, 1);

        dbn.rbm{u}.c  = zeros(n_hidden, 1);    % hidden-layer bias
        dbn.rbm{u}.vc = zeros(n_hidden, 1);
    end
end
dbntrain.m:训练DBN
function dbn = dbntrain(dbn, x, opts)
% DBNTRAIN  Train the RBMs of a DBN one layer at a time.
%   dbn = dbntrain(dbn, x, opts) trains dbn.rbm{1} on x, then for each
%   following RBM passes x through the previous RBM with rbmup and trains
%   on the result.
%
%   Inputs:
%     dbn   struct whose dbn.rbm cell array holds the RBMs to train
%     x     training data for the first RBM
%     opts  options forwarded unchanged to rbmtrain
%
%   Output:
%     dbn   same struct with every dbn.rbm{i} replaced by its trained version

    n = numel(dbn.rbm);

    dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);
    for i = 2 : n
        % Feed the data upwards through the RBM that was just trained.
        % NOTE(review): per the original annotation rbmup computes
        % sigm(W*x + c); it is not defined in this listing -- confirm.
        x = rbmup(dbn.rbm{i - 1}, x);
        dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts);
    end
end
rbmtrain.m:训练RBM
采用对比散度(Contrastive Divergence,CD)算法进行训练,这是Hinton在2002年提出的一种RBM快速学习算法。
算法描述在 《Learning Deep Architectures for AI》 Algorithm 1,主要流程如下:
function rbm = rbmtrain(rbm, x, opts)
% RBMTRAIN  Train one RBM with a single Gibbs step per mini-batch (CD-1 style).
%   rbm = rbmtrain(rbm, x, opts) runs opts.numepochs passes over x. Each
%   pass shuffles the rows, splits them into mini-batches of
%   opts.batchsize rows and updates rbm.W, rbm.b and rbm.c.
%
%   Inputs:
%     rbm   struct with W, b, c, their update buffers vW, vb, vc, and the
%           scalars alpha and momentum
%     x     floating-point matrix, one sample per row, all values in [0, 1]
%     opts  struct with fields batchsize and numepochs
%
%   Output:
%     rbm   the struct with updated parameters and update buffers
%
%   Errors:
%     Assertion failure when x is not floating point, when a value of x
%     lies outside [0, 1], or when size(x, 1) is not a multiple of
%     opts.batchsize.
%
%   Prints the average reconstruction error after every epoch.

    assert(isfloat(x), 'x must be a float');
    assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');
    m = size(x, 1);
    numbatches = m / opts.batchsize;

    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs          % number of passes over the data
        kk = randperm(m);               % shuffle the sample order
        err = 0;
        for l = 1 : numbatches
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

            v1 = batch;
            % One Gibbs step: sample h1 from v1, sample v2 from h1, then
            % take hidden activations h2 = sigm(W*v2 + c) without sampling.
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
            h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

            % Hidden-by-visible products for the data (c1) and for the
            % reconstruction (c2).
            c1 = h1' * v1;
            c2 = h2' * v2;

            % Momentum keeps a share of the previous update direction and
            % combines it with the current one to speed up learning.
            % sum(..., 1) forces a column-wise sum: a plain sum() collapses
            % a 1-by-n row to a scalar when opts.batchsize == 1, which
            % would corrupt the bias updates.
            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)         / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            % Accumulate the squared reconstruction error per sample.
            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
    end
end
dbnunfoldtonn.m:利用DBN的参数去初始化NN,然后用NN进行微调,即 nn = nntrain(nn, train_x, train_y, opts);
function nn = dbnunfoldtonn(dbn, outputsize)
% DBNUNFOLDTONN  Unfolds a DBN to a NN.
%   nn = dbnunfoldtonn(dbn, outputsize) returns the unfolded dbn with a
%   final layer of size outputsize added.
%   nn = dbnunfoldtonn(dbn) returns the unfolded dbn with no extra layer.
%
%   Inputs:
%     dbn         struct with dbn.sizes and the RBMs in dbn.rbm
%     outputsize  (optional) size of the output layer to append
%
%   Output:
%     nn          network created by nnsetup whose first numel(dbn.rbm)
%                 weight matrices are overwritten with the DBN parameters

    % Named layer_sizes rather than size so the builtin size() is not
    % shadowed inside this function.
    if exist('outputsize', 'var')
        layer_sizes = [dbn.sizes outputsize];
    else
        layer_sizes = [dbn.sizes];
    end
    nn = nnsetup(layer_sizes);

    for i = 1 : numel(dbn.rbm)
        % Initialise NN layer i from the RBM: hidden bias c as the first
        % column, followed by the weight matrix W.
        nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];
    end
end
CNN源码解析:http://blog.csdn.net/zouxy09/article/details/9993743
http://blog.csdn.net/dark_scope/article/details/9495505
Reference:
(1) Learning Deep Architectures for AI
(2) A Practical Guide to Training Restricted Boltzmann Machines, 2010