« | »

2010.09.02

exp(x)の高速計算 ~SSE2実装編~

exp(x)の高速計算 ~理論編~ の内容を元に,SSE2で実装してみます.ここからは,SSE2に関する基礎知識が必要になるのですが,すべて書くのは面倒なので,簡単にまとめると,

  • SSE2では,倍精度浮動小数点の2個の値に対して同じ演算をすると速い.
  • 現在のIntel CPUでは,SSE専用のレジスタ(128bit)が8個ある.
  • 加減乗除の算術演算や,ビットシフトなど,殆どの命令がSSE専用に準備されている.
  • メモリとSSE専用のレジスタ間でデータ転送を行うときは,メモリアドレスが16バイトでアライメントされている必要がある(厳密には「すべき」という言い方が正しいが).
  • C言語のソースコードの中でSSE用のコードを手っ取り早く書く方法は,コンパイラのintrinsicを使う方法.Microsoft Visual C++とgccでほぼ完全な互換性があるのが嬉しい.
  • コンパイラintrinsicは,アセンブラではないので,自分が意図したとおりの最適化コードが生成されるとは限らない.一応,コンパイラが最適化して速いコードを生成することになっているが,実際はそう上手くいかないことの方が多いので,コンパイル後のアセンブラコードを確認する必要がある.
  • SSEは,かなり昔にへるみさんのページで勉強しました.大体感触を掴んだら,MSDNのドキュメントなどを見ながら,コンパイラintrinsicを書けば,そんなに難しくないと思います.
  • SSEの命令を並べるときは,命令のロード・演算・ストアのパイプライン処理を頭の中で考えるべき.

だいたいこんな所ですかね.先ほどのexp(x)の計算ルーチン(5次最良多項式近似)をSSE2で書くと,こういうコードになります.

#include < emmintrin.h >

#ifdef _MSC_VER
#define MIE_ALIGN(x) __declspec(align(x))
#else
#define MIE_ALIGN(x) __attribute__((aligned(x)))
#endif

#define CONST_128D(var, val) \
    MIE_ALIGN(16) static const double var[2] = {(val), (val)}

void remez5_0_log2_sse(double *values, int num)
{
    int i;
    CONST_128D(one, 1.);
    CONST_128D(log2e, 1.4426950408889634073599);
    CONST_128D(c1, 6.93145751953125E-1);
    CONST_128D(c2, 1.42860682030941723212E-6);
    CONST_128D(w5, 1.185268231308989403584147407056378360798378534739e-2);
    CONST_128D(w4, 3.87412011356070379615759057344100690905653320886699e-2);
    CONST_128D(w3, 0.16775408658617866431779970932853611481292418818223);
    CONST_128D(w2, 0.49981934577169208735732248650232562589934399402426);
    CONST_128D(w1, 1.00001092396453942157124178508842412412025643386873);
    CONST_128D(w0, 0.99999989311082729779536722205742989232069120354073);
    const __m128i offset = _mm_setr_epi32(1023, 1023, 0, 0);

    for (i = 0;i < num;i += 4) {
        __m128i k1, k2;
        __m128d p1, p2;
        __m128d a1, a2;
        __m128d xmm0, xmm1;
        __m128d x1, x2;

        /* Load four double values. */
        x1 = _mm_load_pd(values+i);
        x2 = _mm_load_pd(values+i+2);

        /* a = x / log2; */
        xmm0 = _mm_load_pd(log2e);
        xmm1 = _mm_setzero_pd();
        a1 = _mm_mul_pd(x1, xmm0);
        a2 = _mm_mul_pd(x2, xmm0);

        /* k = (int)floor(a); p = (float)k; */
        p1 = _mm_cmplt_pd(a1, xmm1);
        p2 = _mm_cmplt_pd(a2, xmm1);
        xmm0 = _mm_load_pd(one);
        p1 = _mm_and_pd(p1, xmm0);
        p2 = _mm_and_pd(p2, xmm0);
        a1 = _mm_sub_pd(a1, p1);
        a2 = _mm_sub_pd(a2, p2);
        k1 = _mm_cvttpd_epi32(a1);
        k2 = _mm_cvttpd_epi32(a2);
        p1 = _mm_cvtepi32_pd(k1);
        p2 = _mm_cvtepi32_pd(k2);

        /* x -= p * log2; */
        xmm0 = _mm_load_pd(c1);
        xmm1 = _mm_load_pd(c2);
        a1 = _mm_mul_pd(p1, xmm0);
        a2 = _mm_mul_pd(p2, xmm0);
        x1 = _mm_sub_pd(x1, a1);
        x2 = _mm_sub_pd(x2, a2);
        a1 = _mm_mul_pd(p1, xmm1);
        a2 = _mm_mul_pd(p2, xmm1);
        x1 = _mm_sub_pd(x1, a1);
        x2 = _mm_sub_pd(x2, a2);

        /* Compute e^x using a polynomial approximation. */
        xmm0 = _mm_load_pd(w5);
        xmm1 = _mm_load_pd(w4);
        a1 = _mm_mul_pd(x1, xmm0);
        a2 = _mm_mul_pd(x2, xmm0);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        xmm0 = _mm_load_pd(w3);
        xmm1 = _mm_load_pd(w2);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm0);
        a2 = _mm_add_pd(a2, xmm0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        xmm0 = _mm_load_pd(w1);
        xmm1 = _mm_load_pd(w0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm0);
        a2 = _mm_add_pd(a2, xmm0);
        a1 = _mm_mul_pd(a1, x1);
        a2 = _mm_mul_pd(a2, x2);
        a1 = _mm_add_pd(a1, xmm1);
        a2 = _mm_add_pd(a2, xmm1);

        /* p = 2^k; */
        k1 = _mm_add_epi32(k1, offset);
        k2 = _mm_add_epi32(k2, offset);
        k1 = _mm_slli_epi32(k1, 20);
        k2 = _mm_slli_epi32(k2, 20);
        k1 = _mm_shuffle_epi32(k1, _MM_SHUFFLE(1,3,0,2));
        k2 = _mm_shuffle_epi32(k2, _MM_SHUFFLE(1,3,0,2));
        p1 = _mm_castsi128_pd(k1);
        p2 = _mm_castsi128_pd(k2);

        /* a *= 2^k. */
        a1 = _mm_mul_pd(a1, p1);
        a2 = _mm_mul_pd(a2, p2);

        /* Store the results. */
        _mm_store_pd(values+i, a1);
        _mm_store_pd(values+i+2, a2);
    }
}

SSE2のintrinsicに書き下しただけという感じですが,いくつか注釈を付けるとすると,

  • CONST_128Dマクロは,倍精度浮動小数点の定数を2個並べた定数を定義し,そのまま128bitのSSEレジスタにロードできるようにしている.
  • MIE_ALIGNマクロは,difference between gcc and vcのページを参照.他にも,Visual C++とgccの差異に関してまとめられていて,このページの情報はすごく有用です.
  • 128bit(倍精度浮動小数点×2)の演算を2つ並べて,パイプライン処理になりやすいように配慮している.そのため,引数nは4の倍数でなければならない.

Trackback URL

Comment & Trackback

No comments.

Comment feed

Comment





XHTML: You can use these tags:
<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>