Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions packages/interface/src/stream-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,30 @@ export interface StreamHandler {
/**
* A callback function that accepts the incoming stream data
*/
(stream: Stream, connection: Connection): void | Promise<void>
(stream: Stream, connection: Connection, context?: StreamContext): void | Promise<void>
}

export interface StreamContext {
get<T = unknown>(key: StreamContextKey<T>): T | undefined
set<T = unknown>(key: StreamContextKey<T>, value: T): void
has<T = unknown>(key: StreamContextKey<T>): boolean
delete<T = unknown>(key: StreamContextKey<T>): boolean
}

export interface StreamContextKey<T = unknown> {
id: symbol
type?: T
}

/**
* Stream middleware allows accessing stream data outside of the stream handler
*
* Return false to stop the middleware chain without aborting the stream.
* Call next to continue the middleware chain.
* Throw or reject to abort the stream.
*/
export interface StreamMiddleware {
(stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void): void | Promise<void>
(stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void, context: StreamContext): void | false | Promise<void | false>
}

export interface StreamHandlerOptions extends AbortOptions {
Expand Down
87 changes: 64 additions & 23 deletions packages/libp2p/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { CONNECTION_CLOSE_TIMEOUT, PROTOCOL_NEGOTIATION_TIMEOUT } from './connec
import { isDirect } from './connection-manager/utils.ts'
import { MuxerUnavailableError } from './errors.ts'
import { DEFAULT_MAX_INBOUND_STREAMS, DEFAULT_MAX_OUTBOUND_STREAMS } from './registrar.ts'
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream, StreamMiddleware, OpenStreamEvent, OpenedStreamEvent } from '@libp2p/interface'
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream, StreamContext, StreamMiddleware, OpenStreamEvent, OpenedStreamEvent } from '@libp2p/interface'
import type { Registrar } from '@libp2p/interface-internal'
import type { Multiaddr } from '@multiformats/multiaddr'

Expand Down Expand Up @@ -185,7 +185,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement

const middleware = this.components.registrar.getMiddleware(muxedStream.protocol)

const stream = await this.runMiddlewareChain(muxedStream, this, middleware)
const stream = await this.runMiddlewareChain(muxedStream, this, middleware, createStreamContext())

options.onProgress?.(new CustomProgressEvent<OpenedStreamEvent>('connection:opened-stream', {
connection: this,
Expand Down Expand Up @@ -250,40 +250,38 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
throw new LimitedConnectionError('Cannot open protocol stream on limited connection')
}

const middleware = this.components.registrar.getMiddleware(muxedStream.protocol)
// Copy registered middleware before appending the handler wrapper below;
// the registered middleware array is reused across streams.
const middleware = [
...this.components.registrar.getMiddleware(muxedStream.protocol)
]

const context = createStreamContext()

middleware.push(async (stream, connection, next) => {
await handler(stream, connection)
middleware.push(async (stream, connection, next, context) => {
await handler(stream, connection, context)
next(stream, connection)
})

await this.runMiddlewareChain(muxedStream, this, middleware)
await this.runMiddlewareChain(muxedStream, this, middleware, context)
} catch (err: any) {
muxedStream.abort(err)
}
}

private async runMiddlewareChain (stream: Stream, connection: ConnectionInterface, middleware: StreamMiddleware[]): Promise<Stream> {
private async runMiddlewareChain (stream: Stream, connection: ConnectionInterface, middleware: StreamMiddleware[], context: StreamContext): Promise<Stream> {
for (let i = 0; i < middleware.length; i++) {
const mw = middleware[i]
stream.log.trace('running middleware', i, mw)

// eslint-disable-next-line no-loop-func
await new Promise<void>((resolve, reject) => {
try {
const result = mw(stream, connection, (s, c) => {
stream = s
connection = c
resolve()
})

if (result instanceof Promise) {
result.catch(reject)
}
} catch (err) {
reject(err)
}
})
const result = await runMiddleware(mw, stream, connection, context)
stream = result.stream
connection = result.connection

if (result.stop) {
stream.log.trace('middleware stopped chain', i, mw)
break
}

stream.log.trace('ran middleware', i, mw)
}
Expand Down Expand Up @@ -353,6 +351,49 @@ function findOutgoingStreamLimit (protocol: string, registrar: Registrar, option
return options.maxOutboundStreams ?? DEFAULT_MAX_OUTBOUND_STREAMS
}

interface RunMiddlewareResult {
stream: Stream
connection: ConnectionInterface
stop: boolean
}

function runMiddleware (mw: StreamMiddleware, stream: Stream, connection: ConnectionInterface, context: StreamContext): Promise<RunMiddlewareResult> {
return new Promise<RunMiddlewareResult>((resolve, reject) => {
const continueChain = (s: Stream, c: ConnectionInterface): void => {
resolve({ stream: s, connection: c, stop: false })
}

const stopChain = (): void => {
resolve({ stream, connection, stop: true })
}

try {
Promise.resolve(mw(stream, connection, continueChain, context))
.then(result => {
if (result === false) {
stopChain()
}
})
.catch(reject)
} catch (err) {
reject(err)
}
})
}

function createStreamContext (): StreamContext {
const values = new Map<symbol, unknown>()

return {
get: <T = unknown>(key: { id: symbol }): T | undefined => values.get(key.id) as T | undefined,
set: <T = unknown>(key: { id: symbol }, value: T) => {
values.set(key.id, value)
},
has: (key: { id: symbol }) => values.has(key.id),
delete: (key: { id: symbol }) => values.delete(key.id)
}
}

function countStreams (protocol: string, direction: 'inbound' | 'outbound', connection: Connection): number {
let streamCount = 0

Expand Down
Loading
Loading