mirror of https://github.com/milvus-io/milvus.git

enhance: [Cherry-pick][2.4] Upgrade bitset for ARM SVE (#33440)

issue: https://github.com/milvus-io/milvus/issues/32826
pr: https://github.com/milvus-io/milvus/pull/32718

improve ARM SVE performance for internal/core/src/bitset

Baseline timings for gcc 11.4 + Graviton 3 + manually enabled SVE: https://gist.github.com/alexanderguzhva/a974b50134c8bb9255fb15f144e5ac83
Candidate timings for gcc 11.4 + Graviton 3 + manually enabled SVE: https://gist.github.com/alexanderguzhva/19fc88f4ad3757e05e0f7feaf563b3d3

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>

Branch: pull/33454/head
parent 2638735d05
commit 5a668a17a9
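The gist of the change: tail-masking predicates were previously built by loading a constexpr lane-index table (the removed SVE_LANES_* arrays) and comparing it against a broadcast element count; the new code asks the hardware for the predicate directly via svwhilelt, and the manual pext-style bitmask writers are replaced by svuzp1-based predicate compaction. A minimal standalone sketch of the predicate part, assuming an aarch64 compiler with SVE enabled (-march=armv8-a+sve); the function names here are illustrative, not from the patch:

#include <arm_sve.h>

#include <cstddef>
#include <cstdint>

// Old scheme (simplified): active lane <=> lane index < n_elements,
// computed against an index table of 0x00, 0x01, 0x02, ... loaded from memory.
// Assumes n_elements <= 255, which the 8-bit lane table also required.
inline svbool_t
tail_pred_via_table(const uint8_t* const lane_idx_table, size_t n_elements) {
    const svbool_t all = svptrue_b8();
    const svuint8_t lanes = svld1_u8(all, lane_idx_table);
    return svcmpgt_u8(all, svdup_n_u8((uint8_t)n_elements), lanes);
}

// New scheme: a single WHILELT instruction, no table and no load at all.
inline svbool_t
tail_pred_via_whilelt(size_t n_elements) {
    return svwhilelt_b8(uint32_t(0), uint32_t(n_elements));
}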
@@ -42,63 +42,6 @@ namespace {
 //
 constexpr size_t MAX_SVE_WIDTH = 2048;
 
-constexpr uint8_t SVE_LANES_8[MAX_SVE_WIDTH / 8] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
-
-    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A,
-    0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55,
-    0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60,
-    0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B,
-    0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
-    0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F,
-
-    0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A,
-    0x8B, 0x8C, 0x8D, 0x8E, 0x8F, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
-    0x96, 0x97, 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, 0xA0,
-    0xA1, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB,
-    0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6,
-    0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBC, 0xBD, 0xBE, 0xBF,
-
-    0xC0, 0xC1, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA,
-    0xCB, 0xCC, 0xCD, 0xCE, 0xCF, 0xD0, 0xD1, 0xD2, 0xD3, 0xD4, 0xD5,
-    0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xDB, 0xDC, 0xDD, 0xDE, 0xDF, 0xE0,
-    0xE1, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xEB,
-    0xEC, 0xED, 0xEE, 0xEF, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6,
-    0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF};
-
-constexpr uint16_t SVE_LANES_16[MAX_SVE_WIDTH / 16] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
-
-    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A,
-    0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55,
-    0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60,
-    0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B,
-    0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
-    0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F};
-
-constexpr uint32_t SVE_LANES_32[MAX_SVE_WIDTH / 32] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
-    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B,
-    0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36,
-    0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F};
-
-constexpr uint64_t SVE_LANES_64[MAX_SVE_WIDTH / 64] = {
-    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
-    0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-    0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F};
-
 /*
 // debugging facilities
@@ -131,179 +74,28 @@ void print_svuint8_t(const svuint8_t value) {
 
 ///////////////////////////////////////////////////////////////////////////
 
-// todo: replace with pext whenever available
-
-// generate 16-bit bitmask from 8 serialized 16-bit svbool_t values
-void
-write_bitmask_16_8x(uint8_t* const __restrict res_u8,
-                    const svbool_t pred_op,
-                    const svbool_t pred_write,
-                    const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 128 uint16_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 32 uint16_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 16 uint16_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 8 uint16_t values
-
-    // this code does reduction of 16-bit 0b0A0B0C0D0E0F0G0H words into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-
-    const svuint8_t mask_04_8b = svand_n_u8_z(pred_op, mask_8b, 0x01);
-    const svuint8_t mask_15_8b = svand_n_u8_z(pred_op, mask_8b, 0x04);
-    const svuint8_t mask_15s_8b = svlsr_n_u8_z(pred_op, mask_15_8b, 1);
-    const svuint8_t mask_26_8b = svand_n_u8_z(pred_op, mask_8b, 0x10);
-    const svuint8_t mask_26s_8b = svlsr_n_u8_z(pred_op, mask_26_8b, 2);
-    const svuint8_t mask_37_8b = svand_n_u8_z(pred_op, mask_8b, 0x40);
-    const svuint8_t mask_37s_8b = svlsr_n_u8_z(pred_op, mask_37_8b, 3);
-
-    const svuint8_t mask_0347_8b = svorr_u8_z(pred_op, mask_04_8b, mask_37s_8b);
-    const svuint8_t mask_1256_8b =
-        svorr_u8_z(pred_op, mask_15s_8b, mask_26s_8b);
-    const svuint8_t mask_cmb_8b =
-        svorr_u8_z(pred_op, mask_0347_8b, mask_1256_8b);
-
-    //
-    const svuint16_t shifts_16b = svdup_u16(0x0400UL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u16(shifts_16b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b);
-
-    const svuint8_t zero_8b = svdup_n_u8(0);
-
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
-
-// generate 32-bit bitmask from 8 serialized 32-bit svbool_t values
-void
-write_bitmask_32_8x(uint8_t* const __restrict res_u8,
-                    const svbool_t pred_op,
-                    const svbool_t pred_write,
-                    const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 64 uint32_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 16 uint32_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 8 uint32_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 4 uint32_t values
-
-    // this code does reduction of 32-bit 0b000A000B000C000D... dwords into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-
-    const svuint8_t mask_024_8b = svand_n_u8_z(pred_op, mask_8b, 0x01);
-    const svuint8_t mask_135s_8b = svlsr_n_u8_z(pred_op, mask_8b, 3);
-    const svuint8_t mask_cmb_8b =
-        svorr_u8_z(pred_op, mask_024_8b, mask_135s_8b);
-
-    //
-    const svuint32_t shifts_32b = svdup_u32(0x06040200UL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u32(shifts_32b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b);
-
-    const svuint8_t zero_8b = svdup_n_u8(0);
-
-    const svuint8_t shifted_8b_m2 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m2, zero_8b),
-                   svuzp2_u8(shifted_8b_m2, zero_8b));
-
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
-
-// generate 64-bit bitmask from 8 serialized 64-bit svbool_t values
-void
-write_bitmask_64_8x(uint8_t* const __restrict res_u8,
-                    const svbool_t pred_op,
-                    const svbool_t pred_write,
-                    const uint8_t* const __restrict pred_buf) {
-    // perform parallel pext
-    // 2048b -> 32 bytes mask -> 256 bytes total, 32 uint64_t values
-    // 512b -> 8 bytes mask -> 64 bytes total, 4 uint64_t values
-    // 256b -> 4 bytes mask -> 32 bytes total, 2 uint64_t values
-    // 128b -> 2 bytes mask -> 16 bytes total, 1 uint64_t values
-
-    // this code does reduction of 64-bit 0b0000000A0000000B... qwords into
-    // uint8_t values 0bABCDEFGH, then writes ones to the memory
-
-    // we need to operate in uint8_t
-    const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf);
-    const svuint64_t shifts_64b = svdup_u64(0x706050403020100ULL);
-    const svuint8_t shifts_8b = svreinterpret_u8_u64(shifts_64b);
-    const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_8b, shifts_8b);
-
-    const svuint8_t zero_8b = svdup_n_u8(0);
-
-    const svuint8_t shifted_8b_m1 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m0, zero_8b),
-                   svuzp2_u8(shifted_8b_m0, zero_8b));
-    const svuint8_t shifted_8b_m2 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m1, zero_8b),
-                   svuzp2_u8(shifted_8b_m1, zero_8b));
-    const svuint8_t shifted_8b_m3 =
-        svorr_u8_z(pred_op,
-                   svuzp1_u8(shifted_8b_m2, zero_8b),
-                   svuzp2_u8(shifted_8b_m2, zero_8b));
-
-    // write a finished bitmask
-    svst1_u8(pred_write, res_u8, shifted_8b_m3);
-}
-
-///////////////////////////////////////////////////////////////////////////
-
 //
 inline svbool_t
 get_pred_op_8(const size_t n_elements) {
-    const svbool_t pred_all_8 = svptrue_b8();
-    const svuint8_t lanes_8 = svld1_u8(pred_all_8, SVE_LANES_8);
-    const svuint8_t leftovers_op = svdup_n_u8(n_elements);
-    const svbool_t pred_op = svcmpgt_u8(pred_all_8, leftovers_op, lanes_8);
-    return pred_op;
+    return svwhilelt_b8(uint32_t(0), uint32_t(n_elements));
 }
 
 //
 inline svbool_t
 get_pred_op_16(const size_t n_elements) {
-    const svbool_t pred_all_16 = svptrue_b16();
-    const svuint16_t lanes_16 = svld1_u16(pred_all_16, SVE_LANES_16);
-    const svuint16_t leftovers_op = svdup_n_u16(n_elements);
-    const svbool_t pred_op = svcmpgt_u16(pred_all_16, leftovers_op, lanes_16);
-    return pred_op;
+    return svwhilelt_b16(uint32_t(0), uint32_t(n_elements));
 }
 
 //
 inline svbool_t
 get_pred_op_32(const size_t n_elements) {
-    const svbool_t pred_all_32 = svptrue_b32();
-    const svuint32_t lanes_32 = svld1_u32(pred_all_32, SVE_LANES_32);
-    const svuint32_t leftovers_op = svdup_n_u32(n_elements);
-    const svbool_t pred_op = svcmpgt_u32(pred_all_32, leftovers_op, lanes_32);
-    return pred_op;
+    return svwhilelt_b32(uint32_t(0), uint32_t(n_elements));
 }
 
 //
 inline svbool_t
 get_pred_op_64(const size_t n_elements) {
-    const svbool_t pred_all_64 = svptrue_b64();
-    const svuint64_t lanes_64 = svld1_u64(pred_all_64, SVE_LANES_64);
-    const svuint64_t leftovers_op = svdup_n_u64(n_elements);
-    const svbool_t pred_op = svcmpgt_u64(pred_all_64, leftovers_op, lanes_64);
-    return pred_op;
+    return svwhilelt_b64(uint32_t(0), uint32_t(n_elements));
 }
 
 //
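For context, this is how such a predicate typically drives a masked tail iteration; a hedged usage sketch, not code from this patch (copy_u8, src, and dst are made-up names):

#include <arm_sve.h>

#include <cstddef>
#include <cstdint>

void
copy_u8(uint8_t* const dst, const uint8_t* const src, const size_t n) {
    const size_t w = svcntb();  // 8-bit lanes per SVE register
    size_t i = 0;
    for (; i + w <= n; i += w) {
        // full registers: an all-true predicate
        const svbool_t all = svptrue_b8();
        svst1_u8(all, dst + i, svld1_u8(all, src + i));
    }
    if (i != n) {
        // tail: active only for the remaining n - i lanes,
        // exactly the predicate get_pred_op_8(n - i) now returns
        const svbool_t tail = svwhilelt_b8(uint32_t(0), uint32_t(n - i));
        svst1_u8(tail, dst + i, svld1_u8(tail, src + i));
    }
}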
@@ -579,7 +371,7 @@ struct SVEVector<int8_t> {
     using sve_type = svint8_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntb();
     }
@@ -606,7 +398,7 @@ struct SVEVector<int16_t> {
     using sve_type = svint16_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcnth();
     }
@@ -633,7 +425,7 @@ struct SVEVector<int32_t> {
     using sve_type = svint32_t;
 
    // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntw();
     }
@@ -660,7 +452,7 @@ struct SVEVector<int64_t> {
     using sve_type = svint64_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntd();
     }
@@ -687,7 +479,7 @@ struct SVEVector<float> {
     using sve_type = svfloat32_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntw();
     }
@@ -714,7 +506,7 @@ struct SVEVector<double> {
     using sve_type = svfloat64_t;
 
     // measured in the number of elements that an SVE register can hold
-    static inline size_t
+    static inline uint64_t
     width() {
         return svcntd();
     }
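The width() return type changes from size_t to uint64_t across all specializations because the ACLE counting intrinsics (svcntb/svcnth/svcntw/svcntd) are declared to return uint64_t; forwarding that type directly avoids an implicit conversion. A tiny sketch of the pattern (the function name is illustrative):

#include <arm_sve.h>

#include <cstdint>

// svcntw() itself returns uint64_t, so width() now simply forwards that type.
inline uint64_t
f32_lanes_per_register() {
    return svcntw();  // number of 32-bit lanes in one SVE vector
}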
@@ -737,159 +529,262 @@ struct SVEVector<double> {
 
 ///////////////////////////////////////////////////////////////////////////
 
 // an interesting discussion here:
 // https://stackoverflow.com/questions/77834169/what-is-a-fast-fallback-algorithm-which-emulates-pdep-and-pext-in-software
 
+// SVE2 has bitperm, which contains the implementation of pext
+
 // todo: replace with pext whenever available
 
 //
 // NBYTES is the size of the underlying datatype in bytes.
 // So, for example, for i8/u8 use 1, for i64/u64/f64 use 8.
 template <size_t NBYTES>
 struct MaskHelper {};
 
 template <>
 struct MaskHelper<1> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
-          const size_t size,
-          const svbool_t pred0,
-          const svbool_t pred1,
-          const svbool_t pred2,
-          const svbool_t pred3,
-          const svbool_t pred4,
-          const svbool_t pred5,
-          const svbool_t pred6,
-          const svbool_t pred7) {
-        const size_t sve_width = svcntb();
-        if (size == 8 * sve_width) {
-            // perform a full write
-            *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred0;
-            *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1;
-            *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred2;
-            *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred3;
-            *((svbool_t*)(bitmask + 4 * sve_width / 8)) = pred4;
-            *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5;
-            *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6;
-            *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7;
-        } else {
-            // perform a partial write
-
-            // this is the buffer for the maximum possible case of 2048 bits
-            uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-            *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred0;
-            *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred1;
-            *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred2;
-            *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred3;
-            *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred4;
-            *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred5;
-            *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred6;
-            *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred7;
-
-            // make the write mask
-            const svbool_t pred_write = get_pred_op_8(size / 8);
-
-            // load the buffer
-            const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
-            // write it to the bitmask
-            svst1_u8(pred_write, bitmask, mask_u8);
-        }
-    }
+    write_full(uint8_t* const __restrict bitmask,
+               const svbool_t pred0,
+               const svbool_t pred1,
+               const svbool_t pred2,
+               const svbool_t pred3,
+               const svbool_t pred4,
+               const svbool_t pred5,
+               const svbool_t pred6,
+               const svbool_t pred7) {
+        const uint64_t sve_width = svcntb();
+
+        // perform a full write
+        *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred0;
+        *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1;
+        *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred2;
+        *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred3;
+        *((svbool_t*)(bitmask + 4 * sve_width / 8)) = pred4;
+        *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5;
+        *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6;
+        *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+                  const size_t size,
+                  const svbool_t pred_0,
+                  const svbool_t pred_1,
+                  const svbool_t pred_2,
+                  const svbool_t pred_3,
+                  const svbool_t pred_4,
+                  const svbool_t pred_5,
+                  const svbool_t pred_6,
+                  const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+
+        // perform a partial write
+
+        // this is a temporary buffer for the maximum possible case of 2048 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_1;
+        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_2;
+        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_3;
+        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred_4;
+        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred_5;
+        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred_6;
+        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred_7;
+
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
+    }
 };
 
 template <>
 struct MaskHelper<2> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
-          const size_t size,
-          const svbool_t pred0,
-          const svbool_t pred1,
-          const svbool_t pred2,
-          const svbool_t pred3,
-          const svbool_t pred4,
-          const svbool_t pred5,
-          const svbool_t pred6,
-          const svbool_t pred7) {
-        const size_t sve_width = svcnth();
-
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 4)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 4)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 4)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 4)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 4)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 4)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 4)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 4)) = pred7;
-
-        const svbool_t pred_op_8 = get_pred_op_8(size / 4);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_16_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
-    }
+    write_full(uint8_t* const __restrict bitmask,
+               const svbool_t pred_0,
+               const svbool_t pred_1,
+               const svbool_t pred_2,
+               const svbool_t pred_3,
+               const svbool_t pred_4,
+               const svbool_t pred_5,
+               const svbool_t pred_6,
+               const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7);
+
+        // perform a full write
+        *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_01;
+        *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_23;
+        *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred_45;
+        *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred_67;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+                  const size_t size,
+                  const svbool_t pred_0,
+                  const svbool_t pred_1,
+                  const svbool_t pred_2,
+                  const svbool_t pred_3,
+                  const svbool_t pred_4,
+                  const svbool_t pred_5,
+                  const svbool_t pred_6,
+                  const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7);
+
+        // this is a temporary buffer for the maximum possible case of 1024 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 16];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_01;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_23;
+        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_45;
+        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_67;
+
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
+    }
 };
 
 template <>
 struct MaskHelper<4> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
-          const size_t size,
-          const svbool_t pred0,
-          const svbool_t pred1,
-          const svbool_t pred2,
-          const svbool_t pred3,
-          const svbool_t pred4,
-          const svbool_t pred5,
-          const svbool_t pred6,
-          const svbool_t pred7) {
-        const size_t sve_width = svcntw();
-
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 2)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 2)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width / 2)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width / 2)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width / 2)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width / 2)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width / 2)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width / 2)) = pred7;
-
-        const svbool_t pred_op_8 = get_pred_op_8(size / 2);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_32_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
-    }
+    write_full(uint8_t* const __restrict bitmask,
+               const svbool_t pred_0,
+               const svbool_t pred_1,
+               const svbool_t pred_2,
+               const svbool_t pred_3,
+               const svbool_t pred_4,
+               const svbool_t pred_5,
+               const svbool_t pred_6,
+               const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67);
+
+        // perform a full write
+        *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_0123;
+        *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_4567;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+                  const size_t size,
+                  const svbool_t pred_0,
+                  const svbool_t pred_1,
+                  const svbool_t pred_2,
+                  const svbool_t pred_3,
+                  const svbool_t pred_4,
+                  const svbool_t pred_5,
+                  const svbool_t pred_6,
+                  const svbool_t pred_7) {
+        const uint64_t sve_width = svcntb();
+
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67);
+
+        // this is a temporary buffer for the maximum possible case of 512 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 32];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0123;
+        *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_4567;
+
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
+    }
 };
 
 template <>
 struct MaskHelper<8> {
     static inline void
-    write(uint8_t* const __restrict bitmask,
-          const size_t size,
-          const svbool_t pred0,
-          const svbool_t pred1,
-          const svbool_t pred2,
-          const svbool_t pred3,
-          const svbool_t pred4,
-          const svbool_t pred5,
-          const svbool_t pred6,
-          const svbool_t pred7) {
-        const size_t sve_width = svcntd();
-
-        // this is the buffer for the maximum possible case of 2048 bits
-        uint8_t pred_buf[MAX_SVE_WIDTH / 8];
-        *((volatile svbool_t*)(pred_buf + 0 * sve_width)) = pred0;
-        *((volatile svbool_t*)(pred_buf + 1 * sve_width)) = pred1;
-        *((volatile svbool_t*)(pred_buf + 2 * sve_width)) = pred2;
-        *((volatile svbool_t*)(pred_buf + 3 * sve_width)) = pred3;
-        *((volatile svbool_t*)(pred_buf + 4 * sve_width)) = pred4;
-        *((volatile svbool_t*)(pred_buf + 5 * sve_width)) = pred5;
-        *((volatile svbool_t*)(pred_buf + 6 * sve_width)) = pred6;
-        *((volatile svbool_t*)(pred_buf + 7 * sve_width)) = pred7;
-
-        const svbool_t pred_op_8 = get_pred_op_8(size / 1);
-        const svbool_t pred_write_8 = get_pred_op_8(size / 8);
-        write_bitmask_64_8x(bitmask, pred_op_8, pred_write_8, pred_buf);
-    }
+    write_full(uint8_t* const __restrict bitmask,
+               const svbool_t pred_0,
+               const svbool_t pred_1,
+               const svbool_t pred_2,
+               const svbool_t pred_3,
+               const svbool_t pred_4,
+               const svbool_t pred_5,
+               const svbool_t pred_6,
+               const svbool_t pred_7) {
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67);
+        const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567);
+
+        // perform a full write
+        *((svbool_t*)bitmask) = pred_01234567;
+    }
+
+    static inline void
+    write_partial(uint8_t* const __restrict bitmask,
+                  const size_t size,
+                  const svbool_t pred_0,
+                  const svbool_t pred_1,
+                  const svbool_t pred_2,
+                  const svbool_t pred_3,
+                  const svbool_t pred_4,
+                  const svbool_t pred_5,
+                  const svbool_t pred_6,
+                  const svbool_t pred_7) {
+        // compact predicates
+        const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1);
+        const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3);
+        const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5);
+        const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7);
+        const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23);
+        const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67);
+        const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567);
+
+        // this is a temporary buffer for the maximum possible case of 256 bits
+        uint8_t pred_buf[MAX_SVE_WIDTH / 64];
+        // write to the temporary buffer
+        *((volatile svbool_t*)(pred_buf)) = pred_01234567;
+
+        // make the write mask. (size % 8) == 0 is guaranteed by the caller.
+        const svbool_t pred_write =
+            svwhilelt_b8(uint32_t(0), uint32_t(size / 8));
+
+        // load the buffer
+        const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf);
+        // write it to the bitmask
+        svst1_u8(pred_write, bitmask, mask_u8);
+    }
 };
 
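The svuzp1 cascades above exploit the layout of SVE predicates: a predicate holds one bit per byte of the vector, so for 16-bit elements only every second bit is meaningful, for 32-bit elements every fourth, and so on. Each svuzp1_bN keeps the even-indexed N-bit positions of its two inputs, halving the stride of the meaningful bits; repeating until the stride is one bit per byte of output yields a densely packed bitmask that can be stored with a plain predicate write. A hedged sketch of a single compaction step, assuming two predicates produced by comparisons over 16-bit elements (the function name is illustrative):

#include <arm_sve.h>

// lo covers 16-bit elements [0, svcnth()), hi covers [svcnth(), 2 * svcnth()).
// In each input only every second predicate bit is meaningful (0b0A0B0C0D...);
// svuzp1_b8 keeps the even b8 positions of both inputs, producing a predicate
// with a meaningful bit in every position, i.e. a byte-dense bitmask.
inline svbool_t
compact_two_b16_predicates(const svbool_t lo, const svbool_t hi) {
    return svuzp1_b8(lo, hi);
}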
@@ -924,16 +819,8 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) {
         const svbool_t cmp6 = func(pred_all, i + 6 * sve_width);
         const svbool_t cmp7 = func(pred_all, i + 7 * sve_width);
 
-        MaskHelper<sizeof(T)>::write(res_u8 + i / 8,
-                                     sve_width * 8,
-                                     cmp0,
-                                     cmp1,
-                                     cmp2,
-                                     cmp3,
-                                     cmp4,
-                                     cmp5,
-                                     cmp6,
-                                     cmp7);
+        MaskHelper<sizeof(T)>::write_full(
+            res_u8 + i / 8, cmp0, cmp1, cmp2, cmp3, cmp4, cmp5, cmp6, cmp7);
     }
 }
 
@@ -985,16 +872,16 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) {
         cmp7 = func(get_partial_pred(7), size_sve8 + 7 * sve_width);
     }
 
-    MaskHelper<sizeof(T)>::write(res_u8 + size_sve8 / 8,
-                                 size - size_sve8,
-                                 cmp0,
-                                 cmp1,
-                                 cmp2,
-                                 cmp3,
-                                 cmp4,
-                                 cmp5,
-                                 cmp6,
-                                 cmp7);
+    MaskHelper<sizeof(T)>::write_partial(res_u8 + size_sve8 / 8,
+                                         size - size_sve8,
+                                         cmp0,
+                                         cmp1,
+                                         cmp2,
+                                         cmp3,
+                                         cmp4,
+                                         cmp5,
+                                         cmp6,
+                                         cmp7);
 }
 
 return true;
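The two call sites above split cleanly: the unrolled main loop always covers exactly 8 * sve_width elements per iteration, so it can take the unmasked write_full path, and only the final partial block pays for the masked write_partial store. A hedged sketch of that control flow, with write_full_stub/write_partial_stub as hypothetical stand-ins for the real MaskHelper calls:

#include <cstddef>
#include <cstdint>

// stand-ins for MaskHelper<N>::write_full / write_partial
static void write_full_stub(uint8_t* /*bitmask*/) {}
static void write_partial_stub(uint8_t* /*bitmask*/, size_t /*n_elements*/) {}

void
op_mask_shape(uint8_t* const res_u8, const size_t size, const size_t sve_width) {
    // number of elements covered by whole 8-register blocks
    const size_t size_sve8 = (size / (8 * sve_width)) * (8 * sve_width);
    for (size_t i = 0; i < size_sve8; i += 8 * sve_width) {
        write_full_stub(res_u8 + i / 8);  // no masking needed here
    }
    if (size_sve8 != size) {
        // leftover elements; (size % 8) == 0 is still guaranteed by callers
        write_partial_stub(res_u8 + size_sve8 / 8, size - size_sve8);
    }
}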