機械学習記9日目 ~ 機械学習モデルの構築
準備が整ったので、k近傍法でクラス分類を行う学習モデルを構築します。k近傍法のアルゴリズムは、scikit-learnのKNeighborsClassifierクラスに実装されています。書式は図1のようになります。
それでは、訓練データから学習モデルを構築したいと思います。
まず、KNeighborsClassifierクラスをインポートします。
KNeighborsClassifierオブジェクトを生成します。kの値は、1にしています。
これで、学習モデルが生成されました。
次に、訓練データをfit関数を使い、読み込ませて学習させます。
KNeighborsClassifierオブジェクトを生成したときのパラメータが表示されます。n_neighborsが1以外は、すべてデフォルトのなのが分かります。これで学習は完了しました。
それでは、試しに学習データの中からデータを1つ与えて、正しく学習しているか確認してみます。
新しいデータを作ります。要素は、setosaのがく片の長さ、がく片の幅、花弁の長さ、花弁の幅です。
では、クラスを予測させます。
予測の結果はprediction1に返ってきます。学習モデルが予想したラベルを表示してみます。
setosaのデータを与えたので、正解です。
別のデータで確認します。
こちらも正解でした。ただし、これは学習データからデータを抜き出して渡しているので、ほぼほぼ正解になるはずです。
正しく動作しているようなので、評価用データを使って正解率を出します。
▶モデルの評価
先ほど分割したテストデータを使って、作成した学習モデルがどのくらいの精度を持っているかを評価します。予測結果と評価用データの比較により、正解率を算出します。評価用データを学習モデルにセットします。
評価用データに対して予測した品種ラベルは、次のようになりました。
では、正解ラベルはどうなっているか、もう一度確認してみると。。。
一見わかりませんが、1カ所間違えています。
それでは、精度を計算してみたいと思います。
つまり、この学習モデルでは、テストデータに対する精度は約97%の正解率ということです。
そこで、kの値を3(n_neighbors=3)にして学習モデルを作成し、学習データを読み込めましたが、正解率は同じでした。そもそもk=1で97%の正解率なので、kを増やしても大きな違いはないということのようです。
なるほど、機械学習のデータとプログラムは、このように作るんですね。
k近傍法はわかりやすいアルゴリズムですが、大量で複雑なデータには向かない気がします。明日は、別のアルゴリズムを試してみたいと思います。