@@ -10,10 +10,12 @@ import (
1010 "sync"
1111 "time"
1212
13+ ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types"
1314 "google.golang.org/protobuf/proto"
1415
1516 ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types"
1617
18+ "github.com/smartcontractkit/chainlink-common/keystore/corekeys/ocr2key"
1719 "github.com/smartcontractkit/chainlink-common/pkg/beholder"
1820 commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities"
1921 "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
@@ -43,6 +45,7 @@ type ClientRequest struct {
4345 totalErrorCount int
4446 responseReceived map [p2ptypes.PeerID ]bool
4547 lggr logger.Logger
48+ ocr3Configs map [string ]ocrtypes.ContractConfig
4649
4750 requiredIdenticalResponses int
4851 remoteNodeCount int
@@ -58,6 +61,7 @@ type ClientRequest struct {
5861func NewClientExecuteRequest (ctx context.Context , lggr logger.Logger , req commoncap.CapabilityRequest ,
5962 remoteCapabilityInfo commoncap.CapabilityInfo , localDonInfo commoncap.DON , dispatcher types.Dispatcher ,
6063 requestTimeout time.Duration , transmissionConfig * transmission.TransmissionConfig , capMethodName string ,
64+ ocr3Configs map [string ]ocrtypes.ContractConfig ,
6165) (* ClientRequest , error ) {
6266 rawRequest , err := proto.MarshalOptions {Deterministic : true }.Marshal (pb .CapabilityRequestToProto (req ))
6367 if err != nil {
@@ -87,14 +91,15 @@ func NewClientExecuteRequest(ctx context.Context, lggr logger.Logger, req common
8791 }
8892
8993 lggr = logger .With (lggr , "requestId" , requestID ) // cap ID and method name included in the parent logger
90- return newClientRequest (ctx , lggr , requestID , remoteCapabilityInfo , localDonInfo , dispatcher , requestTimeout , tc , types .MethodExecute , rawRequest , workflowExecutionID , req .Metadata .ReferenceID , capMethodName )
94+ return newClientRequest (ctx , lggr , requestID , remoteCapabilityInfo , localDonInfo , dispatcher , requestTimeout , tc , types .MethodExecute , rawRequest , workflowExecutionID , req .Metadata .ReferenceID , capMethodName , ocr3Configs )
9195}
9296
9397var defaultDelayMargin = 10 * time .Second
9498
9599func newClientRequest (ctx context.Context , lggr logger.Logger , requestID string , remoteCapabilityInfo commoncap.CapabilityInfo ,
96100 localDonInfo commoncap.DON , dispatcher types.Dispatcher , requestTimeout time.Duration ,
97101 tc transmission.TransmissionConfig , methodType string , rawRequest []byte , workflowExecutionID string , stepRef string , capMethodName string ,
102+ ocr3Configs map [string ]ocrtypes.ContractConfig ,
98103) (* ClientRequest , error ) {
99104 remoteCapabilityDonInfo := remoteCapabilityInfo .DON
100105 if remoteCapabilityDonInfo == nil {
@@ -200,6 +205,7 @@ func newClientRequest(ctx context.Context, lggr logger.Logger, requestID string,
200205 responseCh : make (chan clientResponse , 1 ),
201206 wg : & wg ,
202207 lggr : lggr ,
208+ ocr3Configs : ocr3Configs ,
203209 }, nil
204210}
205211
@@ -301,6 +307,32 @@ func (c *ClientRequest) OnMessage(_ context.Context, msg *types.MessageBody) err
301307 c .responseReceived [sender ] = true
302308
303309 if msg .Error == types .Error_OK {
310+ resp , err := pb .UnmarshalCapabilityResponse (msg .Payload )
311+ if err != nil {
312+ return fmt .Errorf ("failed to unmarshal capability response: %w" , err )
313+ }
314+
315+ if resp .Metadata .OCRAttestation != nil {
316+ // Since signatures are provided switch to OCR based validation. It's enough to get 1 response with F+1 signatures
317+ // to be confident that the response is honest.
318+ err = c .verifyAttestation (resp )
319+ if err != nil {
320+ c .lggr .Errorw ("failed to verify capability response OCR attestation" , "peer" , sender , "err" , err , "requestID" , c .id , "msgPayload" , hex .EncodeToString (msg .Payload ))
321+ return fmt .Errorf ("failed to verify capability response OCR attestation: %w" , err )
322+ }
323+
324+ rpt := resp .Metadata .Metering [0 ]
325+ rpt .Peer2PeerID = sender .String ()
326+ var payload []byte
327+ payload , err = c .encodePayloadWithMetadata (msg , commoncap.ResponseMetadata {Metering : []commoncap.MeteringNodeDetail {rpt }})
328+ if err != nil {
329+ return fmt .Errorf ("failed to encode payload with metadata: %w" , err )
330+ }
331+
332+ c .sendResponse (clientResponse {Result : payload })
333+ return nil
334+ }
335+
304336 // metering reports per node are aggregated into a single array of values. for any single node message, the
305337 // metering values are extracted from the CapabilityResponse, added to an array, and the CapabilityResponse
306338 // is marshalled without the metering value to get the hash. each node could have a different metering value
@@ -359,6 +391,47 @@ func (c *ClientRequest) OnMessage(_ context.Context, msg *types.MessageBody) err
359391 return nil
360392}
361393
394+ func (c * ClientRequest ) verifyAttestation (resp commoncap.CapabilityResponse ) error {
395+ if c .ocr3Configs == nil {
396+ return errors .New ("OCR3 configs not provided, cannot verify signatures" )
397+ }
398+
399+ cfg , ok := c .ocr3Configs [pb .OCR3ConfigDefaultKey ]
400+ if ! ok {
401+ return fmt .Errorf ("OCR3 config with key %s not found" , pb .OCR3ConfigDefaultKey )
402+ }
403+
404+ attestation := resp .Metadata .OCRAttestation
405+ if len (attestation .Sigs ) < int (cfg .F )+ 1 {
406+ return fmt .Errorf ("not enough signatures: got %d, need at least %d" , len (attestation .Sigs ), cfg .F + 1 )
407+ }
408+
409+ if len (resp .Metadata .Metering ) != 1 {
410+ return fmt .Errorf ("unexpected number of metering records: got %d, want 1" , len (resp .Metadata .Metering ))
411+ }
412+
413+ reportData := commoncap .ResponseToReportData (c .id , resp .Payload .Value , resp .Metadata .Metering [0 ].SpendUnit , resp .Metadata .Metering [0 ].SpendValue )
414+ sigData := ocr2key .ReportToSigData3 (attestation .ConfigDigest , attestation .SequenceNumber , reportData )
415+ signed := make ([]bool , len (cfg .Signers ))
416+ for _ , sig := range attestation .Sigs {
417+ if int (sig .Signer ) > len (cfg .Signers ) {
418+ return fmt .Errorf ("invalid signer index: %d" , sig .Signer )
419+ }
420+
421+ if signed [sig .Signer ] {
422+ return fmt .Errorf ("duplicate signature from signer index: %d" , sig .Signer )
423+ }
424+
425+ if ! ocr2key .EvmVerifyBlob (cfg .Signers [sig .Signer ], sigData , sig .Signature ) {
426+ return fmt .Errorf ("invalid signature from signer index: %d" , sig .Signer )
427+ }
428+
429+ signed [sig .Signer ] = true
430+ }
431+
432+ return nil
433+ }
434+
362435func (c * ClientRequest ) sendResponse (response clientResponse ) {
363436 c .responseCh <- response
364437 close (c .responseCh )
0 commit comments