@@ -756,3 +756,119 @@ def test_convolution_2d_boundary_no_nan(boundary):
756756 assert not np .any (np .isnan (da_result .data .compute ()))
757757 np .testing .assert_allclose (
758758 np_result .data , da_result .data .compute (), equal_nan = True , rtol = 1e-5 )
759+
760+
761+ # --- 3D (multi-band) focal tests ---
762+
763+
764+ @pytest .fixture
765+ def rgb_data ():
766+ rng = np .random .default_rng (123 )
767+ return rng .random ((3 , 12 , 14 )).astype (np .float64 )
768+
769+
770+ def test_mean_3d_numpy (rgb_data ):
771+ agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
772+ result = mean (agg )
773+ assert result .shape == (3 , 12 , 14 )
774+ assert result .dims == ('band' , 'y' , 'x' )
775+ for i in range (3 ):
776+ band_result = mean (agg .isel (band = i ))
777+ np .testing .assert_allclose (result .isel (band = i ).data , band_result .data )
778+
779+
780+ @dask_array_available
781+ def test_mean_3d_dask (rgb_data ):
782+ dask_data = da .from_array (rgb_data , chunks = (1 , 6 , 7 ))
783+ agg = xr .DataArray (dask_data , dims = ['band' , 'y' , 'x' ])
784+ result = mean (agg )
785+ assert result .shape == (3 , 12 , 14 )
786+ # compare against numpy per-band
787+ numpy_agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
788+ numpy_result = mean (numpy_agg )
789+ np .testing .assert_allclose (
790+ result .data .compute (), numpy_result .data , equal_nan = True , rtol = 1e-5 )
791+
792+
793+ def test_apply_3d_numpy (rgb_data ):
794+ kernel = np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]])
795+ agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
796+ result = apply (agg , kernel )
797+ assert result .shape == (3 , 12 , 14 )
798+ assert result .dims == ('band' , 'y' , 'x' )
799+ for i in range (3 ):
800+ band_result = apply (agg .isel (band = i ), kernel )
801+ np .testing .assert_allclose (result .isel (band = i ).data , band_result .data )
802+
803+
804+ @dask_array_available
805+ def test_apply_3d_dask (rgb_data ):
806+ kernel = np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]])
807+ dask_data = da .from_array (rgb_data , chunks = (1 , 6 , 7 ))
808+ agg = xr .DataArray (dask_data , dims = ['band' , 'y' , 'x' ])
809+ result = apply (agg , kernel )
810+ assert result .shape == (3 , 12 , 14 )
811+ numpy_agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
812+ numpy_result = apply (numpy_agg , kernel )
813+ np .testing .assert_allclose (
814+ result .data .compute (), numpy_result .data , equal_nan = True , rtol = 1e-5 )
815+
816+
817+ def test_focal_stats_3d_numpy (rgb_data ):
818+ kernel = custom_kernel (np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]]))
819+ stats = ['mean' , 'max' ]
820+ agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
821+ result = focal_stats (agg , kernel , stats_funcs = stats )
822+ # 3D input -> 4D output: (band, stats, y, x)
823+ assert result .shape == (3 , 2 , 12 , 14 )
824+ for i in range (3 ):
825+ band_result = focal_stats (agg .isel (band = i ), kernel , stats_funcs = stats )
826+ np .testing .assert_allclose (
827+ result .isel (band = i ).data , band_result .data , equal_nan = True )
828+
829+
830+ @dask_array_available
831+ def test_focal_stats_3d_dask (rgb_data ):
832+ kernel = custom_kernel (np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]]))
833+ stats = ['mean' , 'max' ]
834+ dask_data = da .from_array (rgb_data , chunks = (1 , 6 , 7 ))
835+ agg = xr .DataArray (dask_data , dims = ['band' , 'y' , 'x' ])
836+ result = focal_stats (agg , kernel , stats_funcs = stats )
837+ assert result .shape == (3 , 2 , 12 , 14 )
838+ numpy_agg = xr .DataArray (rgb_data , dims = ['band' , 'y' , 'x' ])
839+ numpy_result = focal_stats (numpy_agg , kernel , stats_funcs = stats )
840+ np .testing .assert_allclose (
841+ result .data .compute (), numpy_result .data , equal_nan = True , rtol = 1e-5 )
842+
843+
844+ def test_hotspots_3d_numpy ():
845+ rng = np .random .default_rng (42 )
846+ data_2d = rng .standard_normal ((10 , 12 )).astype (np .float64 )
847+ # stack 3 copies with different scales to avoid zero-std bands
848+ data_3d = np .stack ([data_2d , data_2d * 2 , data_2d * 0.5 ])
849+ kernel = np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]], dtype = np .float64 )
850+ agg = xr .DataArray (data_3d , dims = ['band' , 'y' , 'x' ])
851+ result = hotspots (agg , kernel )
852+ assert result .shape == (3 , 10 , 12 )
853+ assert result .dims == ('band' , 'y' , 'x' )
854+ for i in range (3 ):
855+ band_result = hotspots (agg .isel (band = i ), kernel )
856+ np .testing .assert_array_equal (result .isel (band = i ).data , band_result .data )
857+
858+
859+ @dask_array_available
860+ def test_hotspots_3d_dask ():
861+ rng = np .random .default_rng (42 )
862+ data_2d = rng .standard_normal ((10 , 12 )).astype (np .float64 )
863+ data_3d = np .stack ([data_2d , data_2d * 2 , data_2d * 0.5 ])
864+ kernel = np .array ([[0 , 1 , 0 ], [1 , 1 , 1 ], [0 , 1 , 0 ]], dtype = np .float64 )
865+ # numpy reference
866+ numpy_agg = xr .DataArray (data_3d , dims = ['band' , 'y' , 'x' ])
867+ numpy_result = hotspots (numpy_agg , kernel )
868+ # dask
869+ dask_data = da .from_array (data_3d , chunks = (1 , 5 , 6 ))
870+ dask_agg = xr .DataArray (dask_data , dims = ['band' , 'y' , 'x' ])
871+ dask_result = hotspots (dask_agg , kernel )
872+ assert dask_result .shape == (3 , 10 , 12 )
873+ np .testing .assert_array_equal (
874+ dask_result .data .compute (), numpy_result .data )
0 commit comments