@@ -472,16 +472,16 @@ def register_fake(
472472)
473473
474474lib .define (
475- "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
475+ "quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
476476)
477477lib .define (
478- "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
478+ "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
479479)
480480lib .define (
481- "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
481+ "quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
482482)
483483lib .define (
484- "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
484+ "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
485485)
486486
487487# pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements)
@@ -2957,6 +2957,8 @@ def quantized_softmax_meta(
29572957 input : torch .Tensor ,
29582958 mask : torch .Tensor ,
29592959 dim : int ,
2960+ mask_type : int ,
2961+ pos : torch .Tensor ,
29602962 in_scale : torch .Tensor ,
29612963 in_zero_point : torch .Tensor ,
29622964 out_scale : torch .Tensor ,
@@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta(
29702972 input : torch .Tensor ,
29712973 mask : torch .Tensor ,
29722974 dim : int ,
2975+ mask_type : int ,
2976+ pos : torch .Tensor ,
29732977 in_scale : float ,
29742978 in_zero_point : int ,
29752979 out_scale : float ,
0 commit comments