enhance: [Cherry-pick][2.4] Upgrade bitset for ARM SVE (#33440)

issue: https://github.com/milvus-io/milvus/issues/32826
pr: https://github.com/milvus-io/milvus/pull/32718
improve ARM SVE performance for internal/core/src/bitset

Baseline timings for gcc 11.4 + Graviton 3 + manually enabled SVE:
https://gist.github.com/alexanderguzhva/a974b50134c8bb9255fb15f144e5ac83

Candidate timings for gcc 11.4 + Graviton 3 + manually enabled SVE:
https://gist.github.com/alexanderguzhva/19fc88f4ad3757e05e0f7feaf563b3d3

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Alexander Guzhva 2024-05-29 04:17:51 -04:00 committed by GitHub
parent 2638735d05
commit 5a668a17a9
1 changed file with 245 additions and 358 deletions


@@ -42,63 +42,6 @@ namespace {
 //
 constexpr size_t MAX_SVE_WIDTH = 2048;
 
-constexpr uint8_t SVE_LANES_8[MAX_SVE_WIDTH / 8] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
-    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A,
-    0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55,
-    0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60,
-    0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B,
-    0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
-    0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F,
-    0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A,
-    0x8B, 0x8C, 0x8D, 0x8E, 0x8F, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
-    0x96, 0x97, 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, 0xA0,
-    0xA1, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB,
-    0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6,
-    0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBC, 0xBD, 0xBE, 0xBF,
-    0xC0, 0xC1, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA,
-    0xCB, 0xCC, 0xCD, 0xCE, 0xCF, 0xD0, 0xD1, 0xD2, 0xD3, 0xD4, 0xD5,
-    0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xDB, 0xDC, 0xDD, 0xDE, 0xDF, 0xE0,
-    0xE1, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xEB,
-    0xEC, 0xED, 0xEE, 0xEF, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6,
-    0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF};
-
-constexpr uint16_t SVE_LANES_16[MAX_SVE_WIDTH / 16] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
-    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A,
-    0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55,
-    0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60,
-    0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B,
-    0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
-    0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F};
-
-constexpr uint32_t SVE_LANES_32[MAX_SVE_WIDTH / 32] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F};
-
-constexpr uint64_t SVE_LANES_64[MAX_SVE_WIDTH / 64] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F};
-
 /*
 // debugging facilities
@@ -131,179 +74,28 @@ void print_svuint8_t(const svuint8_t value) {
 ///////////////////////////////////////////////////////////////////////////
 
-// todo: replace with pext whenever available
-// generate 16-bit bitmask from 8 serialized 16-bit svbool_t values
-void
-write_bitmask_16_8x(uint8_t* const __restrict res_u8,
-        const svbool_t pred_op,
-        const svbool_t pred_write,
-        const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 128 uint16_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 32 uint16_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 16 uint16_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 8 uint16_t values
-    // this code does reduction of 16-bit 0b0A0B0C0D0E0F0G0H words into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-    const svuint8_t mask_04_8b = svand_n_u8_z(pred_op, mask_8b, 0x01);
-    const svuint8_t mask_15_8b = svand_n_u8_z(pred_op, mask_8b, 0x04);
-    const svuint8_t mask_15s_8b = svlsr_n_u8_z(pred_op, mask_15_8b, 1);
-    const svuint8_t mask_26_8b = svand_n_u8_z(pred_op, mask_8b, 0x10);
-    const svuint8_t mask_26s_8b = svlsr_n_u8_z(pred_op, mask_26_8b, 2);
-    const svuint8_t mask_37_8b = svand_n_u8_z(pred_op, mask_8b, 0x40);
-    const svuint8_t mask_37s_8b = svlsr_n_u8_z(pred_op, mask_37_8b, 3);
-    const svuint8_t mask_0347_8b = svorr_u8_z(pred_op, mask_04_8b, mask_37s_8b);
-    const svuint8_t mask_1256_8b =
-        svorr_u8_z(pred_op, mask_15s_8b, mask_26s_8b);
-    const svuint8_t mask_cmb_8b =
-        svorr_u8_z(pred_op, mask_0347_8b, mask_1256_8b);
-    //
-    const svuint16_t shifts_16b = svdup_u16(0x0400UL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u16(shifts_16b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b);
-    const svuint8_t zero_8b = svdup_n_u8(0);
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
-
-// generate 32-bit bitmask from 8 serialized 32-bit svbool_t values
-void
-write_bitmask_32_8x(uint8_t* const __restrict res_u8,
-        const svbool_t pred_op,
-        const svbool_t pred_write,
-        const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 64 uint32_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 16 uint32_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 8 uint32_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 4 uint32_t values
-    // this code does reduction of 32-bit 0b000A000B000C000D... dwords into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-    const svuint8_t mask_024_8b = svand_n_u8_z(pred_op, mask_8b, 0x01);
-    const svuint8_t mask_135s_8b = svlsr_n_u8_z(pred_op, mask_8b, 3);
-    const svuint8_t mask_cmb_8b =
-        svorr_u8_z(pred_op, mask_024_8b, mask_135s_8b);
-    //
-    const svuint32_t shifts_32b = svdup_u32(0x06040200UL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u32(shifts_32b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b);
-    const svuint8_t zero_8b = svdup_n_u8(0);
-    const svuint8_t shifted_8b_m2 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m2, zero_8b),
-                   svuzp2_u8(shifted_8b_m2, zero_8b));
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
-
-// generate 64-bit bitmask from 8 serialized 64-bit svbool_t values
-void
-write_bitmask_64_8x(uint8_t* const __restrict res_u8,
-        const svbool_t pred_op,
-        const svbool_t pred_write,
-        const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 32 uint64_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 4 uint64_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 2 uint64_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 1 uint64_t value
-    // this code does reduction of 64-bit 0b0000000A0000000B... qwords into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-    const svuint64_t shifts_64b = svdup_u64(0x706050403020100ULL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u64(shifts_64b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_8b, shifts_8b);
-    const svuint8_t zero_8b = svdup_n_u8(0);
-    const svuint8_t shifted_8b_m1 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-    const svuint8_t shifted_8b_m2 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m1, zero_8b),
-                   svuzp2_u8(shifted_8b_m1, zero_8b));
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m2, zero_8b),
-                   svuzp2_u8(shifted_8b_m2, zero_8b));
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
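The three write_bitmask_*_8x helpers removed above emulated x86 pext with mask/shift/svuzp1 ladders, since SVE without the SVE2 bitperm extension has no pext equivalent. A scalar model of the same reduction for the 64-bit case, using the classic multiply trick to gather the low bit of each byte into one byte (illustrative only, not part of the patch):

    #include <cstdint>

    // Scalar model of what write_bitmask_64_8x computes per 8 bytes of
    // predicate data: collect bit 0 of each byte of x into a single byte.
    // The mask isolates the low bit of every byte; the multiply shifts
    // all eight bits into the top byte at distinct positions, so no
    // carries can interfere.
    inline uint8_t
    gather_low_bits_of_bytes(uint64_t x) {
        const uint64_t lsbs = x & 0x0101010101010101ULL;
        return uint8_t((lsbs * 0x0102040810204080ULL) >> 56);
    }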
 ///////////////////////////////////////////////////////////////////////////
 //
 inline svbool_t
 get_pred_op_8(const size_t n_elements) {
-    const svbool_t pred_all_8 = svptrue_b8();
-    const svuint8_t lanes_8 = svld1_u8(pred_all_8, SVE_LANES_8);
-    const svuint8_t leftovers_op = svdup_n_u8(n_elements);
-    const svbool_t pred_op = svcmpgt_u8(pred_all_8, leftovers_op, lanes_8);
-    return pred_op;
+    return svwhilelt_b8(uint32_t(0), uint32_t(n_elements));
 }
 //
 inline svbool_t
 get_pred_op_16(const size_t n_elements) {
-    const svbool_t pred_all_16 = svptrue_b16();
-    const svuint16_t lanes_16 = svld1_u16(pred_all_16, SVE_LANES_16);
-    const svuint16_t leftovers_op = svdup_n_u16(n_elements);
-    const svbool_t pred_op = svcmpgt_u16(pred_all_16, leftovers_op, lanes_16);
-    return pred_op;
+    return svwhilelt_b16(uint32_t(0), uint32_t(n_elements));
 }
 //
 inline svbool_t
 get_pred_op_32(const size_t n_elements) {
-    const svbool_t pred_all_32 = svptrue_b32();
-    const svuint32_t lanes_32 = svld1_u32(pred_all_32, SVE_LANES_32);
-    const svuint32_t leftovers_op = svdup_n_u32(n_elements);
-    const svbool_t pred_op = svcmpgt_u32(pred_all_32, leftovers_op, lanes_32);
-    return pred_op;
+    return svwhilelt_b32(uint32_t(0), uint32_t(n_elements));
 }
 //
 inline svbool_t
 get_pred_op_64(const size_t n_elements) {
-    const svbool_t pred_all_64 = svptrue_b64();
-    const svuint64_t lanes_64 = svld1_u64(pred_all_64, SVE_LANES_64);
-    const svuint64_t leftovers_op = svdup_n_u64(n_elements);
-    const svbool_t pred_op = svcmpgt_u64(pred_all_64, leftovers_op, lanes_64);
-    return pred_op;
+    return svwhilelt_b64(uint32_t(0), uint32_t(n_elements));
 }
 //
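Both the removed table-based helpers and their svwhilelt replacements construct the same thing: a predicate whose lane i is active iff i < n_elements. The svwhilelt form maps to a single WHILELT instruction and makes the 256-byte SVE_LANES_* constants unnecessary, which is presumably part of the speedup. A scalar model of the shared contract (illustrative, capped at 64 lanes):

    #include <cstddef>
    #include <cstdint>

    // Model of get_pred_op_8: bit i of the result is set iff lane i is
    // enabled, i.e. iff i < n_elements. width is the number of 8-bit
    // lanes per register (what svcntb() reports).
    inline uint64_t
    pred_op_8_model(size_t n_elements, size_t width) {
        uint64_t pred = 0;
        for (size_t i = 0; i < width && i < 64; i++) {
            if (i < n_elements) {
                pred |= uint64_t(1) << i;
            }
        }
        return pred;
    }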
@@ -579,7 +371,7 @@ struct SVEVector<int8_t> {
     using sve_type = svint8_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntb();
     }
@@ -606,7 +398,7 @@ struct SVEVector<int16_t> {
     using sve_type = svint16_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcnth();
     }
@@ -633,7 +425,7 @@ struct SVEVector<int32_t> {
     using sve_type = svint32_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntw();
     }
@@ -660,7 +452,7 @@ struct SVEVector<int64_t> {
     using sve_type = svint64_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntd();
     }
@@ -687,7 +479,7 @@ struct SVEVector<float> {
     using sve_type = svfloat32_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
    width() {
         return svcntw();
     }
@@ -714,7 +506,7 @@ struct SVEVector<double> {
     using sve_type = svfloat64_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntd();
     }
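Each width() specialization reports how many elements of its type fit in one SVE register; the svcnt* intrinsics return uint64_t, so the new return type simply matches them. How the counters relate, as a sketch (the 256-bit vector length is an assumed example, real hardware varies):

    #include <cstdint>

    // Assumed 256-bit SVE for illustration; svcntb() is the byte count
    // and the other counters are fixed fractions of it.
    inline uint64_t model_svcntb() { return 32; }                  // 8-bit lanes
    inline uint64_t model_svcnth() { return model_svcntb() / 2; }  // 16-bit lanes
    inline uint64_t model_svcntw() { return model_svcntb() / 4; }  // 32-bit lanes
    inline uint64_t model_svcntd() { return model_svcntb() / 8; }  // 64-bit lanes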
@@ -737,22 +529,15 @@ struct SVEVector<double> {
 ///////////////////////////////////////////////////////////////////////////
 // an interesting discussion here:
 // https://stackoverflow.com/questions/77834169/what-is-a-fast-fallback-algorithm-which-emulates-pdep-and-pext-in-software
+// SVE2 has bitperm, which contains the implementation of pext
 // todo: replace with pext whenever available
 //
+// NBYTES is the size of the underlying datatype in bytes.
+// So, for example, for i8/u8 use 1, for i64/u64/f64 use 8.
 template <size_t NBYTES>
 struct MaskHelper {};
 
 template <>
 struct MaskHelper<1> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
-        const size_t size,
+    write_full(uint8_t* const __restrict bitmask,
         const svbool_t pred0,
         const svbool_t pred1,
         const svbool_t pred2,
@@ -761,8 +546,8 @@ struct MaskHelper<1> {
         const svbool_t pred5,
         const svbool_t pred6,
         const svbool_t pred7) {
-        const size_t sve_width = svcntb();
-        if (size == 8 * sve_width) {
+        const uint64_t sve_width = svcntb();
         // perform a full write
         *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred0;
         *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1;
@@ -772,124 +557,234 @@ struct MaskHelper<1> {
         *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5;
         *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6;
         *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7;
-        } else {
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+        const size_t size,
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
-        // perform a partial write
-        // this is the buffer for the maximum possible case of 2048 bits
+        // this is a temporary buffer for the maximum possible case of 2048 bits
         uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred7;
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_1;
+        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_2;
+        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_3;
+        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred_4;
+        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred_5;
+        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred_6;
+        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred_7;
-        // make the write mask
-        const svbool_t pred_write = get_pred_op_8(size / 8);
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
         // load the buffer
         const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
         // write it to the bitmask
         svst1_u8(pred_write, bitmask, mask_u8);
-        }
     }
 };
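For 8-bit elements an svbool_t already carries one predicate bit per byte-lane, which is exactly the bitmask layout, so write_partial only needs to spill the predicates through a scratch buffer and copy the first size / 8 bytes out. A scalar model of that tail copy (hypothetical helper, assuming size % 8 == 0 as the comment guarantees):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Model of MaskHelper<1>::write_partial: the predicate bytes are
    // already the desired bitmask, so the masked store amounts to a
    // bounded copy of size / 8 bytes.
    inline void
    write_partial_model(uint8_t* bitmask,
                        const uint8_t* pred_buf,
                        size_t size) {
        std::memcpy(bitmask, pred_buf, size / 8);
    }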
 template <>
 struct MaskHelper<2> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
+    write_full(uint8_t* const __restrict bitmask,
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7);
+        // perform a full write
+        *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_01;
+        *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_23;
+        *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred_45;
+        *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred_67;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+        const size_t size,
-        const svbool_t pred0,
-        const svbool_t pred1,
-        const svbool_t pred2,
-        const svbool_t pred3,
-        const svbool_t pred4,
-        const svbool_t pred5,
-        const svbool_t pred6,
-        const svbool_t pred7) {
-        const size_t sve_width = svcnth();
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 4)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 4)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 4)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 4)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 4)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 4)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 4)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 4)) = pred7;
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7);
-        const svbool_t pred_op_8 = get_pred_op_8(size / 4);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_16_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
+        // this is a temporary buffer for the maximum possible case of 1024 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 16];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_01;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_23;
+        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_45;
+        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_67;
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
     }
 };
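For 16-bit elements only every second bit of an svbool_t can be set (one predicate bit per byte of the vector, but one lane per two bytes), so two half-dense predicates are packed into one dense predicate with svuzp1_b8 before being stored. A scalar model of that compaction on 64-bit mask words (illustrative, not the intrinsic):

    #include <cstdint>

    // Model of svuzp1_b8 on two b16-granularity predicates: take the
    // even-numbered bits of a and b and concatenate them.
    inline uint64_t
    uzp1_b8_model(uint64_t a, uint64_t b) {
        uint64_t out = 0;
        for (int i = 0; i < 32; i++) {
            out |= ((a >> (2 * i)) & 1ULL) << i;         // evens of a -> low half
            out |= ((b >> (2 * i)) & 1ULL) << (32 + i);  // evens of b -> high half
        }
        return out;
    }

MaskHelper<4> and MaskHelper<8> below apply the same idea through two and three svuzp1 levels respectively, halving the bit stride at each level.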
 template <>
 struct MaskHelper<4> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
+    write_full(uint8_t* const __restrict bitmask,
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67);
+        // perform a full write
+        *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_0123;
+        *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_4567;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+        const size_t size,
-        const svbool_t pred0,
-        const svbool_t pred1,
-        const svbool_t pred2,
-        const svbool_t pred3,
-        const svbool_t pred4,
-        const svbool_t pred5,
-        const svbool_t pred6,
-        const svbool_t pred7) {
-        const size_t sve_width = svcntw();
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 2)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 2)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 2)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 2)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 2)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 2)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 2)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 2)) = pred7;
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67);
-        const svbool_t pred_op_8 = get_pred_op_8(size / 2);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_32_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
+        // this is a temporary buffer for the maximum possible case of 512 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 32];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0123;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_4567;
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
     }
 };
 template <>
 struct MaskHelper<8> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
+    write_full(uint8_t* const __restrict bitmask,
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67);
+        const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567);
+        // perform a full write
+        *((svbool_t*)bitmask) = pred_01234567;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+        const size_t size,
-        const svbool_t pred0,
-        const svbool_t pred1,
-        const svbool_t pred2,
-        const svbool_t pred3,
-        const svbool_t pred4,
-        const svbool_t pred5,
-        const svbool_t pred6,
-        const svbool_t pred7) {
-        const size_t sve_width = svcntd();
+        const svbool_t pred_0,
+        const svbool_t pred_1,
+        const svbool_t pred_2,
+        const svbool_t pred_3,
+        const svbool_t pred_4,
+        const svbool_t pred_5,
+        const svbool_t pred_6,
+        const svbool_t pred_7) {
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67);
+        const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567);
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width)) = pred7;
+        // this is a temporary buffer for the maximum possible case of 256 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 64];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf)) = pred_01234567;
-        const svbool_t pred_op_8 = get_pred_op_8(size / 1);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_64_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
     }
 };
@@ -924,16 +819,8 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) {
         const svbool_t cmp6 = func(pred_all, i + 6 * sve_width);
         const svbool_t cmp7 = func(pred_all, i + 7 * sve_width);
 
-        MaskHelper<sizeof(T)>::write(res_u8 + i / 8,
-            sve_width * 8,
-            cmp0,
-            cmp1,
-            cmp2,
-            cmp3,
-            cmp4,
-            cmp5,
-            cmp6,
-            cmp7);
+        MaskHelper<sizeof(T)>::write_full(
+            res_u8 + i / 8, cmp0, cmp1, cmp2, cmp3, cmp4, cmp5, cmp6, cmp7);
     }
 }
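After the split, the driver dispatches cleanly: the main loop always covers 8 full registers per iteration and takes the branch-free write_full path, and only the final sub-block goes through write_partial. The control flow, as a sketch with hypothetical scalar stand-ins for the predicate computations:

    #include <cstddef>
    #include <cstdint>

    // Stand-ins for "compute 8 predicates and hand them to MaskHelper".
    void process_block_full(uint8_t* out, size_t i);
    void process_block_partial(uint8_t* out, size_t i, size_t n);

    inline void
    op_mask_model(uint8_t* res_u8, size_t size, size_t sve_width) {
        // whole 8-register blocks -> write_full path
        const size_t size_sve8 = (size / (8 * sve_width)) * (8 * sve_width);
        for (size_t i = 0; i < size_sve8; i += 8 * sve_width) {
            process_block_full(res_u8 + i / 8, i);
        }
        // remaining tail, if any -> write_partial path
        if (size_sve8 != size) {
            process_block_partial(
                res_u8 + size_sve8 / 8, size_sve8, size - size_sve8);
        }
    }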
@@ -985,7 +872,7 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) {
         cmp7 = func(get_partial_pred(7), size_sve8 + 7 * sve_width);
     }
 
-    MaskHelper<sizeof(T)>::write(res_u8 + size_sve8 / 8,
+    MaskHelper<sizeof(T)>::write_partial(res_u8 + size_sve8 / 8,
         size - size_sve8,
         cmp0,
         cmp1,