Home > 9月, 2010

2010.09.02

exp(x)の高速計算 ~実験編~

最後に,e^xの計算がどのくらい高速化できたのか,実験してみました.アルゴリズムとして,以下のものを比較しました.

  • libcのexp関数
  • パデ近似を用いたもの(cephesの実装に近いです)
  • 5次~13次でテーラー展開を0を中心として行ったもの
  • 5次~13次の最良多項式近似を[-0.5, +0.5]の範囲で行ったもの
  • 5次~13次の最良多項式近似を[0, \log 2]の範囲で行ったもの
  • 5次~11次の最良多項式近似を[0, \log 2]の範囲で行ったものをSSE2で実装したもの

SSE2の実装で11次で打ち止めしているのは,11次と13次で演算精度の差が殆どないことが分かり,単に実装するのが面倒くさかったからだけです.0を中心としたテーラー展開や,最良多項式を[-0.5, +0.5]の範囲で構成する手法は,nの値を切り下げではなく,四捨五入(つまり0.5を足してから切り下げる)で求めます(Wapitisse_mathfun.hのexp計算ルーチンは四捨五入を使っています).

実験では,0を平均としたガウス分布で乱数を10,000,000個発生させ,その値のexpを全て計算するのに要する時間と,libcの演算結果を正しいものと仮定したときの誤差を測定しました.実験環境は,Intel(R) Xeon(R) CPU 5140 (2.33GHz) 上でDebian Linuxを動作させたアプリケーション・サーバーです.gccのバージョンは4.1.2で,コンパイルオプションは「-O3 -fomit-frame-pointer -msse2 -mfpmath=sse -ffast-math -lm」です.実験に用いたコードは,こちらからダウンロードもしくは閲覧できます.

横軸に計算時間,縦軸に誤差のRMS(二乗平均平方根)をプロットしたものを示します.Performance of fast exp computation

縦に一本線が入っているのが,libcの計算時間です.これを基準に誤差を測定しているので,エラーは0です.この線よりも左側に入っていないと,高速化の意味が無いことになります.各アルゴリズムは,左上が5次の多項式近似で,右下に向かって7次,9次,11次,13次と,精密かつ遅くなっていきます.だいたい,10^{-16}あたりで計算誤差の改善が止まっています.これは倍精度浮動小数点の有効桁数である16桁と大体一致しています.

計算精度に着目すると,テイラー展開 < 最良近似式 [-0.5, +0.5] < 最良近似式 [0, log 2] という順で,どの近似式も次数を上げていくと,10^{-16}付近まで誤差が改善します.ただ,アルゴリズムをCで実装してしまうと,libcと比較しても大した速度向上が得られません.5次まで近似精度を落としたとしても,2倍程度しか速くなりません.libcのexpの実装も相当に高速化されているので,一筋縄ではいかないようです.

これに対し,アルゴリズムをSSE2で実装したものは,4倍~8倍程度高速化されています.しかも,計算精度はSSE2化していないものとぴったり一致しているので,計算精度を犠牲にすることなく,高速化が達成できたことになります.e^xを大量に計算する状況では,11次の最良近似式を用いたSSE2ルーチンを使えば,演算精度を落とすことなく高速化できますし,近似式の精度が低くて良ければ,次数を下げてさらに高速化できます.次期バージョンのCRFsuiteでは,その他の改良も含めて,SSE2ルーチンを搭載する予定です.

exp(x)の高速計算 ~SSE2実装編~

exp(x)の高速計算 ~理論編~ の内容を元に,SSE2で実装してみます.ここからは,SSE2に関する基礎知識が必要になるのですが,すべて書くのは面倒なので,簡単にまとめると,

  • SSE2では,倍精度浮動小数点の2個の値に対して同じ演算をすると速い.
  • 現在のIntel CPUでは,SSE専用のレジスタ(128bit)が8個ある.
  • 加減乗除の算術演算や,ビットシフトなど,殆どの命令がSSE専用に準備されている.
  • メモリとSSE専用のレジスタ間でデータ転送を行うときは,メモリアドレスが16バイトでアライメントされている必要がある(厳密には「すべき」という言い方が正しいが).
  • C言語のソースコードの中でSSE用のコードを手っ取り早く書く方法は,コンパイラのintrinsicを使う方法.Microsoft Visual C++とgccでほぼ完全な互換性があるのが嬉しい.
  • コンパイラintrinsicは,アセンブラではないので,自分が意図したとおりの最適化コードが生成されるとは限らない.一応,コンパイラが最適化して速いコードを生成することになっているが,実際はそう上手くいかないことの方が多いので,コンパイル後のアセンブラコードを確認する必要がある.
  • SSEは,かなり昔にへるみさんのページで勉強しました.大体感触を掴んだら,MSDNのドキュメントなどを見ながら,コンパイラintrinsicを書けば,そんなに難しくないと思います.
  • SSEの命令を並べるときは,命令のロード・演算・ストアのパイプライン処理を頭の中で考えるべき.

だいたいこんな所ですかね.先ほどのexp(x)の計算ルーチン(5次最良多項式近似)をSSE2で書くと,こういうコードになります.

#include < emmintrin.h >

#ifdef _MSC_VER
#define MIE_ALIGN(x) __declspec(align(x))
#else
#define MIE_ALIGN(x) __attribute__((aligned(x)))
#endif

#define CONST_128D(var, val) \
    MIE_ALIGN(16) static const double var[2] = {(val), (val)}

void remez5_0_log2_sse(double *values, int num)
{
    int i;
    CONST_128D(one, 1.);
    CONST_128D(log2e, 1.4426950408889634073599);
    CONST_128D(c1, 6.93145751953125E-1);
    CONST_128D(c2, 1.42860682030941723212E-6);
    CONST_128D(w5, 1.185268231308989403584147407056378360798378534739e-2);
    CONST_128D(w4, 3.87412011356070379615759057344100690905653320886699e-2);
    CONST_128D(w3, 0.16775408658617866431779970932853611481292418818223);
    CONST_128D(w2, 0.49981934577169208735732248650232562589934399402426);
    CONST_128D(w1, 1.00001092396453942157124178508842412412025643386873);
    CONST_128D(w0, 0.99999989311082729779536722205742989232069120354073);
    const __m128i offset = _mm_setr_epi32(1023, 1023, 0, 0);

    for (i = 0;i < num;i += 4) {
        __m128i k1, k2;
        __m128d p1, p2;
        __m128d a1, a2;
        __m128d xmm0, xmm1;
        __m128d x1, x2;

        /* Load four double values. */
        x1 = _mm_load_pd(values+i);
        x2 = _mm_load_pd(values+i+2);

        /* a = x / log2; */
        xmm0 = _mm_load_pd(log2e);
        xmm1 = _mm_setzero_pd();
        a1 = _mm_mul_pd(x1, xmm0);
        a2 = _mm_mul_pd(x2, xmm0);

        /* k = (int)floor(a); p = (float)k; */
        p1 = _mm_cmplt_pd(a1, xmm1);
        p2 = _mm_cmplt_pd(a2, xmm1);
        xmm0 = _mm_load_pd(one);
        p1 = _mm_and_pd(p1, xmm0);
        p2 = _mm_and_pd(p2, xmm0);
        a1 = _mm_sub_pd(a1, p1);
        a2 = _mm_sub_pd(a2, p2);
        k1 = _mm_cvttpd_epi32(a1);
        k2 = _mm_cvttpd_epi32(a2);
        p1 = _mm_cvtepi32_pd(k1);
        p2 = _mm_cvtepi32_pd(k2);

        /* x -= p * log2; */
        xmm0 = _mm_load_pd(c1);
        xmm1 = _mm_load_pd(c2);
        a1 = _mm_mul_pd(p1, xmm0);
        a2 = _mm_mul_pd(p2, xmm0);
        x1 = _mm_sub_pd(x1, a1);
        x2 = _mm_sub_pd(x2, a2);
        a1 = _mm_mul_pd(p1, xmm1);
        a2 = _mm_mul_pd(p2, xmm1);
        x1 = _mm_sub_pd(x1, a1);
        x2 = _mm_sub_pd(x2, a2);

        /* Compute e^x using a polynomial approximation. */
        xmm0 = _mm_load_pd(w5);
        xmm1 = _mm_load_pd(w4);
        a1 = _mm_mul_pd(x1, xmm0);
        a2 = _mm_mul_pd(x2, xmm0);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        xmm0 = _mm_load_pd(w3);
        xmm1 = _mm_load_pd(w2);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm0);
        a2 = _mm_add_pd(a2, xmm0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        xmm0 = _mm_load_pd(w1);
        xmm1 = _mm_load_pd(w0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm0);
        a2 = _mm_add_pd(a2, xmm0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        /* p = 2^k; */
        k1 = _mm_add_epi32(k1, offset);
        k2 = _mm_add_epi32(k2, offset);
        k1 = _mm_slli_epi32(k1, 20);
        k2 = _mm_slli_epi32(k2, 20);
        k1 = _mm_shuffle_epi32(k1, _MM_SHUFFLE(1,3,0,2));
        k2 = _mm_shuffle_epi32(k2, _MM_SHUFFLE(1,3,0,2));
        p1 = _mm_castsi128_pd(k1);
        p2 = _mm_castsi128_pd(k2);

        /* a *= 2^k. */
        a1 = _mm_mul_pd(a1, p1);
        a2 = _mm_mul_pd(a2, p2);

        /* Store the results. */
        _mm_store_pd(values+i, a1);
        _mm_store_pd(values+i+2, a2);
    }
}

SSE2のintrinsicに書き下しただけという感じですが,いくつか注釈を付けるとすると,

  • CONST_128Dマクロは,倍精度浮動小数点の定数を2個並べた定数を定義し,そのまま128bitのSSEレジスタにロードできるようにしている.
  • MIE_ALIGNマクロは,difference between gcc and vcのページを参照.他にも,Visual C++とgccの差異に関してまとめられていて,このページの情報はすごく有用です.
  • 128bit(倍精度浮動小数点×2)の演算を2つ並べて,パイプライン処理になりやすいように配慮している.そのため,引数nは4の倍数でなければならない.

exp(x)の高速計算 ~理論編~

ロジスティック回帰やCRFなどの対数線形モデルの学習でよく出てくるのが,expの計算です.これをSSE2を使って高速化するのが,今回のテーマです.まずは,背景の理論を説明します.

まず,指数関数e^xを2の指数関数2^aに変換することを考えます(なぜ2の指数関数かはいずれ分かります).

e^x = 2^a

両辺の自然対数をとり,aについて解くと,a = x/\log 2 .IEEE754など,2を基数とした指数部を採用している浮動小数点形式では,整数nに対して2^nを容易に構築できるので,上式の実数解aの代わりに整数,n = \left\lfloor x/\log 2 \right\rfloor を用い,e^xの大まかな値を計算することを考えます.ただし,\lfloor a \rflooraを超えない最大の整数を表します.

さて,anで近似したときの誤差(a - n)の範囲は,0 \leq a - n < 1 .誤差(a - n)\log 2を乗じたものを,bと定義すると,

b = (a - n) \log 2 = x - n \log2

上式の両辺の指数をとると,

e^x = 2^n e^b

ここで,naを超えない最大の整数なので,bの値域は,0 \leq b < \log 2

これらのことから,e^xは以下のステップで計算出来ることが分かります.

  1. n = \left\lfloor x/\log 2 \right\rfloor を計算する.
  2. b = x - n \log2 を計算する.
  3. e^bを定義域(0 \leq b < \log 2)において,テーラー展開もしくは最良近似多項式などで近似計算する.
  4. 浮動小数点形式を利用して2^nを構築し,e^bとの積を計算する.

ステップ1は簡単そうに見えますが,負の数を整数にキャストすると,0に近づいてしまうことに注意が必要です.つまり,doubleからintへのキャストは,正の値は切り下げ,負の値は切り上げになってしまいます.このことに注意すると,このステップは次のように書けます(a < 0で,ifによる分岐を作らないのが賢いやり方).

a = LOG2E * x;
a -= (a < 0);
n = (int)a;

ステップ2はそのまま書けばOKです.

ステップ3では,e^bを多項式に近似します.多項式近似というとテーラー展開が一般的ですが,「・・・点の周りについてテーラー展開すると」という表現の通り,ある点の近傍の近似は良いのですが,そこから離れた点では近似精度があんまりよくないという欠点があります.Cephesというライブラリでは,パデ近似(Padé approximant)を採用していますが,除算を1回だけ行う必要があります.最近のCPUは速くなったとはいえ,SSE2で倍精度のパック除算(DIVPD)をやるのに必要なレイテンシ・スループットは,70・70で,倍精度のパック乗算(MULPD)の7・6と比べると,除算が圧倒的に遅いことになります.そこで,今回は最良近似式を用います.なにやら難しそうに見えますが,Sollya というソフトウェアを使えば,簡単に最良多項式が求まります.以下の例では,[0, \log 2]の範囲で,5次の最良多項式を求めています.

> remez(exp(x), 5, [0, log(2)]);
0.99999989311082729779536722205742989232069120354073 + x *
(1.00001092396453942157124178508842412412025643386873 + x *
(0.49981934577169208735732248650232562589934399402426 + x *
(0.16775408658617866431779970932853611481292418818223 + x *
(3.87412011356070379615759057344100690905653320886699e-2 + x *
1.185268231308989403584147407056378360798378534739e-2))))

多項式がホーナー形式(horner scheme)で出てくるので,至れり尽くせりです.これをそのままC言語に書き直せば,ステップ3の実装は完了です.

y = 1.185268231308989403584147407056378360798378534739e-2;
y *= b;
y += 3.87412011356070379615759057344100690905653320886699e-2;
y *= b;
y += 0.16775408658617866431779970932853611481292418818223;
y *= b;
y += 0.49981934577169208735732248650232562589934399402426;
y *= b;
y += 1.00001092396453942157124178508842412412025643386873;
y *= b;
y += 0.99999989311082729779536722205742989232069120354073;

最後のステップ4ですが,もちろんpow(2, n)なんてやってしまったら,せっかくの高速化が台無しです.浮動小数点形式IEEE754では,2を基数として指数部が構成されているので,2^nはちょっとした工夫で一瞬で作れます.あんまり移植性がないコードですが,2^nは次のように作ります.

typedef union {
    double d;
    unsigned short s[4];
} ieee754;

ieee754 u;
u.d = 0;
u.s[3] = (unsigned short)(((n + 1023) << 4) & 0x7FF0);

こうすると,u.dで2^nの値にアクセスできます.なぜそうなるかは,IEEE754の仕様と睨めっこすればすぐに分かると思います.あとは,u.dとyの積を計算すれば,e^xの計算が出来たことになります.