|
18 | 18 | import static io.aklivity.zilla.runtime.engine.buffer.BufferPool.NO_SLOT; |
19 | 19 | import static io.aklivity.zilla.runtime.engine.concurrent.Signaler.NO_CANCEL_ID; |
20 | 20 | import static java.lang.System.currentTimeMillis; |
21 | | -import static java.nio.ByteOrder.BIG_ENDIAN; |
22 | 21 | import static java.util.concurrent.TimeUnit.SECONDS; |
23 | | -import static org.agrona.BitUtil.SIZE_OF_BYTE; |
24 | | -import static org.agrona.BitUtil.SIZE_OF_SHORT; |
25 | 22 |
|
26 | 23 | import java.util.Optional; |
27 | 24 | import java.util.function.Consumer; |
|
36 | 33 | import io.aklivity.zilla.runtime.binding.tls.internal.config.TlsBindingConfig; |
37 | 34 | import io.aklivity.zilla.runtime.binding.tls.internal.config.TlsRouteConfig; |
38 | 35 | import io.aklivity.zilla.runtime.binding.tls.internal.types.OctetsFW; |
| 36 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsExtensionFW; |
| 37 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsExtensionsFW; |
| 38 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsHandshakeClientHelloFW; |
| 39 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsHandshakeMessageFW; |
39 | 40 | import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsRecordFW; |
| 41 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsServerNameExtensionFW; |
| 42 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsServerNameTypeFW; |
| 43 | +import io.aklivity.zilla.runtime.binding.tls.internal.types.codec.TlsSniHostNameFW; |
40 | 44 | import io.aklivity.zilla.runtime.binding.tls.internal.types.stream.AbortFW; |
41 | 45 | import io.aklivity.zilla.runtime.binding.tls.internal.types.stream.BeginFW; |
42 | 46 | import io.aklivity.zilla.runtime.binding.tls.internal.types.stream.DataFW; |
@@ -92,6 +96,13 @@ public final class TlsProxyFactory implements TlsStreamFactory |
92 | 96 | private final ResetFW.Builder resetRW = new ResetFW.Builder(); |
93 | 97 |
|
94 | 98 | private final TlsRecordFW tlsRecordRO = new TlsRecordFW(); |
| 99 | + private final TlsHandshakeMessageFW tlsHandshakeRO = new TlsHandshakeMessageFW(); |
| 100 | + private final TlsHandshakeClientHelloFW tlsClientHelloRO = new TlsHandshakeClientHelloFW(); |
| 101 | + private final TlsExtensionsFW tlsExtensionsRO = new TlsExtensionsFW(); |
| 102 | + private final TlsExtensionFW tlsExtensionRO = new TlsExtensionFW(); |
| 103 | + private final TlsServerNameExtensionFW tlsServerNameExRO = new TlsServerNameExtensionFW(); |
| 104 | + private final TlsServerNameTypeFW tlsServerNameTypeRO = new TlsServerNameTypeFW(); |
| 105 | + private final TlsSniHostNameFW tlsHostnameRO = new TlsSniHostNameFW(); |
95 | 106 |
|
96 | 107 | private final TlsProxyDecoder decodeRecord = this::decodeRecord; |
97 | 108 | private final TlsProxyDecoder decodeRecordBytes = this::decodeRecordBytes; |
@@ -446,94 +457,93 @@ private int decodeRecord( |
446 | 457 | break decode; |
447 | 458 | } |
448 | 459 |
|
449 | | - DirectBuffer message = tlsRecord.payload().value(); |
450 | | - int messageProgress = 0; |
451 | | - byte messageType = message.getByte(messageProgress++); |
452 | | - int messageLength = |
453 | | - (message.getByte(messageProgress++) & 0xff) << 16 | |
454 | | - (message.getByte(messageProgress++) & 0xff) << 8 | |
455 | | - (message.getByte(messageProgress++) & 0xff) << 0; |
456 | | - |
457 | | - if (messageType != MESSAGE_TYPE_CLIENT_HELLO || |
458 | | - messageProgress + messageLength != message.capacity()) |
| 460 | + TlsHandshakeMessageFW tlsHandshake = tlsRecord.payload().get(tlsHandshakeRO::tryWrap); |
| 461 | + if (tlsHandshake == null || |
| 462 | + tlsHandshake.typeAndLength() >> 24 != MESSAGE_TYPE_CLIENT_HELLO || |
| 463 | + (tlsHandshake.typeAndLength() & 0x00ff_ffff) + tlsHandshake.limit() > tlsRecord.limit()) |
459 | 464 | { |
460 | 465 | proxy.doNetReset(traceId); |
461 | 466 | proxy.decoder = decodeIgnoreAll; |
462 | 467 | break decode; |
463 | 468 | } |
464 | 469 |
|
465 | | - // skip version |
466 | | - messageProgress += 2; |
467 | | - |
468 | | - // skip random |
469 | | - messageProgress += 32; |
470 | | - |
471 | | - // skip session id |
472 | | - messageProgress += SIZE_OF_BYTE + (message.getByte(messageProgress) & 0xff); |
473 | | - |
474 | | - // cipher suites |
475 | | - messageProgress += SIZE_OF_SHORT + (message.getShort(messageProgress, BIG_ENDIAN) & 0xffff); |
476 | | - |
477 | | - // compress methods |
478 | | - messageProgress += SIZE_OF_BYTE + (message.getByte(messageProgress) & 0xff); |
479 | | - |
480 | | - int extensionsLength = message.getShort(messageProgress, BIG_ENDIAN) & 0xffff; |
481 | | - messageProgress += SIZE_OF_SHORT; |
482 | | - |
483 | | - if (messageProgress + extensionsLength != message.capacity()) |
| 470 | + TlsHandshakeClientHelloFW tlsClientHello = |
| 471 | + tlsClientHelloRO.tryWrap(tlsRecord.buffer(), tlsHandshake.limit(), tlsRecord.limit()); |
| 472 | + if (tlsClientHello == null || |
| 473 | + tlsClientHello.limit() > tlsRecord.limit()) |
484 | 474 | { |
485 | 475 | proxy.doNetReset(traceId); |
486 | 476 | proxy.decoder = decodeIgnoreAll; |
487 | 477 | break decode; |
488 | 478 | } |
489 | 479 |
|
490 | 480 | String serverName = null; |
491 | | - while (messageProgress < message.capacity()) |
492 | | - { |
493 | | - int extensionType = message.getShort(messageProgress, BIG_ENDIAN) & 0xffff; |
494 | | - messageProgress += SIZE_OF_SHORT; |
495 | | - int extensionLength = message.getShort(messageProgress, BIG_ENDIAN) & 0xffff; |
496 | | - messageProgress += SIZE_OF_SHORT; |
497 | 481 |
|
498 | | - if (messageProgress + extensionLength > message.capacity()) |
| 482 | + if (tlsClientHello.limit() < tlsRecord.limit()) |
| 483 | + { |
| 484 | + TlsExtensionsFW tlsExtensions = |
| 485 | + tlsExtensionsRO.tryWrap(tlsRecord.buffer(), tlsClientHello.limit(), tlsRecord.limit()); |
| 486 | + if (tlsExtensions == null || |
| 487 | + tlsExtensions.limit() != tlsRecord.limit()) |
499 | 488 | { |
500 | 489 | proxy.doNetReset(traceId); |
501 | 490 | proxy.decoder = decodeIgnoreAll; |
502 | 491 | break decode; |
503 | 492 | } |
504 | 493 |
|
505 | | - if (extensionType == EXTENSION_TYPE_SNI) |
506 | | - { |
507 | | - int sniLength = message.getShort(messageProgress, BIG_ENDIAN) & 0xffff; |
508 | | - messageProgress += SIZE_OF_SHORT; |
| 494 | + DirectBuffer tlsExtensionsBuf = tlsExtensions.value().value(); |
| 495 | + int tlsExtensionsLimit = tlsExtensions.length(); |
| 496 | + int tlsExtensionsProgress = 0; |
509 | 497 |
|
510 | | - if (messageProgress + sniLength > message.capacity()) |
| 498 | + while (tlsExtensionsProgress < tlsExtensionsLimit) |
| 499 | + { |
| 500 | + TlsExtensionFW tlsExtension = |
| 501 | + tlsExtensionRO.tryWrap(tlsExtensionsBuf, tlsExtensionsProgress, tlsExtensionsLimit); |
| 502 | + if (tlsExtension == null || |
| 503 | + tlsExtension.limit() > tlsRecord.limit()) |
511 | 504 | { |
512 | 505 | proxy.doNetReset(traceId); |
513 | 506 | proxy.decoder = decodeIgnoreAll; |
514 | 507 | break decode; |
515 | 508 | } |
516 | 509 |
|
517 | | - int sniType = message.getByte(messageProgress++); |
518 | | - if (sniType == SNI_TYPE_HOSTNAME) |
| 510 | + if (tlsExtension.type() == EXTENSION_TYPE_SNI) |
519 | 511 | { |
520 | | - int hostnameLength = message.getShort(messageProgress, BIG_ENDIAN) & 0xffff; |
521 | | - messageProgress += SIZE_OF_SHORT; |
| 512 | + TlsServerNameExtensionFW tlsServerNameEx = tlsExtension.data().get(tlsServerNameExRO::tryWrap); |
| 513 | + if (tlsServerNameEx == null || |
| 514 | + tlsServerNameEx.limit() > tlsRecord.limit()) |
| 515 | + { |
| 516 | + proxy.doNetReset(traceId); |
| 517 | + proxy.decoder = decodeIgnoreAll; |
| 518 | + break decode; |
| 519 | + } |
522 | 520 |
|
523 | | - if (messageProgress + hostnameLength > message.capacity()) |
| 521 | + int tlsServerNameTypeOffset = tlsServerNameEx.value().offset(); |
| 522 | + TlsServerNameTypeFW tlsServerNameType = |
| 523 | + tlsServerNameTypeRO.tryWrap(tlsExtensionsBuf, tlsServerNameTypeOffset, tlsExtensionsLimit); |
| 524 | + if (tlsServerNameType == null) |
524 | 525 | { |
525 | 526 | proxy.doNetReset(traceId); |
526 | 527 | proxy.decoder = decodeIgnoreAll; |
527 | 528 | break decode; |
528 | 529 | } |
529 | 530 |
|
530 | | - serverName = message.getStringWithoutLengthUtf8(messageProgress, hostnameLength); |
531 | | - messageProgress += hostnameLength; |
| 531 | + if (tlsServerNameType.value() == SNI_TYPE_HOSTNAME) |
| 532 | + { |
| 533 | + TlsSniHostNameFW tlsHostname = |
| 534 | + tlsHostnameRO.tryWrap(tlsExtensionsBuf, tlsServerNameType.limit(), tlsExtensionsLimit); |
| 535 | + if (tlsHostname == null) |
| 536 | + { |
| 537 | + proxy.doNetReset(traceId); |
| 538 | + proxy.decoder = decodeIgnoreAll; |
| 539 | + break decode; |
| 540 | + } |
| 541 | + |
| 542 | + serverName = tlsHostname.value().asString(); |
| 543 | + } |
532 | 544 | } |
533 | | - } |
534 | | - else |
535 | | - { |
536 | | - messageProgress += extensionLength; |
| 545 | + |
| 546 | + tlsExtensionsProgress = tlsExtension.limit(); |
537 | 547 | } |
538 | 548 | } |
539 | 549 |
|
|
0 commit comments