File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ def forward(
4747 input_ids : Optional [torch .LongTensor ] = None ,
4848 past_key_values : Optional [RemotePastKeyValues ] = None ,
4949 attention_mask : Optional [torch .Tensor ] = None ,
50+ position_ids : Optional [torch .LongTensor ] = None ,
5051 head_mask : Optional [torch .LongTensor ] = None ,
5152 inputs_embeds : Optional [torch .LongTensor ] = None ,
5253 use_cache : Optional [bool ] = None ,
@@ -68,6 +69,9 @@ def forward(
6869 assert (
6970 attention_mask is None or (attention_mask == 1 ).all ()
7071 ), f"Custom attention masks are not supported, { attention_mask = } "
72+ assert (
73+ position_ids is None or (position_ids [:, 1 :] - position_ids [:, :- 1 ] == 1 ).all ()
74+ ), f"Non-consecutive position_ids are not supported, { position_ids = } "
7175 assert head_mask is None , f"Custom head masks are not supported, { head_mask = } "
7276 assert use_cache is None or use_cache , f"{ use_cache = } is not supported"
7377 assert not output_attentions , f"{ output_attentions = } is not supported"
You can’t perform that action at this time.
0 commit comments