猫になりたい

コンサルのデータ分析屋、計量経済とか機械学習をやっています。pyてょnは3.7を使ってマスコレルウィンストングリーン。

バイアス バリアンストレードオフ(Bias-Variance Tradeoff)とは

機械学習でバイアス−バリアンストレードオフ(Bias-Variance Tradeoff(Decomposition))の話をよく目にするので自分なりに整理をしました。

参考文献

参考文献は以下の通りです。

Bias-Variance Tradeoffとは

Bias-Variance Tradeoffは $$ \begin{align} σ ^2 + \mathrm{Var} \big[ \hat{f}(x) \big] + ( f(x) - \mathrm{E} \big[ \hat{f}(x) \big])^2  \tag{1} \\ \end{align}$$

という様な式から判ることで、端的にいうと「単純なモデルは学習に用いるサンプルが変化しても関数の形は大きく変わらないが(低分散, Low Variance)、誤差(真の値とのズレ)が大きい(高バイアス, High Variance)。複雑なモデルは逆に、学習に用いるサンプルが変化すると関数の形が大きく変わる(高分散, High Variance)が、表現力が高いために誤差は小さくなる(低バイアス, Low Variance)。」ということです。

導出

先ずBias-Variance Tradeoffの式を導出してみましょう。次の様な回帰問題を考えます。 今、手元に母集団からサンプリングされた目的変数 yと説明変数のベクトル xがあり、真のモデルが y=f(x) + \epsilonであるとします。 \epsilonは平均0で分散 σ^2の確率変数です。私たちの目的は真のモデルをなるべく良く近似するモデル \hat{f}(x)を見つけることです。
そこで構築したモデルの良さを図る指標として二乗誤差(MSE)を使い、手元のサンプル xを用いて構築したモデル \hat{y} = \hat{f}(x)と真のモデル y=f(x) + \epsilonとのMSE \mathrm{E} [(y-\hat{y})^2]を見てましょう(\hat{}が付いた文字は推定された変数・モデルであることを意味しています)。MSEの式を展開していくと、

$$ \begin{align} \mathrm{E} \big[ (y-\hat{y} )^2 \big] &= \mathrm{E}\big[( y^2 + \hat{f}(x)^2 - 2y \hat{f}(x) ) \big] \\ &= \mathrm{E}\big[ y^2 \big] + \mathrm{E}\big[( \hat{f}(x)^2 \big] - 2y\mathrm{E}\big[( \hat{f}(x) \big] \\ \end{align} $$

ここで $$\begin{align} \mathrm{Var} \big[ x \big] = \mathrm{E} \big[ x^2 \big] - \mathrm{E} \big[ x \big]^2 \\ \Leftrightarrow \mathrm{Var} \big[ x \big] +\mathrm{E} \big[ x \big]^2 = \mathrm{E} \big[ x^2 \big] \end{align}$$ より、

$$ \begin{align} &= \mathrm{Var} \big[ y \big] +\mathrm{E} \big[ y \big]^2 + \mathrm{Var} \big[ \hat{f}(x) \big] +\mathrm{E} \big[ \hat{f}(x) \big]^2 - 2y\mathrm{E}\big[( \hat{f}(x) \big] \\ &= \mathrm{Var} \big[ y \big] + \mathrm{Var} \big[ \hat{f}(x) \big] + ( \mathrm{E} \big[ y \big]^2+\mathrm{E} \big[ \hat{f}(x) \big]^2 - 2y\mathrm{E}\big[( \hat{f}(x) \big] ) \\ &= \mathrm{Var} \big[ y \big] + \mathrm{Var} \big[ \hat{f}(x) \big] + ( y ^2+\mathrm{E} \big[ \hat{f}(x) \big]^2 - 2y\mathrm{E}\big[( \hat{f}(x) \big] ) \\ &= \mathrm{Var} \big[ y \big] + \mathrm{Var} \big[ \hat{f}(x) \big] + ( y - \mathrm{E} \big[ \hat{f}(x) \big])^2 \\ &= σ ^2 + \mathrm{Var} \big[ \hat{f}(x) \big] + ( f(x) - \mathrm{E} \big[ \hat{f}(x) \big])^2 \\ \end{align} $$*1 となり先に掲載した(1)式が導出されました。

但し1行目→2行目の式変形は、 $$ \begin{align} \mathrm{E} \big[ y \big] = \mathrm{E} \big[ f(x)\big] + \mathrm{E} \big[\epsilon \big] = \mathrm{E} \big[ f(x)\big] = f(x) \end{align} $$ 最後から2行目→最終行の式変形は、 $$ \begin{align} \mathrm{Var} \big[ y \big] &= \mathrm{E} \big[ (f(x)- \mathrm{E} \big[f(x) \big] )^2 \big] \\ &= \mathrm{E} \big[ (f(x) +\epsilon - f(x) )^2 \big] \\ &= \mathrm{E} \big[ \epsilon ^2 \big] \\ &= \mathrm{Var} \big[ \epsilon \big] + \mathrm{E} \big[ \epsilon \big] ^2\\ &= σ^2\\ \end{align} $$ によります。

次頁ではこの式の意味するところを見てみます。


スポンサーリンク

解釈

得られた (1)式を見てみましょう。 $$ \begin{align} σ ^2 + \mathrm{Var} \big[ \hat{f}(x) \big] + ( f(x) - \mathrm{E} \big[ \hat{f}(x) \big])^2 \\ \end{align}$$

この式はMSEを表すので、当然0に近いほどいいわけです。
各項を個別に見てみると、
第1項の σ^2は真のモデルに内在する撹乱項によるもので我々にコントロールは不可能です。
第2項の\mathrm{Var}  \big[ \hat{f}(x) \big]  は予測値の分散を表します。ここでの期待値はサンプルに対する期待値なので、 予測値の分散が大きいということは学習データが変わると予測値も大きく変わることを意味します。これはモデルが学習データを過学習してしまうことで予測精度が悪化するということです。
第3項の ( f(x) - \mathrm{E} [ \hat{f}(x) ])^2 はバイアス(真の値と予測値の期待値の差)の2乗を表します。予測値のバイアス f(x) -
\mathrm{E} [ \hat{f}(x) ] が0に近づけばその2乗誤差も0に近づくことがわかると思います。

例を考えてみましょう。
 xをスカラー、 \hat{f}_h(x)をh次の多項式とし、これで真の関数 f(x)を近似したいとします。hは次数を表すハイパーパラメータです。
(E.g.  h=2 \hat{f}_2(x) = \beta_0 + \beta_1x + \beta_2x^2 h=0 \hat{f}_0(x) = \beta_0 (定数関数)。)
 h=0、つまり定数関数の時予測値は xに依存しないので、予測値の分散は\mathrm{Var}  \big[ \hat{f}_0(x)^2 \big] = 0 になります。しかしこの f_0(x) = \beta_0 と真の値 f(x)  との誤差 f(x) - \mathrm{E} [ \hat{f}_0(x) ] が大きい(予測が全然当たらない)ことは明らかでしょう。 一方で h=10の様にhが大きいモデルは表現力が高いので複雑なデータにもキレイに当てはまるようにパラメータを学習できます(Low Variance)が、学習データにオーバーフィットしてしまいデータが少し変わるとその関数の見た目が大きく変わってしまいます(High Variance)。 具体的な図は、はじめてのパターン認識のp.18の図や参考文献のwikipedia の右側の図を御覧ください。

まとめ

以上の様に真のy と予測された\hat{y} の誤差をMSEで測ると、バイアスと、バリアンスと、誤差の分散に分解できることがわかりました。モデルでは改善できない誤差の分散を除くと、モデルが単純であればバイアスは小さくなるがバリアンスは大きくなり、モデルが複雑であればバリアンスは小さくなるがバイアスは大きくなります。つまり我々はMSEを最小にするような適度に複雑なモデルを見つければ良い予測を行うことが出来るわけです。 実際には例で上げたようなモデルの複雑さを決めるハイパーパラメータをCV等で探します。

*1:何故かσが\sigmaで表示されないのははてなの環境の問題……?