Modern Arm Assembly Language Programming: Covers Armv8-A 32-bit, 64-bit, and SIMD

by Daniel Kusswurm
2021.07.28: updated by
Up

Chapter 15: Armv8-64 SIMD Floating-Point Programming

Packed Floating-Point Arrays

Correlation Coefficient

NEON を使って相関係数の計算を行うプログラムの説明

ch15_04/main.cpp
#include <iostream>
#include <iomanip>
#include <string>
#include <random>
#include "AlignedMem.h"

using namespace std;

// A64 SIMD counterpart of CalcCorrCoef below. Only declared here; the
// definition (inline assembly) lives in ch15_04/neon.cpp.
extern bool CalcCorrCoef_(float* rho, float sums[5], const float* x, const float* y, size_t n, float epsilon);

// Alignment handed to AlignedArray for the input buffers and re-checked with
// AlignedMem::IsAligned in CalcCorrCoef (16 matches the 0x0f address mask
// tested by the assembly version).
const size_t c_Alignment = 16;

// Fills x[0..n) and y[0..n) with reproducible, correlated test data.
//
//   x[i] = a uniform draw from [0, 50), rounded to a whole number
//   y[i] = x[i] + a normal draw (mean 25, std-dev 7), rounded to a whole number
//
// The same seed always yields the same data on a given standard library.
void Init(float* x, float* y, size_t n, unsigned int seed) {
    std::mt19937 engine(seed);
    std::uniform_real_distribution<float> x_dist(0.0f, 50.0f);
    std::normal_distribution<float> offset_dist(25.0f, 7.0f);

    float* const x_end = x + n;
    while (x != x_end) {
        // Draw order matters for reproducibility: x first, then the y offset.
        const float x_val = std::round(x_dist(engine));
        const float y_offset = std::round(offset_dist(engine));
        *x++ = x_val;
        *y++ = x_val + y_offset;
    }
}

// Scalar C++ reference implementation of the correlation coefficient
//
//   rho = (n*Sxy - Sx*Sy) / (sqrt(n*Sxx - Sx^2) * sqrt(n*Syy - Sy^2))
//
// rho     - receives the coefficient (0 when the denominator is rejected)
// sums    - receives { Sx, Sy, Sxx, Syy, Sxy }
// x, y    - input arrays of n elements, each aligned to c_Alignment
// n       - element count, must be non-zero
// epsilon - smallest denominator accepted
//
// Returns false without touching rho or sums when n is 0 or an array is
// misaligned. Returns false with sums filled in and *rho == 0 when the
// denominator is below epsilon. Returns true otherwise.
bool CalcCorrCoef(float* rho, float sums[5], const float* x, const float* y, size_t n, float epsilon) {
    // Validate the element count first, then the alignment of x, then y.
    const bool args_ok = n != 0
        && AlignedMem::IsAligned(x, c_Alignment)
        && AlignedMem::IsAligned(y, c_Alignment);
    if (!args_ok)
        return false;

    // Accumulate the five running totals in a single pass.
    float total_x = 0, total_y = 0, total_xx = 0, total_yy = 0, total_xy = 0;

    for (size_t k = 0; k != n; ++k)
    {
        const float xv = x[k];
        const float yv = y[k];

        total_x += xv;
        total_y += yv;
        total_xx += xv * xv;
        total_yy += yv * yv;
        total_xy += xv * yv;
    }

    // Publish the totals in the order callers expect.
    sums[0] = total_x;
    sums[1] = total_y;
    sums[2] = total_xx;
    sums[3] = total_yy;
    sums[4] = total_xy;

    // Numerator and denominator of rho.
    const float numerator = n * total_xy - total_x * total_y;
    const float denominator = std::sqrt(n * total_xx - total_x * total_x) * std::sqrt(n * total_yy - total_y * total_y);

    // ">=" is false for a NaN denominator, so that case is rejected as well.
    const bool denominator_ok = denominator >= epsilon;
    *rho = denominator_ok ? numerator / denominator : 0;
    return denominator_ok;
}

// Test driver: builds 103 pseudo-random sample pairs, runs the C++ and the
// A64 SIMD correlation routines on them and prints both result sets side by
// side. Exits with 1 when either routine reports failure, 0 otherwise.
int main()
{
    const char nl = '\n';
    const size_t n = 103;
    const int w = 14;                   // width of each numeric column

    // Aligned input buffers filled with reproducible test data.
    AlignedArray<float> x_aa(n, c_Alignment);
    AlignedArray<float> y_aa(n, c_Alignment);
    float* x = x_aa.Data();
    float* y = y_aa.Data();
    Init(x, y, n, 71);

    // Run both implementations on identical input.
    float sums1[5], sums2[5];
    float rho1, rho2;
    float epsilon = 1.0e-9;
    const bool rc1 = CalcCorrCoef(&rho1, sums1, x, y, n, epsilon);
    const bool rc2 = CalcCorrCoef_(&rho2, sums2, x, y, n, epsilon);

    cout << "Results for CalcCorrCoef\n\n";

    if (!(rc1 && rc2))
    {
        cout << "Invalid return code ";
        cout << "rc1 = " << boolalpha << rc1 << ", ";
        cout << "rc2 = " << boolalpha << rc2 << nl;
        return 1;
    }

    // Table header.
    const string sep(w * 3, '-');
    cout << fixed << setprecision(6);
    cout << "Value    " << setw(w) << "C++" << " " << setw(w) << "A64 SIMD" << nl;
    cout << sep << nl;

    // One table row: label, C++ value, SIMD value.
    auto print_row = [&](const char* label, float cpp_value, float simd_value) {
        cout << label << setw(w) << cpp_value << " " << setw(w) << simd_value << nl;
    };

    cout << setprecision(2);
    const char* const sum_labels[5] = {
        "sum_x:   ", "sum_y:   ", "sum_xx:  ", "sum_yy:  ", "sum_xy:  "
    };
    for (size_t i = 0; i < 5; i++)
        print_row(sum_labels[i], sums1[i], sums2[i]);
    print_row("rho:     ", rho1, rho2);
    return 0;
}
ch15_04/neon.cpp
#include "Vec128.h"

//#define UpdateSums(VregX, VregY)					\
//	fadd	v16.4s, v16.4s, \VregX\().4s     // update sum_x \n\
//	fadd	v17.4s, v17.4s, \VregX\().4s     // update sum_y \n\



// Emits the five packed (4 x f32) accumulator updates for one pair of vector
// registers: v16 += x, v17 += y, v18 += x*x, v19 += y*y, v20 += x*y.
#define CALC_CORR_COEF_UPDATE_SUMS(vx, vy)                  \
    "fadd v16.4s, v16.4s, " vx ".4s\n\t"                    \
    "fadd v17.4s, v17.4s, " vy ".4s\n\t"                    \
    "fmla v18.4s, " vx ".4s, " vx ".4s\n\t"                 \
    "fmla v19.4s, " vy ".4s, " vy ".4s\n\t"                 \
    "fmla v20.4s, " vx ".4s, " vy ".4s\n\t"

// Emits a horizontal add of all four lanes of one accumulator. After the two
// pairwise adds every lane, including lane 0 (readable as sN), holds the total.
#define CALC_CORR_COEF_REDUCE(v)                            \
    "faddp " v ".4s, " v ".4s, " v ".4s\n\t"                \
    "faddp " v ".4s, " v ".4s, " v ".4s\n\t"

// A64 SIMD implementation of the correlation coefficient
//
//   rho = (n*Sxy - Sx*Sy) / (sqrt(n*Sxx - Sx^2) * sqrt(n*Syy - Sy^2))
//
// rho     - receives the coefficient (0 when the denominator is rejected)
// sums    - receives { Sx, Sy, Sxx, Syy, Sxy }
// x, y    - input arrays of n elements, each 16-byte aligned
// n       - element count, must be non-zero
// epsilon - smallest denominator accepted
//
// Returns false without touching rho or sums when n is 0 or an array is not
// 16-byte aligned. Returns false with sums filled in and *rho == 0 when the
// denominator is below epsilon or is NaN. Returns true otherwise.
//
// The assembly is an extended-asm statement: every value it reads or writes is
// a named operand and every scratch register is listed as a clobber, so the
// compiler keeps control of the prologue, epilogue and return value and the
// routine does not depend on the optimisation level. epsilon never enters the
// assembly, so the vector loads cannot overwrite it. Labels are numeric local
// labels, which stay unique even if the statement is emitted more than once.
bool CalcCorrCoef_(float* rho, float sums[5], const float* x, const float* y, size_t n, float epsilon) {
    // Argument validation: non-empty input, both arrays on a 16-byte boundary.
    const size_t low_address_bits = (reinterpret_cast<size_t>(x) | reinterpret_cast<size_t>(y)) & 0x0f;
    if (n == 0 || low_address_bits != 0)
        return false;

    const float n_f32 = static_cast<float>(n);
    float rho_num;
    float rho_den;

#if defined(__aarch64__)
    size_t remaining = n;       // elements still to be consumed by the asm

    __asm__ volatile(
        // Clear the packed accumulators.
        "eor v16.16b, v16.16b, v16.16b\n\t"             // sum_x  = 0
        "eor v17.16b, v17.16b, v17.16b\n\t"             // sum_y  = 0
        "eor v18.16b, v18.16b, v18.16b\n\t"             // sum_xx = 0
        "eor v19.16b, v19.16b, v19.16b\n\t"             // sum_yy = 0
        "eor v20.16b, v20.16b, v20.16b\n\t"             // sum_xy = 0

        "cmp %[cnt], #16\n\t"                           // fewer than 16 elements?
        "b.lo 2f\n\t"                                   //   skip the packed loop

        // Packed loop: 16 elements of x and of y per iteration.
        "1:\n\t"
        "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[x]], #64\n\t"
        "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[y]], #64\n\t"
        CALC_CORR_COEF_UPDATE_SUMS("v0", "v4")
        CALC_CORR_COEF_UPDATE_SUMS("v1", "v5")
        CALC_CORR_COEF_UPDATE_SUMS("v2", "v6")
        CALC_CORR_COEF_UPDATE_SUMS("v3", "v7")
        "sub %[cnt], %[cnt], #16\n\t"
        "cmp %[cnt], #16\n\t"
        "b.hs 1b\n\t"                                   // at least 16 left: loop

        // Collapse each packed accumulator to a scalar in s16..s20.
        "2:\n\t"
        CALC_CORR_COEF_REDUCE("v16")
        CALC_CORR_COEF_REDUCE("v17")
        CALC_CORR_COEF_REDUCE("v18")
        CALC_CORR_COEF_REDUCE("v19")
        CALC_CORR_COEF_REDUCE("v20")
        "cbz %[cnt], 4f\n\t"                            // no tail elements

        // Scalar loop for the remaining 1..15 elements.
        "3:\n\t"
        "ldr s1, [%[x]], #4\n\t"                        // s1 = *x++
        "ldr s2, [%[y]], #4\n\t"                        // s2 = *y++
        "fadd s16, s16, s1\n\t"                         // sum_x  += x
        "fadd s17, s17, s2\n\t"                         // sum_y  += y
        "fmadd s18, s1, s1, s18\n\t"                    // sum_xx += x * x
        "fmadd s19, s2, s2, s19\n\t"                    // sum_yy += y * y
        "fmadd s20, s1, s2, s20\n\t"                    // sum_xy += x * y
        "subs %[cnt], %[cnt], #1\n\t"
        "b.ne 3b\n\t"

        // Store the five sums.
        "4:\n\t"
        "stp s16, s17, [%[sums]]\n\t"                   // sums[0], sums[1]
        "stp s18, s19, [%[sums], #8]\n\t"               // sums[2], sums[3]
        "str s20, [%[sums], #16]\n\t"                   // sums[4]

        // rho numerator: n * sum_xy - sum_x * sum_y
        "fmul s1, %s[nf], s20\n\t"
        "fmsub %s[num], s16, s17, s1\n\t"               // fmsub d,n,m,a: d = a - n*m

        // rho denominator: sqrt(n*sum_xx - sum_x^2) * sqrt(n*sum_yy - sum_y^2)
        "fmul s2, %s[nf], s18\n\t"
        "fmsub s2, s16, s16, s2\n\t"
        "fsqrt s2, s2\n\t"
        "fmul s3, %s[nf], s19\n\t"
        "fmsub s3, s17, s17, s3\n\t"
        "fsqrt s3, s3\n\t"
        "fmul %s[den], s2, s3\n\t"
        : [x] "+&r"(x), [y] "+&r"(y), [cnt] "+&r"(remaining),
          [num] "=&w"(rho_num), [den] "=&w"(rho_den)
        : [sums] "r"(sums), [nf] "w"(n_f32)
        : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
          "v16", "v17", "v18", "v19", "v20", "cc", "memory");
#else
    // Portable scalar fallback so the test harness still builds and runs on
    // hosts without A64 NEON; it follows the same formula as the assembly.
    float sum_x = 0.0f, sum_y = 0.0f, sum_xx = 0.0f, sum_yy = 0.0f, sum_xy = 0.0f;

    for (size_t i = 0; i < n; i++) {
        sum_x += x[i];
        sum_y += y[i];
        sum_xx += x[i] * x[i];
        sum_yy += y[i] * y[i];
        sum_xy += x[i] * y[i];
    }

    sums[0] = sum_x;
    sums[1] = sum_y;
    sums[2] = sum_xx;
    sums[3] = sum_yy;
    sums[4] = sum_xy;

    rho_num = n_f32 * sum_xy - sum_x * sum_y;
    rho_den = __builtin_sqrtf(n_f32 * sum_xx - sum_x * sum_x)
            * __builtin_sqrtf(n_f32 * sum_yy - sum_y * sum_y);
#endif

    // Written as !(a >= b) rather than (a < b) so that a NaN denominator
    // (square root of a slightly negative variance) is rejected too, exactly
    // like the C++ reference implementation.
    if (!(rho_den >= epsilon)) {
        *rho = 0.0f;
        return false;
    }

    *rho = rho_num / rho_den;
    return true;
}

#undef CALC_CORR_COEF_REDUCE
#undef CALC_CORR_COEF_UPDATE_SUMS
ch15_04/main.cpp の実行例
arm64@manet ch15_04 % g++ -I.. -std=c++11 -S neon.cpp
arm64@manet ch15_04 % g++ -I.. -std=c++11 -O main.cpp neon.cpp -o a.out
arm64@manet ch15_04 % ./a.out
Results for CalcCorrCoef

Value               C++       A64 SIMD
------------------------------------------
sum_x:          2567.00        2567.00
sum_y:          5160.00        5160.00
sum_xx:        88805.00       88805.00
sum_yy:       287412.00      287412.00
sum_xy:       153065.00      153065.00
rho:               0.91           0.91


http://ynitta.com/