ホーム » 機械学習 » 交差検証(Cross-Validation)をPythonでやる

交差検証(Cross-Validation)をPythonでやる

こんにちは。

前回、交差検証というワードに触れ、その図を確認して終わりました。今回はその交差検証をしっかりと理解しましょう。

What is Cross-Validation?

モデルの誤差値を検証すること

この精度であっているのか?などの疑問を交差検証により解決します。

Cross-Validation consists of two main parts

  • Leave one out Cross-Validation (LOOCV)
  • K-Folds Cross Validation

という基本的な2つを紹介します。Holdout methodというものが他にありますがK-foldの下位互換と思っていいでしょう。

まずは「K-Folds Cross Validation」についてです。

  1. データセットをK分割する(k-foldと名付ける)。
  2. \forall i \in \{1, \cdots, k \}に対して
    i-foldを除いた他の全foldでモデルを学習させる
    i-foldを用いてモデルのerrorを計算する
  3. k個のerrorを平均する

図で確認しましょう。

では次に「Leave one out Cross-Validation (LOOCV)」についてです。これは上においてK=データの数としたものです。なのでerrorの平均は次のようにかけます。

    \[CV(n) = \frac{1}{n} \sum_{i=1}^n \left( y_{i} - \hat{y_{i}}^{(-i)}    \right)^2  \]

ただしnはデータ数、\hat{y_{i}}^{(-i)}i番目を除いたデータで学習をしたのちのy_{i}に対するpredictionです。しかし、こちらが利用されている例を見たことがありません。なので特に気にしなくていいでしょう。

交差検証=KFOLD

で行きましょう。では以下で実装例を見て見ましょう。ちなみにニューラルネットワーク前処理についての記事を復習しておくといいかもしれません。(ニューラルネットワークについては新しい記事を書く予定です)

話が逸れてしまいましたがこれで交差検証は大丈夫そうですね。次回からは再度、正則化についてのお話です。

でわ

READMORE


コメントする

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です