Skip to content

Commit 6136d99

Browse files
authored
Merge pull request #63 from axelhzf/complex-dtype-support
Implement loading and dumping of complex64 and complex128 dtypes
2 parents f33a3eb + d970d0a commit 6136d99

13 files changed

Lines changed: 252 additions & 15 deletions

README.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ console.log(tensor.get(10, 15));
103103
- `float32`
104104
- `float64`
105105
- `float16` (converted to float32 by default)
106+
- `complex64` (as `Float32Array` with interleaved real/imag)
107+
- `complex128` (as `Float64Array` with interleaved real/imag)
106108

107109
### Float16 Control
108110

@@ -114,6 +116,56 @@ const n1 = new npyjs();
114116
const n2 = new npyjs({ convertFloat16: false });
115117
```
116118

119+
### Complex Numbers
120+
121+
Complex arrays are returned as typed arrays with interleaved real and imaginary parts: `[real0, imag0, real1, imag1, ...]`
122+
123+
```ts
124+
import { load } from "npyjs";
125+
126+
const { data, shape } = await load("complex-array.npy");
127+
// For a shape of [3], data will have 6 elements: [re0, im0, re1, im1, re2, im2]
128+
129+
// Access the first complex number
130+
const real0 = data[0];
131+
const imag0 = data[1];
132+
```
133+
134+
---
135+
136+
## Writing .npy Files
137+
138+
Use the `dump` function to create `.npy` files:
139+
140+
```ts
141+
import { dump } from "npyjs";
142+
import { writeFileSync } from "fs";
143+
144+
// Dump a typed array
145+
const arr = new Float32Array([1.0, 2.0, 3.0, 4.0]);
146+
const bytes = dump(arr, [2, 2]); // 2x2 shape
147+
writeFileSync("output.npy", Buffer.from(bytes));
148+
149+
// Dump a plain array (dtype is inferred)
150+
const plain = [1, 2, 3, 4];
151+
const bytes2 = dump(plain, [4]);
152+
```
153+
154+
### Dumping Complex Arrays
155+
156+
Since complex types cannot be inferred from plain number arrays, use the `dtype` option:
157+
158+
```ts
159+
import { dump } from "npyjs";
160+
161+
// Complex array: 1+2j, 3-4j as interleaved [real, imag, ...]
162+
const complexData = [1, 2, 3, -4];
163+
const bytes = dump(complexData, [2], { dtype: "c8" }); // complex64
164+
165+
// Or use c16 for complex128
166+
const bytes128 = dump(complexData, [2], { dtype: "c16" });
167+
```
168+
117169
---
118170

119171
## Development

docs/index.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ console.log(tensor.get(10, 15));
9797
- `float32`
9898
- `float64`
9999
- `float16` (converted to float32 by default)
100+
- `complex64` (as `Float32Array` with interleaved real/imag)
101+
- `complex128` (as `Float64Array` with interleaved real/imag)
100102

101103
### Float16 Control
102104

@@ -108,6 +110,56 @@ const n1 = new npyjs();
108110
const n2 = new npyjs({ convertFloat16: false });
109111
```
110112

113+
### Complex Numbers
114+
115+
Complex arrays are returned as typed arrays with interleaved real and imaginary parts: `[real0, imag0, real1, imag1, ...]`
116+
117+
```ts
118+
import { load } from "npyjs";
119+
120+
const { data, shape } = await load("complex-array.npy");
121+
// For a shape of [3], data will have 6 elements: [re0, im0, re1, im1, re2, im2]
122+
123+
// Access the first complex number
124+
const real0 = data[0];
125+
const imag0 = data[1];
126+
```
127+
128+
---
129+
130+
## Writing .npy Files
131+
132+
Use the `dump` function to create `.npy` files:
133+
134+
```ts
135+
import { dump } from "npyjs";
136+
import { writeFileSync } from "fs";
137+
138+
// Dump a typed array
139+
const arr = new Float32Array([1.0, 2.0, 3.0, 4.0]);
140+
const bytes = dump(arr, [2, 2]); // 2x2 shape
141+
writeFileSync("output.npy", Buffer.from(bytes));
142+
143+
// Dump a plain array (dtype is inferred)
144+
const plain = [1, 2, 3, 4];
145+
const bytes2 = dump(plain, [4]);
146+
```
147+
148+
### Dumping Complex Arrays
149+
150+
Since complex types cannot be inferred from plain number arrays, use the `dtype` option:
151+
152+
```ts
153+
import { dump } from "npyjs";
154+
155+
// Complex array: 1+2j, 3-4j as interleaved [real, imag, ...]
156+
const complexData = [1, 2, 3, -4];
157+
const bytes = dump(complexData, [2], { dtype: "c8" }); // complex64
158+
159+
// Or use c16 for complex128
160+
const bytes128 = dump(complexData, [2], { dtype: "c16" });
161+
```
162+
111163
---
112164

113165
## Development

src/index.ts

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export type DType =
22
| "i1" | "u1" | "i2" | "u2" | "i4" | "u4" | "i8" | "u8"
3-
| "f2" | "f4" | "f8" | "b1" | `U${number}`; // e.g., U10 for strings of length 10
3+
| "f2" | "f4" | "f8" | "b1" | "c8" | "c16" | `U${number}`; // e.g., U10 for strings of length 10
44

55
export type TypedArray =
66
| Int8Array
@@ -27,6 +27,14 @@ export interface Options {
2727
convertFloat16?: boolean;
2828
}
2929

30+
export interface DumpOptions {
31+
/**
32+
* Specify the dtype for the output. Required for complex types (c8, c16)
33+
* since they cannot be inferred from a plain number array.
34+
*/
35+
dtype?: DType;
36+
}
37+
3038
class StringFromCodePoint extends String {
3139
constructor(buf: ArrayBufferLike, byteOffset?: number, length?: number) {
3240
const uint32 = new Uint32Array(buf, byteOffset, length);
@@ -109,6 +117,8 @@ function dtypeToArray(dtype: string, buf: ArrayBufferLike, offset: number, opts:
109117
case "u8": return new BigUint64Array(buf, offset);
110118
case "f4": return new Float32Array(buf, offset);
111119
case "f8": return new Float64Array(buf, offset);
120+
case "c8": return new Float32Array(buf, offset);
121+
case "c16": return new Float64Array(buf, offset);
112122
case "f2": {
113123
if (opts.convertFloat16 !== false) {
114124
const u16 = new Uint16Array(buf, offset);
@@ -235,6 +245,8 @@ export function arrayToTypedArray(dtype: DType, array: ArrayLike<number | string
235245
case "u8": return new BigUint64Array(array);
236246
case "f4": return new Float32Array(array);
237247
case "f8": return new Float64Array(array);
248+
case "c8": return new Float32Array(array);
249+
case "c16": return new Float64Array(array);
238250
default: throw new Error(`Unsupported dtype: ${dtype}`);
239251
}
240252
}
@@ -332,11 +344,20 @@ function createPyDescription(dtype : DType, shape: number[]) : string {
332344
return `{'descr':'${descr}','fortran_order':False,'shape':(${pyShape})}`;
333345
}
334346

335-
export function dump(array: TypedArray | Array<number | string>, shape: number[] | undefined) : ArrayBuffer{
336-
const dtype = array instanceof Array ? inferDtypeFromArray(array) : arrayToDtype(array);
347+
export function dump(array: TypedArray | Array<number | string>, shape?: number[], options?: DumpOptions) : ArrayBuffer{
348+
let dtype: DType;
349+
if (options?.dtype) {
350+
dtype = options.dtype;
351+
} else if (array instanceof Array) {
352+
dtype = inferDtypeFromArray(array);
353+
} else {
354+
dtype = arrayToDtype(array);
355+
}
337356
array = array instanceof Array ? arrayToTypedArray(dtype, array) : array;
338-
339-
let pyDesc = createPyDescription(dtype, shape ?? [array.length]);
357+
358+
// For complex types, shape refers to number of complex elements, not the flat array length
359+
const effectiveLength = (dtype === "c8" || dtype === "c16") ? array.length / 2 : array.length;
360+
let pyDesc = createPyDescription(dtype, shape ?? [effectiveLength]);
340361
let headerSize = 10 + pyDesc.length;
341362
const pad = 8 - ((headerSize + 1) % 8);
342363
pyDesc = pyDesc + " ".repeat(pad) + "\x0A";
@@ -368,7 +389,7 @@ export default class N {
368389
return f16toF32(u16);
369390
}
370391

371-
dump(array: TypedArray | Array<number | string>, shape: number[]) {
372-
return dump(array, shape);
392+
dump(array: TypedArray | Array<number | string>, shape?: number[], options?: DumpOptions) {
393+
return dump(array, shape, options);
373394
}
374395
}

test/data/10-complex128.npy

288 Bytes
Binary file not shown.

test/data/10-complex64.npy

208 Bytes
Binary file not shown.

test/data/4x4x4x4x4-complex128.npy

16.1 KB
Binary file not shown.

test/data/4x4x4x4x4-complex64.npy

8.13 KB
Binary file not shown.

test/data/65x65-complex128.npy

66.1 KB
Binary file not shown.

test/data/65x65-complex64.npy

33.1 KB
Binary file not shown.

test/dump.test.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,54 @@ describe("npyjs dump", () => {
6161
expect(error.toString()).toContain("MyArray")
6262
}
6363
});
64+
65+
it("complex64 (c8) with plain array", async () => {
66+
const complexData = [1, 2, 3, -4, 5, 6];
67+
const bytes = npyjs.dump(complexData, [3], { dtype: "c8" });
68+
const result = await npyjs.load(bytes);
69+
70+
expect(result.dtype).toBe("c8");
71+
expect(result.shape).toEqual([3]);
72+
expect(result.data).toEqual(new Float32Array(complexData));
73+
});
74+
75+
it("complex128 (c16) with plain array", async () => {
76+
const complexData = [1.5, 2.5, -3.5, 4.5];
77+
const bytes = npyjs.dump(complexData, [2], { dtype: "c16" });
78+
const result = await npyjs.load(bytes);
79+
80+
expect(result.dtype).toBe("c16");
81+
expect(result.shape).toEqual([2]);
82+
expect(result.data).toEqual(new Float64Array(complexData));
83+
});
84+
85+
it("complex64 (c8) with Float32Array", async () => {
86+
const original = new Float32Array([1, 2, 3, -4, 5.5, 6.5]);
87+
const bytes = npyjs.dump(original, [3], { dtype: "c8" });
88+
const result = await npyjs.load(bytes);
89+
90+
expect(result.dtype).toBe("c8");
91+
expect(result.shape).toEqual([3]);
92+
expect(result.data).toEqual(original);
93+
});
94+
95+
it("complex128 (c16) with Float64Array", async () => {
96+
const original = new Float64Array([1.1, 2.2, 3.3, -4.4]);
97+
const bytes = npyjs.dump(original, [2], { dtype: "c16" });
98+
const result = await npyjs.load(bytes);
99+
100+
expect(result.dtype).toBe("c16");
101+
expect(result.shape).toEqual([2]);
102+
expect(result.data).toEqual(original);
103+
});
104+
105+
it("2D complex array", async () => {
106+
const complexData = [1, 1, 2, 2, 3, 3, 4, 4];
107+
const bytes = npyjs.dump(complexData, [2, 2], { dtype: "c8" });
108+
const result = await npyjs.load(bytes);
109+
110+
expect(result.dtype).toBe("c8");
111+
expect(result.shape).toEqual([2, 2]);
112+
expect(result.data).toEqual(new Float32Array(complexData));
113+
});
64114
});

0 commit comments

Comments
 (0)