Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions lib/src/ssh_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ class SSHClient {
/// method. Set this to null to disable automatic keep-alive messages.
final Duration? keepAliveInterval;

/// Maximum time to wait for the SSH transport handshake to complete.
final Duration? handshakeTimeout;

/// Maximum time to wait for authentication after the transport is ready.
final Duration? authTimeout;

/// Function called when additional host keys are received. This is an OpenSSH
/// extension. May not be called if the server does not support the extension.
// final SSHHostKeysHandler? onHostKeys;
Expand Down Expand Up @@ -217,6 +223,8 @@ class SSHClient {
this.onX11Forward,
this.agentHandler,
this.keepAliveInterval = const Duration(seconds: 10),
this.handshakeTimeout,
this.authTimeout,
this.disableHostkeyVerification = false,
String ident = 'DartSSH_2.0',
}) : ident = _validateIdent(ident) {
Expand Down Expand Up @@ -247,6 +255,11 @@ class SSHClient {
if (identities != null) {
_keyPairsLeft.addAll(identities!);
}

final handshakeTimeout = this.handshakeTimeout;
if (handshakeTimeout != null) {
_handshakeTimeoutTimer = Timer(handshakeTimeout, _handleHandshakeTimeout);
}
}

static String _validateIdent(String ident) {
Expand Down Expand Up @@ -295,6 +308,12 @@ class SSHClient {

SSHAuthMethod? _currentAuthMethod;

var _transportReady = false;

Timer? _handshakeTimeoutTimer;

Timer? _authTimeoutTimer;

/// A [Future] that completes when the client has authenticated, or
/// completes with an error if the client could not authenticate.
Future<void> get authenticated => _authenticated.future;
Expand Down Expand Up @@ -676,6 +695,8 @@ class SSHClient {
/// Shutdown the entire SSH connection. Sessions and channels will also be
/// closed immediately.
void close() {
_handshakeTimeoutTimer?.cancel();
_authTimeoutTimer?.cancel();
_closeChannels();
_transport.close();
}
Expand All @@ -692,11 +713,25 @@ class SSHClient {

void _handleTransportReady() {
printDebug?.call('SSHClient._onTransportReady');
_transportReady = true;
_handshakeTimeoutTimer?.cancel();
_handshakeTimeoutTimer = null;

final authTimeout = this.authTimeout;
if (authTimeout != null) {
_authTimeoutTimer = Timer(authTimeout, _handleAuthTimeout);
}

_requestAuthentication();
}

void _handleTransportClosed(SSHError? error) {
printDebug?.call('SSHClient._onTransportClosed');
_handshakeTimeoutTimer?.cancel();
_handshakeTimeoutTimer = null;
_authTimeoutTimer?.cancel();
_authTimeoutTimer = null;

if (!_authenticated.isCompleted) {
_authenticated.completeError(
SSHAuthAbortError('Connection closed before authentication', error),
Expand Down Expand Up @@ -806,11 +841,31 @@ class SSHClient {
void _handleUserauthSuccess() {
printTrace?.call('<- $socket: SSH_Message_Userauth_Success');
printDebug?.call('SSHClient._handleUserauthSuccess');
_authTimeoutTimer?.cancel();
_authTimeoutTimer = null;
_authenticated.complete();
onAuthenticated?.call();
_keepAlive?.start();
}

void _handleHandshakeTimeout() {
if (_authenticated.isCompleted || _transportReady) return;

_handshakeTimeoutTimer = null;
final error = SSHHandshakeError('Handshake timed out');
_authenticated.completeError(error, StackTrace.current);
}

void _handleAuthTimeout() {
if (_authenticated.isCompleted) return;

_authTimeoutTimer = null;
_authenticated.completeError(
SSHAuthAbortError('Authentication timed out'),
StackTrace.current,
);
}

void _handleUserauthFailure(Uint8List payload) {
final message = SSH_Message_Userauth_Failure.decode(payload);
printTrace?.call('<- $socket: $message');
Expand Down
103 changes: 103 additions & 0 deletions test/src/ssh_client_timeout_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import 'dart:async';
import 'dart:convert';
import 'dart:mirrors';
import 'dart:typed_data';

import 'package:dartssh2/dartssh2.dart';
import 'package:test/test.dart';

void main() {
final clientLibrary = reflectClass(SSHClient).owner as LibraryMirror;
Symbol privateSymbol(String name) =>
MirrorSystem.getSymbol(name, clientLibrary);

group('SSHClient timeouts', () {
test('fails authentication future when handshake times out', () async {
final socket = _FakeSSHSocket();
final client = SSHClient(
socket,
username: 'demo',
handshakeTimeout: const Duration(milliseconds: 10),
);

await expectLater(
client.authenticated,
throwsA(isA<SSHHandshakeError>()),
);

client.close();
});

test('fails authentication future when auth times out', () async {
final socket = _FakeSSHSocket();
final client = SSHClient(
socket,
username: 'demo',
authTimeout: const Duration(milliseconds: 10),
);

reflect(client).invoke(privateSymbol('_handleTransportReady'), const []);

await expectLater(
client.authenticated,
throwsA(isA<SSHAuthAbortError>()),
);

client.close();
});
});
}

class _FakeSSHSocket implements SSHSocket {
final _inputController = StreamController<Uint8List>();
final _doneCompleter = Completer<void>();
final _sink = _RecordingSink();

@override
Stream<Uint8List> get stream => _inputController.stream;

@override
StreamSink<List<int>> get sink => _sink;

@override
Future<void> get done => _doneCompleter.future;

@override
Future<void> close() async {
if (!_doneCompleter.isCompleted) {
_doneCompleter.complete();
}
await _inputController.close();
}

@override
void destroy() {
if (!_doneCompleter.isCompleted) {
_doneCompleter.complete();
}
unawaited(_inputController.close());
}
}

class _RecordingSink implements StreamSink<List<int>> {
@override
void add(List<int> data) {
latin1.decode(data);
}

@override
void addError(Object error, [StackTrace? stackTrace]) {}

@override
Future<void> addStream(Stream<List<int>> stream) async {
await for (final data in stream) {
add(data);
}
}

@override
Future<void> close() async {}

@override
Future<void> get done async {}
}