2010.09.02
exp(x)の高速計算 ~理論編~
ロジスティック回帰やCRFなどの対数線形モデルの学習でよく出てくるのが,expの計算です.これをSSE2を使って高速化するのが,今回のテーマです.まずは,背景の理論を説明します.
まず,指数関数を2の指数関数
に変換することを考えます(なぜ2の指数関数かはいずれ分かります).
両辺の自然対数をとり,について解くと,
.IEEE754など,2を基数とした指数部を採用している浮動小数点形式では,整数
に対して
を容易に構築できるので,上式の実数解
の代わりに整数,
を用い,
の大まかな値を計算することを考えます.ただし,
は
を超えない最大の整数を表します.
さて,を
で近似したときの誤差
の範囲は,
.誤差
に
を乗じたものを,
と定義すると,
上式の両辺の指数をとると,
ここで,は
を超えない最大の整数なので,
の値域は,
.
これらのことから,は以下のステップで計算出来ることが分かります.
を計算する.
を計算する.
を定義域
において,テーラー展開もしくは最良近似多項式などで近似計算する.
- 浮動小数点形式を利用して
を構築し,
との積を計算する.
ステップ1は簡単そうに見えますが,負の数を整数にキャストすると,0に近づいてしまうことに注意が必要です.つまり,doubleからintへのキャストは,正の値は切り下げ,負の値は切り上げになってしまいます.このことに注意すると,このステップは次のように書けます(a < 0で,ifによる分岐を作らないのが賢いやり方).
a = LOG2E * x; a -= (a < 0); n = (int)a;
ステップ2はそのまま書けばOKです.
ステップ3では,を多項式に近似します.多項式近似というとテーラー展開が一般的ですが,「・・・点の周りについてテーラー展開すると」という表現の通り,ある点の近傍の近似は良いのですが,そこから離れた点では近似精度があんまりよくないという欠点があります.Cephesというライブラリでは,パデ近似(Padé approximant)を採用していますが,除算を1回だけ行う必要があります.最近のCPUは速くなったとはいえ,SSE2で倍精度のパック除算(DIVPD)をやるのに必要なレイテンシ・スループットは,70・70で,倍精度のパック乗算(MULPD)の7・6と比べると,除算が圧倒的に遅いことになります.そこで,今回は最良近似式を用います.なにやら難しそうに見えますが,Sollya というソフトウェアを使えば,簡単に最良多項式が求まります.以下の例では,
の範囲で,5次の最良多項式を求めています.
> remez(exp(x), 5, [0, log(2)]); 0.99999989311082729779536722205742989232069120354073 + x * (1.00001092396453942157124178508842412412025643386873 + x * (0.49981934577169208735732248650232562589934399402426 + x * (0.16775408658617866431779970932853611481292418818223 + x * (3.87412011356070379615759057344100690905653320886699e-2 + x * 1.185268231308989403584147407056378360798378534739e-2))))
多項式がホーナー形式(horner scheme)で出てくるので,至れり尽くせりです.これをそのままC言語に書き直せば,ステップ3の実装は完了です.
y = 1.185268231308989403584147407056378360798378534739e-2; y *= b; y += 3.87412011356070379615759057344100690905653320886699e-2; y *= b; y += 0.16775408658617866431779970932853611481292418818223; y *= b; y += 0.49981934577169208735732248650232562589934399402426; y *= b; y += 1.00001092396453942157124178508842412412025643386873; y *= b; y += 0.99999989311082729779536722205742989232069120354073;
最後のステップ4ですが,もちろんpow(2, n)なんてやってしまったら,せっかくの高速化が台無しです.浮動小数点形式IEEE754では,2を基数として指数部が構成されているので,はちょっとした工夫で一瞬で作れます.あんまり移植性がないコードですが,
は次のように作ります.
typedef union {
double d;
unsigned short s[4];
} ieee754;
ieee754 u;
u.d = 0;
u.s[3] = (unsigned short)(((n + 1023) << 4) & 0x7FF0);
こうすると,u.dでの値にアクセスできます.なぜそうなるかは,IEEE754の仕様と睨めっこすればすぐに分かると思います.あとは,u.dとyの積を計算すれば,
の計算が出来たことになります.



