diff --git a/common/protocol/quic/sniff.go b/common/protocol/quic/sniff.go index 5b29d6ffada4..0c20c3f61a17 100644 --- a/common/protocol/quic/sniff.go +++ b/common/protocol/quic/sniff.go @@ -100,17 +100,17 @@ func SniffQUIC(b []byte) (*SniffHeader, error) { } if isQuicInitial { // Only initial packets have token, see https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.2 - tokenLen, err := quicvarint.Read(buffer) - if err != nil || tokenLen > uint64(len(b)) { + tokenLen, err := readShortQuicVarint(buffer) + if err != nil || tokenLen > int32(len(b)) { return nil, errNotQuic } - if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil { + if _, err = buffer.ReadBytes(tokenLen); err != nil { return nil, errNotQuic } } - packetLen, err := quicvarint.Read(buffer) + packetLen, err := readShortQuicVarint(buffer) if err != nil { return nil, errNotQuic } @@ -179,45 +179,45 @@ func SniffQUIC(b []byte) (*SniffHeader, error) { case 0x00: // PADDING frame case 0x01: // PING frame case 0x02, 0x03: // ACK frame - if _, err = quicvarint.Read(buffer); err != nil { // Field: Largest Acknowledged + if _, err = readShortQuicVarint(buffer); err != nil { // Field: Largest Acknowledged return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Delay + if _, err = readShortQuicVarint(buffer); err != nil { // Field: ACK Delay return nil, io.ErrUnexpectedEOF } - ackRangeCount, err := quicvarint.Read(buffer) // Field: ACK Range Count + ackRangeCount, err := readShortQuicVarint(buffer) // Field: ACK Range Count if err != nil { return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { // Field: First ACK Range + if _, err = readShortQuicVarint(buffer); err != nil { // Field: First ACK Range return nil, io.ErrUnexpectedEOF } for i := 0; i < int(ackRangeCount); i++ { // Field: ACK Range - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> Gap + if _, err = readShortQuicVarint(buffer); err != nil { // Field: ACK Range -> Gap return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> ACK Range Length + if _, err = readShortQuicVarint(buffer); err != nil { // Field: ACK Range -> ACK Range Length return nil, io.ErrUnexpectedEOF } } if frameType == 0x03 { - if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT0 Count + if _, err = readShortQuicVarint(buffer); err != nil { // Field: ECN Counts -> ECT0 Count return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT1 Count + if _, err = readShortQuicVarint(buffer); err != nil { // Field: ECN Counts -> ECT1 Count return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { //nolint:misspell // Field: ECN Counts -> ECT-CE Count + if _, err = readShortQuicVarint(buffer); err != nil { //nolint:misspell // Field: ECN Counts -> ECT-CE Count return nil, io.ErrUnexpectedEOF } } case 0x06: // CRYPTO frame, we will use this frame - offset, err := quicvarint.Read(buffer) // Field: Offset + offset, err := readShortQuicVarint(buffer) // Field: Offset if err != nil { return nil, io.ErrUnexpectedEOF } - length, err := quicvarint.Read(buffer) // Field: Length - if err != nil || length > uint64(buffer.Len()) { + length, err := readShortQuicVarint(buffer) // Field: Length + if err != nil || length > buffer.Len() { return nil, io.ErrUnexpectedEOF } currentCryptoLen := int32(offset + length) @@ -228,17 +228,17 @@ func SniffQUIC(b []byte) (*SniffHeader, error) { cryptoDataBuf.Extend(currentCryptoLen - cryptoLen) cryptoLen = currentCryptoLen } - if _, err := buffer.Read(cryptoDataBuf.BytesRange(int32(offset), currentCryptoLen)); err != nil { // Field: Crypto Data + if _, err := buffer.Read(cryptoDataBuf.BytesRange(offset, currentCryptoLen)); err != nil { // Field: Crypto Data return nil, io.ErrUnexpectedEOF } case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet - if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code + if _, err = readShortQuicVarint(buffer); err != nil { // Field: Error Code return nil, io.ErrUnexpectedEOF } - if _, err = quicvarint.Read(buffer); err != nil { // Field: Frame Type + if _, err = readShortQuicVarint(buffer); err != nil { // Field: Frame Type return nil, io.ErrUnexpectedEOF } - length, err := quicvarint.Read(buffer) // Field: Reason Phrase Length + length, err := readShortQuicVarint(buffer) // Field: Reason Phrase Length if err != nil { return nil, io.ErrUnexpectedEOF } @@ -283,3 +283,18 @@ func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, len } return out } + +// readShortQuicVarint wraps quicvarint.Read with a max limit for length related fields. +// we only handle QUIC Initial so these numbers should not exceed 65535 +// returns int32 to reduce type conversion +func readShortQuicVarint(reader io.ByteReader) (int32, error) { + v, err := quicvarint.Read(reader) + if err != nil { + return 0, err + } + if v > 65535 { + // not used( + return 0, errNotQuicInitial + } + return int32(v), nil +}