ぶつりやAI

物理屋の視点から、原理原則を大事に、ディープラーニングのわかり易い説明を心がけています。

【実践基礎1】過学習とバリデーション(Validation)

前の記事  次の記事


ディープラーニングの実践に進む前に、実用上の理論やフレームワークを理解することが重要 です。特に「過学習(Overfitting)」は、ディープラーニングの学習において最も重要な問題の一つであり、適切な対策をしなければ、モデルの性能は大幅に低下してしまいます。

本記事では、過学習とは何か、そしてそれを防ぐための手法であるバリデーション(Validation) について解説します。

ディープラーニングをより基本的な理論から学びたい方は、ディープラーニング入門【初心者向け】 の「理論基礎編」を御覧ください。

最新記事の更新情報等はXでお知らせしています。

目次

フィッティング

理論のあるフィッティング

フィッティングとは、複数のデータの間の因果関係を、数式でできるだけ適切に表現することを指します。データに対して、あるパラメーターを持つ数式をフィット(最適化) させることで、関係性をモデル化します。

例えばオームの法則:

 V = RI

を考えてみましょう。 V は電圧、 R は抵抗値、 I は電流で、電圧は電流に比例します。実験によって様々な電流  I を流し、電圧  V を測定すると、データ点が得られます。これらのデータを用いて、関数  V = f(I | R) をフィッティングすると、最適な  R の値を推定できます。

過学習とバリデーション (validation)の説明。データ点がたくさんあり、物理法則に沿って正しくフィッティングした場合。

オームの法則のフィッティング

ディープラーニングのフィッティング

(教師あり)ディープラーニングも本質的にはこのフィッティングと同じです。実際、インプットとアウトプットを1変数、活性化関数を  y = x、バイアスを0に固定し、中間層をなくせば、オームの式と同じになります。

しかし、一般的にディープラーニングは、きれいな理論式では表せないような複雑な因果関係のものに適応されるため、無数のパラメーターを持たせて自由度を高めたアプローチになっています。

言い換えると、ディープラーニングは、理論的に厳密な数式によるモデル化を放棄し、あらゆるパターンを捉えるために膨大なパラメーターを持たせた手法 だといえます。

良いフィッティングとは?

ここで、「オームの法則を知らない状態で、複雑な関数を使ってデータをフィッティングするとどうなるか?」を考えてみます。ディープラーニングと全く同じ戦法です。最も簡単な方法は、多項式を使うことです。

 y = f_n = a_0 + a_1 \cdot x + a_2 \cdot x^2 + \cdots a_n \cdot x^n

上はn次の多項式と呼ばれます。さて、とりあえず十分に複雑な現象でも表せるように、80次の多項式で先ほどと同じデータ点をフィッティングしたら、以下のようになりました。

オームの法則に従うデータを多項式でフィッティング(Fitting)した図。過学習してノイズを説明してしまい、多項式が複雑な挙動を示している。

多項式によるフィッティング

いかがでしょうか?こちらの損失 (loss) は、先程の  V = f( I | R) よりも小さい値になっています。ではこの理論式は現象をよく表しているでしょうか?直観的に、これはまずいだろうと思われたのではないでしょうか?

良いフィッティングのために

過学習(Overfitting)とは?

上で感じた不安を確かめるために、もう一度同じ実験をして、先ほどフィッティングしたモデルはそのままに、新しいデータを重ねて見てみましょう。

新しい測定データ(バリデーションデータ:validation data)、フィット済みのオームの式、フィット済みの多項式。フィット済みの多項式は、新しい測定データだと損失関数の値が大きくなってしまう。フィットした関数が必要以上に複雑で、自由度がデータ点の数に近いため、ノイズを説明してしまい、モデルの汎化性能が低くなっている。

新しい測定データ、フィットしたオームの式、多項式。

先ほどとそっくりですが、データ点のばらつき方だけが変わっています。そして、損失 (loss) を比較すると、直線モデル(青)のほうが良い値を示しています。

これは、このデータ点の本質が直線であり、そこからのズレはノイズだからです。多項式は測定するたびにランダムにばらつくノイズまで説明しようとしてしまったために、測定を改めてノイズが変わった途端に損失が悪くなってしまうのです。これを「過学習(オーバーフィッティング)」といいます。

このように、モデルがたくさんのパラメーターを持っていて、極めて複雑な関係を説明できてしまう、というのは、諸刃の剣でもあるのです。過学習の問題点を改めて整理すると、次のようになります:

  • フィットしたデータの損失(loss)は非常に小さくなるが、関数が無意味に複雑になる
  • 新しいデータの損失(loss)が大きくなり、よく説明できない

つまり、フィットしたデータのノイズまで説明してしまい、本来の規則性を捉えられなくなってしまう のです。実はこの実験で用いた2つ目の新しいデータが、次に説明する「バリデーションデータ」に相当します。

バリデーション(Validation)とは?

過学習を監視するための基本的な手法が、バリデーション(Validation)です。ディープラーニングでは、モデルの学習を進める際に、フィッティングに用いないデータを用意し、学習の間に損失をモニタリングします。

このデータを バリデーションデータ(Validation Data)と呼びます。したがって、学習に必要なデータは以下の2つになります。

  • 訓練データ(Training Data): モデルを学習させるために使用するデータ
  • バリデーションデータ(Validation Data): 学習中のモデルの性能を確認するためのデータ(モデルのパラメーター更新には使わない)

訓練データでパラメーターを更新するたびに、バリデーションデータで損失を計算し(Validation Loss)、本当の性能をモニタリングするのです。例えば Validation Loss が上昇し始めたら学習を自動で止める「Early Stopping」 という手法などに使われたりします。

過学習を防止するには?

学習中、validation loss などを監視する、あるいは記録することで、過学習が起きたかどうかをいち早く検知することができます。もし過学習を防ぐには様々な手法がありますが、その中でも最も基本的なのは、データを増やすこと、そしてデータの多様性を確保することです。

先ほど多項式でフィッティングをしたら過学習が起きたのは、自由度(パラメータの数)が高すぎたからでした。じつは多項式のパラメータの数がデータ数と同じだと、すべてのデータ点を通るようにフィッティングできてしまいます。このように、自由度が高すぎるかどうかは、データの数で決まります。データが十分に多ければ、過学習のリスクは減ります。

しかしデータを増やすことで過学習を防げるのは、多様性が確保される前提です。例えば画像認識でネコを判別するモデルを考えましょう。もし特定の種類のネコの画像だけでトレーニングをしたら、他の種類のネコは認識できなくなるかもしれません。これはデータにバイアスがあれば、結局特定のパターンに過剰にフィットしてしまいます。こういう状況を、データにバイアスがある、といいます。

まとめ

  • フィッティングとは、データの間の関係を数式で表現すること
  • 過学習(Overfitting)は、モデルがノイズまで学習してしまう現象
  • バリデーション(Validation)によって、過学習をモニタリングできる

過学習を避けるための手法は他にもたくさんありますが、それらは別の記事で詳しく解説していきます!

おすすめカテゴリー

buturiya-ai.hatenablog.com

buturiya-ai.hatenablog.com


前の記事: 誤差逆伝搬とは?勾配計算による学習の仕組み

次の記事: ミニバッチとエポック