@@ -164,7 +164,7 @@ def matmul_kernel(
164164 # NOTE mask will be applied on accumulator, which is always FP32, so we may truncate up to 23b
165165 # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
166166 # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
167- trun_mask = tl .cast ((0xFFFFFFFF >> chunk_trun_bits ) << chunk_trun_bits , tl .uint32 )
167+ trun_mask = ~ tl .cast ((1 << chunk_trun_bits ) - 1 , tl .uint32 )
168168 round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
169169 ## ---------------------------------------------------------
170170
@@ -386,7 +386,7 @@ def matmul_kernel_DABC(
386386 # NOTE mask will be applied on accumulator, which is always FP32, so we may truncate up to 23b
387387 # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
388388 # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
389- trun_mask = tl .cast ((0xFFFFFFFF >> chunk_trun_bits ) << chunk_trun_bits , tl .uint32 )
389+ trun_mask = ~ tl .cast ((1 << chunk_trun_bits ) - 1 , tl .uint32 )
390390 round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
391391 ## ---------------------------------------------------------
392392
@@ -448,10 +448,11 @@ def round_and_trun(x, round_bit, trun_mask):
448448@triton .jit
449449def fp32_clamp_to_dl16 (x ):
450450 """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
451- # 1. rounding: add round bit to full uint representation , zero out last 13 bits, back to float
451+ # 1. rounding: add round bit, zero out last 13 bits, back to float
452452 x = libdevice .float_as_uint (x )
453453 round_bit = 1 << (23 - 9 - 1 )
454- x = libdevice .uint_as_float (((x + round_bit ) >> 13 ) << 13 )
454+ mask_13x0 = ~ tl .cast ((1 << 13 ) - 1 , tl .uint32 )
455+ x = libdevice .uint_as_float ((x + round_bit ) & mask_13x0 )
455456
456457 # 2. clamp to min/max:
457458 # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
0 commit comments