diff --git a/packages/interface/src/stream-handler.ts b/packages/interface/src/stream-handler.ts index 728fedfe40..43e6ec152c 100644 --- a/packages/interface/src/stream-handler.ts +++ b/packages/interface/src/stream-handler.ts @@ -9,9 +9,12 @@ export interface StreamHandler { /** * Stream middleware allows accessing stream data outside of the stream handler + * + * Return false to stop the middleware chain without aborting the stream. + * Throw or reject to abort the stream. */ export interface StreamMiddleware { - (stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void): void | Promise + (stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void): void | false | Promise } export interface StreamHandlerOptions extends AbortOptions { diff --git a/packages/libp2p/src/connection.ts b/packages/libp2p/src/connection.ts index fb16e4ba85..08792e0a4d 100644 --- a/packages/libp2p/src/connection.ts +++ b/packages/libp2p/src/connection.ts @@ -250,7 +250,11 @@ export class Connection extends TypedEventEmitter implement throw new LimitedConnectionError('Cannot open protocol stream on limited connection') } - const middleware = this.components.registrar.getMiddleware(muxedStream.protocol) + // Copy registered middleware before appending the handler wrapper below; + // the registered middleware array is reused across streams. + const middleware = [ + ...this.components.registrar.getMiddleware(muxedStream.protocol) + ] middleware.push(async (stream, connection, next) => { await handler(stream, connection) @@ -268,22 +272,14 @@ export class Connection extends TypedEventEmitter implement const mw = middleware[i] stream.log.trace('running middleware', i, mw) - // eslint-disable-next-line no-loop-func - await new Promise((resolve, reject) => { - try { - const result = mw(stream, connection, (s, c) => { - stream = s - connection = c - resolve() - }) - - if (result instanceof Promise) { - result.catch(reject) - } - } catch (err) { - reject(err) - } - }) + const result = await runMiddleware(mw, stream, connection) + stream = result.stream + connection = result.connection + + if (result.stop) { + stream.log.trace('middleware stopped chain', i, mw) + break + } stream.log.trace('ran middleware', i, mw) } @@ -353,6 +349,40 @@ function findOutgoingStreamLimit (protocol: string, registrar: Registrar, option return options.maxOutboundStreams ?? DEFAULT_MAX_OUTBOUND_STREAMS } +interface RunMiddlewareResult { + stream: Stream + connection: ConnectionInterface + stop: boolean +} + +function runMiddleware (mw: StreamMiddleware, stream: Stream, connection: ConnectionInterface): Promise { + return new Promise((resolve, reject) => { + const continueChain = (s: Stream, c: ConnectionInterface): void => { + resolve({ stream: s, connection: c, stop: false }) + } + + const stopChain = (): void => { + resolve({ stream, connection, stop: true }) + } + + try { + const result = mw(stream, connection, continueChain) + + if (result === false) { + stopChain() + } else if (result != null) { + result.then(result => { + if (result === false) { + stopChain() + } + }).catch(reject) + } + } catch (err) { + reject(err) + } + }) +} + function countStreams (protocol: string, direction: 'inbound' | 'outbound', connection: Connection): number { let streamCount = 0 diff --git a/packages/libp2p/test/connection/index.spec.ts b/packages/libp2p/test/connection/index.spec.ts index 0dd1f0fa2c..fa2d03503c 100644 --- a/packages/libp2p/test/connection/index.spec.ts +++ b/packages/libp2p/test/connection/index.spec.ts @@ -338,10 +338,11 @@ describe('connection', () => { middleware1, middleware2 ] + const handler = Sinon.stub() registrar.getMiddleware.withArgs(streamProtocol).returns(middleware) registrar.getHandler.withArgs(streamProtocol).returns({ - handler: () => {}, + handler, options: {} }) @@ -379,6 +380,328 @@ describe('connection', () => { expect(middleware1.called).to.be.true() expect(middleware2.called).to.be.true() + expect(handler.calledOnce).to.be.true() + }) + + it('should not mutate incoming stream middleware when appending handler', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const middleware = [ + middleware1 + ] + const handler = Sinon.stub() + + registrar.getMiddleware.withArgs(streamProtocol).returns(middleware) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler, + options: {} + }) + + const muxer = stubInterface({ + streams: [] + }) + + createConnection(components, { + ...init, + muxer + }) + + expect(muxer.addEventListener.getCall(0).args[0]).to.equal('stream') + const onIncomingStream = muxer.addEventListener.getCall(0).args[1] + + if (typeof onIncomingStream !== 'function') { + throw new Error('Stream handler was not function') + } + + onIncomingStream(new CustomEvent('stream', { + detail: stubInterface({ + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol + }) + })) + onIncomingStream(new CustomEvent('stream', { + detail: stubInterface({ + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol + }) + })) + + // incoming streams are opened asynchronously + await delay(100) + + expect(middleware).to.have.lengthOf(1) + expect(middleware1.callCount).to.equal(2) + expect(handler.callCount).to.equal(2) + }) + + it('should allow incoming stream middleware to stop the chain without aborting', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake(() => { + return false + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const handler = Sinon.stub() + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware1, + middleware2 + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler, + options: {} + }) + + const muxer = stubInterface({ + streams: [] + }) + + createConnection(components, { + ...init, + muxer + }) + + const onIncomingStream = muxer.addEventListener.getCall(0).args[1] + + if (typeof onIncomingStream !== 'function') { + throw new Error('Stream handler was not function') + } + + const incomingStream = stubInterface({ + abort: Sinon.stub(), + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol + }) + + onIncomingStream(new CustomEvent('stream', { + detail: incomingStream + })) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.calledOnce).to.be.true() + expect(middleware2.called).to.be.false() + expect(handler.called).to.be.false() + expect(incomingStream.abort.called).to.be.false() + }) + + it('should allow async incoming stream middleware to stop the chain without aborting', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake(async () => { + await delay(1) + return false + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const handler = Sinon.stub() + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware1, + middleware2 + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler, + options: {} + }) + + const muxer = stubInterface({ + streams: [] + }) + + createConnection(components, { + ...init, + muxer + }) + + const onIncomingStream = muxer.addEventListener.getCall(0).args[1] + + if (typeof onIncomingStream !== 'function') { + throw new Error('Stream handler was not function') + } + + const incomingStream = stubInterface({ + abort: Sinon.stub(), + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol + }) + + onIncomingStream(new CustomEvent('stream', { + detail: incomingStream + })) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.calledOnce).to.be.true() + expect(middleware2.called).to.be.false() + expect(handler.called).to.be.false() + expect(incomingStream.abort.called).to.be.false() + }) + + it('should allow outbound stream middleware to stop the chain without aborting', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake(() => { + return false + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware1, + middleware2 + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler: () => {}, + options: {} + }) + + const abort = Sinon.stub() + const muxedStream = stubInterface({ + abort, + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol, + status: 'open' + }) + muxer.createStream = async () => muxedStream + + const connection = createConnection(components, init) + + await connection.newStream(streamProtocol) + + expect(middleware1.calledOnce).to.be.true() + expect(middleware2.called).to.be.false() + expect(abort.called).to.be.false() + }) + + it('should allow async outbound stream middleware to stop the chain without aborting', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake(async () => { + await delay(1) + return false + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware1, + middleware2 + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler: () => {}, + options: {} + }) + + const abort = Sinon.stub() + const muxedStream = stubInterface({ + abort, + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol, + status: 'open' + }) + muxer.createStream = async () => muxedStream + + const connection = createConnection(components, init) + + await connection.newStream(streamProtocol) + + expect(middleware1.calledOnce).to.be.true() + expect(middleware2.called).to.be.false() + expect(abort.called).to.be.false() + }) + + it('should abort the outgoing stream when middleware throws', async () => { + const streamProtocol = '/test/protocol' + const err = new Error('boom') + + const middleware = Sinon.stub().callsFake(() => { + throw err + }) + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler: () => {}, + options: {} + }) + + const abort = Sinon.stub() + const muxedStream = stubInterface({ + abort, + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol, + status: 'open' + }) + muxer.createStream = async () => muxedStream + + const connection = createConnection(components, init) + + await expect(connection.newStream(streamProtocol)).to.eventually.be.rejectedWith(err) + + expect(middleware.calledOnce).to.be.true() + expect(abort.calledOnceWith(err)).to.be.true() + }) + + it('should abort the incoming stream when middleware throws', async () => { + const streamProtocol = '/test/protocol' + const err = new Error('boom') + + const middleware = Sinon.stub().callsFake(() => { + throw err + }) + const handler = Sinon.stub() + + registrar.getMiddleware.withArgs(streamProtocol).returns([ + middleware + ]) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler, + options: {} + }) + + const muxer = stubInterface({ + streams: [] + }) + + createConnection(components, { + ...init, + muxer + }) + + const onIncomingStream = muxer.addEventListener.getCall(0).args[1] + + if (typeof onIncomingStream !== 'function') { + throw new Error('Stream handler was not function') + } + + const incomingStream = stubInterface({ + abort: Sinon.stub(), + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol + }) + + onIncomingStream(new CustomEvent('stream', { + detail: incomingStream + })) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware.calledOnce).to.be.true() + expect(handler.called).to.be.false() + expect(incomingStream.abort.calledOnceWith(err)).to.be.true() }) it('should not call outbound middleware if previous middleware errors', async () => {