Skip to content

Commit 8ff8cab

Browse files
committed
wip: GPU acceleration
1 parent 443d53c commit 8ff8cab

31 files changed

Lines changed: 728 additions & 303 deletions

.gitignore

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ ipython_config.py
101101
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102102
poetry.lock
103103

104+
# UV
105+
uv.lock
106+
104107
# Cargo
105108
Cargo.lock
106109

@@ -166,4 +169,8 @@ cython_debug/
166169

167170
# Rust
168171
/target
169-
.env
172+
.env
173+
174+
test.py
175+
176+
*.class

CONTRIBUTING.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@
33
Thank you
44

55
## Rollout
6-
1. `pip install -r dev-requirements.txt` - Installing python dependencies
7-
2. `cargo build` - Installing rust dependencies
6+
1. `pip install -r dev-requirements.txt` or alternatively `uv sync` - Installing python dependencies
7+
2. `cargo build` - Installing rust dependencies
8+
9+
## Run benchmarks
10+
- `uv run pytest python/benches/compare_benchmark_test.py` (via uv)
11+
- `pytest python/benches/compare_benchmark_test.py`

Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ crate-type = ["cdylib", "rlib"]
1919
numpy = "0.25.0"
2020
pyo3 = { version = "0.25.1", features = ["extension-module"] }
2121
rayon = "1.10.0"
22+
anyhow = "1.0"
2223

23-
[dependencies.opencl3]
24-
version = "0.11"
25-
features = ["CL_VERSION_2_1", "CL_VERSION_2_2", "CL_VERSION_3_0"]
24+
[dependencies.rem_math_gpu]
25+
path = "rem_math_gpu"
2626

2727
[dev-dependencies]
2828
criterion = "0.3"
@@ -34,3 +34,8 @@ harness = false
3434
[profile.release]
3535
lto = true
3636
codegen-units = 1
37+
38+
[workspace]
39+
members = [
40+
"rem_math_gpu", "rem_math_simd",
41+
]

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,17 @@ Found 15 outliers among 100 measurements (15.00%)
155155
8 (8.00%) high severe
156156
```
157157

158+
## Integrations (java will be done later)
159+
<img height="50" src="https://raw.githubusercontent.com/marwin1991/profile-technology-icons/refs/heads/main/icons/python.png">
160+
<img height="50" src="https://raw.githubusercontent.com/marwin1991/profile-technology-icons/refs/heads/main/icons/java.png">
161+
162+
## GPU Integrations
163+
![nVIDIA](https://img.shields.io/badge/cuda-000000.svg?style=for-the-badge&logo=nVIDIA&logoColor=green)
164+
![nVIDIA](https://img.shields.io/badge/opencl-000000.svg?style=for-the-badge&logo=opencl&logoColor=green)
165+
158166
## Roadmap
159167

160-
- Add GPU-accelerated operations for improved performance.
168+
- Add GPU-accelerated operations for improved performance. (__in progress__)
161169
- Implement own custom type objects for best performance from ecosystem.
162170
- Expand mathematical functionality with additional features and algorithms.
163171

dev-requirements.txt

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,61 @@
1-
colorama==0.4.6 ; python_version >= "3.11" and sys_platform == "win32"
2-
iniconfig==2.1.0 ; python_version >= "3.11"
3-
maturin==1.9.0 ; python_version >= "3.11"
4-
numpy==2.3.1 ; python_version >= "3.11"
5-
packaging==25.0 ; python_version >= "3.11"
6-
pluggy==1.6.0 ; python_version >= "3.11"
7-
py-cpuinfo==9.0.0 ; python_version >= "3.11"
8-
pygments==2.19.2 ; python_version >= "3.11"
9-
pytest-benchmark==5.1.0 ; python_version >= "3.11"
10-
pytest==8.4.1 ; python_version >= "3.11"
1+
# This file was autogenerated by uv via the following command:
2+
# uv export --group dev --no-hashes --format requirements-txt
3+
-e .
4+
click==8.3.1
5+
# via mkdocs
6+
colorama==0.4.6 ; sys_platform == 'win32'
7+
# via
8+
# click
9+
# mkdocs
10+
# pytest
11+
ghp-import==2.1.0
12+
# via mkdocs
13+
iniconfig==2.3.0
14+
# via pytest
15+
jinja2==3.1.6
16+
# via mkdocs
17+
markdown==3.10
18+
# via mkdocs
19+
markupsafe==3.0.3
20+
# via
21+
# jinja2
22+
# mkdocs
23+
maturin==1.10.2
24+
mergedeep==1.3.4
25+
# via
26+
# mkdocs
27+
# mkdocs-get-deps
28+
mkdocs==1.6.1
29+
mkdocs-get-deps==0.2.0
30+
# via mkdocs
31+
numpy==2.4.0
32+
packaging==25.0
33+
# via
34+
# mkdocs
35+
# pytest
36+
pathspec==0.12.1
37+
# via mkdocs
38+
platformdirs==4.5.1
39+
# via mkdocs-get-deps
40+
pluggy==1.6.0
41+
# via pytest
42+
py-cpuinfo==9.0.0
43+
# via pytest-benchmark
44+
pygments==2.19.2
45+
# via pytest
46+
pytest==9.0.2
47+
# via pytest-benchmark
48+
pytest-benchmark==5.2.3
49+
python-dateutil==2.9.0.post0
50+
# via ghp-import
51+
pyyaml==6.0.3
52+
# via
53+
# mkdocs
54+
# mkdocs-get-deps
55+
# pyyaml-env-tag
56+
pyyaml-env-tag==1.1
57+
# via mkdocs
58+
six==1.17.0
59+
# via python-dateutil
60+
watchdog==6.0.0
61+
# via mkdocs

java/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[package]
2+
name = "rem_math_java"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[dependencies]
7+
anyhow = "1.0"
8+
jni = "0.21.1"

java/RemMath.h

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

java/RemMath.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Experimental module.
3+
*
4+
* This code is under active development and provided for internal testing only.
5+
* APIs, behavior, and implementation details are unstable and may change
6+
* without notice. Do not rely on this functionality in production or
7+
* draw conclusions from current results.
8+
*/
9+
10+
class RemMath {
11+
private static native int[] sum_two_ints32(int[] a, int[] b, String method);
12+
static {
13+
System.loadLibrary("rem_math");
14+
}
15+
16+
public static void main(String[] args) {
17+
int[] a = {1, 2, 3};
18+
int[] b = {1, 2, 3};
19+
20+
int[] output = RemMath.sum_two_ints32(a, b, "gpu");
21+
System.out.println(output);
22+
}
23+
}

java/src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
RemMath Java SDK integration.
3+
Note, this is template function, it will be replaced soon
4+
*/
5+
6+
use jni::JNIEnv;
7+
use jni::objects::{JClass, JString};
8+
use jni::sys::jstring;
9+
10+
#[no_mangle]
11+
pub extern "system" fn Java_HelloWorld_hello<'local>(
12+
mut env: JNIEnv<'local>,
13+
class: JClass<'local>,
14+
input: JString<'local>
15+
) -> jstring {
16+
let input: String = env.get_string(&input).expect("Couldn't get java string!").into();
17+
let output = env.new_string(format!("Hello, {}!", input)).expect("Couldn't create java string!");
18+
19+
output.into_raw()
20+
}

pyproject.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@ requires-python = ">=3.11"
1010
dependencies = [
1111
]
1212

13-
1413
[build-system]
1514
requires = ["maturin>=1.0,<2.0"]
1615
build-backend = "maturin"
1716

1817
[tool.maturin]
18+
python-source = "python"
1919
module-name = "rem_math._rem_math"
2020

21-
[tool.poetry.group.dev.dependencies]
22-
pytest-benchmark = "^5.1.0"
23-
numpy = "^2.3.1"
24-
maturin = "^1.9.0"
25-
mkdocs = "^1.6.1"
21+
[dependency-groups]
22+
dev = [
23+
"maturin>=1.10.2",
24+
"numpy>=2.4.0",
25+
"pytest-benchmark>=5.2.3",
26+
"mkdocs>=1.6.1",
27+
]

0 commit comments

Comments
 (0)