scikit-learn 最小二乗法
scikit-learnで最小二乗法
y=sin(x)に誤差をランダムで付加した100個のデータを準備し、
学習用:テスト用=8:2に分けて実行。
import numpy as np import matplotlib.pyplot as plt #sklearn.cross_validationは非推奨、次のヴァージョンで廃止予定 #sklearn.model_selectionへ移動 from sklearn.model_selection import train_test_split as tts from sklearn.linear_model import LinearRegression from sklearn.preprocessing import PolynomialFeatures from sklearn.pipeline import make_pipeline from sklearn.metrics import mean_squared_error as mse # y=sin(x)に誤差を付加したデータを作る data_size = 100 # サンプル数 x_orig = np.linspace(0, 1, data_size) def f(x): return np.sin(2 * np.pi * x) x = np.random.rand(data_size)[:, np.newaxis] y = f(x) + np.random.rand(data_size)[:, np.newaxis] - 0.5 x_train, x_test, y_train, y_test = tts(x, y, test_size=0.8) plt.plot(x_orig, f(x_orig), ":") plt.scatter(x_train, y_train) plt.xlim( (0, 1) )
# 最小二乗法の次数を変えたサブプロットを描画 row, column = 2, 3 # サブプロットの行数、列数 fig, axs = plt.subplots(row, column, figsize=(8, 6)) for deg, ax in enumerate(axs.ravel()): # パイプラインを作る e = make_pipeline(PolynomialFeatures(deg), LinearRegression()) # 学習 e.fit(x_train, y_train) # 予測 px = e.predict(x_orig[:, np.newaxis]) ax.scatter(x_train, y_train) ax.plot(x_orig, px, "m") ax.set(xlim=(0, 1), ylim=(-2, 2), yticks=([-1, 0, 1]), title="degree={}".format(deg)) plt.tight_layout() plt.show()
train_error = np.empty(10) test_error = np.empty(10) # 最適な次数を探す for deg in range(10): # 学習と予測 e = make_pipeline(PolynomialFeatures(deg), LinearRegression()) e.fit(x_train, y_train) # 誤差の評価 train_error[deg] = mse(y_train, e.predict(x_train)) test_error[deg] = mse(y_test, e.predict(x_test)) plt.plot(np.arange(10), train_error, "--", label="train") plt.plot(np.arange(10), test_error, label="test") # 最適次数を表示 plt.plot(test_error.argmin(), test_error.min(), "ro") plt.text(test_error.argmin()+0.5, test_error.min(), "best fit", color="r", backgroundcolor="w", fontsize=14) plt.ylim( (0, 1) ) plt.legend(fontsize=14) plt.show()
3次関数での近似が良さそう。
選んだ関数がsinなのでなんとなくうなづける。
今回のデータサイズだと、3~6次は大差なさそう。
乱数の初期値(seed)がかわると、4次や5次が最適となることもあった。
ただ、7次以上は学習用データで過学習をしてしまい、
テスト時の誤差が非常に大きくなることも。