diff --git a/internal/core/src/bitset/detail/platform/arm/sve-impl.h b/internal/core/src/bitset/detail/platform/arm/sve-impl.h index 18433402d0..dfc84f2824 100644 --- a/internal/core/src/bitset/detail/platform/arm/sve-impl.h +++ b/internal/core/src/bitset/detail/platform/arm/sve-impl.h @@ -42,63 +42,6 @@ namespace { // constexpr size_t MAX_SVE_WIDTH = 2048; -constexpr uint8_t SVE_LANES_8[MAX_SVE_WIDTH / 8] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, - 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, - 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, - 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, - 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, - - 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, - 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, - 0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60, - 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B, - 0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, - 0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F, - - 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A, - 0x8B, 0x8C, 0x8D, 0x8E, 0x8F, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, - 0x96, 0x97, 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, 0xA0, - 0xA1, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB, - 0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6, - 0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBC, 0xBD, 0xBE, 0xBF, - - 0xC0, 0xC1, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA, - 0xCB, 0xCC, 0xCD, 0xCE, 0xCF, 0xD0, 0xD1, 0xD2, 0xD3, 0xD4, 0xD5, - 0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xDB, 0xDC, 0xDD, 0xDE, 0xDF, 0xE0, - 0xE1, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xEB, - 0xEC, 0xED, 0xEE, 0xEF, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, - 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF}; - -constexpr uint16_t SVE_LANES_16[MAX_SVE_WIDTH / 16] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, - 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, - 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, - 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, - 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, - - 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, - 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, - 0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60, - 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B, - 0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, - 0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F}; - -constexpr uint32_t SVE_LANES_32[MAX_SVE_WIDTH / 32] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, - 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, - 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, - 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, - 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F}; - -constexpr uint64_t SVE_LANES_64[MAX_SVE_WIDTH / 64] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, - 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}; - /* // debugging facilities @@ -131,179 +74,28 @@ void print_svuint8_t(const svuint8_t value) { /////////////////////////////////////////////////////////////////////////// -// todo: replace with pext whenever available - -// generate 16-bit bitmask from 8 serialized 16-bit svbool_t values -void -write_bitmask_16_8x(uint8_t* const __restrict res_u8, - const svbool_t pred_op, - const svbool_t pred_write, - const uint8_t* const __restrict pred_buf) { - // perform parallel pext - // 2048b -> 32 bytes mask -> 256 bytes total, 128 uint16_t values - // 512b -> 8 bytes mask -> 64 bytes total, 32 uint16_t values - // 256b -> 4 bytes mask -> 32 bytes total, 16 uint16_t values - // 128b -> 2 bytes mask -> 16 bytes total, 8 uint16_t values - - // this code does reduction of 16-bit 0b0A0B0C0D0E0F0G0H words into - // uint8_t values 0bABCDEFGH, then writes ones to the memory - - // we need to operate in uint8_t - const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); - - const svuint8_t mask_04_8b = svand_n_u8_z(pred_op, mask_8b, 0x01); - const svuint8_t mask_15_8b = svand_n_u8_z(pred_op, mask_8b, 0x04); - const svuint8_t mask_15s_8b = svlsr_n_u8_z(pred_op, mask_15_8b, 1); - const svuint8_t mask_26_8b = svand_n_u8_z(pred_op, mask_8b, 0x10); - const svuint8_t mask_26s_8b = svlsr_n_u8_z(pred_op, mask_26_8b, 2); - const svuint8_t mask_37_8b = svand_n_u8_z(pred_op, mask_8b, 0x40); - const svuint8_t mask_37s_8b = svlsr_n_u8_z(pred_op, mask_37_8b, 3); - - const svuint8_t mask_0347_8b = svorr_u8_z(pred_op, mask_04_8b, mask_37s_8b); - const svuint8_t mask_1256_8b = - svorr_u8_z(pred_op, mask_15s_8b, mask_26s_8b); - const svuint8_t mask_cmb_8b = - svorr_u8_z(pred_op, mask_0347_8b, mask_1256_8b); - - // - const svuint16_t shifts_16b = svdup_u16(0x0400UL); - const svuint8_t shifts_8b = svreinterpret_u8_u16(shifts_16b); - const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b); - - const svuint8_t zero_8b = svdup_n_u8(0); - - const svuint8_t shifted_8b_m3 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m0, zero_8b), - svuzp2_u8(shifted_8b_m0, zero_8b)); - - // write a finished bitmask - svst1_u8(pred_write, res_u8, shifted_8b_m3); -} - -// generate 32-bit bitmask from 8 serialized 32-bit svbool_t values -void -write_bitmask_32_8x(uint8_t* const __restrict res_u8, - const svbool_t pred_op, - const svbool_t pred_write, - const uint8_t* const __restrict pred_buf) { - // perform parallel pext - // 2048b -> 32 bytes mask -> 256 bytes total, 64 uint32_t values - // 512b -> 8 bytes mask -> 64 bytes total, 16 uint32_t values - // 256b -> 4 bytes mask -> 32 bytes total, 8 uint32_t values - // 128b -> 2 bytes mask -> 16 bytes total, 4 uint32_t values - - // this code does reduction of 32-bit 0b000A000B000C000D... dwords into - // uint8_t values 0bABCDEFGH, then writes ones to the memory - - // we need to operate in uint8_t - const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); - - const svuint8_t mask_024_8b = svand_n_u8_z(pred_op, mask_8b, 0x01); - const svuint8_t mask_135s_8b = svlsr_n_u8_z(pred_op, mask_8b, 3); - const svuint8_t mask_cmb_8b = - svorr_u8_z(pred_op, mask_024_8b, mask_135s_8b); - - // - const svuint32_t shifts_32b = svdup_u32(0x06040200UL); - const svuint8_t shifts_8b = svreinterpret_u8_u32(shifts_32b); - const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b); - - const svuint8_t zero_8b = svdup_n_u8(0); - - const svuint8_t shifted_8b_m2 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m0, zero_8b), - svuzp2_u8(shifted_8b_m0, zero_8b)); - const svuint8_t shifted_8b_m3 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m2, zero_8b), - svuzp2_u8(shifted_8b_m2, zero_8b)); - - // write a finished bitmask - svst1_u8(pred_write, res_u8, shifted_8b_m3); -} - -// generate 64-bit bitmask from 8 serialized 64-bit svbool_t values -void -write_bitmask_64_8x(uint8_t* const __restrict res_u8, - const svbool_t pred_op, - const svbool_t pred_write, - const uint8_t* const __restrict pred_buf) { - // perform parallel pext - // 2048b -> 32 bytes mask -> 256 bytes total, 32 uint64_t values - // 512b -> 8 bytes mask -> 64 bytes total, 4 uint64_t values - // 256b -> 4 bytes mask -> 32 bytes total, 2 uint64_t values - // 128b -> 2 bytes mask -> 16 bytes total, 1 uint64_t values - - // this code does reduction of 64-bit 0b0000000A0000000B... qwords into - // uint8_t values 0bABCDEFGH, then writes ones to the memory - - // we need to operate in uint8_t - const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); - const svuint64_t shifts_64b = svdup_u64(0x706050403020100ULL); - const svuint8_t shifts_8b = svreinterpret_u8_u64(shifts_64b); - const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_8b, shifts_8b); - - const svuint8_t zero_8b = svdup_n_u8(0); - - const svuint8_t shifted_8b_m1 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m0, zero_8b), - svuzp2_u8(shifted_8b_m0, zero_8b)); - const svuint8_t shifted_8b_m2 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m1, zero_8b), - svuzp2_u8(shifted_8b_m1, zero_8b)); - const svuint8_t shifted_8b_m3 = - svorr_u8_z(pred_op, - svuzp1_u8(shifted_8b_m2, zero_8b), - svuzp2_u8(shifted_8b_m2, zero_8b)); - - // write a finished bitmask - svst1_u8(pred_write, res_u8, shifted_8b_m3); -} - -/////////////////////////////////////////////////////////////////////////// - // inline svbool_t get_pred_op_8(const size_t n_elements) { - const svbool_t pred_all_8 = svptrue_b8(); - const svuint8_t lanes_8 = svld1_u8(pred_all_8, SVE_LANES_8); - const svuint8_t leftovers_op = svdup_n_u8(n_elements); - const svbool_t pred_op = svcmpgt_u8(pred_all_8, leftovers_op, lanes_8); - return pred_op; + return svwhilelt_b8(uint32_t(0), uint32_t(n_elements)); } // inline svbool_t get_pred_op_16(const size_t n_elements) { - const svbool_t pred_all_16 = svptrue_b16(); - const svuint16_t lanes_16 = svld1_u16(pred_all_16, SVE_LANES_16); - const svuint16_t leftovers_op = svdup_n_u16(n_elements); - const svbool_t pred_op = svcmpgt_u16(pred_all_16, leftovers_op, lanes_16); - return pred_op; + return svwhilelt_b16(uint32_t(0), uint32_t(n_elements)); } // inline svbool_t get_pred_op_32(const size_t n_elements) { - const svbool_t pred_all_32 = svptrue_b32(); - const svuint32_t lanes_32 = svld1_u32(pred_all_32, SVE_LANES_32); - const svuint32_t leftovers_op = svdup_n_u32(n_elements); - const svbool_t pred_op = svcmpgt_u32(pred_all_32, leftovers_op, lanes_32); - return pred_op; + return svwhilelt_b32(uint32_t(0), uint32_t(n_elements)); } // inline svbool_t get_pred_op_64(const size_t n_elements) { - const svbool_t pred_all_64 = svptrue_b64(); - const svuint64_t lanes_64 = svld1_u64(pred_all_64, SVE_LANES_64); - const svuint64_t leftovers_op = svdup_n_u64(n_elements); - const svbool_t pred_op = svcmpgt_u64(pred_all_64, leftovers_op, lanes_64); - return pred_op; + return svwhilelt_b64(uint32_t(0), uint32_t(n_elements)); } // @@ -579,7 +371,7 @@ struct SVEVector { using sve_type = svint8_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcntb(); } @@ -606,7 +398,7 @@ struct SVEVector { using sve_type = svint16_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcnth(); } @@ -633,7 +425,7 @@ struct SVEVector { using sve_type = svint32_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcntw(); } @@ -660,7 +452,7 @@ struct SVEVector { using sve_type = svint64_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcntd(); } @@ -687,7 +479,7 @@ struct SVEVector { using sve_type = svfloat32_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcntw(); } @@ -714,7 +506,7 @@ struct SVEVector { using sve_type = svfloat64_t; // measured in the number of elements that an SVE register can hold - static inline size_t + static inline uint64_t width() { return svcntd(); } @@ -737,159 +529,262 @@ struct SVEVector { /////////////////////////////////////////////////////////////////////////// -// an interesting discussion here: -// https://stackoverflow.com/questions/77834169/what-is-a-fast-fallback-algorithm-which-emulates-pdep-and-pext-in-software - -// SVE2 has bitperm, which contains the implementation of pext - -// todo: replace with pext whenever available - -// +// NBYTES is the size of the underlying datatype in bytes. +// So, for example, for i8/u8 use 1, for i64/u64/f64 use 8/ template struct MaskHelper {}; template <> struct MaskHelper<1> { static inline void - write(uint8_t* const __restrict bitmask, - const size_t size, - const svbool_t pred0, - const svbool_t pred1, - const svbool_t pred2, - const svbool_t pred3, - const svbool_t pred4, - const svbool_t pred5, - const svbool_t pred6, - const svbool_t pred7) { - const size_t sve_width = svcntb(); - if (sve_width == 8 * sve_width) { - // perform a full write - *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred0; - *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1; - *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred2; - *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred3; - *((svbool_t*)(bitmask + 4 * sve_width / 8)) = pred4; - *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5; - *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6; - *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7; - } else { - // perform a partial write + write_full(uint8_t* const __restrict bitmask, + const svbool_t pred0, + const svbool_t pred1, + const svbool_t pred2, + const svbool_t pred3, + const svbool_t pred4, + const svbool_t pred5, + const svbool_t pred6, + const svbool_t pred7) { + const uint64_t sve_width = svcntb(); - // this is the buffer for the maximum possible case of 2048 bits - uint8_t pred_buf[MAX_SVE_WIDTH / 8]; - *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred0; - *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred1; - *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred2; - *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred3; - *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred4; - *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred5; - *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred6; - *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred7; + // perform a full write + *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred0; + *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1; + *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred2; + *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred3; + *((svbool_t*)(bitmask + 4 * sve_width / 8)) = pred4; + *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5; + *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6; + *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7; + } - // make the write mask - const svbool_t pred_write = get_pred_op_8(size / 8); + static inline void + write_partial(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + const uint64_t sve_width = svcntb(); - // load the buffer - const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); - // write it to the bitmask - svst1_u8(pred_write, bitmask, mask_u8); - } + // perform a partial write + + // this is a temporary buffer for the maximum possible case of 2048 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 8]; + // write to the temporary buffer + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_1; + *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_2; + *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_3; + *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred_4; + *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred_5; + *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred_6; + *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred_7; + + // make the write mask. (size % 8) == 0 is guaranteed by the caller. + const svbool_t pred_write = + svwhilelt_b8(uint32_t(0), uint32_t(size / 8)); + + // load the buffer + const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); + // write it to the bitmask + svst1_u8(pred_write, bitmask, mask_u8); } }; template <> struct MaskHelper<2> { static inline void - write(uint8_t* const __restrict bitmask, - const size_t size, - const svbool_t pred0, - const svbool_t pred1, - const svbool_t pred2, - const svbool_t pred3, - const svbool_t pred4, - const svbool_t pred5, - const svbool_t pred6, - const svbool_t pred7) { - const size_t sve_width = svcnth(); + write_full(uint8_t* const __restrict bitmask, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + const uint64_t sve_width = svcntb(); - // this is the buffer for the maximum possible case of 2048 bits - uint8_t pred_buf[MAX_SVE_WIDTH / 8]; - *((volatile svbool_t*)(pred_buf + 0 * sve_width / 4)) = pred0; - *((volatile svbool_t*)(pred_buf + 1 * sve_width / 4)) = pred1; - *((volatile svbool_t*)(pred_buf + 2 * sve_width / 4)) = pred2; - *((volatile svbool_t*)(pred_buf + 3 * sve_width / 4)) = pred3; - *((volatile svbool_t*)(pred_buf + 4 * sve_width / 4)) = pred4; - *((volatile svbool_t*)(pred_buf + 5 * sve_width / 4)) = pred5; - *((volatile svbool_t*)(pred_buf + 6 * sve_width / 4)) = pred6; - *((volatile svbool_t*)(pred_buf + 7 * sve_width / 4)) = pred7; + // compact predicates + const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7); - const svbool_t pred_op_8 = get_pred_op_8(size / 4); - const svbool_t pred_write_8 = get_pred_op_8(size / 8); - write_bitmask_16_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + // perform a full write + *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_01; + *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_23; + *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred_45; + *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred_67; + } + + static inline void + write_partial(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + const uint64_t sve_width = svcntb(); + + // compact predicates + const svbool_t pred_01 = svuzp1_b8(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b8(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b8(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b8(pred_6, pred_7); + + // this is a temporary buffer for the maximum possible case of 1024 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 16]; + // write to the temporary buffer + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_01; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_23; + *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred_45; + *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred_67; + + // make the write mask. (size % 8) == 0 is guaranteed by the caller. + const svbool_t pred_write = + svwhilelt_b8(uint32_t(0), uint32_t(size / 8)); + + // load the buffer + const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); + // write it to the bitmask + svst1_u8(pred_write, bitmask, mask_u8); } }; template <> struct MaskHelper<4> { static inline void - write(uint8_t* const __restrict bitmask, - const size_t size, - const svbool_t pred0, - const svbool_t pred1, - const svbool_t pred2, - const svbool_t pred3, - const svbool_t pred4, - const svbool_t pred5, - const svbool_t pred6, - const svbool_t pred7) { - const size_t sve_width = svcntw(); + write_full(uint8_t* const __restrict bitmask, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + const uint64_t sve_width = svcntb(); - // this is the buffer for the maximum possible case of 2048 bits - uint8_t pred_buf[MAX_SVE_WIDTH / 8]; - *((volatile svbool_t*)(pred_buf + 0 * sve_width / 2)) = pred0; - *((volatile svbool_t*)(pred_buf + 1 * sve_width / 2)) = pred1; - *((volatile svbool_t*)(pred_buf + 2 * sve_width / 2)) = pred2; - *((volatile svbool_t*)(pred_buf + 3 * sve_width / 2)) = pred3; - *((volatile svbool_t*)(pred_buf + 4 * sve_width / 2)) = pred4; - *((volatile svbool_t*)(pred_buf + 5 * sve_width / 2)) = pred5; - *((volatile svbool_t*)(pred_buf + 6 * sve_width / 2)) = pred6; - *((volatile svbool_t*)(pred_buf + 7 * sve_width / 2)) = pred7; + // compact predicates + const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7); + const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23); + const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67); - const svbool_t pred_op_8 = get_pred_op_8(size / 2); - const svbool_t pred_write_8 = get_pred_op_8(size / 8); - write_bitmask_32_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + // perform a full write + *((svbool_t*)(bitmask + 0 * sve_width / 8)) = pred_0123; + *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred_4567; + } + + static inline void + write_partial(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + const uint64_t sve_width = svcntb(); + + // compact predicates + const svbool_t pred_01 = svuzp1_b16(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b16(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b16(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b16(pred_6, pred_7); + const svbool_t pred_0123 = svuzp1_b8(pred_01, pred_23); + const svbool_t pred_4567 = svuzp1_b8(pred_45, pred_67); + + // this is a temporary buffer for the maximum possible case of 512 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 32]; + // write to the temporary buffer + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred_0123; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred_4567; + + // make the write mask. (size % 8) == 0 is guaranteed by the caller. + const svbool_t pred_write = + svwhilelt_b8(uint32_t(0), uint32_t(size / 8)); + + // load the buffer + const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); + // write it to the bitmask + svst1_u8(pred_write, bitmask, mask_u8); } }; template <> struct MaskHelper<8> { static inline void - write(uint8_t* const __restrict bitmask, - const size_t size, - const svbool_t pred0, - const svbool_t pred1, - const svbool_t pred2, - const svbool_t pred3, - const svbool_t pred4, - const svbool_t pred5, - const svbool_t pred6, - const svbool_t pred7) { - const size_t sve_width = svcntd(); + write_full(uint8_t* const __restrict bitmask, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + // compact predicates + const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7); + const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23); + const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67); + const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567); - // this is the buffer for the maximum possible case of 2048 bits - uint8_t pred_buf[MAX_SVE_WIDTH / 8]; - *((volatile svbool_t*)(pred_buf + 0 * sve_width)) = pred0; - *((volatile svbool_t*)(pred_buf + 1 * sve_width)) = pred1; - *((volatile svbool_t*)(pred_buf + 2 * sve_width)) = pred2; - *((volatile svbool_t*)(pred_buf + 3 * sve_width)) = pred3; - *((volatile svbool_t*)(pred_buf + 4 * sve_width)) = pred4; - *((volatile svbool_t*)(pred_buf + 5 * sve_width)) = pred5; - *((volatile svbool_t*)(pred_buf + 6 * sve_width)) = pred6; - *((volatile svbool_t*)(pred_buf + 7 * sve_width)) = pred7; + // perform a full write + *((svbool_t*)bitmask) = pred_01234567; + } - const svbool_t pred_op_8 = get_pred_op_8(size / 1); - const svbool_t pred_write_8 = get_pred_op_8(size / 8); - write_bitmask_64_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + static inline void + write_partial(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred_0, + const svbool_t pred_1, + const svbool_t pred_2, + const svbool_t pred_3, + const svbool_t pred_4, + const svbool_t pred_5, + const svbool_t pred_6, + const svbool_t pred_7) { + // compact predicates + const svbool_t pred_01 = svuzp1_b32(pred_0, pred_1); + const svbool_t pred_23 = svuzp1_b32(pred_2, pred_3); + const svbool_t pred_45 = svuzp1_b32(pred_4, pred_5); + const svbool_t pred_67 = svuzp1_b32(pred_6, pred_7); + const svbool_t pred_0123 = svuzp1_b16(pred_01, pred_23); + const svbool_t pred_4567 = svuzp1_b16(pred_45, pred_67); + const svbool_t pred_01234567 = svuzp1_b8(pred_0123, pred_4567); + + // this is a temporary buffer for the maximum possible case of 256 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 64]; + // write to the temporary buffer + *((volatile svbool_t*)(pred_buf)) = pred_01234567; + + // make the write mask. (size % 8) == 0 is guaranteed by the caller. + const svbool_t pred_write = + svwhilelt_b8(uint32_t(0), uint32_t(size / 8)); + + // load the buffer + const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); + // write it to the bitmask + svst1_u8(pred_write, bitmask, mask_u8); } }; @@ -924,16 +819,8 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) { const svbool_t cmp6 = func(pred_all, i + 6 * sve_width); const svbool_t cmp7 = func(pred_all, i + 7 * sve_width); - MaskHelper::write(res_u8 + i / 8, - sve_width * 8, - cmp0, - cmp1, - cmp2, - cmp3, - cmp4, - cmp5, - cmp6, - cmp7); + MaskHelper::write_full( + res_u8 + i / 8, cmp0, cmp1, cmp2, cmp3, cmp4, cmp5, cmp6, cmp7); } } @@ -985,16 +872,16 @@ op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) { cmp7 = func(get_partial_pred(7), size_sve8 + 7 * sve_width); } - MaskHelper::write(res_u8 + size_sve8 / 8, - size - size_sve8, - cmp0, - cmp1, - cmp2, - cmp3, - cmp4, - cmp5, - cmp6, - cmp7); + MaskHelper::write_partial(res_u8 + size_sve8 / 8, + size - size_sve8, + cmp0, + cmp1, + cmp2, + cmp3, + cmp4, + cmp5, + cmp6, + cmp7); } return true;