「Deep Learning Javaプログラミング 深層学習の理論と実装」memo

概要

Deep LearningのアルゴリズムをJavaで実装する書籍。

ゼロからの実装とあるが、書籍に記載されているコードのみでは完成しない。
ソースコードをダウンロードすることが前提。
写経する事での学ぶが出来ない。
コードリーディングによる解説になる。
ニューラルネットワークの基本であれば「ニューラルネットワーク自作入門」が図解も含めて分かりやすい。
但しPythonで書かています。

memo

  • SingleLayerNeuralNetworks
    4箇所書籍に記載がない。
    utilはダウンロードのみ。

    package dlwj;
    
    import java.util.Random;
    
    public class Perceptrons {
        
        int nIn; // 書籍には出てこない。
        double[] w; // 書籍には出てこない。
    
        public Perceptrons(int nIn) {
            this.nIn = nIn;
            w = new double[nIn];
        }
        
        public int train(double[] x, int t, double learningRate) {
            int classified = 0;
            double c = 0;
            
            // データが正しく分類されているかチェック。誤分類されているときのみ、勾配降下法を適用
            for(int i = 0; i < nIn; i++) {
                c += w[i] * x[i] * t;
            }
            
            if(c > 0) {
                classified = 1;
            }
            else {
                for(int i = 0; i < nIn; i++) {
                    w[i] += learningRate * x[i] * t;
                }
            }
            
            return classified;
        }
        
        public int predict(double[] x) {
            double preActivation = 0.;
            
            for(int i = 0; i < nIn; i++) {
                preActivation += w[i] * x[i];
            }
            
            return ActivationFunction.step(preActivation); // 呼び方変えています。
        }
    
        public static void main(String[] args) {
            final Random rng = new Random(1234); // 書籍には出てこない。
            
            final int train_N = 1000;
            final int test_N = 200;
            final int nIn = 2;
            
            double[][] train_X = new double[train_N][nIn];
            int[] train_T = new int[train_N];
            
            double[][] test_X = new double[test_N][nIn];
            int[] test_T = new int[test_N];
            int[] predicted_T = new int[test_N];
            
            final int epochs = 2000;
            final double  learningRate = 1;
            
            GaussianDistribution g1 = new GaussianDistribution(-2.0, 1.0, rng);
            GaussianDistribution g2 = new GaussianDistribution(2.0, 1.0, rng);
    
            // 書籍には出てこない。↓
            for(int i = 0; i < train_N/2 - 1; i++) {
                train_X[i][0] = g1.random();
                train_X[i][1] = g2.random();
                train_T[i] = 1;
            }
            for(int i = 0; i < test_N/2 - 1; i++) {
                test_X[i][0] = g1.random();
                test_X[i][1] = g2.random();
                test_T[i] = 1;
            }
    
            // 書籍には出てこない。
            for(int i = train_N/2; i < train_N; i++) {
                train_X[i][0] = g2.random();
                train_X[i][1] = g1.random();
                train_T[i] = -1;
            }
            for(int i = test_N/2; i < test_N; i++) {
                test_X[i][0] = g2.random();
                test_X[i][1] = g1.random();
                test_T[i] = -1;
            }
            // 書籍には出てこない。↑
            
            Perceptrons classifier = new Perceptrons(nIn);
    
            int epoch = 0; // 書籍には出てこない。
            while(true) {
                int classified_ = 0;
                
                for(int i = 0; i < train_N; i++) {
                    classified_ += classifier.train(train_X[i], train_T[i], learningRate);
                }
                
                if(classified_ == train_N) {
                    break;
                }
                
                epoch++;
                if(epoch > epochs) {
                    break;
                }
            }
            
            for(int i = 0; i < test_N; i++) {
                // テストデータの分類結果を配列に格納
                predicted_T[i] = classifier.predict(test_X[i]);
            }        
            
            int[][] confusionMatrix = new int[2][2];
            double accuracy = 0.;
            double precision = 0.;
            double recall = 0.;
            
            for(int i = 0; i < test_N; i++) {
                if(predicted_T[i] > 0) {
                    if(test_T[i] > 0) {
                        accuracy += 1;
                        precision += 1;
                        recall += 1;
                        confusionMatrix[0][0] += 1;
                    }
                    else {
                        confusionMatrix[1][0] += 1;
                    }
                }
                else {
                    if(test_T[i] > 0) {
                        confusionMatrix[0][1] += 1;
                    }
                    else {
                        accuracy += 1;
                        confusionMatrix[1][1] +=1;
                    }
                }
            }
            
            accuracy /= test_N;
            precision /= confusionMatrix[0][0] + confusionMatrix[1][0];
            recall /= confusionMatrix[0][0] + confusionMatrix[0][1];
            
            System.out.println("----------------------------");
            System.out.println("Perceptrons model evaluation");
            System.out.println("----------------------------");
            System.out.printf("Accuracy:  %.1f %%\n", accuracy * 100);
            System.out.printf("Precision: %.1f %%\n", precision * 100);
            System.out.printf("Recall:    %.1f %%\n", recall * 100);
        }
    }
    
  • 多クラスロジスティック回帰以降は書籍がどこのコードなのか、ソースを見ないと分からない。
  • 断片コードの説明であって「ゼロから実装」感はない。

memo

Posted by shi-n