Skip to content

Commit e9c15f1

Browse files
Nikhil0250copybara-github
authored andcommitted
Use CappedTag to prevent potential out of bound reads.
PiperOrigin-RevId: 879112470
1 parent d09285c commit e9c15f1

1 file changed

Lines changed: 32 additions & 28 deletions

File tree

hwy/contrib/math/fast_math-inl.h

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ HWY_INLINE V FastTan(D d, V x) {
169169
// Convert to Integer Vector (Signed)
170170
auto idx_int = ConvertTo(RebindToSigned<D>(), idx_float);
171171

172-
HWY_ALIGN static constexpr T arr_a[] = {
172+
HWY_ALIGN static constexpr T arr_a[8] = {
173173
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321),
174174
static_cast<T>(343.77467707849392), static_cast<T>(572.95779513082321),
175175
static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
176176
static_cast<T>(57.295779513082323), static_cast<T>(57.295779513082323)};
177177

178-
HWY_ALIGN static constexpr T arr_b[] = {static_cast<T>(0.0000000000000000),
178+
HWY_ALIGN static constexpr T arr_b[8] = {static_cast<T>(0.0000000000000000),
179179
static_cast<T>(10.0000000000000000),
180180
static_cast<T>(46.0000000000000000),
181181
static_cast<T>(217.00000000000000),
@@ -184,7 +184,7 @@ HWY_INLINE V FastTan(D d, V x) {
184184
static_cast<T>(542.00000000000000),
185185
static_cast<T>(542.00000000000000)};
186186

187-
HWY_ALIGN static constexpr T arr_c[] = {
187+
HWY_ALIGN static constexpr T arr_c[8] = {
188188
static_cast<T>(-57.295779513082323),
189189
static_cast<T>(-229.18311805232929),
190190
static_cast<T>(-286.47889756541161),
@@ -194,7 +194,7 @@ HWY_INLINE V FastTan(D d, V x) {
194194
static_cast<T>(-630.25357464271012),
195195
static_cast<T>(-630.25357464271012)};
196196

197-
HWY_ALIGN static constexpr T arr_d[] = {
197+
HWY_ALIGN static constexpr T arr_d[8] = {
198198
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000),
199199
static_cast<T>(541.00000000000000), static_cast<T>(1252.0000000000000),
200200
static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
@@ -203,10 +203,11 @@ HWY_INLINE V FastTan(D d, V x) {
203203
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
204204
// Cast to "Indices" Type
205205
auto idx = IndicesFromVec(d, idx_int);
206-
a = TableLookupLanes(Load(d, arr_a), idx);
207-
b = TableLookupLanes(Load(d, arr_b), idx);
208-
c = TableLookupLanes(Load(d, arr_c), idx);
209-
d_val = TableLookupLanes(Load(d, arr_d), idx);
206+
CappedTag<T, 8> d8;
207+
a = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_a)), idx);
208+
b = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_b)), idx);
209+
c = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_c)), idx);
210+
d_val = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_d)), idx);
210211
} else {
211212
auto idx = IndicesFromVec(d, idx_int);
212213
FixedTag<T, 4> d4;
@@ -331,20 +332,20 @@ HWY_INLINE V FastAtan(D d, V val) {
331332
idx_i = Add(idx_i, And(VecFromMask(DI(), mask60), one_i));
332333
idx_i = Add(idx_i, And(VecFromMask(DI(), mask75), one_i));
333334

334-
HWY_ALIGN static constexpr T arr_a[] = {
335+
HWY_ALIGN static constexpr T arr_a[8] = {
335336
static_cast<T>(630.25357464271012), static_cast<T>(572.95779513082321),
336337
static_cast<T>(343.77467707849392), static_cast<T>(572.95779513082321),
337338
static_cast<T>(229.18311805232929), static_cast<T>(57.295779513082323),
338339
static_cast<T>(57.295779513082323), static_cast<T>(57.295779513082323)};
339-
HWY_ALIGN static constexpr T arr_b[] = {static_cast<T>(0.0000000000000000),
340+
HWY_ALIGN static constexpr T arr_b[8] = {static_cast<T>(0.0000000000000000),
340341
static_cast<T>(10.0000000000000000),
341342
static_cast<T>(46.0000000000000000),
342343
static_cast<T>(217.00000000000000),
343344
static_cast<T>(297.00000000000000),
344345
static_cast<T>(542.00000000000000),
345346
static_cast<T>(542.00000000000000),
346347
static_cast<T>(542.00000000000000)};
347-
HWY_ALIGN static constexpr T arr_c[] = {
348+
HWY_ALIGN static constexpr T arr_c[8] = {
348349
static_cast<T>(-57.295779513082323),
349350
static_cast<T>(-229.18311805232929),
350351
static_cast<T>(-286.47889756541161),
@@ -353,18 +354,19 @@ HWY_INLINE V FastAtan(D d, V val) {
353354
static_cast<T>(-630.25357464271012),
354355
static_cast<T>(-630.25357464271012),
355356
static_cast<T>(-630.25357464271012)};
356-
HWY_ALIGN static constexpr T arr_d[] = {
357+
HWY_ALIGN static constexpr T arr_d[8] = {
357358
static_cast<T>(632.00000000000000), static_cast<T>(657.00000000000000),
358359
static_cast<T>(541.00000000000000), static_cast<T>(1252.0000000000000),
359360
static_cast<T>(910.00000000000000), static_cast<T>(990.00000000000000),
360361
static_cast<T>(990.00000000000000), static_cast<T>(990.00000000000000)};
361362

362363
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
363364
auto idx = IndicesFromVec(d, idx_i);
364-
a = TableLookupLanes(Load(d, arr_a), idx);
365-
b = TableLookupLanes(Load(d, arr_b), idx);
366-
c = TableLookupLanes(Load(d, arr_c), idx);
367-
d_coef = TableLookupLanes(Load(d, arr_d), idx);
365+
CappedTag<T, 8> d8;
366+
a = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_a)), idx);
367+
b = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_b)), idx);
368+
c = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_c)), idx);
369+
d_coef = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_d)), idx);
368370
} else {
369371
auto idx = IndicesFromVec(d, idx_i);
370372
FixedTag<T, 4> d4;
@@ -520,7 +522,7 @@ HWY_INLINE V FastTanh(D d, V val) {
520522
// Clamp index to 7
521523
idx_i = Min(idx_i, Set(DI(), 7));
522524

523-
HWY_ALIGN static constexpr T arr_a[] = {
525+
HWY_ALIGN static constexpr T arr_a[8] = {
524526
static_cast<T>(-2870.653300658652),
525527
static_cast<T>(-193.8913447691486),
526528
static_cast<T>(-37.25783093771139),
@@ -530,7 +532,7 @@ HWY_INLINE V FastTanh(D d, V val) {
530532
static_cast<T>(-0.9603919422736032),
531533
static_cast<T>(-0.4265454062350802)};
532534
// arr_b is not needed since its always 1.0
533-
HWY_ALIGN static constexpr T arr_c[] = {
535+
HWY_ALIGN static constexpr T arr_c[8] = {
534536
static_cast<T>(-316.5640994591445),
535537
static_cast<T>(-49.14374182730444),
536538
static_cast<T>(-15.69264419046708),
@@ -540,7 +542,7 @@ HWY_INLINE V FastTanh(D d, V val) {
540542
static_cast<T>(-0.9298342163526662),
541543
static_cast<T>(-0.426230503963466)};
542544

543-
HWY_ALIGN static constexpr T arr_d[] = {
545+
HWY_ALIGN static constexpr T arr_d[8] = {
544546
static_cast<T>(-2838.258534620734),
545547
static_cast<T>(-181.5331279956489),
546548
static_cast<T>(-30.30794802185292),
@@ -552,9 +554,10 @@ HWY_INLINE V FastTanh(D d, V val) {
552554

553555
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
554556
auto idx = IndicesFromVec(d, idx_i);
555-
a = TableLookupLanes(Load(d, arr_a), idx);
556-
c = TableLookupLanes(Load(d, arr_c), idx);
557-
d_coef = TableLookupLanes(Load(d, arr_d), idx);
557+
CappedTag<T, 8> d8;
558+
a = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_a)), idx);
559+
c = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_c)), idx);
560+
d_coef = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_d)), idx);
558561
} else {
559562
auto idx = IndicesFromVec(d, idx_i);
560563
FixedTag<T, 4> d4;
@@ -731,7 +734,7 @@ HWY_INLINE V FastLog(D d, V x) {
731734
// Clamp index to 7 to handle overshoots
732735
idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
733736

734-
HWY_ALIGN static constexpr T arr_a[] = {
737+
HWY_ALIGN static constexpr T arr_a[8] = {
735738
static_cast<T>(-9.9805647568302591e-01),
736739
static_cast<T>(-9.9957356952094290e-01),
737740
static_cast<T>(-9.9997448030468128e-01),
@@ -741,7 +744,7 @@ HWY_INLINE V FastLog(D d, V x) {
741744
static_cast<T>(-1.0012578436820159e+00),
742745
static_cast<T>(-1.0026088937292035e+00)};
743746
// b array is not needed since b is always 1.0.
744-
HWY_ALIGN static constexpr T arr_c[] = {
747+
HWY_ALIGN static constexpr T arr_c[8] = {
745748
static_cast<T>(-5.8272115256950630e-01),
746749
static_cast<T>(-5.4794075644717266e-01),
747750
static_cast<T>(-5.1959981902435026e-01),
@@ -750,7 +753,7 @@ HWY_INLINE V FastLog(D d, V x) {
750753
static_cast<T>(-4.5972782480224245e-01),
751754
static_cast<T>(-4.4546134537646059e-01),
752755
static_cast<T>(-4.3319821691832594e-01)};
753-
HWY_ALIGN static constexpr T arr_d[] = {
756+
HWY_ALIGN static constexpr T arr_d[8] = {
754757
static_cast<T>(-4.3704086438791473e-01),
755758
static_cast<T>(-4.5946229210571821e-01),
756759
static_cast<T>(-4.8168192392472370e-01),
@@ -762,9 +765,10 @@ HWY_INLINE V FastLog(D d, V x) {
762765

763766
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
764767
auto idx = IndicesFromVec(d, idx_i);
765-
a = TableLookupLanes(Load(d, arr_a), idx);
766-
c = TableLookupLanes(Load(d, arr_c), idx);
767-
d_coef = TableLookupLanes(Load(d, arr_d), idx);
768+
CappedTag<T, 8> d8;
769+
a = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_a)), idx);
770+
c = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_c)), idx);
771+
d_coef = TableLookupLanes(ResizeBitCast(d, Load(d8, arr_d)), idx);
768772
} else {
769773
auto idx = IndicesFromVec(d, idx_i);
770774
FixedTag<T, 4> d4;

0 commit comments

Comments
 (0)