@@ -82,10 +82,132 @@ static bool
// add_impl_hwy: Highway (hn::) SIMD implementation of R = A + B over `roi`.
// Rtype/Atype/Btype are the element types used to address the pixels of
// R, A and B (template parameters declared above this excerpt).
// Returns true when the RGB-in-RGBA fast path below ran; otherwise returns
// the result of hwy_binary_perpixel_op.
8282add_impl_hwy (ImageBuf& R, const ImageBuf& A, const ImageBuf& B, ROI roi,
8383 int nthreads)
8484{
// Lane-wise addition; the descriptor argument is accepted but unused.
85+ auto op = [](auto /* d*/ , auto a, auto b) {
86+ return hn::Add (a, b);
87+ };
88+
89+ // Special-case: RGBA images but ROI is RGB (strided channel subset). We
90+ // still can SIMD the RGB channels by processing full RGBA and preserving
91+ // alpha exactly (bitwise) from the destination.
92+ if (roi.chbegin == 0 && roi.chend == 3 ) {
93+ // Only support same-type float/half/double in this fast path.
94+ constexpr bool floaty = (std::is_same_v<Rtype, float >
95+ || std::is_same_v<Rtype, double >
96+ || std::is_same_v<Rtype, half>)
97+ && std::is_same_v<Rtype, Atype>
98+ && std::is_same_v<Rtype, Btype>;
99+ if constexpr (floaty) {
100+ auto Rv = HwyPixels (R);
101+ auto Av = HwyPixels (A);
102+ auto Bv = HwyPixels (B);
// All three buffers must have at least 4 channels and pass
// ChannelsContiguous(..., 4).
// NOTE(review): the offsets below (x * 4) assume exactly 4 elements per
// pixel; presumably ChannelsContiguous(..., 4) guarantees that -- confirm.
103+ if (Rv.nchannels >= 4 && Av.nchannels >= 4 && Bv.nchannels >= 4
104+ && ChannelsContiguous<Rtype>(Rv, 4 )
105+ && ChannelsContiguous<Atype>(Av, 4 )
106+ && ChannelsContiguous<Btype>(Bv, 4 )) {
// Widen the channel range to [0,4) so whole 4-channel pixels are addressed.
107+ ROI roi4 = roi;
108+ roi4.chbegin = 0 ;
109+ roi4.chend = 4 ;
// MathT is the type the arithmetic is done in; `lanes` is the number of
// MathT elements per vector and therefore the number of pixels handled
// per iteration of the vector loop.
110+ using MathT = typename SimdMathType<Rtype>::type;
111+ const hn::ScalableTag<MathT> d;
112+ const size_t lanes = hn::Lanes (d);
// NOTE(review): the lambda parameter `roi4` shadows the outer `roi4`;
// inside the lambda it is the ROI handed to this invocation by
// parallel_image.
113+ ImageBufAlgo::parallel_image (roi4, nthreads, [&](ROI roi4) {
114+ for (int y = roi4.ybegin ; y < roi4.yend ; ++y) {
115+ Rtype* r_row = RoiRowPtr<Rtype>(Rv, y, roi4);
116+ const Atype* a_row = RoiRowPtr<Atype>(Av, y, roi4);
117+ const Btype* b_row = RoiRowPtr<Btype>(Bv, y, roi4);
118+ const size_t npixels = static_cast <size_t >(roi4.width ());
119+
// Vector loop: `lanes` pixels per iteration while a full vector fits.
120+ size_t x = 0 ;
121+ for (; x + lanes <= npixels; x += lanes) {
// Element offset of pixel x within the row (4 elements per pixel).
122+ const size_t off = x * 4 ;
123+ if constexpr (std::is_same_v<Rtype, half>) {
// half path: view the storage as hwy::float16_t and use a descriptor
// rebound from `d` to the 16-bit type.
124+ using T16 = hwy::float16_t ;
125+ auto d16 = hn::Rebind<T16, decltype (d)>();
126+ const T16* a16
127+ = reinterpret_cast <const T16*>(a_row + off);
128+ const T16* b16
129+ = reinterpret_cast <const T16*>(b_row + off);
130+ T16* r16 = reinterpret_cast <T16*>(r_row + off);
131+
// De-interleave the four channels of A, B and the destination R.
132+ hn::Vec<decltype (d16)> ar16, ag16, ab16, aa16;
133+ hn::Vec<decltype (d16)> br16, bg16, bb16, ba16;
134+ hn::Vec<decltype (d16)> dr16, dg16, db16, da16;
135+ hn::LoadInterleaved4 (d16, a16, ar16, ag16, ab16,
136+ aa16);
137+ hn::LoadInterleaved4 (d16, b16, br16, bg16, bb16,
138+ ba16);
139+ hn::LoadInterleaved4 (d16, r16, dr16, dg16, db16,
140+ da16);
// Only the destination alpha (da16) is used from these loads; the source
// alphas and the destination RGB are deliberately ignored.
141+ (void )aa16;
142+ (void )ba16;
143+ (void )dr16;
144+ (void )dg16;
145+ (void )db16;
146+
// Promote each RGB channel to MathT, apply op, then demote back to 16-bit.
147+ auto rr = op (d, hn::PromoteTo (d, ar16),
148+ hn::PromoteTo (d, br16));
149+ auto rg = op (d, hn::PromoteTo (d, ag16),
150+ hn::PromoteTo (d, bg16));
151+ auto rb = op (d, hn::PromoteTo (d, ab16),
152+ hn::PromoteTo (d, bb16));
153+
154+ auto rr16 = hn::DemoteTo (d16, rr);
155+ auto rg16 = hn::DemoteTo (d16, rg);
156+ auto rb16 = hn::DemoteTo (d16, rb);
// Write the new RGB together with the alpha that was loaded from R.
157+ hn::StoreInterleaved4 (rr16, rg16, rb16, da16, d16,
158+ r16);
159+ } else {
// float/double path: same load/op/store sequence without promotion.
160+ hn::Vec<decltype (d)> ar, ag, ab, aa;
161+ hn::Vec<decltype (d)> br, bg, bb, ba;
162+ hn::Vec<decltype (d)> dr, dg, db, da;
163+ hn::LoadInterleaved4 (d, a_row + off, ar, ag, ab,
164+ aa);
165+ hn::LoadInterleaved4 (d, b_row + off, br, bg, bb,
166+ ba);
167+ hn::LoadInterleaved4 (d, r_row + off, dr, dg, db,
168+ da);
// Only `da` (destination alpha) is used from these loads.
169+ (void )aa;
170+ (void )ba;
171+ (void )dr;
172+ (void )dg;
173+ (void )db;
174+
175+ auto rr = op (d, ar, br);
176+ auto rg = op (d, ag, bg);
177+ auto rb = op (d, ab, bb);
178+ hn::StoreInterleaved4 (rr, rg, rb, da, d,
179+ r_row + off);
180+ }
181+ }
182+
// Scalar tail: remaining pixels that do not fill a whole vector. Only
// elements off+0..off+2 are written, so element off+3 is left untouched.
183+ for (; x < npixels; ++x) {
184+ const size_t off = x * 4 ;
185+ if constexpr (std::is_same_v<Rtype, half>) {
// half values are added as float and converted back to half.
186+ r_row[off + 0 ]
187+ = half ((float )a_row[off + 0 ]
188+ + (float )b_row[off + 0 ]);
189+ r_row[off + 1 ]
190+ = half ((float )a_row[off + 1 ]
191+ + (float )b_row[off + 1 ]);
192+ r_row[off + 2 ]
193+ = half ((float )a_row[off + 2 ]
194+ + (float )b_row[off + 2 ]);
195+ } else {
196+ r_row[off + 0 ] = a_row[off + 0 ] + b_row[off + 0 ];
197+ r_row[off + 1 ] = a_row[off + 1 ] + b_row[off + 1 ];
198+ r_row[off + 2 ] = a_row[off + 2 ] + b_row[off + 2 ];
199+ }
200+ // Preserve alpha (off+3).
201+ }
202+ }
203+ });
204+ return true ;
205+ }
206+ }
207+ }
208+
// Fast path not taken: hand the same op to the generic per-pixel helper.
85209 return hwy_binary_perpixel_op<Rtype, Atype, Btype>(R, A, B, roi, nthreads,
86- [](auto /* d*/ , auto a, auto b) {
87- return hn::Add (a, b);
88- });
210+ op);
89211}
90212
91213template <class Rtype , class Atype >
@@ -143,6 +265,25 @@ add_impl(ImageBuf& R, const ImageBuf& A, const ImageBuf& B, ROI roi,
143265 return add_impl_hwy_native_int<Rtype>(R, A, B, roi, nthreads);
144266 return add_impl_hwy<Rtype, Atype, Btype>(R, A, B, roi, nthreads);
145267 }
268+
269+ // Handle the common RGBA + RGB ROI strided case (preserving alpha).
270+ constexpr bool floaty_strided = (std::is_same_v<Rtype, float >
271+ || std::is_same_v<Rtype, double >
272+ || std::is_same_v<Rtype, half>)
273+ && std::is_same_v<Rtype, Atype>
274+ && std::is_same_v<Rtype, Btype>;
275+ if constexpr (floaty_strided) {
276+ if (roi.chbegin == 0 && roi.chend == 3 ) {
277+ const bool contig4 = (Rv.nchannels >= 4 && Av.nchannels >= 4
278+ && Bv.nchannels >= 4 )
279+ && ChannelsContiguous<Rtype>(Rv, 4 )
280+ && ChannelsContiguous<Atype>(Av, 4 )
281+ && ChannelsContiguous<Btype>(Bv, 4 );
282+ if (contig4)
283+ return add_impl_hwy<Rtype, Atype, Btype>(R, A, B, roi,
284+ nthreads);
285+ }
286+ }
146287 }
147288#endif
148289 return add_impl_scalar<Rtype, Atype, Btype>(R, A, B, roi, nthreads);
@@ -177,10 +318,132 @@ static bool
// sub_impl_hwy: Highway (hn::) SIMD implementation of R = A - B over `roi`.
// Rtype/Atype/Btype are the element types used to address the pixels of
// R, A and B (template parameters declared above this excerpt).
// Returns true when the RGB-in-RGBA fast path below ran; otherwise returns
// the result of hwy_binary_perpixel_op.
177318sub_impl_hwy (ImageBuf& R, const ImageBuf& A, const ImageBuf& B, ROI roi,
178319 int nthreads)
179320{
// Lane-wise subtraction; the descriptor argument is accepted but unused.
321+ auto op = [](auto /* d*/ , auto a, auto b) {
322+ return hn::Sub (a, b);
323+ };
324+
325+ // Special-case: RGBA images but ROI is RGB (strided channel subset). We
326+ // still can SIMD the RGB channels by processing full RGBA and preserving
327+ // alpha exactly (bitwise) from the destination.
328+ if (roi.chbegin == 0 && roi.chend == 3 ) {
329+ // Only support same-type float/half/double in this fast path.
330+ constexpr bool floaty = (std::is_same_v<Rtype, float >
331+ || std::is_same_v<Rtype, double >
332+ || std::is_same_v<Rtype, half>)
333+ && std::is_same_v<Rtype, Atype>
334+ && std::is_same_v<Rtype, Btype>;
335+ if constexpr (floaty) {
336+ auto Rv = HwyPixels (R);
337+ auto Av = HwyPixels (A);
338+ auto Bv = HwyPixels (B);
// All three buffers must have at least 4 channels and pass
// ChannelsContiguous(..., 4).
// NOTE(review): the offsets below (x * 4) assume exactly 4 elements per
// pixel; presumably ChannelsContiguous(..., 4) guarantees that -- confirm.
339+ if (Rv.nchannels >= 4 && Av.nchannels >= 4 && Bv.nchannels >= 4
340+ && ChannelsContiguous<Rtype>(Rv, 4 )
341+ && ChannelsContiguous<Atype>(Av, 4 )
342+ && ChannelsContiguous<Btype>(Bv, 4 )) {
// Widen the channel range to [0,4) so whole 4-channel pixels are addressed.
343+ ROI roi4 = roi;
344+ roi4.chbegin = 0 ;
345+ roi4.chend = 4 ;
// MathT is the type the arithmetic is done in; `lanes` is the number of
// MathT elements per vector and therefore the number of pixels handled
// per iteration of the vector loop.
346+ using MathT = typename SimdMathType<Rtype>::type;
347+ const hn::ScalableTag<MathT> d;
348+ const size_t lanes = hn::Lanes (d);
// NOTE(review): the lambda parameter `roi4` shadows the outer `roi4`;
// inside the lambda it is the ROI handed to this invocation by
// parallel_image.
349+ ImageBufAlgo::parallel_image (roi4, nthreads, [&](ROI roi4) {
350+ for (int y = roi4.ybegin ; y < roi4.yend ; ++y) {
351+ Rtype* r_row = RoiRowPtr<Rtype>(Rv, y, roi4);
352+ const Atype* a_row = RoiRowPtr<Atype>(Av, y, roi4);
353+ const Btype* b_row = RoiRowPtr<Btype>(Bv, y, roi4);
354+ const size_t npixels = static_cast <size_t >(roi4.width ());
355+
// Vector loop: `lanes` pixels per iteration while a full vector fits.
356+ size_t x = 0 ;
357+ for (; x + lanes <= npixels; x += lanes) {
// Element offset of pixel x within the row (4 elements per pixel).
358+ const size_t off = x * 4 ;
359+ if constexpr (std::is_same_v<Rtype, half>) {
// half path: view the storage as hwy::float16_t and use a descriptor
// rebound from `d` to the 16-bit type.
360+ using T16 = hwy::float16_t ;
361+ auto d16 = hn::Rebind<T16, decltype (d)>();
362+ const T16* a16
363+ = reinterpret_cast <const T16*>(a_row + off);
364+ const T16* b16
365+ = reinterpret_cast <const T16*>(b_row + off);
366+ T16* r16 = reinterpret_cast <T16*>(r_row + off);
367+
// De-interleave the four channels of A, B and the destination R.
368+ hn::Vec<decltype (d16)> ar16, ag16, ab16, aa16;
369+ hn::Vec<decltype (d16)> br16, bg16, bb16, ba16;
370+ hn::Vec<decltype (d16)> dr16, dg16, db16, da16;
371+ hn::LoadInterleaved4 (d16, a16, ar16, ag16, ab16,
372+ aa16);
373+ hn::LoadInterleaved4 (d16, b16, br16, bg16, bb16,
374+ ba16);
375+ hn::LoadInterleaved4 (d16, r16, dr16, dg16, db16,
376+ da16);
// Only the destination alpha (da16) is used from these loads; the source
// alphas and the destination RGB are deliberately ignored.
377+ (void )aa16;
378+ (void )ba16;
379+ (void )dr16;
380+ (void )dg16;
381+ (void )db16;
382+
// Promote each RGB channel to MathT, apply op, then demote back to 16-bit.
383+ auto rr = op (d, hn::PromoteTo (d, ar16),
384+ hn::PromoteTo (d, br16));
385+ auto rg = op (d, hn::PromoteTo (d, ag16),
386+ hn::PromoteTo (d, bg16));
387+ auto rb = op (d, hn::PromoteTo (d, ab16),
388+ hn::PromoteTo (d, bb16));
389+
390+ auto rr16 = hn::DemoteTo (d16, rr);
391+ auto rg16 = hn::DemoteTo (d16, rg);
392+ auto rb16 = hn::DemoteTo (d16, rb);
// Write the new RGB together with the alpha that was loaded from R.
393+ hn::StoreInterleaved4 (rr16, rg16, rb16, da16, d16,
394+ r16);
395+ } else {
// float/double path: same load/op/store sequence without promotion.
396+ hn::Vec<decltype (d)> ar, ag, ab, aa;
397+ hn::Vec<decltype (d)> br, bg, bb, ba;
398+ hn::Vec<decltype (d)> dr, dg, db, da;
399+ hn::LoadInterleaved4 (d, a_row + off, ar, ag, ab,
400+ aa);
401+ hn::LoadInterleaved4 (d, b_row + off, br, bg, bb,
402+ ba);
403+ hn::LoadInterleaved4 (d, r_row + off, dr, dg, db,
404+ da);
// Only `da` (destination alpha) is used from these loads.
405+ (void )aa;
406+ (void )ba;
407+ (void )dr;
408+ (void )dg;
409+ (void )db;
410+
411+ auto rr = op (d, ar, br);
412+ auto rg = op (d, ag, bg);
413+ auto rb = op (d, ab, bb);
414+ hn::StoreInterleaved4 (rr, rg, rb, da, d,
415+ r_row + off);
416+ }
417+ }
418+
// Scalar tail: remaining pixels that do not fill a whole vector. Only
// elements off+0..off+2 are written, so element off+3 is left untouched.
419+ for (; x < npixels; ++x) {
420+ const size_t off = x * 4 ;
421+ if constexpr (std::is_same_v<Rtype, half>) {
// half values are subtracted as float and converted back to half.
422+ r_row[off + 0 ]
423+ = half ((float )a_row[off + 0 ]
424+ - (float )b_row[off + 0 ]);
425+ r_row[off + 1 ]
426+ = half ((float )a_row[off + 1 ]
427+ - (float )b_row[off + 1 ]);
428+ r_row[off + 2 ]
429+ = half ((float )a_row[off + 2 ]
430+ - (float )b_row[off + 2 ]);
431+ } else {
432+ r_row[off + 0 ] = a_row[off + 0 ] - b_row[off + 0 ];
433+ r_row[off + 1 ] = a_row[off + 1 ] - b_row[off + 1 ];
434+ r_row[off + 2 ] = a_row[off + 2 ] - b_row[off + 2 ];
435+ }
436+ // Preserve alpha (off+3).
437+ }
438+ }
439+ });
440+ return true ;
441+ }
442+ }
443+ }
444+
// Fast path not taken: hand the same op to the generic per-pixel helper.
180445 return hwy_binary_perpixel_op<Rtype, Atype, Btype>(R, A, B, roi, nthreads,
181- [](auto /* d*/ , auto a, auto b) {
182- return hn::Sub (a, b);
183- });
446+ op);
184447}
185448#endif // defined(OIIO_USE_HWY) && OIIO_USE_HWY
186449
@@ -210,6 +473,25 @@ sub_impl(ImageBuf& R, const ImageBuf& A, const ImageBuf& B, ROI roi,
210473 return sub_impl_hwy_native_int<Rtype>(R, A, B, roi, nthreads);
211474 return sub_impl_hwy<Rtype, Atype, Btype>(R, A, B, roi, nthreads);
212475 }
476+
477+ // Handle the common RGBA + RGB ROI strided case (preserving alpha).
478+ constexpr bool floaty_strided = (std::is_same_v<Rtype, float >
479+ || std::is_same_v<Rtype, double >
480+ || std::is_same_v<Rtype, half>)
481+ && std::is_same_v<Rtype, Atype>
482+ && std::is_same_v<Rtype, Btype>;
483+ if constexpr (floaty_strided) {
484+ if (roi.chbegin == 0 && roi.chend == 3 ) {
485+ const bool contig4 = (Rv.nchannels >= 4 && Av.nchannels >= 4
486+ && Bv.nchannels >= 4 )
487+ && ChannelsContiguous<Rtype>(Rv, 4 )
488+ && ChannelsContiguous<Atype>(Av, 4 )
489+ && ChannelsContiguous<Btype>(Bv, 4 );
490+ if (contig4)
491+ return sub_impl_hwy<Rtype, Atype, Btype>(R, A, B, roi,
492+ nthreads);
493+ }
494+ }
213495 }
214496#endif
215497 return sub_impl_scalar<Rtype, Atype, Btype>(R, A, B, roi, nthreads);
0 commit comments