機械学習記6日目つづき ~ 回帰の実装

 回帰分析は、「結果データ」と「結果に影響を及ぼすデータ」の関係性を統計的に求める手法です。結果データを「目的変数」、結果に影響を及ぼすデータを「説明変数」と呼びます。

 scikit-learnには、この回帰分析を行うモデルとして「線形回帰モデル」が搭載されています。線形回帰とは、すべてのデータにできるだけフィットする線を引くことで、予測を行うモデルです。線の式は「最小2乗法」を用いて求めます。

 また、説明変数が1つの場合を「単回帰分析」、複数の場合を「重回帰分析」と呼びます。scikit-learnでは、どちらの単回帰分析も行うことができます。ここでは、単回帰分析を行ってみます。

 

▶単回帰分析

 scikit-learnでは、「linear_model.LinearRegression」というクラスを用いることで、線形回帰モデルを作ることができます。「linear_model.LinearRegression」を使ってモデルを生成するときの書式は、図1のようになります。

 実行はJupyter Notebookで行いますので、起動しておきます。最初に必要なライブラリをインポート(import)します。おなじみのNumPy、matplotlib、Pandasの他に、「sklearn」の「linear_model」をインポートします。

f:id:hackU0001:20190225230251p:plain

図1.linear_model.LinearRegression

f:id:hackU0001:20190225231943p:plain

 

pandasのDataFrame形式で、X軸とY軸のデータをわかりやすい感じで用意します。

f:id:hackU0001:20190225232754p:plain

とりあえず、どのような散布図になるかmatplotlibでプロットしてみます(図2)。

f:id:hackU0001:20190225233325p:plain

線形回帰は、すべてのデータにできるだけフィットする線を引くことで、予測を行うモデルです。まずは、目測で回帰直線を引いてみます。大体、こんな感じになるはずです(図3)。

 XとYのデータを公式に当てはめると回帰係数や切片が求めるので、そこから計算して回帰直線を引くことができますが、その計算をscikit-learnのLinearRegressionで行います。

 では、scikit-learnのLinearRegressionで、線形回帰モデルを生成します。

f:id:hackU0001:20190225234849p:plain

図2.データの散布図

 

 

f:id:hackU0001:20190225235457p:plain

図3.目測で引いた回帰直線

 

f:id:hackU0001:20190225235836p:plain

 

 作成したモデルに、XとYのデータを渡します。学習(機械学習では「訓練」と呼ぶ)には、fit関数を使います。実行すると、どのようなパラメータで線形回帰モデルが学習したか確認の表示が出力されます。

f:id:hackU0001:20190226000646p:plain

 

 これで、XとYのデータを学習した線形回帰モデルが完成しました。

 NumPyのarange関数を使い、Xの最小値から最大値まで0.01刻みの配列(array([0.], [0.01],  [0.02], …]))を生成します。

f:id:hackU0001:20190226001519p:plain

 

 NumPyのnewaxis関数を使い、X座標を2次元配列( array ([ [0. ], [0.01], [ 0.02], …])に変換します。

f:id:hackU0001:20190226002527p:plain

 

 X座標を線形回帰モデルのpredict関数に渡して、Y座標を作ります。

f:id:hackU0001:20190226002843p:plain

 

 では、元のデータと作成した回帰直線を一緒に表示してみます(図4)。

f:id:hackU0001:20190226003528p:plain

 

 

f:id:hackU0001:20190226003813p:plain

図4.回帰直線の表示

 

 

f:id:hackU0001:20190226004617p:plain

 

f:id:hackU0001:20190226005029p:plain

図5.グラフタイトルの追加

うまくいきました!

もし、導き出した回帰係数や切片が知りたいときは、係数はcoef_属性に、切片はintercept_属性に格納されています。

f:id:hackU0001:20190226005808p:plain

 

f:id:hackU0001:20190226010201p:plain

図6.回帰グラフ

 機械学習の勉強をはじめてから6日目でやっと回帰までたどり着きました。