@@ -809,14 +809,17 @@ export const applyAttention = (
809809) => {
810810 // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
811811 const outputCount = Math . min ( context . outputCount , 1 + ( pastKey ? 1 : 0 ) + ( pastValue ? 1 : 0 ) ) ;
812+ // When there are no present key/value outputs (outputCount <= 1), ignore past to match CPU EP semantics.
813+ const effectivePastKey = outputCount > 1 ? pastKey : undefined ;
814+ const effectivePastValue = outputCount > 1 ? pastValue : undefined ;
812815 const pastSequenceLength = outputCount > 1 ? parameters . pastSequenceLength : 0 ;
813816 const totalSequenceLength = pastSequenceLength + parameters . kvSequenceLength ;
814817 const attentionBias =
815818 attentionBiasInput && ShapeUtil . size ( attentionBiasInput . dims ) > 0 ? attentionBiasInput : undefined ;
816819
817820 const inputsK = [ q , k ] ;
818- if ( outputCount > 1 && pastKey && ShapeUtil . size ( pastKey . dims ) > 0 ) {
819- inputsK . push ( pastKey ) ;
821+ if ( effectivePastKey && ShapeUtil . size ( effectivePastKey . dims ) > 0 ) {
822+ inputsK . push ( effectivePastKey ) ;
820823 }
821824 if ( attentionBias ) {
822825 inputsK . push ( attentionBias ) ;
@@ -833,7 +836,7 @@ export const applyAttention = (
833836 outputCount ,
834837 q ,
835838 k ,
836- pastKey ,
839+ effectivePastKey ,
837840 attentionBias ,
838841 parameters ,
839842 pastSequenceLength ,
@@ -860,8 +863,8 @@ export const applyAttention = (
860863
861864 // Run AttentionScore
862865 const inputsV = [ probs , v ] ;
863- if ( outputCount > 1 && pastValue && ShapeUtil . size ( pastValue . dims ) > 0 ) {
864- inputsV . push ( pastValue ) ;
866+ if ( effectivePastValue && ShapeUtil . size ( effectivePastValue . dims ) > 0 ) {
867+ inputsV . push ( effectivePastValue ) ;
865868 }
866869 if ( seqLens ) {
867870 inputsV . push ( seqLens ) ;
@@ -874,7 +877,7 @@ export const applyAttention = (
874877 outputCount ,
875878 probs ,
876879 v ,
877- pastValue ,
880+ effectivePastValue ,
878881 parameters ,
879882 pastSequenceLength ,
880883 seqLens ,
0 commit comments