diff --git a/src/ts/ssh-tcp/index.ts b/src/ts/ssh-tcp/index.ts index d8a583e..cf7bf1c 100644 --- a/src/ts/ssh-tcp/index.ts +++ b/src/ts/ssh-tcp/index.ts @@ -11,6 +11,7 @@ export { PortForwardingService } from './services/portForwardingService'; export { LocalPortForwarder } from './services/localPortForwarder'; export { RemotePortForwarder } from './services/remotePortForwarder'; export { RemotePortStreamer } from './services/remotePortStreamer'; +export { StreamForwarder } from './services/streamForwarder'; export { PortForwardMessageFactory } from './portForwardMessageFactory'; export { PortForwardRequestMessage } from './messages/portForwardRequestMessage'; diff --git a/src/ts/ssh-tcp/services/localPortForwarder.ts b/src/ts/ssh-tcp/services/localPortForwarder.ts index 9943d11..3aaaabd 100644 --- a/src/ts/ssh-tcp/services/localPortForwarder.ts +++ b/src/ts/ssh-tcp/services/localPortForwarder.ts @@ -14,6 +14,7 @@ import { SshProtocolExtensionNames, SshStream, } from '@microsoft/dev-tunnels-ssh'; +import { Duplex } from 'stream'; import { StreamForwarder } from './streamForwarder'; import { PortForwardingService } from './portForwardingService'; @@ -157,6 +158,20 @@ export class LocalPortForwarder extends SshService { // TODO: Set socket options? + // Attach a temporary 'error' handler so a peer reset between accept and + // StreamForwarder construction (while we await openChannel and the + // forwardedPortConnecting event handler) does not crash the host with + // an unhandled 'error' event. Removed once the forwarder takes over. + const acceptErrorHandler = (e: Error) => { + this.trace( + TraceLevel.Warning, + SshTraceEventIds.portForwardConnectionFailed, + `PortForwardingService accepted socket errored before forwarding started: ${e.message}`, + e, + ); + }; + socket.on('error', acceptErrorHandler); + let channel: SshChannel | null; try { channel = await this.pfs.openChannel( @@ -172,6 +187,7 @@ export class LocalPortForwarder extends SshService { // TODO: Destroy the socket in a way that causes a connection reset: // https://github.com/nodejs/node/issues/27428 + socket.removeListener('error', acceptErrorHandler); socket.destroy(); // Don't re-throw. This is an async event handler so the caller isn't awaiting. @@ -180,19 +196,41 @@ export class LocalPortForwarder extends SshService { } // The event handler may return a transformed stream. - const forwardedStream = await this.pfs.forwardedPortConnecting( - this.remotePort ?? this.localPort, - false, - new SshStream(channel), - ); + const sshStream = new SshStream(channel); + let forwardedStream: Duplex | null; + try { + forwardedStream = await this.pfs.forwardedPortConnecting( + this.remotePort ?? this.localPort, + false, + sshStream, + ); + } catch (e) { + socket.removeListener('error', acceptErrorHandler); + socket.destroy(); + sshStream.destroy(); + throw e; + } if (!forwardedStream) { // The event handler rejected the connection. + socket.removeListener('error', acceptErrorHandler); + socket.destroy(); + sshStream.destroy(); return; } - const forwarder = new StreamForwarder(socket, forwardedStream, channel.session.trace); - this.pfs.streamForwarders.push(forwarder); + // Hand off socket error handling to the StreamForwarder. + socket.removeListener('error', acceptErrorHandler); + + const forwarder = new StreamForwarder( + socket, + forwardedStream, + channel.session.trace, + this.pfs.removeStreamForwarder, + ); + if (!forwarder.isDisposed) { + this.pfs.streamForwarders.add(forwarder); + } } public dispose() { diff --git a/src/ts/ssh-tcp/services/portForwardingService.ts b/src/ts/ssh-tcp/services/portForwardingService.ts index 50de5f6..e4bd947 100644 --- a/src/ts/ssh-tcp/services/portForwardingService.ts +++ b/src/ts/ssh-tcp/services/portForwardingService.ts @@ -89,7 +89,12 @@ export class PortForwardingService extends SshService { private readonly remoteConnectors = new Map(); /* @internal */ - public readonly streamForwarders: StreamForwarder[] = []; + public readonly streamForwarders: Set = new Set(); + + /* @internal */ + public readonly removeStreamForwarder = (forwarder: StreamForwarder): void => { + this.streamForwarders.delete(forwarder); + }; /* @internal */ public constructor(session: SshSession) { @@ -936,7 +941,7 @@ export class PortForwardingService extends SshService { ...this.remoteConnectors.values(), ]; - this.streamForwarders.splice(0, this.streamForwarders.length); + this.streamForwarders.clear(); this.localForwarders.clear(); this.remoteConnectors.clear(); diff --git a/src/ts/ssh-tcp/services/remotePortForwarder.ts b/src/ts/ssh-tcp/services/remotePortForwarder.ts index bf91ba7..84d310e 100644 --- a/src/ts/ssh-tcp/services/remotePortForwarder.ts +++ b/src/ts/ssh-tcp/services/remotePortForwarder.ts @@ -14,6 +14,7 @@ import { TraceLevel, SshStream, } from '@microsoft/dev-tunnels-ssh'; +import { Duplex } from 'stream'; import { StreamForwarder } from './streamForwarder'; import { PortForwardingService } from './portForwardingService'; import { RemotePortConnector } from './remotePortConnector'; @@ -75,16 +76,41 @@ export class RemotePortForwarder extends RemotePortConnector { cancellation?: CancellationToken, ): Promise { const channel = request.channel; + const sshStream = new SshStream(channel); + + // Attach a temporary 'error' handler so a remote channel reset between + // channel-open and StreamForwarder construction (while we await + // forwardedPortConnecting and the local TCP connect) does not crash + // the host with an unhandled 'error' event. Removed once the forwarder + // takes over (or the connection is rejected/aborted below). + const channelErrorHandler = (e: Error) => { + trace( + TraceLevel.Warning, + SshTraceEventIds.portForwardConnectionFailed, + `PortForwardingService channel stream errored before forwarding started: ${e.message}`, + e, + ); + }; + sshStream.on('error', channelErrorHandler); - const forwardedStream = await pfs.forwardedPortConnecting( - remotePort ?? localPort, - true, - new SshStream(channel), - cancellation, - ); + let forwardedStream: Duplex | null; + try { + forwardedStream = await pfs.forwardedPortConnecting( + remotePort ?? localPort, + true, + sshStream, + cancellation, + ); + } catch (e) { + sshStream.removeListener('error', channelErrorHandler); + sshStream.destroy(); + throw e; + } if (!forwardedStream) { // The event handler rejected the connection. + sshStream.removeListener('error', channelErrorHandler); + sshStream.destroy(); request.failureReason = SshChannelOpenFailureReason.connectFailed; return; } @@ -119,6 +145,11 @@ export class RemotePortForwarder extends RemotePortConnector { await connectCompletion.promise; } catch (e) { if (!(e instanceof Error) || cancellation?.isCancellationRequested) { + sshStream.removeListener('error', channelErrorHandler); + // The forwardedStream may be the user-substituted stream; close it + // so we don't leak the underlying SSH channel. + forwardedStream.destroy(); + if (forwardedStream !== sshStream) sshStream.destroy(); throw e; } @@ -131,19 +162,39 @@ export class RemotePortForwarder extends RemotePortConnector { ); request.failureReason = SshChannelOpenFailureReason.connectFailed; request.failureDescription = e.message; + + // Tear down the SSH side and abandon: do NOT proceed to construct a + // StreamForwarder around a destroyed socket. Pre-existing bug: the + // previous code fell through and built a forwarder anyway, leaking + // resources and (before PR #138) potentially crashing the host. + sshStream.removeListener('error', channelErrorHandler); + forwardedStream.destroy(); + if (forwardedStream !== sshStream) sshStream.destroy(); + return; } finally { cancellationRegistration?.dispose(); } // TODO: Set socket options? - const streamForwarder = new StreamForwarder(socket, forwardedStream, channel.session.trace); + // Hand off SSH stream error handling to the StreamForwarder. + sshStream.removeListener('error', channelErrorHandler); + + const streamForwarder = new StreamForwarder( + socket, + forwardedStream, + channel.session.trace, + pfs.removeStreamForwarder, + ); trace( TraceLevel.Info, SshTraceEventIds.portForwardConnectionOpened, `${channel.session} PortForwardingService forwarded channel ` + `#${channel.channelId} connection to ${localHost}:${localPort}.`, ); - pfs.streamForwarders.push(streamForwarder); + pfs.streamForwarders.add(streamForwarder); + if (streamForwarder.isDisposed) { + pfs.streamForwarders.delete(streamForwarder); + } } } diff --git a/src/ts/ssh-tcp/services/streamForwarder.ts b/src/ts/ssh-tcp/services/streamForwarder.ts index cd49f97..6a0d89a 100644 --- a/src/ts/ssh-tcp/services/streamForwarder.ts +++ b/src/ts/ssh-tcp/services/streamForwarder.ts @@ -9,16 +9,23 @@ import { Socket } from 'net'; export class StreamForwarder implements Disposable { private disposed: boolean = false; + private readonly onDisposedCallback?: (forwarder: StreamForwarder) => void; + + public get isDisposed(): boolean { + return this.disposed; + } - /* @internal */ public constructor( public readonly localStream: Duplex, public readonly remoteStream: Duplex, public readonly trace: Trace, + onDisposed?: (forwarder: StreamForwarder) => void, ) { if (!localStream) throw new TypeError('Local stream is required.'); if (!remoteStream) throw new TypeError('Remote stream is required.'); + this.onDisposedCallback = onDisposed; + // Without these listeners, errors from either side of the forwarder // propagate up to the Node process as unhandled 'error' events and // crash the host. Node's pipe() does not propagate errors between @@ -26,6 +33,9 @@ export class StreamForwarder implements Disposable { localStream.on('error', (err) => this.onStreamError('local', err)); remoteStream.on('error', (err) => this.onStreamError('remote', err)); + // pipe() forwards 'end' (so EOF on one side gracefully ends the other), + // but does NOT forward 'error'. Error propagation is handled above + // by disposing the forwarder, which tears down both sides. localStream.pipe(remoteStream); remoteStream.pipe(localStream); } @@ -72,6 +82,18 @@ export class StreamForwarder implements Disposable { if (!this.disposed) { this.disposed = true; this.close(true); + if (this.onDisposedCallback) { + try { + this.onDisposedCallback(this); + } catch (e) { + const errorMessage = e instanceof Error ? e.message : String(e); + this.trace( + TraceLevel.Warning, + SshTraceEventIds.unknownError, + `Stream forwarder onDisposed callback threw: ${errorMessage}`, + ); + } + } } } } diff --git a/src/ts/ssh/sshStream.ts b/src/ts/ssh/sshStream.ts index 3a282b6..9c5916b 100644 --- a/src/ts/ssh/sshStream.ts +++ b/src/ts/ssh/sshStream.ts @@ -4,6 +4,7 @@ import { SshChannel } from './sshChannel'; import { PromiseCompletionSource } from './util/promiseCompletionSource'; +import { TraceLevel, SshTraceEventIds } from './trace'; import { Duplex } from 'stream'; /** @@ -132,7 +133,14 @@ export class SshStream extends Duplex { * Destroys the stream and closes the underlying SSH channel. */ public destroy(error?: Error) { - void this.channel.close().catch(); + this.channel.close().catch((e) => { + const message = e instanceof Error ? e.message : String(e); + this.channel.session.trace( + TraceLevel.Warning, + SshTraceEventIds.unknownError, + `${this} channel close on destroy failed: ${message}`, + ); + }); super.destroy(error); return this; } diff --git a/test/ts/ssh-test/streamForwarderTests.ts b/test/ts/ssh-test/streamForwarderTests.ts new file mode 100644 index 0000000..4c7a242 --- /dev/null +++ b/test/ts/ssh-test/streamForwarderTests.ts @@ -0,0 +1,234 @@ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// + +import * as assert from 'assert'; +import { suite, test, slow, timeout } from '@testdeck/mocha'; +import { Duplex } from 'stream'; + +import { Trace } from '@microsoft/dev-tunnels-ssh'; +import { StreamForwarder } from '@microsoft/dev-tunnels-ssh-tcp'; + +const timeoutMs = 3000; + +/** + * A minimal Duplex stream that captures written data and allows pushing readable data. + * Does not echo writes back to the readable side (unlike PassThrough), making it safe + * for use in bidirectional pipe scenarios without creating infinite loops. + */ +class MockDuplex extends Duplex { + public readonly written: Buffer[] = []; + + constructor() { + super(); + } + + _write(chunk: Buffer, _encoding: string, callback: (error?: Error | null) => void): void { + this.written.push(Buffer.from(chunk)); + callback(); + } + + _read(_size: number): void { + // No-op; data is pushed externally via this.push() + } + + pushData(data: Buffer): void { + this.push(data); + } +} + +/** + * A Duplex that emits 'error' on nextTick after pipe() is called, + * simulating the synchronous-dispose race in StreamForwarder construction. + */ +class ErrorOnPipeDuplex extends MockDuplex { + private readonly pipeError: Error; + + constructor(error: Error) { + super(); + this.pipeError = error; + } + + pipe(destination: T): T { + const result = super.pipe(destination); + process.nextTick(() => this.emit('error', this.pipeError)); + return result; + } +} + +function createTrace(): Trace { + return () => {}; +} + +@suite +@slow(2000) +@timeout(timeoutMs * 2) +export class StreamForwarderTests { + @test + public async forwardDataLocalToRemote() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + const forwarder = new StreamForwarder(local, remote, createTrace()); + + // Push data into local's readable side; it should be piped to remote's writable side. + local.pushData(Buffer.from('hello')); + + await new Promise((r) => setImmediate(r)); + assert.strictEqual(Buffer.concat(remote.written).toString(), 'hello'); + + forwarder.dispose(); + local.destroy(); + remote.destroy(); + } + + @test + public async forwardDataRemoteToLocal() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + const forwarder = new StreamForwarder(local, remote, createTrace()); + + // Push data into remote's readable side; it should be piped to local's writable side. + remote.pushData(Buffer.from('world')); + + await new Promise((r) => setImmediate(r)); + assert.strictEqual(Buffer.concat(local.written).toString(), 'world'); + + forwarder.dispose(); + local.destroy(); + remote.destroy(); + } + + @test + public async localStreamErrorDisposesForwarder() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + let disposedCalled = false; + const forwarder = new StreamForwarder(local, remote, createTrace(), () => { + disposedCalled = true; + }); + + local.emit('error', new Error('connection reset')); + + await new Promise((r) => setImmediate(r)); + assert.strictEqual(forwarder.isDisposed, true); + assert.strictEqual(disposedCalled, true); + local.destroy(); + remote.destroy(); + } + + @test + public async remoteStreamErrorDisposesForwarder() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + let disposedCalled = false; + const forwarder = new StreamForwarder(local, remote, createTrace(), () => { + disposedCalled = true; + }); + + remote.emit('error', new Error('channel closed')); + + await new Promise((r) => setImmediate(r)); + assert.strictEqual(forwarder.isDisposed, true); + assert.strictEqual(disposedCalled, true); + local.destroy(); + remote.destroy(); + } + + @test + public async onDisposedCallbackInvokedOnDispose() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + let callbackForwarder: StreamForwarder | null = null; + const forwarder = new StreamForwarder(local, remote, createTrace(), (f: StreamForwarder) => { + callbackForwarder = f; + }); + + forwarder.dispose(); + assert.strictEqual(callbackForwarder, forwarder); + local.destroy(); + remote.destroy(); + } + + @test + public async disposeIsIdempotent() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + let disposeCount = 0; + const forwarder = new StreamForwarder(local, remote, createTrace(), () => { + disposeCount++; + }); + + forwarder.dispose(); + forwarder.dispose(); + forwarder.dispose(); + assert.strictEqual(disposeCount, 1); + local.destroy(); + remote.destroy(); + } + + @test + public async synchronousErrorDuringPipeMarksDisposed() { + const errorStream = new ErrorOnPipeDuplex(new Error('immediate failure')); + const remote = new MockDuplex(); + const forwarder = new StreamForwarder(errorStream, remote, createTrace()); + + // The error fires on nextTick, so wait a tick. + await new Promise((r) => setImmediate(r)); + assert.strictEqual(forwarder.isDisposed, true); + errorStream.destroy(); + remote.destroy(); + } + + @test + public async forwarderRemovedFromSetOnDispose() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + const set = new Set(); + const forwarder = new StreamForwarder(local, remote, createTrace(), (f: StreamForwarder) => { + set.delete(f); + }); + set.add(forwarder); + + forwarder.dispose(); + assert.strictEqual(set.size, 0); + local.destroy(); + remote.destroy(); + } + + @test + public async synchronousDisposeRaceDoesNotLeaveStaleEntry() { + const errorStream = new ErrorOnPipeDuplex(new Error('race error')); + const remote = new MockDuplex(); + const set = new Set(); + const forwarder = new StreamForwarder(errorStream, remote, createTrace(), (f: StreamForwarder) => { + set.delete(f); + }); + + // Caller adds after construction (the real code pattern). + if (!forwarder.isDisposed) { + set.add(forwarder); + } + + // Wait for the error to fire and dispose to run. + await new Promise((r) => setImmediate(r)); + assert.strictEqual(set.size, 0); + assert.strictEqual(forwarder.isDisposed, true); + errorStream.destroy(); + remote.destroy(); + } + + @test + public async onDisposedCallbackErrorIsSwallowed() { + const local = new MockDuplex(); + const remote = new MockDuplex(); + const forwarder = new StreamForwarder(local, remote, createTrace(), () => { + throw new Error('callback error'); + }); + + // Should not throw — the error is traced and swallowed. + forwarder.dispose(); + assert.strictEqual(forwarder.isDisposed, true); + local.destroy(); + remote.destroy(); + } +}