Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions lib/src/dynamic_forward_io.dart
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,39 @@ class _SocksConnection {
StreamSubscription<Uint8List>? _remoteSub;
Timer? _handshakeTimer;
bool _closed = false;
bool _dialing = false;
_SocksState _state = _SocksState.greeting;

void start() {
_handshakeTimer = Timer(options.handshakeTimeout, () async {
_sendReply(_SocksReply.ttlExpired);
await close();
});

_clientSub = _client.listen(
_onClientData,
onDone: close,
onDone: _handleClientEOF,
onError: (_, __) => close(),
cancelOnError: true,
);

_handshakeTimer = Timer(options.handshakeTimeout, () async {
_sendReply(_SocksReply.ttlExpired);
await close();
});
}

void _handleClientEOF() {
if (_state == _SocksState.streaming) {
_remote?.sink.close();
_clientSub?.cancel();
} else {
close();
}
}

void _handleRemoteEOF() {
if (_state == _SocksState.streaming) {
_client.destroy();
_remoteSub?.cancel();
} else {
close();
}
}

Future<void> close() async {
Expand All @@ -152,9 +171,8 @@ class _SocksConnection {
return;
}

_buffer.add(chunk);

try {
_buffer.add(chunk);
await _consumeHandshake();
} catch (_) {
await close();
Expand All @@ -169,41 +187,55 @@ class _SocksConnection {
}

if (_state == _SocksState.request) {
if (_dialing) return;
final target = _parseConnectRequest();
if (target == null) return;
_dialing = true;

if (filter != null && !filter!(target.host, target.port)) {
_sendReply(_SocksReply.connectionNotAllowed);
_dialing = false;
await close();
return;
}

if (!canOpenTunnel()) {
_sendReply(_SocksReply.connectionRefused);
_dialing = false;
await close();
return;
}

_handshakeTimer?.cancel();
_handshakeTimer = null;

try {
_remote = await dial(target.host, target.port).timeout(
options.connectTimeout,
);
} catch (_) {
_sendReply(_SocksReply.hostUnreachable);
_dialing = false;
await close();
return;
}

_dialing = false;

if (_closed) {
_remote?.destroy();
_remote = null;
return;
}

_remoteSub = _remote!.stream.listen(
_client.add,
onDone: close,
onDone: _handleRemoteEOF,
onError: (_, __) => close(),
cancelOnError: true,
);

_sendReply(_SocksReply.succeeded);
_handshakeTimer?.cancel();
_handshakeTimer = null;
_state = _SocksState.streaming;

final pending = _buffer.takeAll();
Expand Down Expand Up @@ -288,7 +320,7 @@ class _SocksConnection {
if (atyp == 0x03) {
final length = request[4];
final bytes = request.sublist(5, 5 + length);
return utf8.decode(bytes);
return utf8.decode(bytes, allowMalformed: true);
}

final raw = request.sublist(4, 20);
Expand Down Expand Up @@ -316,12 +348,18 @@ class _SocksConnection {
}

class _ByteBuffer {
static const kMaxHandshakeSize = 32768;

final _data = <int>[];
int _offset = 0;

int get length => _data.length - _offset;

void add(List<int> chunk) {
if (length + chunk.length > kMaxHandshakeSize) {
throw StateError(
'Handshake buffer overflow: $length + ${chunk.length} > $kMaxHandshakeSize');
}
_data.addAll(chunk);
}

Expand Down
20 changes: 19 additions & 1 deletion lib/src/ssh_agent.dart
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,16 @@ class SSHKeyPairAgent implements SSHAgentHandler {
) {
final key = _rsaKeyFrom(identity);
if (key == null) {
return identity.sign(data) as SSHRsaSignature;
final signature = identity.sign(data);
if (signature is SSHRsaSignature) {
if (signature.type != signatureType) {
throw StateError(
'RSA signature type mismatch: requested $signatureType but identity produced ${signature.type}');
}
return signature;
}
throw StateError(
'RSA signing requested but identity produced non-RSA signature: ${signature.runtimeType}');
}

final signer = _rsaSignerFor(signatureType);
Expand Down Expand Up @@ -154,6 +163,8 @@ class SSHKeyPairAgent implements SSHAgentHandler {
}

class SSHAgentChannel {
static const maxFrameSize = 256 * 1024;

SSHAgentChannel(this._channel, this._handler, {this.printDebug}) {
_subscription = _channel.stream.listen(
_handleData,
Expand Down Expand Up @@ -188,6 +199,13 @@ class SSHAgentChannel {
Future<void> _processQueue() async {
while (_buffer.length >= 4) {
final length = ByteData.sublistView(_buffer, 0, 4).getUint32(0);
if (length == 0 || length > maxFrameSize) {
printDebug
?.call('SSH agent: invalid frame length $length, closing channel');
_channel.destroy();
_buffer = Uint8List(0);
return;
}
if (_buffer.length < 4 + length) return;
final payload = _buffer.sublist(4, 4 + length);
_buffer = _buffer.sublist(4 + length);
Expand Down
50 changes: 43 additions & 7 deletions lib/src/ssh_key_pair.dart
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,13 @@ class OpenSSHKeyPairs {

final key = Uint8List.view(kdfHash.buffer, 0, cipher.keySize);
final iv = Uint8List.view(kdfHash.buffer, cipher.keySize, cipher.ivSize);
final decryptCipher = cipher.createCipher(key, iv, forEncryption: false);
return decryptCipher.processAll(blob);

try {
final decryptCipher = cipher.createCipher(key, iv, forEncryption: false);
return decryptCipher.processAll(blob);
} catch (e) {
throw SSHKeyDecryptError('Failed to decrypt private key', e);
}
}

@override
Expand Down Expand Up @@ -339,7 +344,7 @@ class OpenSSHRsaKeyPair with OpenSSHKeyPair {
final iqmp = reader.readMpint();
final p = reader.readMpint();
final q = reader.readMpint();
final comment = reader.readUtf8();
final comment = reader.readUtf8(allowMalformed: true);
return OpenSSHRsaKeyPair(n, e, d, iqmp, p, q, comment);
}

Expand Down Expand Up @@ -397,7 +402,7 @@ class OpenSSHEd25519KeyPair with OpenSSHKeyPair {
factory OpenSSHEd25519KeyPair.readFrom(SSHMessageReader reader) {
final publicKey = reader.readString();
final privateKey = reader.readString();
final comment = reader.readUtf8();
final comment = reader.readUtf8(allowMalformed: true);
return OpenSSHEd25519KeyPair(publicKey, privateKey, comment);
}

Expand Down Expand Up @@ -446,7 +451,7 @@ class OpenSSHEcdsaKeyPair with OpenSSHKeyPair {
final curve = reader.readUtf8();
final q = reader.readString();
final d = reader.readMpint();
final comment = reader.readUtf8();
final comment = reader.readUtf8(allowMalformed: true);
return OpenSSHEcdsaKeyPair(curve, q, d, comment);
}

Expand Down Expand Up @@ -538,6 +543,8 @@ class RsaKeyPair {

try {
return RsaPrivateKey.decode(keyBlob);
} on UnsupportedError {
rethrow;
} catch (e) {
throw SSHKeyDecodeError('Failed to decode private key', e);
}
Expand Down Expand Up @@ -755,6 +762,8 @@ class EcKeyPair {

try {
return _decodeLegacyEcPrivateKey(keyBlob);
} on UnsupportedError {
rethrow;
} catch (e) {
throw SSHKeyDecodeError('Failed to decode private key', e);
}
Expand All @@ -772,28 +781,55 @@ class EcKeyPair {
final d = decodeBigIntWithSign(1, privateKeyOctets);

Uint8List? publicPoint;
String? curveId;

for (var i = 2; i < sequence.elements.length; i++) {
final element = sequence.elements[i];
if (element.tag == 0xA1) {
if (element.tag == 0xA0) {
final inner = ASN1Parser(element.valueBytes()).nextObject();
if (inner is ASN1ObjectIdentifier && inner.identifier != null) {
final oid = inner.identifier!;
curveId = _curveIdFromOid(oid);
if (curveId == null) {
throw UnsupportedError(
'Unsupported EC PRIVATE KEY curve OID: $oid');
}
}
} else if (element.tag == 0xA1) {
final inner = ASN1Parser(element.valueBytes()).nextObject();
if (inner is ASN1BitString) {
publicPoint = inner.contentBytes();
}
}
}

final curveId =
curveId ??=
_inferCurveId(publicPoint?.length ?? 0, privateKeyOctets.length);
if (curveId == null) {
throw UnsupportedError('Unsupported EC PRIVATE KEY curve');
}

if (publicPoint != null) {
final expectedPublicPoint = _derivePublicPoint(curveId, d);
if (publicPoint.length != expectedPublicPoint.length ||
!publicPoint.equals(expectedPublicPoint)) {
throw UnsupportedError(
'EC PRIVATE KEY public point does not match curve $curveId');
}
}

final q = publicPoint ?? _derivePublicPoint(curveId, d);

return OpenSSHEcdsaKeyPair(curveId, q, d, '');
}

String? _curveIdFromOid(String oid) {
if (oid == '1.2.840.10045.3.1.7') return 'nistp256';
if (oid == '1.3.132.0.34') return 'nistp384';
if (oid == '1.3.132.0.35') return 'nistp521';
return null;
}

String? _inferCurveId(int publicPointLength, int privateKeyLength) {
if (publicPointLength == 65 || privateKeyLength == 32) {
return 'nistp256';
Expand Down
63 changes: 63 additions & 0 deletions test/src/socket/dynamic_forward_io_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,69 @@ void main() {
expect(dialedHosts[0], '192.168.1.2');
expect(dialedHosts[1], contains(':'));
});

test('closes remote sink when client EOF arrives during streaming',
() async {
late _DialedTunnel dialed;

final forward = await startDynamicForward(
bindHost: '127.0.0.1',
bindPort: 0,
options: const SSHDynamicForwardOptions(),
dial: (_, __) async {
dialed = _DialedTunnel.create();
return dialed.channel;
},
);

final client = await Socket.connect(forward.host, forward.port);
final incoming = client.asBroadcastStream();
addTearDown(() async {
await client.close();
await forward.close();
dialed.dispose();
});

await _sendGreeting(client, incoming);
final reply =
await _sendConnectDomain(client, incoming, 'example.com', 443);
expect(reply[1], 0x00);

// Send some data then close client side (half-close / EOF).
client.add(utf8.encode('data'));
await client.close();
await Future<void>.delayed(const Duration(milliseconds: 30));
});

test('handles handshake buffer overflow gracefully', () async {
final forward = await startDynamicForward(
bindHost: '127.0.0.1',
bindPort: 0,
options: const SSHDynamicForwardOptions(),
dial: (_, __) async => _DialedTunnel.create().channel,
);
addTearDown(() => forward.close());

final client = await Socket.connect(forward.host, forward.port);
addTearDown(() => client.close());

// Send a valid greeting but then flood the handshake buffer beyond
// kMaxHandshakeSize (32768). The server should close the connection
// rather than keep buffering indefinitely.
await _sendGreeting(client, client.asBroadcastStream());
final huge = Uint8List(33000);
client.add(huge);

// Give the server time to detect the overflow and close. The test
// passes if the server does not crash — the forward is still usable
// for a new connection after the overflow victim is cleaned up.
await Future<void>.delayed(const Duration(milliseconds: 100));

// Verify the forward still accepts new connections.
final client2 = await Socket.connect(forward.host, forward.port);
addTearDown(() => client2.close());
await _sendGreeting(client2, client2.asBroadcastStream());
});
});
}

Expand Down
Loading
Loading