Skip to content

Commit 88b16ef

Browse files
authored
Respect maximumCount argument in request reader (#44)
Motivation: Currently, we do not use the `maximumCount` argument in `HTTPRequestConcludingAsyncReader`'s `read(maximumCount:body:)` method. Modifications: Updated `HTTPRequestConcludingAsyncReader/read(maximumCount:body:)` to throw a `LimitExceeded` error when the bytes available exceed `maximumCount`. Result: The `maximumCount` argument is now respected in `HTTPRequestConcludingAsyncReader/read(maximumCount:body:)`.
1 parent cee4932 commit 88b16ef

2 files changed

Lines changed: 233 additions & 27 deletions

File tree

Sources/NIOHTTPServer/HTTPRequestConcludingAsyncReader.swift

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,90 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
4141
/// The HTTP trailer fields captured at the end of the request.
4242
fileprivate var state: ReaderState
4343

44-
/// The iterator that provides HTTP request parts from the underlying channel.
45-
private var iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
44+
struct RequestBodyStateMachine {
45+
enum State {
46+
// The request body is currently being read: expecting more request body parts or a request end part.
47+
case readingBody(ReadingBodyState)
48+
49+
// The request end part was received. We have finished.
50+
case finished
51+
52+
enum ReadingBodyState {
53+
// All received bytes have been consumed; no excess bytes need to be stored.
54+
case noExcess
55+
56+
// `read` was called with a `maximumCount` value that was lower than the bytes available. The excess
57+
// bytes are stored here so they can be dispensed in future calls to `read`.
58+
case excess(ByteBuffer)
59+
}
60+
}
61+
62+
private var state: State
63+
64+
/// The iterator that provides HTTP request parts from the underlying channel.
65+
private var iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
66+
67+
init(iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator) {
68+
self.state = .readingBody(.noExcess)
69+
self.iterator = iterator
70+
}
71+
72+
enum ReadResult {
73+
case readBody(ByteBuffer)
74+
case readEnd(HTTPFields?)
75+
case streamFinished
76+
}
77+
78+
mutating func read(limit: Int?) async throws -> ReadResult {
79+
switch self.state {
80+
case .readingBody(let readingBodyState):
81+
var bodyElement: ByteBuffer
82+
83+
switch readingBodyState {
84+
case .excess(let excessElement):
85+
// There was an excess of bytes from the previous call to `read`. We read directly from this
86+
// excess and don't advance the iterator.
87+
bodyElement = excessElement
88+
89+
case .noExcess:
90+
// There is no excess from previous reads. We obtain the next element from the stream.
91+
let requestPart = try await self.iterator.next(isolation: #isolation)
92+
93+
switch requestPart {
94+
case .head:
95+
fatalError("Unexpectedly received a request head.")
96+
97+
case .none:
98+
fatalError("The stream unexpectedly ended before receiving a request end.")
99+
100+
case .body(let element):
101+
bodyElement = element
102+
103+
case .end(let trailers):
104+
self.state = .finished
105+
return .readEnd(trailers)
106+
}
107+
}
108+
109+
if let limit, limit < bodyElement.readableBytes,
110+
let truncated = bodyElement.readSlice(length: limit)
111+
{
112+
// There are more bytes available than `limit`. We must store the excess in a buffer for it to
113+
// be consumed in the next call to `read`.
114+
self.state = .readingBody(.excess(bodyElement))
115+
return .readBody(truncated)
116+
}
117+
118+
self.state = .readingBody(.noExcess)
119+
return .readBody(bodyElement)
120+
121+
case .finished:
122+
return .streamFinished
123+
}
124+
}
125+
}
126+
127+
var requestBodyStateMachine: RequestBodyStateMachine
46128

47129
/// Initializes a new request body reader with the given NIO async channel iterator.
48130
///
@@ -51,7 +133,7 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
51133
iterator: consuming sending NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
52134
readerState: ReaderState
53135
) {
54-
self.iterator = iterator
136+
self.requestBodyStateMachine = .init(iterator: iterator)
55137
self.state = readerState
56138
}
57139

@@ -65,26 +147,26 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
65147
maximumCount: Int?,
66148
body: nonisolated(nonsending) (consuming Span<ReadElement>) async throws(Failure) -> Return
67149
) async throws(EitherError<ReadFailure, Failure>) -> Return {
68-
let requestPart: HTTPRequestPart?
150+
let readResult: RequestBodyStateMachine.ReadResult
69151
do {
70-
requestPart = try await self.iterator.next(isolation: #isolation)
152+
readResult = try await self.requestBodyStateMachine.read(limit: maximumCount)
71153
} catch {
72154
throw .first(error)
73155
}
74156

75157
do {
76-
switch requestPart {
77-
case .head:
78-
fatalError()
79-
case .body(let element):
80-
return try await body(Array(buffer: element).span)
81-
case .end(let trailers):
158+
switch readResult {
159+
case .readBody(let readElement):
160+
return try await body(Array(buffer: readElement).span)
161+
162+
case .readEnd(let trailers):
82163
self.state.wrapped.withLock { state in
83164
state.trailers = trailers
84165
state.finishedReading = true
85166
}
86167
return try await body(.init())
87-
case .none:
168+
169+
case .streamFinished:
88170
return try await body(.init())
89171
}
90172
} catch {

Tests/NIOHTTPServerTests/HTTPRequestConcludingAsyncReaderTests.swift

Lines changed: 139 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,32 @@ struct HTTPRequestConcludingAsyncReaderTests {
4242

4343
_ = try await requestReader.consumeAndConclude { bodyReader in
4444
var bodyReader = bodyReader
45-
try await bodyReader.read(maximumCount: nil) { element in () }
45+
try await bodyReader.read(maximumCount: nil) { _ in }
46+
}
47+
}
48+
}
49+
50+
@Test("Stream cannot be finished before writing request end part")
51+
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
52+
func testNotWritingRequestEndPartFatalError() async throws {
53+
await #expect(processExitsWith: .failure) {
54+
let (stream, source) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
55+
56+
// Only write a request body part; do not write an end part.
57+
source.yield(.body(.init()))
58+
source.finish()
59+
60+
let requestReader = HTTPRequestConcludingAsyncReader(
61+
iterator: stream.makeAsyncIterator(),
62+
readerState: .init()
63+
)
64+
65+
_ = try await requestReader.consumeAndConclude { bodyReader in
66+
var bodyReader = bodyReader
67+
68+
try await bodyReader.read(maximumCount: nil) { _ in }
69+
// The stream has finished without an end part. Calling `read` now should result in a fatal error.
70+
try await bodyReader.read(maximumCount: nil) { _ in }
4671
}
4772
}
4873
}
@@ -172,26 +197,125 @@ struct HTTPRequestConcludingAsyncReaderTests {
172197
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
173198
@Test("More bytes available than consumption limit")
174199
func testCollectMoreBytesThanAvailable() async throws {
175-
await #expect(processExitsWith: .failure) {
176-
let (stream, source) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
200+
let (stream, source) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
177201

178-
// Write 10 bytes
179-
source.yield(.body(.init(repeating: 5, count: 10)))
180-
source.finish()
202+
// Write 10 bytes
203+
source.yield(.body(.init(repeating: 5, count: 10)))
204+
source.finish()
181205

182-
let requestReader = HTTPRequestConcludingAsyncReader(
183-
iterator: stream.makeAsyncIterator(),
184-
readerState: .init()
185-
)
206+
let requestReader = HTTPRequestConcludingAsyncReader(
207+
iterator: stream.makeAsyncIterator(),
208+
readerState: .init()
209+
)
210+
211+
_ = try await requestReader.consumeAndConclude { requestBodyReader in
212+
var requestBodyReader = requestBodyReader
213+
214+
// There are more bytes available than our limit.
215+
let collected = try await requestBodyReader.collect(upTo: 9) { element in
216+
var buffer = ByteBuffer()
217+
buffer.writeBytes(element.bytes)
218+
return buffer
219+
}
220+
221+
// We should only collect up to the limit (the first 9 bytes).
222+
#expect(collected == .init(repeating: 5, count: 9))
223+
}
224+
}
225+
226+
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
227+
@Test("Multiple body chunks; multiple reads with limits")
228+
func testReadWithLimits() async throws {
229+
let (stream, source) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
230+
231+
// First write 10 bytes;
232+
source.yield(.body(.init(repeating: 1, count: 10)))
233+
// Then write another 5 bytes.
234+
source.yield(.body(.init(repeating: 2, count: 5)))
235+
source.yield(.end(nil))
236+
source.finish()
237+
238+
let streamIterator = stream.makeAsyncIterator()
239+
240+
let requestReader = HTTPRequestConcludingAsyncReader(iterator: streamIterator, readerState: .init())
241+
_ = try await requestReader.consumeAndConclude { requestBodyReader in
242+
var requestBodyReader = requestBodyReader
243+
244+
// Collect 8 bytes (partial of first write).
245+
let collectedPartOne = try await requestBodyReader.collect(upTo: 8) { element in
246+
var buffer = ByteBuffer()
247+
buffer.writeBytes(element.bytes)
248+
return buffer
249+
}
250+
251+
// Then collect 4 more bytes (overlap of first and second write).
252+
let collectedPartTwo = try await requestBodyReader.collect(upTo: 4) { element in
253+
var buffer = ByteBuffer()
254+
buffer.writeBytes(element.bytes)
255+
return buffer
256+
}
257+
258+
// Then collect 3 more bytes (partial of second write).
259+
let collectedPartThree = try await requestBodyReader.collect(upTo: 3) { element in
260+
var buffer = ByteBuffer()
261+
buffer.writeBytes(element.bytes)
262+
return buffer
263+
}
186264

187-
_ = try await requestReader.consumeAndConclude { requestBodyReader in
188-
var requestBodyReader = requestBodyReader
265+
#expect(collectedPartOne == .init(repeating: 1, count: 8))
266+
#expect(collectedPartTwo == .init([1, 1, 2, 2]))
267+
#expect(collectedPartThree == .init(repeating: 2, count: 3))
268+
}
269+
}
270+
271+
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
272+
@Test("Multiple random-length chunks; multiple reads with random limits")
273+
func testMultipleReadsWithRandomLimits() async throws {
274+
let (stream, source) = NIOAsyncChannelInboundStream<HTTPRequestPart>.makeTestingStream()
275+
276+
// Generate random ByteBuffers of varying length and write them to the stream.
277+
var randomBuffer = ByteBuffer()
278+
for _ in 0..<100 {
279+
let randomNumber = UInt8.random(in: 1...50)
280+
let randomCount = Int.random(in: 1...50)
281+
282+
let randomData = ByteBuffer(repeating: randomNumber, count: randomCount)
283+
// Store the data so we can track what we have wrote
284+
randomBuffer.writeImmutableBuffer(randomData)
285+
286+
source.yield(.body(randomData))
287+
}
288+
source.yield(.end(nil))
289+
source.finish()
290+
291+
let streamIterator = stream.makeAsyncIterator()
292+
293+
let requestReader = HTTPRequestConcludingAsyncReader(iterator: streamIterator, readerState: .init())
294+
_ = try await requestReader.consumeAndConclude { requestBodyReader in
295+
var requestBodyReader = requestBodyReader
189296

190-
// Since there are more bytes than requested, this should fail.
191-
try await requestBodyReader.collect(upTo: 9) { element in
192-
()
297+
var collectedBuffer = ByteBuffer()
298+
while true {
299+
let randomMaxCount = Int.random(in: 1...100)
300+
301+
let collected = try await requestBodyReader.collect(upTo: randomMaxCount) { element in
302+
var localBuffer = ByteBuffer()
303+
localBuffer.writeBytes(element.bytes)
304+
return localBuffer
193305
}
306+
307+
if collected.readableBytes == 0 {
308+
break
309+
}
310+
311+
// The collected buffer should never exceed the specified max count.
312+
try #require(collected.readableBytes <= randomMaxCount)
313+
314+
collectedBuffer.writeImmutableBuffer(collected)
194315
}
316+
317+
// Check if the collected buffer exactly matches what was written to the stream.
318+
try #require(randomBuffer == collectedBuffer)
195319
}
196320
}
197321
}

0 commit comments

Comments
 (0)