読者です 読者をやめる 読者になる 読者になる

ニューラルネットワーク実装

Coursera機械学習コースでニューラルネットワークを学んだので、習作としてC#で実装してみた。

多層パーセプトロン対応。Classification専用(シグモイド関数をベタ書きしてるので)。
ソースコードはここにおいてある。


実装の概略
ネットワーク各層の関数を、行列型の配列Thetas[]で保持している。

public class NeuralNetwork
{
    public Matrix[] Thetas { get; set; }
    public int NumLayers { get; set; }
    public int[] NumNeurons { get; set; }

これに、(RandomizeThetas()で初期化したあと)関数Learn()に学習用データとパラメタを渡せば最急降下法で学習を開始する。

public void Learn(Matrix[] input, Matrix[] output, double alpha, double lambda, int maxItr)
{
    while (maxItr > 0)
    {
        var grad = GetGrad(Thetas, input, output, input.Length, lambda);
        UpdateThetas(input.GetLength(0), grad, alpha, lambda, input, output);
        maxItr--;
    }
}

上記のGetGrad()が、各ノードの偏微分を返す関数で、それを使ってUpdateThetas()内でThetas[]を更新する。

GetGrad()の実装も教科書通り。

public Matrix[] GetGrad(Matrix[] thetas, Matrix[] input, Matrix[] output, int m, double lambda)
{
    var L_deltas = InitializeLDeltas();

    for (int idx = 0; idx < input.GetLength(0); idx++)
    {
        Matrix[] z = null;
        var values = ForwardProp(input[idx], thetas, ref z);
        BackProp(L_deltas, output[idx], values, z);
    }

    var grad = GetGrad(thetas, L_deltas, input.Length, lambda);
    return grad;
}

ForwardProp()は名前の通りフォワードプロパゲーションを、BackProp()はバックプロパゲーションをする関数。
ForwardProp()は各層の中間値を返すので、それを使ってBackProp()で後ろの層から変化量もとめ、L_deltasを更新している。
全入力のループが終わったら、GetGrad()がL_deltas
を使って偏微分を求める。

偏微分を求める実装がややこしいのだが、コースで推奨されているとおりに、近似値を求めて比較する方法でデバッグした。
{\theta}-eと{\theta}+e(eは小さい数)のそれぞれでコストを求め、その傾きを使うというもの。

Matrix GetNumericGradient(NeuralNetwork nn, Matrix[] thetas, int idx, Matrix[] input, Matrix[] output)
{
    var t = thetas.ToArray();

    const double e = 1e-4;
    var ret = new Matrix(t[idx].RowNum, t[idx].ColNum);
    var diffMat = new Matrix(t[idx].RowNum, t[idx].ColNum);

    for (int row = 0; row < diffMat.RowNum; row++)
    {
        for (int col = 0; col < diffMat.ColNum; col++)
        {
            var orgValue = t[idx][row, col];

            t[idx][row, col] = orgValue - e;
            var loss1 = nn.J(thetas, input, output);

            thetas[idx][row, col] = orgValue + e;
            var loss2 = nn.J(thetas, input, output);

            ret[row, col] = (loss2 - loss1) / (2.0 * e);

            thetas[idx][row, col] = orgValue;
        }
    }

    return ret;
}
var numgrads = GetNumericGradients(nn, nn.Thetas, x, y);
var grads = nn.GetGrad(new Matrix[] { theta1, theta2 }, x, y, x.Length, lambda);

2通りの結果がほぼ同じなので、実装が合っていることを確認できる。


使ってみる

こちらのサイトにある手書き数字の学習と認識に挑戦してみる。

static void NeuralNetworkTest_Example()
{
    Console.WriteLine("**Neural Network test (test data from: http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/)**");

    var nn = new NeuralNetwork(3, new int[] { 16, 16, 10 });
    nn.RandomizeThetas();

    //train
    var trainingDataTuple = GetNNData(TraininData.Data);
    var inputs = trainingDataTuple.Item1;
    var outputs = trainingDataTuple.Item2;

    //test
    var testDataTuple = GetNNData(TestData.Data);
    var test_inputs = testDataTuple.Item1;
    var test_outputs = testDataTuple.Item2;

    var alphas = new double[] { 5, 7 };
    var lambdas = new double[] { 0.03, 0.1, 0.3, 1.0 };

    var parameters = nn.FindParameters(alphas, lambdas, inputs, outputs, test_inputs, test_outputs, 300);
            
    Console.WriteLine("Chosen params - alpha:{0}, lambda:{1}", parameters.Item1, parameters.Item2);

    nn.Learn(inputs, outputs, parameters.Item1, parameters.Item2, 300);

    var result = nn.GetResult(test_inputs, test_outputs);
    var err = result.Item1;
    var total = result.Item2;

    Console.WriteLine(string.Format("Total:{0}, Error:{1}, Success Rate: {2}", total, err, (double)(total - err) / (double)total));            
}

トレーニング用とテスト用のデータがすでに用意されていたので、そのまま利用した。
好ましいパラメタを選ぶために、本来ならクロスバリデーション用として別のデータを用意すべきだが、今回はテスト用を流用した。

3層ネットワークで、イテレーションは適当に300回。
合計5回の学習で1時間以上かかってしまった。

f:id:yambe2002:20160324232921p:plain

認識率は約90%だった。何も工夫しないとこんなものだろうか。または、イテレーション回数が少なくて収束しきっていないのかもしれない。
このあたりは、別にちゃんとしたライブラリを使っていろいろ試してみようと思う。

一から自分で組んでみてかなり勉強になったが、本格的に実装するならライブラリを使わないと難しいと改めて実感。(数学的に最適化されてないので実行がかなり遅いし、デバッグも難しい)