@@ -553,58 +553,54 @@ func.func @load_4bit_vector_boundary_case(%mem: memref<4294967295xi4>) -> vector
553553}
554554
555555// CHECK-LABEL: func.func @load_paged_scalar
556- // CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
557- func.func @load_paged_scalar (%mem: memref <1 x 64 x 8192 x f16 >, %pagePtr: i64 , %offset: index ) -> f16 attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
556+ // CHECK-SAME: (%[[mem:.*]]: memref<8192xf16 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
557+ func.func @load_paged_scalar (%mem: memref <8192 x f16 >, %pagePtr: i64 , %offset: index ) -> f16 attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
558558 %true = arith.constant true
559- // Paged load converts page ptr to llvm.ptr, creates buffer resource, and loads
560559 // CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
561560 // CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
562561 // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
563562 // CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
564563 %ret = rock.global_load %mem [%offset ] if %true paged %pagePtr {pageSize = 8192 : i64 }
565- : memref <1 x 64 x 8192 x f16 > -> f16
564+ : memref <8192 x f16 > -> f16
566565 return %ret : f16
567566}
568567
569568// CHECK-LABEL: func.func @load_paged_vector
570- // CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
571- func.func @load_paged_vector (%mem: memref <1 x 64 x 8192 x f16 >, %pagePtr: i64 , %offset: index ) -> vector <2 xf16 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
569+ // CHECK-SAME: (%[[mem:.*]]: memref<8192xf16 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
570+ func.func @load_paged_vector (%mem: memref <8192 x f16 >, %pagePtr: i64 , %offset: index ) -> vector <2 xf16 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
572571 %true = arith.constant true
573- // Paged vector load converts page ptr to llvm.ptr, creates buffer resource, and loads
574572 // CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
575573 // CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
576574 // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
577575 // CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
578576 %ret = rock.global_load %mem [%offset ] if %true paged %pagePtr {pageSize = 8192 : i64 }
579- : memref <1 x 64 x 8192 x f16 > -> vector <2 xf16 >
577+ : memref <8192 x f16 > -> vector <2 xf16 >
580578 return %ret : vector <2 xf16 >
581579}
582580
583581// CHECK-LABEL: func.func @load_paged_vector_maybe_oob
584- // CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index, %[[valid:.*]]: i1)
585- func.func @load_paged_vector_maybe_oob (%mem: memref <1 x64 x8192 xf16 >, %pagePtr: i64 , %offset: index , %valid: i1 ) -> vector <2 xf16 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
586- // Paged load with validity check - scf.if guards the buffer load
582+ // CHECK-SAME: (%[[mem:.*]]: memref<8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index, %[[valid:.*]]: i1)
583+ func.func @load_paged_vector_maybe_oob (%mem: memref <8192 xf16 >, %pagePtr: i64 , %offset: index , %valid: i1 ) -> vector <2 xf16 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
587584 // CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
588585 // CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
589586 // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
590587 // CHECK: scf.if %[[valid]]
591588 // CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
592589 %ret = rock.global_load %mem [%offset ] if %valid paged %pagePtr {pageSize = 8192 : i64 }
593- : memref <1 x 64 x 8192 x f16 > -> vector <2 xf16 >
590+ : memref <8192 x f16 > -> vector <2 xf16 >
594591 return %ret : vector <2 xf16 >
595592}
596593
597594// CHECK-LABEL: func.func @load_paged_vector_large_page
598- // CHECK-SAME: (%[[mem:.*]]: memref<1x64x16384xf32 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
599- func.func @load_paged_vector_large_page (%mem: memref <1 x 64 x 16384 x f32 >, %pagePtr: i64 , %offset: index ) -> vector <4 xf32 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
595+ // CHECK-SAME: (%[[mem:.*]]: memref<16384xf32 >, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
596+ func.func @load_paged_vector_large_page (%mem: memref <16384 x f32 >, %pagePtr: i64 , %offset: index ) -> vector <4 xf32 > attributes {arch = " amdgcn-amd-amdhsa:gfx942" } {
600597 %true = arith.constant true
601- // Larger page size (16384 elements * 4 bytes = 65536 bytes)
602598 // CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(65536 : i64) : i64
603599 // CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
604600 // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
605601 // CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
606602 %ret = rock.global_load %mem [%offset ] if %true paged %pagePtr {pageSize = 16384 : i64 }
607- : memref <1 x 64 x 16384 x f32 > -> vector <4 xf32 >
603+ : memref <16384 x f32 > -> vector <4 xf32 >
608604 return %ret : vector <4 xf32 >
609605}
610606}
0 commit comments