diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index e2a04d0dd8..f853f7af17 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -905,6 +905,48 @@ struct ElementwiseFilter<Array<T, N> > {
   }
 };
 
+// Snake activation: Snake_a(x) = x + (1/a) * sin^2(a*x)
+// Introduced in Ziyin, Hartwig, Ueda, "Neural Networks Fail to Learn
+// Periodic Functions and How to Fix It," NeurIPS 2020 (arXiv:2006.08195).
+// The per-channel learnable frequency `a` is passed as the second operand
+// (intended to flow through an EVT child such as Sm90RowBroadcast).
+// The formula is singular at a = 0, but sin^2(a*x)/a -> 0 as a -> 0, so the
+// operator returns x for a == 0 instead of evaluating 0/0 (NaN).
+template <typename T>
+struct Snake {
+  static const bool kIsHeavy = true;
+
+  /// Returns x + sin^2(alpha * x) / alpha, evaluated in float and converted to T.
+  CUTLASS_HOST_DEVICE
+  T operator()(T const& x, T const& alpha) const {
+    float xf = float(x);
+    float af = float(alpha);
+    // Compares equal for both +0 and -0.
+    if (af == 0.0f) {
+      return x;
+    }
+    float s = fast_sin(af * xf);
+    return T(xf + s * s / af);
+  }
+};
+
+/// Array specialization: applies the scalar Snake to each (x[i], alpha[i]) pair.
+template <typename T, int N>
+struct Snake<Array<T, N>> {
+  static const bool kIsHeavy = true;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& x, Array<T, N> const& alpha) const {
+    Array<T, N> result;
+    Snake<T> scalar_op;
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i){
+      result[i] = scalar_op(x[i], alpha[i]);
+    }
+    return result;
+  }
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace thread
diff --git a/test/unit/epilogue/thread/activation.cu b/test/unit/epilogue/thread/activation.cu
index 3c98082d2c..e747d003cf 100644
--- a/test/unit/epilogue/thread/activation.cu
+++ b/test/unit/epilogue/thread/activation.cu
@@ -51,6 +51,18 @@ __global__ void test_Epilogue_thread_activation(T *out, T *in) {
   vec_out[threadIdx.x] = func(vec_in[threadIdx.x]);
 }
 
+// Binary counterpart of test_Epilogue_thread_activation: thread t applies Func to the t-th
+// N-element vectors of x and alpha and stores the result in the t-th vector of out.
+template <typename T, int N, typename Func>
+__global__ void test_Epilogue_thread_activation_binary(T *out, T *x, T *alpha){
+  cutlass::Array<T, N> *vec_out = reinterpret_cast<cutlass::Array<T, N> *>(out);
+  cutlass::Array<T, N> *vec_x = reinterpret_cast<cutlass::Array<T, N> *>(x);
+  cutlass::Array<T, N> *vec_alpha = reinterpret_cast<cutlass::Array<T, N> *>(alpha);
+
+  Func func;
+  vec_out[threadIdx.x] = func(vec_x[threadIdx.x], vec_alpha[threadIdx.x]);
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 //
@@ -191,6 +203,342 @@ static double GELU_golden_output[] = {
   -0.168275654316, 0.220552831888, -0.159705042839, 1.549110531807
 };
 
+// Reference data for the Snake tests: 256 x values, 256 alpha values, and the expected
+// x + sin^2(alpha * x) / alpha for each pair.
+static double Snake_golden_x[] = {
+  0.029053429184, -0.073700162105, -0.957751321690, -0.229464130144,
+  -1.054150496919, -0.438869051576, 0.058027082664, -0.003119327765,
+  0.432405505892, 0.530008876465, -0.193385668869, -0.186124514126,
+  0.987970150471, -0.990709745322, 0.417275110485, 0.806415034595,
+  0.555878890939, 1.091370569098, -0.652624600678, -0.305256813078,
+  -3.677834378094, 0.020385454985, 1.029275814765, -1.658167845616,
+  1.657781136611, -0.627017252985, -0.956601541851, 1.061830856399,
+  -1.996895384427, -0.796726103109, -0.620245955882, -0.216337758007,
+  -0.970205843319, -0.762363183314, -0.263468892540, 0.706260146861,
+  -0.547733438907, 1.028736714398, 0.532700417479, 0.219131887583,
+  0.410371090789, -0.860321949693, -0.134011925815, 0.378068516506,
+  -0.397935138621, -0.281645814281, -0.674965764600, 0.075066377376,
+  0.591349731308, -0.986103386873, 0.540573540728, -0.195186025501,
+  0.382170818241, -0.336124810010, 0.514623317515, -0.633037271176,
+  -2.225114959496, -0.424278574433, 0.864951591244, -0.748100501207,
+  -0.814455285078, 1.399324223175, 2.550507919197, 0.112022056199,
+  0.501069610182, 0.615049456991, -2.455269089844, -0.133235353840,
+  -1.048553628708, 0.614094770306, 1.135570854607, 2.135839794654,
+  -1.132154443249, 0.416582342787, 0.199975770530, 0.870957070562,
+  -1.067125609377, -0.317593855414, -1.365204049444, 0.581168074356,
+  -1.034777548687, -1.060219263238, -0.952406555466, -0.003488014184,
+  1.072586326741, -0.397457304592, -0.479034742410, 1.616406066530,
+  -0.340577642369, 1.352103052797, -0.691795383481, -0.623777129584,
+  -0.774545149758, 1.277156048075, 0.098359200811, -0.686292493575,
+  -0.469269332700, -1.084840021288, -0.732852705424, 0.869019317992,
+  -0.073697571167, -1.893281354625, -1.110974735000, 0.558214090747,
+  -0.451057560756, 0.447461089338, 0.197494571084, 1.828269271566,
+  0.827678321070, 0.798278975336, -0.586337889400, 0.344753541998,
+  0.815051610629, -0.614818900479, 1.582063437842, 0.651619160959,
+  1.581119442075, 1.303423829832, 1.186057802117, -1.537848796338,
+  0.439386823326, -1.989887737521, -0.530844704054, -0.092983961880,
+  2.731809826905, -0.818315266683, 0.307108151906, 0.419630822952,
+  -0.228842566535, -0.215408664510, -0.288090633921, -1.365735055710,
+  -1.513522512402, 1.790313840521, 0.365196032814, -0.428147708497,
+  0.944347054741, 1.889699917802, -1.242820812714, -0.295221915419,
+  0.078524208465, 1.054135848649, 1.027491739965, -0.638342920129,
+  1.076303143523, -0.431866264152, 0.582247177956, -1.448992543035,
+  -0.756821186630, -0.148059611738, 0.413561474504, -0.685984494800,
+  -0.145935889932, 2.010263640011, 0.703426641971, -2.445829502306,
+  -1.085952602443, -1.030567244448, -0.778586822390, 1.409630762219,
+  0.277775313817, -0.817427176209, 1.535450399874, 0.689783169875,
+  -0.558834477613, -1.653817347161, 1.712497692950, 0.069523477922,
+  0.629550163029, -1.335994370169, 1.938423995860, 0.110381170776,
+  -0.572298694654, -0.510668568690, -0.182994362066, -1.180016882135,
+  0.265853520200, 0.193906608187, 1.428012477429, 0.191521664853,
+  1.540343614656, 0.351226457409, -0.508023007883, 0.285515852982,
+  2.007526227190, 0.513980525975, 1.671017405384, 1.609615930737,
+  0.791780795594, -1.086131492859, -0.252483101516, -0.456344955137,
+  -0.457268634699, 0.032448480614, -0.002629782119, 0.121148264944,
+  -0.909985276560, -0.615054078531, -0.523624031453, 0.614206124249,
+  -0.370579889227, 1.692268634801, -0.032660060202, 1.486452809603,
+  -0.111649745792, 0.493170630381, -1.490800986310, -0.468854631740,
+  0.573790416548, -2.243287177026, 0.778386044561, -0.376802969344,
+  1.936419879235, -1.017792188928, 1.318757680373, 0.712829317167,
+  -1.937630450484, 0.758523630020, -0.624877514003, 0.228800672781,
+  0.876842599715, -0.973031770011, -0.502147893064, -0.839962712295,
+  0.233819743409, 1.727011814368, 0.580076431287, -0.219354728997,
+  -0.644087881572, 0.259351024516, -0.659000482282, -0.717816416909,
+  0.611838700073, 2.284100566159, 0.976363940000, -0.168807743207,
+  0.224294957866, 0.919491656813, 0.089184696718, -0.742632491463,
+  1.165372466386, -1.195659887110, 0.368437873632, 1.075869087616,
+  -0.070620875083, -2.492959152949, 0.851985725585, 0.398781627880,
+  0.710014437585, -1.460378031581, -1.532444654966, 2.996609973191,
+  0.570902491622, 0.123891040827, -0.820849107832, 0.255297073043,
+};
+
+static double Snake_golden_alpha[] = {
+  1.314910917070, 0.147520434923, 0.622555704901, 0.524100402483,
+  1.499295306912, 1.385729026104, 1.795141178639, 0.265183781996,
+  0.901651457402, 0.156614716932, 0.515412152127, 1.060175047396,
+  0.150418342399, 0.477791536305, 1.334780431781, 1.135388813146,
+  0.518837181877, 1.219604799364, 1.637917867688, 0.112347643388,
+  1.631056578482, 1.426464850478, 0.746475981384, 0.395411049642,
+  1.918704837193, 0.739529635714, 0.276217102422, 0.283761115984,
+  1.710239296060, 1.247079459597, 1.633543719221, 1.486490394718,
+  1.118833373764, 1.948919951561, 0.819215316696, 1.148877199419,
+  1.675868862081, 1.275187529492, 1.737243110590, 1.196969075988,
+  1.438686488808, 0.187066328946, 0.533006723738, 0.649837130844,
+  0.251604756155, 0.542302684086, 0.291902715878, 0.628149845909,
+  1.307800444102, 0.793181140043, 0.803343837522, 0.498063358466,
+  0.607257861893, 1.879643716654, 1.331267231969, 1.257348910767,
+  0.425163431576, 1.485340916106, 0.410464738148, 0.820965339340,
+  1.980094366210, 1.315999543723, 1.158204513172, 1.400767076881,
+  1.701418648361, 1.574399831938, 0.535191336732, 0.160990463418,
+  0.699360791312, 0.608707664354, 0.500867402814, 1.891528457237,
+  1.765098490298, 0.697887973517, 1.345333464060, 0.851700612015,
+  1.837640420507, 0.971818519916, 0.603272316346, 0.568592264619,
+  1.166599454910, 0.599209056194, 1.210713381425, 1.805863478845,
+  0.858860959767, 0.516709442399, 1.995321452341, 1.068099957985,
+  0.272727883130, 0.189521113307, 0.308333347666, 1.292147479236,
+  1.604950792290, 0.902103936919, 0.220702641689, 0.825076644362,
+  1.992630622456, 1.105317255688, 1.945048917466, 1.735481434246,
+  0.121813941691, 1.469371456784, 1.395249701150, 1.120243627777,
+  0.606967860910, 1.317827417302, 0.311949129832, 0.926053976271,
+  0.962075042025, 1.912250262290, 1.764120586719, 0.600439196427,
+  1.051113614796, 0.439438573007, 1.833992894755, 1.753985282690,
+  0.667045103752, 1.314004040245, 1.257043401733, 0.390394610244,
+  1.548770520143, 1.124820157227, 1.579390309398, 1.107671977171,
+  0.101086602643, 0.715896508309, 0.137005810533, 1.865287370903,
+  1.769571567864, 1.680164505786, 0.684276838265, 0.210057816339,
+  1.768218238488, 1.899203946066, 0.262741558929, 1.023381880302,
+  0.231503785090, 1.545144113989, 1.555085415683, 0.343943782550,
+  1.003036518388, 1.144626827640, 0.603607594986, 1.757622778062,
+  0.903962086382, 0.502416590340, 1.124662568681, 1.486869031271,
+  0.482187020440, 0.692260953472, 1.990783777656, 1.334768309515,
+  0.932390159438, 1.083394097968, 0.329907972150, 0.526924940360,
+  0.742362568080, 1.217786565069, 0.537217991933, 0.518413030458,
+  0.234886863417, 1.299095618813, 0.534989389241, 1.820298024712,
+  1.733307260482, 0.234628964788, 0.552208805301, 1.371057778763,
+  0.507049934004, 0.351392512578, 1.877477057103, 1.184981877318,
+  0.998074949992, 1.590776906152, 1.634244295757, 0.461778837288,
+  0.284168547035, 0.918997246572, 0.904799383738, 0.987346869270,
+  1.485244113974, 1.379392639857, 1.969913901595, 0.286993955189,
+  0.864980435994, 0.744674950250, 1.737177819070, 0.572447034449,
+  0.461396926038, 0.952365740883, 0.901575115685, 0.629235774867,
+  0.574632250976, 1.854204638624, 0.941948415602, 1.736563299047,
+  1.145618093655, 0.196117826097, 1.998636689984, 1.688452411652,
+  1.941092888841, 1.860097267715, 1.712521895387, 0.415991110147,
+  1.022718138356, 0.506119868478, 0.861976555844, 0.211407259947,
+  0.820048926056, 1.972086803181, 0.603885810527, 1.589734143702,
+  0.964515897944, 0.903714223381, 1.918903517633, 1.991303110036,
+  1.155959814471, 1.464975723063, 0.394113968021, 0.663744868440,
+  1.940547793441, 1.200442552551, 1.130170882611, 1.521153564720,
+  0.208614018524, 1.209937429472, 1.055415727547, 1.720167794892,
+  0.399122183085, 1.925479916221, 0.252211783957, 0.453067425863,
+  1.230566702255, 1.382903851848, 0.546887400502, 0.327784566500,
+  1.791545896846, 0.567809160798, 1.229586391714, 1.276824869631,
+  0.896527339138, 1.208977349653, 1.093287159511, 1.875941889699,
+  0.488092478905, 1.460764421500, 0.553503309970, 0.851993108903,
+  1.376211423624, 0.669994451618, 0.700736672917, 1.528542535587,
+  0.237831917537, 0.970742492975, 1.997063437623, 1.992583250925,
+};
+
+static double Snake_golden_output[] = {
+  0.030162807937, -0.072898904881, -0.451235357962, -0.202001042643,
+  -0.387233066601, -0.203293172314, 0.064049747723, -0.003116747473,
+  0.592620610942, 0.573902473192, -0.174174024192, -0.149871786124,
+  1.133713590934, -0.555749922155, 0.626624953928, 1.360172568431,
+  0.711804070287, 1.865072258447, -0.183375257350, -0.294792169803,
+  -3.629557790529, 0.020978079314, 1.676228747578, -0.718106263419,
+  1.658581601367, -0.356520741146, -0.709666913330, 1.372201772507,
+  -1.954214035750, -0.233634934140, -0.179485640924, -0.149132420195,
+  -0.270902259126, -0.252957737752, -0.207479952329, 1.164095231296,
+  -0.171221404601, 1.761510017151, 0.900069003977, 0.275302814603,
+  0.625778121370, -0.723055354236, -0.124455822242, 0.469099634932,
+  -0.358225882828, -0.238961475856, -0.543692958091, 0.078603354731,
+  0.964472741676, -0.359814582475, 0.760935446722, -0.176270715394,
+  0.469282660995, -0.150556971280, 0.815370048789, -0.226984684910,
+  -0.677733423654, -0.190473406374, 1.159349608758, -0.343584035169,
+  -0.310315044483, 2.104861670423, 2.580534178202, 0.129456397669,
+  0.834279493899, 1.046249631450, -0.707274523119, -0.130377944208,
+  -0.407949718281, 0.833151490387, 1.714738405775, 2.459414719751,
+  -0.663025883113, 0.534320399265, 0.252490523020, 1.406886887967,
+  -0.601681821943, -0.222643876025, -0.473145016659, 0.766324510773,
+  -0.286018010333, -0.472545537912, -0.262369684875, -0.003466043901,
+  1.810937227027, -0.316972658667, -0.144678092813, 2.530137759850,
+  -0.309033953907, 1.691064488623, -0.546457367230, -0.220875718256,
+  -0.216018047603, 2.202461263004, 0.100494060288, -0.337481553094,
+  -0.144343811696, -0.299469854864, -0.229514049722, 1.442971110350,
+  -0.073035976981, -1.808978356092, -0.394564481565, 0.864091753499,
+  -0.330622849956, 0.682118910180, 0.209646481704, 2.892053970254,
+  1.358707200042, 1.320198072576, -0.167539638392, 0.415105207584,
+  1.358353624039, -0.452712667833, 1.612896990941, 1.123591800211,
+  2.715150738504, 2.049233543250, 1.976511261696, -0.720306639514,
+  0.694995256047, -1.441499342653, -0.180764329912, -0.083440827570,
+  3.467215173315, -0.391315325873, 0.320022288481, 0.686256722794,
+  -0.141128137179, -0.140792612702, -0.232030129207, -0.984560009377,
+  -1.399641042245, 1.824741858878, 0.400129978571, -0.252253096938,
+  1.147531900195, 1.921001492668, -0.680375923993, -0.265348040001,
+  0.084696204485, 1.816932450233, 1.587114829655, -0.176514350448,
+  1.832056765146, -0.339622624154, 0.912045912589, -0.980690637506,
+  -0.492679505498, -0.132937177582, 0.683745005253, -0.214917977569,
+  -0.126200761750, 2.632874230636, 0.863758613512, -0.695016654463,
+  -0.384453837757, -0.288605904012, -0.471480605518, 2.268948289718,
+  0.295873275711, -0.230375350268, 2.537377334201, 1.186352685088,
+  -0.167005613002, -1.043649922380, 2.903221727544, 0.076130458029,
+  0.823777276176, -0.753543706804, 2.059843094880, 0.124736844178,
+  -0.279443963185, -0.179416250318, -0.129880671289, -0.598191370944,
+  0.285899830656, 0.228096484580, 2.449564888553, 0.227308464352,
+  1.922903611519, 0.508484639786, -0.148227747872, 0.308859084627,
+  3.132178423292, 0.701284943286, 1.703208304556, 2.717753607115,
+  1.068396770315, -0.310346108239, -0.195995603805, -0.328868361179,
+  -0.339856050311, 0.034398425285, -0.002623267849, 0.146261819239,
+  -0.258880864936, -0.541223489008, -0.148655781444, 1.053168999018,
+  -0.146930666868, 1.692289244233, -0.030835252178, 2.294266447178,
+  -0.098956195752, 0.613732471310, -0.422860765158, -0.422534094080,
+  0.824433000317, -1.777238421851, 1.118104980062, -0.176829371659,
+  2.884464242244, -0.317590030795, 1.490282214541, 1.203599102469,
+  -1.405324950154, 1.306823510371, -0.474073201204, 0.263281246252,
+  1.383400285807, -0.267965517976, -0.246482570724, -0.237437585109,
+  0.245215978030, 2.350318084309, 0.893000796635, -0.140440535337,
+  -0.482127721991, 0.378450318066, -0.550474425828, -0.492483912395,
+  0.991771804327, 2.284311962680, 1.450002778639, -0.159476705205,
+  0.309676599108, 1.357490057435, 0.098925576361, -0.225763312326,
+  1.999603147308, -0.381426939450, 0.508992666191, 1.509134957604,
+  -0.068187571496, -2.335591589999, 1.224848727676, 0.529137461176,
+  1.209259963395, -0.433076347361, -0.429737367104, 3.639504604020,
+  0.647943774070, 0.138719258661, -0.322459202822, 0.374343913903,
+};
+
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Snake on float operands (4 per thread) against the golden table; relative tolerance 1e-5.
+TEST(Epilogue_thread_snake, device_f32) {
+
+  int const kN = 256;
+  int const kV = 4;
+
+  using Element = float;
+  using Func = cutlass::epilogue::thread::Snake<cutlass::Array<Element, kV>>;
+
+  double tolerance = 1e-5;
+
+  //
+  // Construct workspace
+  //
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_X({1, kN});
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Alpha({1, kN});
+
+  for (int i = 0; i < kN; ++i) {
+    tensor_X.host_data(i) = Element(Snake_golden_x[i]);
+    tensor_Alpha.host_data(i) = Element(Snake_golden_alpha[i]);
+  }
+
+  tensor_Destination.sync_device();
+  tensor_X.sync_device();
+  tensor_Alpha.sync_device();
+
+  //
+  // Launch the kernel
+  //
+  dim3 grid(1,1,1);
+  dim3 block(kN / kV, 1, 1);
+
+  test_Epilogue_thread_activation_binary<Element, kV, Func><<< grid, block >>>(
+    tensor_Destination.device_data(),
+    tensor_X.device_data(),
+    tensor_Alpha.device_data());
+
+  tensor_Destination.sync_host();
+
+  //
+  // Verify
+  //
+
+  for (int i = 0; i < kN; ++i) {
+    Element x_in = Element(Snake_golden_x[i]);
+    Element alpha_in = Element(Snake_golden_alpha[i]);
+    Element got = tensor_Destination.host_data(i);
+    Element expected = Element(Snake_golden_output[i]);
+
+    double rel_error = (double(got) - double(expected)) / double(expected);
+
+    EXPECT_LT(std::abs(rel_error), tolerance)
+      << "Input[" << i << "]: x=" << x_in << ", alpha=" << alpha_in
+      << ", Got: " << got << ", expected: " << expected;
+  }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Snake on bfloat16 operands (8 per thread) against the golden table; relative tolerance 0.02.
+TEST(Epilogue_thread_snake, device_bf16) {
+
+  int const kN = 256;
+  int const kV = 8;
+
+  using Element = cutlass::bfloat16_t;
+  using Func = cutlass::epilogue::thread::Snake<cutlass::Array<Element, kV>>;
+
+  double tolerance = 0.02;
+
+  //
+  // Construct workspace
+  //
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_X({1, kN});
+  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Alpha({1, kN});
+
+  for (int i = 0; i < kN; ++i) {
+    tensor_X.host_data(i) = Element(Snake_golden_x[i]);
+    tensor_Alpha.host_data(i) = Element(Snake_golden_alpha[i]);
+  }
+
+  tensor_Destination.sync_device();
+  tensor_X.sync_device();
+  tensor_Alpha.sync_device();
+
+  //
+  // Launch the kernel
+  //
+  dim3 grid(1,1,1);
+  dim3 block(kN / kV, 1, 1);
+
+  test_Epilogue_thread_activation_binary<Element, kV, Func><<< grid, block >>>(
+    tensor_Destination.device_data(),
+    tensor_X.device_data(),
+    tensor_Alpha.device_data());
+
+  tensor_Destination.sync_host();
+
+  //
+  // Verify
+  //
+
+  for (int i = 0; i < kN; ++i) {
+    Element x_in = Element(Snake_golden_x[i]);
+    Element alpha_in = Element(Snake_golden_alpha[i]);
+    Element got = tensor_Destination.host_data(i);
+    Element expected = Element(Snake_golden_output[i]);
+
+    double rel_error = (double(got) - double(expected)) / double(expected);
+
+    EXPECT_LT(std::abs(rel_error), tolerance)
+      << "Input[" << i << "]: x=" << x_in << ", alpha=" << alpha_in
+      << ", Got: " << got << ", expected: " << expected;
+  }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Host-side check of the scalar functor: alpha == 0 (either sign) must give x, not NaN.
+TEST(Epilogue_thread_snake, host_zero_alpha) {
+
+  cutlass::epilogue::thread::Snake<float> func;
+
+  EXPECT_EQ(func(1.5f, 0.0f), 1.5f);
+  EXPECT_EQ(func(-2.25f, -0.0f), -2.25f);
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 TEST(Epilogue_thread_gelu_taylor, device_f32) {