アシアルブログ

アシアルの中の人が技術と想いのたけをつづるブログです

交差エントロピーについて基本的な考え方をまとめてみました

こんにちは。 Monaca開発チームの内藤です。

交差エントロピーという関数をご存知でしょうか? 機械学習における分類問題やロジスティック回帰などで評価関数として利用されることの多い関数で、次のような形をしています。

 H(p,q) = - \sum_{x \in X} p(x) \log q(x)

でもこれ、なんかちょっと、ヘンテコな形の式ですよね。 log とか出てきますし、、。あまり普段は見かけない形の関数です。

よく使われる関数であるわりに、なかなかまとまった解説がないようなので、今回はこの交差エントロピーについて、基本的な事項についてまとめてみました。

(交差でない普通の)エントロピーについて

普通のエントロピーは、情報量とも呼ばれていて

 H(p) = - \sum_{x \in X} p(x) \log p(x)

という形をしています。これは、(定数倍を除いて)データの平均量を表しています。(なお、この定義は情報理論で定義されているエントロピーであり、単位はありません。エントロピー自体はもともと統計力学で最初に導入されたもので、こちらはボルツマン定数がかかるため物理的な単位を持っています)

まず、この普通のエントロピーがどうしてこんな式になっているかについて、調べてみましょう。

ビット数と、表現出来るパターン数

エントロピー自体は、データを内部的に二進数で扱うかどうかとは無関係なのですが、二進数から考えるのが簡単だと思うので、ここではまず二進数でデータを表現することから考えてみます。

二進数でデータを表現する場合、n ビットで何パターンを表現できるかを考えてみると、1ビットで2パターン、2ビットで4パターン、3ビットで8パターンというように、1ビット増えるごとに表現できるパターンが2倍になるので、nビットであれば、

 2^n 

パターンを表現出来ることが分かります。

ということは、逆に、kパターンを何ビットで表せるかを考えると、

 \log_{2} k

ビット必要であることが分かります。(これはk=6などのときに整数になりませんが、今は気にしないことにします)

簡単な例として、ABCDという4パターンを表現したい場合を考えましょう。わかりやすいように、A = 晴れ、B = 曇り、C = 雨、 D = 雪と考えてくれても良いです。天気を例としたこちらの記事が非常に分かりやすかったので、参考とさせてもらっています。 cookie-box.hatenablog.com

そうすると、これは2ビットで表すことができます。

 00 = A  (晴れ) 
 01 = B  (曇り) 
 10 = C  (雨) 
 11 = D  (雪) 

のような感じで対応させれば良いでしょう。

さて、このデータをn個分(あるいは、天気で言うと、n日分)欲しいとすると、合計で何ビットのデータになるでしょうか? ここで、A,B,C,Dは、どれも同じくらいの頻度で発生するものとします。 つまり、A, B, C, Dのそれぞれの発生する確率p(X)は

 p(A) = 1/4   (= 晴れの確率) 
 p(B) = 1/4   (= 曇りの確率) 
 p(C) = 1/4   (= 雨の確率) 
 p(D) = 1/4   (= 雪の確率) 

と考えています。

この場合、1個分(1日分)を表現するのに2ビット必要なので、n個分(n日分)を表現するのであれば、2nビットが必要になりますね。

では、次に、そもそもCとDを区別しないで、どちらもCとしたらどうでしょうか? このとき、パターンはA, B, Cの3種類になりますが、CはAやBよりも頻度が2倍になっています。

 p(A) = 1/4  (= 晴れの確率) 
 p(B) = 1/4  (= 曇りの確率) 
 p(C) = 1/2  (= 雨または雪の確率) 

このとき、AやBは以前と同様、表現するのに2ビット必要です。 けれども実は、Cについてだけは、表現するのに1ビットで十分であり、かつ、この時もっとも効率が良くなります。実際に、もとのビットでの対応を思い出せば、

 00 = A
 01 = B
 10 = C
 11 = C

となっているので、AやBは2ビット必要ですが、Cについては、上位1ビットが1であれば、その時点で下位1ビットを参照することなくCであることがわかります。この下位1ビットは、もともとは、CとDを区別するのに必要だった1ビットだから、今は不要になったという訳です。

しかも、Cは、もともとはCとDだったものなので、AやBより出現頻度(=出現確率)が高いのですが、高い頻度で出現するCを1ビットだけで表せるということで、データ効率が最も良くなっていることになります。

これを一般化すると、k種類のパターンを表現するのには、

  \log_{2} k

ビットのデータが必要なのだけれども、そのうちの k_C 個のパターンをすべてCとして同一視する場合、「Cのときは」 k_C個のパターンを識別するのに必要だったビット数

 \log_{2} k_C 

を節約することが出来るということです。 つまり、「Cのときは」表現するのに

 \log_{2} k - \log_{2} k_C

ビット必要で、「C以外のときは」以前と同様に

 \log_{2} k

ビットが必要になるということですね。

より一般化してみる

これをもう少し一般化して考えると、いま、パターンとしてはx_1, x_2, ... , x_mというmパターンあるのだけれど、その重複数(あるいは、発生頻度を整数比で表したもの)がそれぞれ k_1, k_2, ..., k_mだとすると、まずは全部で

 k = k_1 + k_2 + ... + k_m 

パターンを完全に区別すると考えれば、必要なビット数は

 \log_{2} k

ビットになるけれども、そのうちのx_1のときはk_1個を区別しないので、k_1個を区別するのに 必要だったビット数

 \log_{2} k_1

を節約出来ますし、また、x_2のときはk_2個を区別しないので、k_2個を区別するのに 必要だったビット数

  \log_{2} k_2

を節約出来ると考えていくと、結局、平均的に必要なビット数 Mは

 M = \log_{2} k - \sum_{i=1,..,m} p(x_i) \log_2 k_i 

となります。但しここで、

 p(x_i) = \frac{ k_i }{ k } 

を x_i の出現確率(つまり頻度)としました。

まとめると

 M = \log_{2} k - \sum_{i=1,...,m} p(x_i) \log_2 k_i 

です。

さらに、k_i を p(x_i) で書き直すと、

 M = \log_{2} k - \sum_{i=1,...,m} p(x_i) \log_{2} \{ p(x_i) k \} 
 = - \sum_{i=1,...,m} p(x_i) \log_{2} p(x_i) 

となります。

添え字 iの代わりに、x そのものを使えば、結局これは(Mが明示的に確率pに依存するためpの関数として記述しなおせば)

 M(p) = - \sum_{x \in X} p(x) \log_{2} p(x) 

となります。さらに底を変換して、logの底を自然対数の底 e(ネイピア数) にするために

 \alpha = \log_{2} e 

を導入すると、結局これは

 M(p) = - \alpha \sum_{x \in X} p(x) \log p(x) 

となります。

これは、最初に導入したエントロピーの式

 H(p) = - \sum_{x \in X} p(x) \log p(x)

と比べて、定数α倍だけ異なりますが、基本的には同じ式になります。

つまり、結論としては、このエントロピーというのは、データのパターンの出現確率(もしくは頻度)に偏りがある場合、それを表現するのに必要な平均ビット数(に比例した)関数である、ということがわかります。だからこそ「情報量」と考えることが出来るのですね。

交差エントロピー

もう一度、平均ビット数 M(p) の式をみると、

 M(p) = - \alpha \sum_{x \in X} p(x) \log p(x) 

を良くみましょう。これを書き直して

 M(p) = \sum_{x \in X} p(x) \times  \left( - \alpha \log p(x) \right) 

としてみます。これと、平均値の定義式

 M(p) = \sum_{x \in X} (xとなる確率) \times (xを表現するのに必要なビット数) 

を比較すれば、

 (xを表現するのに必要なビット数) = - \alpha  \log p(x) 

であることが分かります。

このことは、(xを表現するのに必要なビット数)を、確率 p に依存して決めることで、ビット数を節約して、効率の良い平均ビット数を得ることが出来ることを意味しています。

では、もしも、(xを表現するのに必要なビット数)を、確率pではなく、別の確率qで設定したらどうなるでしょうか? これは、もともとの平均ビット数 M(p)という関数を、人為的に変形して、pとqに別々に依存する関数M(p, q)としたものです。

 M(p, q) = - \alpha \sum_{x \in X} p(x) \log q(x) 

この(人為的に定義された)平均ビット数M(p, q)は、明らかに効率的ではありません。xが発生する確率p(x)と、xを表現するのに必要なビット数を決める確率q(x)を別々に設定してしまったからです。

逆に言うと、M(p, q)が最も効率が良くなる(M(p, q)が最小になる)のは、p = qのときなワケですから、M(p, q)を「pとqがどれくらい異なっているか?」の目安として使うことが出来ます。M(p, q)が小さい程、pとqは似ていることになります。

交差エントロピー H(p, q)は、上記の(人為的に定義された)平均ビット数 M(p, q)と定数倍しか違わず、p=qのときに普通のエントロピー H(p)となるように定義されたものです。

 H(p, q) = - \sum_{x \in X} p(x) \log q(x) 

繰り返しになりますが、交差エントロピーH(p, q)は、M(p, q)と定数倍しか違いませんから、結局、これが小さくなるほど、関数p(x)と関数q(x)は近いとみなすことが出来ます。