
Add swizzled local memory accessors for FP16 and INT8 on the PTX backend #841

Merged
stratika merged 4 commits into beehive-lab:develop from mairooni:feat/swizzled_local_memory on May 28, 2026

Conversation

@mairooni
Collaborator

Description

This PR adds swizzled shared-memory accessors to the KernelContext for the PTX backend, providing bank-conflict-free layouts for FP16 and INT8 matrix tiles. The new accessors exposed through the KernelContext are:

  • swizzleLoadFp16Stride32 / swizzleStoreFp16Stride32
  • swizzleLoadFp16Stride16 / swizzleStoreFp16Stride16
  • swizzleLoadInt8 / swizzleStoreInt8

Each applies an involutive XOR permutation to the logical (row, col) coordinate before accessing shared memory, so that the resulting access pattern spreads across distinct memory banks instead of colliding. On NVIDIA GPUs, shared memory has 32 banks of 4 bytes each. A naive row-major tile layout causes many threads in a warp to hit the same bank, serializing the access. The XOR rotates each row's bank assignment to avoid this.
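As a rough illustration of the bank arithmetic described above, the following sketch counts how many distinct banks a 32-lane, column-wise access touches with and without an XOR on the column index. The tile shape (32 four-byte words per row), the class name, and the specific XOR `col ^ row` are assumptions chosen to keep the example small; they are not the constants used by the accessors in this PR.

```java
import java.util.HashSet;
import java.util.Set;

public class BankConflictSketch {
    // 32 banks of 4 bytes each, as stated in the description.
    static final int BANKS = 32;
    // Hypothetical tile: 32 four-byte words per row (illustration only).
    static final int WORDS_PER_ROW = 32;

    // Bank that a given 4-byte word index falls into.
    static int bankOf(int wordIndex) {
        return wordIndex % BANKS;
    }

    // Plain row-major addressing.
    static int naiveIndex(int row, int col) {
        return row * WORDS_PER_ROW + col;
    }

    // Row-major addressing with the column XORed by the row number.
    // Applying the same XOR twice returns the original column (involutive).
    static int swizzledIndex(int row, int col) {
        return row * WORDS_PER_ROW + (col ^ (row % WORDS_PER_ROW));
    }

    // Each of the 32 lanes reads column 0 of its own row; count distinct banks hit.
    static int distinctBanks(boolean swizzled) {
        Set<Integer> banks = new HashSet<>();
        for (int lane = 0; lane < 32; lane++) {
            int index = swizzled ? swizzledIndex(lane, 0) : naiveIndex(lane, 0);
            banks.add(bankOf(index));
        }
        return banks.size();
    }

    public static void main(String[] args) {
        System.out.println(distinctBanks(false)); // prints 1: every lane hits bank 0
        System.out.println(distinctBanks(true));  // prints 32: one bank per lane
    }
}
```

With the naive layout all 32 lanes land in the same bank and the access is serialized; with the XOR applied each lane lands in its own bank.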

The constants follow the CUTLASS Swizzle<> template and are parameterized per layout:

  • FP16 stride-32: permutation at the 16-byte boundary over groups of 8 rows (a 32-byte row holds 16 fp16 elements).
  • FP16 stride-16: the same shape shifted down one bit, for a 16-byte row (8 fp16 elements), used for narrower, transposed tiles.
  • INT8: identical math to FP16 stride-32. The permutation operates on a 16-byte granularity that is independent of element type, so int8 reuses the stride-32 constants directly. Because int8 elements are one byte, the caller expresses the row stride in bytes (32 for a wide tile, 16 for a narrow one) and the same swizzle serves both.
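The general shape of such a parameterized XOR permutation can be sketched as below. The function mirrors the usual three-parameter form (number of bits, base bit, shift) associated with CUTLASS-style swizzles. The class name, method names, and the parameter values (3, 3, 3) are assumptions for illustration and are not the constants this PR uses for any of its layouts.

```java
public class SwizzleSketch {
    // XOR the bBits bits that sit sShift positions above bit mBase into the
    // bBits bits starting at mBase. Bits below mBase are left untouched.
    // Assumes sShift >= bBits so the source and target bit fields do not overlap.
    static int swizzle(int offset, int bBits, int mBase, int sShift) {
        int bitMask = (1 << bBits) - 1;
        int sourceMask = bitMask << (mBase + sShift);
        return offset ^ ((offset & sourceMask) >> sShift);
    }

    // Check that applying the permutation twice restores every offset below limit.
    static boolean isInvolution(int limit, int bBits, int mBase, int sShift) {
        for (int offset = 0; offset < limit; offset++) {
            int once = swizzle(offset, bBits, mBase, sShift);
            if (swizzle(once, bBits, mBase, sShift) != offset) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Hypothetical parameters: 3 bits, base bit 3, shift 3.
        System.out.println(swizzle(64, 3, 3, 3));        // prints 72 (bit 6 is XORed into bit 3)
        System.out.println(isInvolution(512, 3, 3, 3));  // prints true
    }
}
```

Because the source bits are never modified, the same function serves as both the store-side and the load-side mapping, which is what makes a paired swizzled load/store round-trip return the original element.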

The layout is primarily intended for staging matrix tiles for future Tensor Core (MMA) work, but it is a general bank-conflict-avoidance mechanism usable by any kernel with a matching access pattern.

Problem description

Efficient Tensor Core (MMA) matrix multiplication requires shared-memory tiles to be laid out so that warp-level matrix loads do not incur bank conflicts. Currently, shared arrays are addressed linearly, which produces heavy bank conflicts for the strided access patterns that matrix loads use. This PR adds the swizzled-layout support needed to lay those tiles out in shared memory conflict-free, ahead of the MMA work that will consume it.

Backend/s tested

Mark the backends affected by this PR.

  • OpenCL
  • PTX
  • SPIRV
  • Metal

OS tested

Mark the OS where this PR is tested.

  • Linux
  • OSx
  • Windows

Did you check on FPGAs?

If it is applicable, check your changes on FPGAs.

  • Yes
  • No

How to test the new patch?

The load/store functionality can be verified by running the unit test:
tornado-test -V uk.ac.manchester.tornado.unittests.kernelcontext.local.memory.TestSwizzledLocalArrays

Each kernel in the test was also profiled with Nsight Compute and produces zero shared-memory bank conflicts (l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld/st = 0), with non-zero smsp__inst_executed_op_shared_ld/st, confirming the accesses execute.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds PTX-only swizzled local/shared-memory accessors to KernelContext for FP16 and INT8 tile layouts, with backend plugin registration and unit coverage for the new API.

Changes:

  • Adds public KernelContext swizzled load/store helpers for FP16 stride-32, FP16 stride-16, and INT8.
  • Implements PTX graph builder plugins, Graal nodes, and PTX LIR emission for the new accessors.
  • Registers unsupported-backend stubs and adds unit tests to the Tornado test suite.

Reviewed changes

Copilot reviewed 13 out of 14 changed files in this pull request and generated 3 comments.

File | Description
tornado-api/src/main/java/uk/ac/manchester/tornado/api/KernelContext.java | Adds Java API and fallback implementations for swizzled local-memory accessors.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java | Registers PTX lowering plugins for the new accessors.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java | Emits PTX shared-memory swizzle address calculations and load/store instructions.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledLoadFP16Stride32Node.java | Adds FP16 stride-32 swizzled load node.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledStoreFP16Stride32Node.java | Adds FP16 stride-32 swizzled store node.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledLoadFP16Stride16Node.java | Adds FP16 stride-16 swizzled load node.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledStoreFP16Stride16Node.java | Adds FP16 stride-16 swizzled store node.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledLoadInt8Node.java | Adds INT8 swizzled load node.
tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/SwizzledStoreInt8Node.java | Adds INT8 swizzled store node.
tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java | Registers unsupported stubs for the new PTX-only API.
tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java | Registers unsupported stubs for the new PTX-only API.
tornado-drivers/metal/src/main/java/uk/ac/manchester/tornado/drivers/metal/graal/compiler/plugins/MetalGraphBuilderPlugins.java | Registers unsupported stubs for the new PTX-only API.
tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/kernelcontext/local/memory/TestSwizzledLocalArrays.java | Adds PTX-focused unit tests for FP16 and INT8 swizzled local-memory round trips.
tornado-assembly/src/bin/tornado-test | Adds the new swizzled local-array test class to the test suite.


r.register(new InvocationPlugin("swizzleStoreFp16Stride32", InvocationPlugin.Receiver.class, HalfFloat[].class, int.class, int.class, int.class, HalfFloat.class) {
    @Override
    public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode local_array, ValueNode row, ValueNode column, ValueNode stride, ValueNode value) {
        b.addPush(JavaKind.Object, new SwizzledStoreFP16Stride32Node(local_array, row, column, stride, value));

r.register(new InvocationPlugin("swizzleStoreFp16Stride16", InvocationPlugin.Receiver.class, HalfFloat[].class, int.class, int.class, int.class, HalfFloat.class) {
    @Override
    public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode local_array, ValueNode row, ValueNode column, ValueNode stride, ValueNode value) {
        b.addPush(JavaKind.Object, new SwizzledStoreFP16Stride16Node(local_array, row, column, stride, value));

    @Override
    public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver,
            ValueNode local_array, ValueNode row, ValueNode column, ValueNode stride, ValueNode value) {
        b.addPush(JavaKind.Byte, new SwizzledStoreInt8Node(local_array, row, column, stride, value));

Collaborator

@stratika stratika left a comment


LGTM. Please consider and try the comment from Copilot.

…d instead of addPush for stores, extend the unittest and include the swizzledstore* nodes in the TornadoHalfFloatReplacement)
@mairooni
Collaborator Author

The comments have been applied to both this PR and PR #843.

@stratika stratika merged commit 1742067 into beehive-lab:develop May 28, 2026
7 checks passed
@stratika stratika moved this from Backport PR Open to Merged to JDK25 in JDK PR Backport Tracking Project May 28, 2026

Projects

Status: Merged to JDK25

3 participants