This repository was archived by the owner on May 11, 2023. It is now read-only.

feat: Allowing the use of Hyper-Efficient Kernel Operator Library. #56

@adam-hartshorne

Description


I wondered whether you were aware of, or had considered, using a third-party high-performance kernel library such as KeOps or Triton.

KeOps currently has PyTorch bindings, and it can already be used with GPyTorch to define kernels. Although KeOps does not support JAX at present, I believe an integration is probably possible through JAX's custom_call functionality, which allows JAX to call into C++ libraries. KeOps is extremely fast, and because it evaluates kernel entries on the fly rather than storing the full matrix, it can perform operations on kernel matrices that would normally be far too large to fit in memory.
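To make the memory point concrete, here is a minimal pure-JAX sketch of the idea behind KeOps-style reductions: computing a kernel-matrix-vector product `K(x, y) @ v` one block of rows at a time, so the full N x M matrix is never stored. This is not KeOps and does not use its API. The helper names `rbf_block` and `lazy_kernel_matvec`, the choice of an RBF kernel, and the `block_size` parameter are hypothetical choices made for illustration. KeOps performs this kind of tiling inside generated GPU code, which is where its speed comes from; this sketch only shows the memory-saving structure.

```python
import jax
import jax.numpy as jnp


def rbf_block(x_block, y, lengthscale):
    """RBF kernel values between a block of rows of x and every row of y."""
    sq_dist = (
        jnp.sum(x_block**2, axis=1)[:, None]
        + jnp.sum(y**2, axis=1)[None, :]
        - 2.0 * x_block @ y.T
    )
    # Clamp tiny negative values caused by floating-point cancellation.
    sq_dist = jnp.maximum(sq_dist, 0.0)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


def lazy_kernel_matvec(x, y, v, lengthscale=1.0, block_size=256):
    """Compute K(x, y) @ v without materialising the full N x M kernel matrix.

    Only a (block_size, M) slice of K exists at any one time.
    Hypothetical helper; assumes x.shape[0] is divisible by block_size.
    """
    n, d = x.shape
    x_blocks = x.reshape(n // block_size, block_size, d)

    def one_block(x_block):
        # Build one slice of K and immediately reduce it against v.
        return rbf_block(x_block, y, lengthscale) @ v

    # lax.map applies one_block to each block sequentially.
    out = jax.lax.map(one_block, x_blocks)
    return out.reshape((n,) + v.shape[1:])


if __name__ == "__main__":
    kx, ky, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(kx, (1024, 3))
    y = jax.random.normal(ky, (512, 3))
    v = jax.random.normal(kv, (512,))
    dense = rbf_block(x, y, 1.0) @ v  # materialises the full 1024 x 512 matrix
    lazy = lazy_kernel_matvec(x, y, v)
    print(jnp.allclose(dense, lazy, atol=1e-4))
```

Peak memory here scales with `block_size * M` rather than `N * M`; a smaller block saves memory at the cost of more sequential steps.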

Alternatively, there is already a JAX wrapper around Triton (triton-lang.org), a language created by OpenAI for writing high-performance GPU kernels. I believe JAX already uses Triton for some internal operations. In effect, Triton lets you write a custom kernel much as you would in CUDA, but with added advantages, one being that the kernel can be called from JAX with very little effort.
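As a rough illustration of a custom kernel written in Python and called directly from JAX, the sketch below uses Pallas (`jax.experimental.pallas`), JAX's own kernel language, which on GPUs can be lowered through Triton. This is a swapped-in technique, not the Triton wrapper mentioned in this issue, and `rbf_gram_kernel` / `rbf_gram` are hypothetical names. It runs in interpret mode so it works on a CPU; removing `interpret=True` on a GPU would compile it instead, subject to the stricter shape and operation constraints of the GPU backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def rbf_gram_kernel(x_ref, y_ref, o_ref):
    # Pallas passes references to the input and output buffers;
    # indexing with [...] loads or stores the whole block.
    x = x_ref[...]
    y = y_ref[...]
    sq_dist = (
        jnp.sum(x * x, axis=1)[:, None]
        + jnp.sum(y * y, axis=1)[None, :]
        - 2.0 * jnp.dot(x, y.T)
    )
    # Unit-lengthscale RBF kernel, with round-off negatives clamped to zero.
    o_ref[...] = jnp.exp(-0.5 * jnp.maximum(sq_dist, 0.0))


def rbf_gram(x, y):
    """Gram matrix K[i, j] = exp(-0.5 * ||x_i - y_j||^2) via a Pallas kernel."""
    return pl.pallas_call(
        rbf_gram_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[0]), x.dtype),
        interpret=True,  # evaluate with the Pallas interpreter (works on CPU)
    )(x, y)


if __name__ == "__main__":
    x = jnp.arange(12.0).reshape(4, 3) / 10.0
    y = jnp.arange(6.0).reshape(2, 3) / 10.0
    print(rbf_gram(x, y))
```

The kernel body is ordinary `jax.numpy` code operating on loaded blocks, which is what makes this style of kernel easy to slot into an existing JAX model.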

Metadata



    Labels

    enhancement: New feature or request
