@@ -1018,7 +1018,7 @@ std::shared_ptr<Tensor> EqualsForward(const std::shared_ptr<Tensor> &a, const st
10181018 DISPATCH (a->Dtype (),
10191019 return BinaryForward (a, b,
10201020 [] __device__ (auto x, auto y) { return (x == y) ? decltype (x){1 } : decltype (x){0 }; });
1021- , INFINI_ALL_TYPES )
1021+ , INFINI_ALL_NUMERIC_TYPES )
10221022}
10231023
10241024std::shared_ptr<Tensor> EqualsScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1033,7 +1033,7 @@ std::shared_ptr<Tensor> EqualsScalarForward(const std::shared_ptr<Tensor> &a, fl
10331033std::shared_ptr<Tensor> LtForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10341034 DISPATCH (a->Dtype (), return BinaryForward (
10351035 a, b, [] __device__ (auto x, auto y) { return x < y ? decltype (x){1 } : decltype (x){0 }; });
1036- , INFINI_ALL_TYPES )
1036+ , INFINI_ALL_NUMERIC_TYPES )
10371037}
10381038
10391039std::shared_ptr<Tensor> LtScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1042,14 +1042,14 @@ std::shared_ptr<Tensor> LtScalarForward(const std::shared_ptr<Tensor> &a, float
10421042 return (x < static_cast <decltype (x)>(scalar)) ? decltype (x){1 }
10431043 : decltype (x){0 };
10441044 });
1045- , INFINI_ALL_TYPES )
1045+ , INFINI_ALL_NUMERIC_TYPES )
10461046}
10471047
10481048std::shared_ptr<Tensor> LeForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10491049 DISPATCH (a->Dtype (),
10501050 return BinaryForward (a, b,
10511051 [] __device__ (auto x, auto y) { return (x <= y) ? decltype (x){1 } : decltype (x){0 }; });
1052- , INFINI_ALL_TYPES )
1052+ , INFINI_ALL_NUMERIC_TYPES )
10531053}
10541054
10551055std::shared_ptr<Tensor> LeScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1058,13 +1058,13 @@ std::shared_ptr<Tensor> LeScalarForward(const std::shared_ptr<Tensor> &a, float
10581058 return (x <= static_cast <decltype (x)>(scalar)) ? decltype (x){1 }
10591059 : decltype (x){0 };
10601060 });
1061- , INFINI_ALL_TYPES )
1061+ , INFINI_ALL_NUMERIC_TYPES )
10621062}
10631063
10641064std::shared_ptr<Tensor> GtForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10651065 DISPATCH (a->Dtype (), return BinaryForward (
10661066 a, b, [] __device__ (auto x, auto y) { return x > y ? decltype (x){1 } : decltype (x){0 }; });
1067- , INFINI_ALL_TYPES )
1067+ , INFINI_ALL_NUMERIC_TYPES )
10681068}
10691069
10701070std::shared_ptr<Tensor> GtScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1073,14 +1073,14 @@ std::shared_ptr<Tensor> GtScalarForward(const std::shared_ptr<Tensor> &a, float
10731073 return (x > static_cast <decltype (x)>(scalar)) ? decltype (x){1 }
10741074 : decltype (x){0 };
10751075 });
1076- , INFINI_ALL_TYPES )
1076+ , INFINI_ALL_NUMERIC_TYPES )
10771077}
10781078
10791079std::shared_ptr<Tensor> GeForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10801080 DISPATCH (a->Dtype (),
10811081 return BinaryForward (a, b,
10821082 [] __device__ (auto x, auto y) { return (x >= y) ? decltype (x){1 } : decltype (x){0 }; });
1083- , INFINI_ALL_TYPES )
1083+ , INFINI_ALL_NUMERIC_TYPES )
10841084}
10851085
10861086std::shared_ptr<Tensor> GeScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1089,7 +1089,7 @@ std::shared_ptr<Tensor> GeScalarForward(const std::shared_ptr<Tensor> &a, float
10891089 return (x >= static_cast <decltype (x)>(scalar)) ? decltype (x){1 }
10901090 : decltype (x){0 };
10911091 });
1092- , INFINI_ALL_TYPES )
1092+ , INFINI_ALL_NUMERIC_TYPES )
10931093}
10941094
10951095std::shared_ptr<Tensor> OrForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1098,7 +1098,7 @@ std::shared_ptr<Tensor> OrForward(const std::shared_ptr<Tensor> &a, const std::s
10981098 return (x != decltype (x){0 } || y != decltype (y){0 }) ? decltype (x){1 }
10991099 : decltype (x){0 };
11001100 });
1101- , INFINI_ALL_TYPES )
1101+ , INFINI_ALL_NUMERIC_TYPES )
11021102}
11031103
11041104std::shared_ptr<Tensor> AndForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1107,7 +1107,7 @@ std::shared_ptr<Tensor> AndForward(const std::shared_ptr<Tensor> &a, const std::
11071107 return (x != decltype (x){0 } && y != decltype (y){0 }) ? decltype (x){1 }
11081108 : decltype (x){0 };
11091109 });
1110- , INFINI_ALL_TYPES )
1110+ , INFINI_ALL_NUMERIC_TYPES )
11111111}
11121112
11131113std::shared_ptr<Tensor> AddForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1125,19 +1125,19 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> AddBackward(const st
11251125std::shared_ptr<Tensor> AddScalarForward (const std::shared_ptr<Tensor> &a, float scalar) {
11261126 DISPATCH (a->Dtype (),
11271127 return UnaryForward (a, [scalar] __device__ (auto x) { return Add (x, static_cast <decltype (x)>(scalar)); });
1128- , INFINI_ALL_TYPES )
1128+ , INFINI_ALL_NUMERIC_TYPES )
11291129}
11301130
11311131std::shared_ptr<Tensor> AddScalarBackward (const std::shared_ptr<Tensor> &grad_output) {
11321132 DISPATCH (grad_output->Dtype (),
11331133 return UnaryBackward (grad_output, nullptr ,
11341134 [] __device__ (auto x) { return common::cuda::Cast<decltype (x)>(1 ); });
1135- , INFINI_ALL_TYPES )
1135+ , INFINI_ALL_NUMERIC_TYPES )
11361136}
11371137
11381138std::shared_ptr<Tensor> SubForward (const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
11391139 DISPATCH (a->Dtype (), return BinaryForward (a, b, [] __device__ (auto x, auto y) { return Sub (x, y); });
1140- , INFINI_ALL_TYPES )
1140+ , INFINI_ALL_NUMERIC_TYPES )
11411141}
11421142
11431143std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> SubBackward (const std::shared_ptr<Tensor> &grad_output,
0 commit comments