Skip to content

Commit 874cd0b

Browse files
sauterpexedev
andauthored
[Routines] Add column_stack, concatenate, hstack, vstack, and row_stack (#352)
Implement numpy-compatible array joining routines in numojo/routines/manipulation.mojo:

- concatenate(): Join arrays along an existing axis
- column_stack(): Stack 1-D arrays as columns into 2-D, or hstack 2-D+ arrays
- row_stack(): Stack arrays vertically (1-D reshaped to (1, N))
- hstack(): Horizontal stacking (axis=0 for 1-D, axis=1 for 2-D+)
- vstack(): Vertical stacking (same behavior as row_stack)

Tests validate against numpy equivalents for 1-D, 2-D, 3-D, and mixed-dimension inputs.

Co-authored-by: exedev <exedev@exe.dev>
1 parent df3e56f commit 874cd0b

4 files changed

Lines changed: 492 additions & 0 deletions

File tree

numojo/__init__.mojo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ from numojo.routines.manipulation import (
270270
transpose,
271271
broadcast_to,
272272
flip,
273+
concatenate,
274+
column_stack,
275+
row_stack,
276+
hstack,
277+
vstack,
273278
)
274279

275280
from numojo.routines import random

numojo/routines/__init__.mojo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ from .manipulation import (
171171
transpose,
172172
broadcast_to,
173173
flip,
174+
concatenate,
175+
column_stack,
176+
row_stack,
177+
hstack,
178+
vstack,
174179
)
175180

176181
from .sorting import sort, argsort

numojo/routines/manipulation.mojo

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,363 @@ def flip[
691691
A._buf.ptr[I._buf.ptr[i + A.shape[axis] - 1 - j]] = temp
692692

693693
return A^
694+
695+
696+
# ===----------------------------------------------------------------------=== #
697+
# Joining arrays
698+
# ===----------------------------------------------------------------------=== #
699+
700+
701+
def _concatenate_list[
    dtype: DType
](arrays: List[NDArray[dtype]], axis: Int = 0) raises -> NDArray[dtype]:
    """Internal: Join a list of arrays along an existing axis.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to join. Every array must have the same number of
            dimensions as `arrays[0]` and the same extent as `arrays[0]` in
            every dimension other than `axis`.
        axis: The axis along which the arrays are joined. A negative value
            is offset by the number of dimensions of `arrays[0]`. Default
            is 0.

    Returns:
        For a single input, `arrays[0].contiguous()`. Otherwise a newly
        allocated array whose extent along `axis` is the sum of the inputs'
        extents along that axis.

    Raises:
        Error: If `arrays` is empty.
        Error: If `axis` is out of bounds for the dimensionality of
            `arrays[0]` (checked for any number of inputs).
        Error: If the arrays do not all have the same number of dimensions.
        Error: If the array shapes differ along a non-concatenation axis.
    """
    if len(arrays) == 0:
        raise Error(
            NumojoError(
                category="value",
                message="Need at least one array to concatenate.",
                location="concatenate()",
            )
        )

    ref first = arrays[0]
    var ndims = first.ndim

    # Normalize and validate the axis BEFORE the single-array shortcut so
    # that an invalid axis is rejected no matter how many arrays are given.
    var ax = axis
    if ax < 0:
        ax += ndims
    if ax < 0 or ax >= ndims:
        raise Error(
            NumojoError(
                category="value",
                message=String(
                    "axis {} is out of bounds for array of dimension {}."
                ).format(axis, ndims),
                location="concatenate()",
            )
        )

    # Nothing to join: hand back the single input in contiguous form.
    if len(arrays) == 1:
        return arrays[0].contiguous()

    # Validate shapes and compute the total size along the concat axis.
    var total_along_axis: Int = first.shape[ax]
    for i in range(1, len(arrays)):
        ref arr = arrays[i]
        if arr.ndim != ndims:
            raise Error(
                NumojoError(
                    category="value",
                    message=String(
                        "All arrays must have the same number of dimensions."
                        " Array 0 has {} dims, array {} has {} dims."
                    ).format(ndims, i, arr.ndim),
                    location="concatenate()",
                )
            )
        for d in range(ndims):
            if d != ax and arr.shape[d] != first.shape[d]:
                raise Error(
                    NumojoError(
                        category="shape",
                        message=String(
                            "All array dimensions except for the"
                            " concatenation axis must match. Dimension {}"
                            " of array {} has size {} but expected {}."
                        ).format(d, i, arr.shape[d], first.shape[d]),
                        location="concatenate()",
                    )
                )
        total_along_axis += arr.shape[ax]

    # Build the output shape: identical to the first array except along the
    # concat axis, where it is the summed extent.
    var out_shape_list = List[Int]()
    for d in range(ndims):
        if d == ax:
            out_shape_list.append(total_along_axis)
        else:
            out_shape_list.append(first.shape[d])
    var out_shape = NDArrayShape(out_shape_list)
    var result = NDArray[dtype](out_shape)

    # Copy strategy: walk the output buffer linearly, convert each flat
    # offset to a multi-dimensional index using the result's strides, map
    # the concat-axis coordinate back to the owning source array, and read
    # that element from the source.
    # NOTE(review): decomposing the flat offset with `result.strides`
    # assumes the freshly allocated result is laid out in C-order - confirm.

    # boundaries[i] is the coordinate along the concat axis at which
    # arrays[i] starts inside the result.
    var boundaries = List[Int]()
    var running: Int = 0
    for i in range(len(arrays)):
        boundaries.append(running)
        running += arrays[i].shape[ax]

    # Scratch index buffer, allocated once and reused for every element;
    # every slot is overwritten at the start of each iteration below.
    var nd_index = List[Int]()
    for _ in range(ndims):
        nd_index.append(0)

    for flat_idx in range(result.size):
        # Convert flat_idx to an nd-index.
        var remainder = flat_idx
        for d in range(ndims):
            nd_index[d] = remainder // result.strides[d]
            remainder = remainder % result.strides[d]

        # Find the source array: scanning from the last array backwards,
        # the first one whose start boundary is <= the coordinate owns it.
        var coord_along_axis = nd_index[ax]
        var src_idx: Int = len(arrays) - 1
        for i in range(len(arrays) - 1, -1, -1):
            if coord_along_axis >= boundaries[i]:
                src_idx = i
                break

        # Make the concat-axis coordinate local to the source array.
        nd_index[ax] = coord_along_axis - boundaries[src_idx]

        result._buf.ptr[flat_idx] = arrays[src_idx]._getitem(nd_index)

    return result^
814+
815+
816+
def concatenate[
    dtype: DType
](*arrays: NDArray[dtype], axis: Int = 0) raises -> NDArray[dtype]:
    """Join arrays along an axis that already exists in every input.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to join, in order. Their shapes must agree in
            every dimension other than the one selected by `axis`.
        axis: The axis along which the arrays are joined. Default is 0.

    Returns:
        The joined array.

    Raises:
        Error: If no arrays are given.
        Error: If the arrays do not have the same number of dimensions.
        Error: If the array shapes are incompatible along non-concatenation axes.
        Error: If `_concatenate_list` rejects the inputs for another reason
            (for example an out-of-range `axis`).

    Examples:
    ```mojo
    import numojo as nm
    var a = nm.arange[nm.f64](0, 6, 1)
    var a2d = nm.reshape(a, nm.Shape(2, 3))
    var b = nm.arange[nm.f64](6, 12, 1)
    var b2d = nm.reshape(b, nm.Shape(2, 3))
    var c = nm.concatenate(a2d, b2d, axis=0)  # Shape (4, 3)
    var d = nm.concatenate(a2d, b2d, axis=1)  # Shape (2, 6)
    ```
    """
    # Gather copies of the variadic inputs into a list, preserving order;
    # validation and the actual join are done by the list-based helper.
    var count = len(arrays)
    var collected = List[NDArray[dtype]]()
    for idx in range(count):
        collected.append(arrays[idx].copy())
    return _concatenate_list(collected, axis)
852+
853+
854+
def column_stack[
    dtype: DType
](*arrays: NDArray[dtype]) raises -> NDArray[dtype]:
    """Place the inputs side by side as columns (like `numpy.column_stack`).

    Each 1-D input of shape `(N,)` is reshaped to an `(N, 1)` column; inputs
    of any other dimensionality are used unchanged. The pieces are then
    joined along axis 1.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to stack. All arrays must have the same number
            of rows (first dimension).

    Returns:
        The 2-D (or higher) array formed by stacking the inputs as columns.

    Raises:
        Error: If no arrays are given.
        Error: If `_concatenate_list` rejects the reshaped inputs (for
            example, mismatched dimension counts or shapes).

    Examples:
    ```mojo
    import numojo as nm
    var a = nm.arange[nm.f64](0, 3, 1)  # Shape (3,)
    var b = nm.arange[nm.f64](3, 6, 1)  # Shape (3,)
    var c = nm.column_stack(a, b)       # Shape (3, 2)
    ```
    """
    var count = len(arrays)
    if count == 0:
        raise Error(
            NumojoError(
                category="value",
                message="Need at least one array to column_stack.",
                location="column_stack()",
            )
        )

    var columns = List[NDArray[dtype]]()
    for idx in range(count):
        if arrays[idx].ndim != 1:
            # Not 1-D: take a copy as-is.
            columns.append(arrays[idx].copy())
            continue
        # 1-D input: turn (N,) into an (N, 1) column vector.
        columns.append(
            reshape(
                arrays[idx].copy(),
                NDArrayShape(arrays[idx].shape[0], 1),
            )
        )

    return _concatenate_list(columns, axis=1)
906+
907+
908+
def row_stack[dtype: DType](*arrays: NDArray[dtype]) raises -> NDArray[dtype]:
    """Stack the inputs on top of each other as rows, equivalent to
    `numpy.row_stack` / `numpy.vstack`.

    Each 1-D input of shape `(N,)` is reshaped to a `(1, N)` row; inputs of
    any other dimensionality are used unchanged. The pieces are then joined
    along axis 0.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to stack.

    Returns:
        The array formed by stacking the inputs vertically.

    Raises:
        Error: If no arrays are given.
        Error: If `_concatenate_list` rejects the reshaped inputs (for
            example, mismatched dimension counts or shapes).

    Examples:
    ```mojo
    import numojo as nm
    var a = nm.arange[nm.f64](0, 3, 1)  # Shape (3,)
    var b = nm.arange[nm.f64](3, 6, 1)  # Shape (3,)
    var c = nm.row_stack(a, b)          # Shape (2, 3)
    ```
    """
    var count = len(arrays)
    if count == 0:
        raise Error(
            NumojoError(
                category="value",
                message="Need at least one array to row_stack.",
                location="row_stack()",
            )
        )

    var rows = List[NDArray[dtype]]()
    for idx in range(count):
        if arrays[idx].ndim != 1:
            # Not 1-D: take a copy as-is.
            rows.append(arrays[idx].copy())
            continue
        # 1-D input: turn (N,) into a (1, N) row vector.
        rows.append(
            reshape(
                arrays[idx].copy(),
                NDArrayShape(1, arrays[idx].shape[0]),
            )
        )

    return _concatenate_list(rows, axis=0)
956+
957+
958+
def hstack[dtype: DType](*arrays: NDArray[dtype]) raises -> NDArray[dtype]:
    """Join the inputs horizontally (column-wise), equivalent to
    `numpy.hstack`.

    The join axis is chosen from the first array only: axis 0 when it is
    1-D, axis 1 otherwise.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to stack.

    Returns:
        The array formed by stacking the inputs horizontally.

    Raises:
        Error: If no arrays are given.
        Error: If `_concatenate_list` rejects the inputs (for example,
            mismatched dimension counts or shapes).

    Examples:
    ```mojo
    import numojo as nm
    var a = nm.arange[nm.f64](0, 3, 1)  # Shape (3,)
    var b = nm.arange[nm.f64](3, 6, 1)  # Shape (3,)
    var c = nm.hstack(a, b)             # Shape (6,)
    ```
    """
    var count = len(arrays)
    if count == 0:
        raise Error(
            NumojoError(
                category="value",
                message="Need at least one array to hstack.",
                location="hstack()",
            )
        )

    var pieces = List[NDArray[dtype]]()
    for idx in range(count):
        pieces.append(arrays[idx].copy())

    # 1-D inputs are joined end to end (axis 0); anything else is joined
    # along the second axis.
    var join_axis: Int = 0 if pieces[0].ndim == 1 else 1
    return _concatenate_list(pieces, axis=join_axis)
1003+
1004+
1005+
def vstack[dtype: DType](*arrays: NDArray[dtype]) raises -> NDArray[dtype]:
    """Join the inputs vertically (row-wise), equivalent to `numpy.vstack`.

    Each 1-D input of shape `(N,)` is first reshaped to `(1, N)`; inputs of
    any other dimensionality are used unchanged. Everything is then joined
    along axis 0.

    Parameters:
        dtype: The data type of the arrays.

    Args:
        arrays: The arrays to stack.

    Returns:
        The array formed by stacking the inputs vertically.

    Raises:
        Error: If no arrays are given.
        Error: If `_concatenate_list` rejects the reshaped inputs (for
            example, mismatched dimension counts or shapes).

    Examples:
    ```mojo
    import numojo as nm
    var a = nm.arange[nm.f64](0, 3, 1)  # Shape (3,)
    var b = nm.arange[nm.f64](3, 6, 1)  # Shape (3,)
    var c = nm.vstack(a, b)             # Shape (2, 3)
    ```
    """
    var count = len(arrays)
    if count == 0:
        raise Error(
            NumojoError(
                category="value",
                message="Need at least one array to vstack.",
                location="vstack()",
            )
        )

    var layers = List[NDArray[dtype]]()
    for idx in range(count):
        if arrays[idx].ndim != 1:
            # Not 1-D: take a copy as-is.
            layers.append(arrays[idx].copy())
            continue
        # 1-D input: promote (N,) to a single (1, N) row.
        layers.append(
            reshape(
                arrays[idx].copy(),
                NDArrayShape(1, arrays[idx].shape[0]),
            )
        )

    return _concatenate_list(layers, axis=0)

0 commit comments

Comments
 (0)