45fan.com - 路饭网

搜索: 您的位置主页 > 网络频道 > 阅读资讯:DeepLearnToolbox DBN源码介绍

DeepLearnToolbox DBN源码介绍

2016-09-03 09:40:27 来源:www.45fan.com 【

DeepLearnToolbox DBN源码介绍

这几天看了下DeepLearnToolbox的源码,在此记录一下自己对DBN代码的理解。

 

test_example_DBN.m:测试代码

function test_example_DBN
load ../data/mnist_40000_10000;
addpath('../DBN');
addpath('../NN');
addpath('../util');
train_x = double(train_x) / 255;
test_x = double(test_x) / 255;
train_y = double(train_y);
test_y = double(test_y);

rand('state',0)
//train dbn
dbn.sizes = [100 200]; //DBN的结构,v1层为raw pixel/原始图片,h1/v2层的节点数为100,h2/v3层的节点数为200
opts.numepochs =  3;
opts.batchsize = 100;
opts.momentum =  0; //记录以前的更新方向,并与现在的方向结合下,从而加快学习的速度
opts.alpha   =  1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

//train nn
//得到DBN的初始化参数后,用nn进行微调
opts.numepochs = 3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');


 

dbnsetup.m:建立DBN网络

function dbn = dbnsetup(dbn, x, opts)
  n = size(x, 2);
  dbn.sizes = [n, dbn.sizes]; //[784, 100,200]
  // 初始化W,b,c
  for u = 1 : numel(dbn.sizes) - 1
    dbn.rbm{u}.alpha  = opts.alpha;
    dbn.rbm{u}.momentum = opts.momentum;

    dbn.rbm{u}.W = zeros(dbn.sizes(u + 1), dbn.sizes(u));
    dbn.rbm{u}.vW = zeros(dbn.sizes(u + 1), dbn.sizes(u));

    dbn.rbm{u}.b = zeros(dbn.sizes(u), 1); //可视层的偏置bias
    dbn.rbm{u}.vb = zeros(dbn.sizes(u), 1);

    dbn.rbm{u}.c = zeros(dbn.sizes(u + 1), 1); //隐层的偏置bias
    dbn.rbm{u}.vc = zeros(dbn.sizes(u + 1), 1);
  end

end

dbntrain.m:训练DBN

 

function dbn = dbntrain(dbn, x, opts)
  n = numel(dbn.rbm);


  dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);
  for i = 2 : n
    x = rbmup(dbn.rbm{i - 1}, x); // 即sigm(W*x+c)
    dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts);
  end


end

 

rbmtrain.m:训练RBM

采用对比散度(Contrastive Divergence,CD)算法进行训练,这是Hinton在2002年提出了RBM的一个快速学习算法
算法描述在 《Learning Deep Architectures for AI》 Algorithm 1,主要流程如下:

DeepLearnToolbox DBN源码介绍

 

function rbm = rbmtrain(rbm, x, opts)
  assert(isfloat(x), 'x must be a float');
  assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');
  m = size(x, 1);
  numbatches = m / opts.batchsize;
  
  assert(rem(numbatches, 1) == 0, 'numbatches not integer');

  for i = 1 : opts.numepochs //迭代次数 
    kk = randperm(m); //将样本随机打乱
    err = 0;
    for l = 1 : numbatches
      batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);
      
      v1 = batch;
      h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');  // Gibbs采样
      v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);  // Gibbs采样
      h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');   // sigm(W*v2+c)
      // 对比上述流程图
      c1 = h1' * v1;
      c2 = h2' * v2;
      
      // rbm.momentum:记录以前的更新方向,并与现在的方向结合,从而加快学习速度  
      rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)   / opts.batchsize;
      rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)' / opts.batchsize;
      rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)' / opts.batchsize;

      rbm.W = rbm.W + rbm.vW;
      rbm.b = rbm.b + rbm.vb;
      rbm.c = rbm.c + rbm.vc;

      err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
    end
    
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
    
  end
end   


dbnunfoldtonn.m:利用DBN的参数去初始化NN,然后用NN进行微调nn = nntrain(nn, train_x, train_y, opts);

 

function nn = dbnunfoldtonn(dbn, outputsize)
//  DBNUNFOLDTONN Unfolds a DBN to a NN
//  dbnunfoldtonn(dbn, outputsize ) returns the unfolded dbn with a final layer of size outputsize added.
  if(exist('outputsize','var'))
    size = [dbn.sizes outputsize];
  else
    size = [dbn.sizes];
  end
  nn = nnsetup(size);
  for i = 1 : numel(dbn.rbm)
    nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W]; //利用DBN每层的W和c去初始化NN的参数
  end
end


CNN源码解析http://blog.csdn.net/zouxy09/article/details/9993743

http://blog.csdn.net/dark_scope/article/details/9495505

 

 

 

 


Reference:
(1) Learning Deep Architectures for AI
(2) A Practical Guide to Training Restricted Boltzmann Machines2010

 

 

本文地址:http://www.45fan.com/a/question/71586.html
Tags: 源码 DeepLearnToolbox DBN
编辑:路饭网
关于我们 | 联系我们 | 友情链接 | 网站地图 | Sitemap | App | 返回顶部