diff --git a/protocol/ssh/connection.go b/protocol/ssh/connection.go index c348f1b2..f1faca83 100644 --- a/protocol/ssh/connection.go +++ b/protocol/ssh/connection.go @@ -82,7 +82,7 @@ func NewConnection(cfg Config, opts ...Option) (*Connection, error) { } var ( - authMethodCache = sync.Map{} + signerCache = sync.Map{} knownHostsMU sync.Mutex globalOnce sync.Once @@ -307,8 +307,7 @@ func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop log.Trace(context.Background(), "using passed-in auth methods", "count", len(c.AuthMethods)) config.Auth = c.AuthMethods } else if len(signers) > 0 { - c.Log().Debug("using all keys from ssh agent because a keypath was not explicitly given", "count", len(signers)) - config.Auth = append(config.Auth, ssh.PublicKeys(signers...)) + c.Log().Debug("using all keys from ssh agent", "count", len(signers)) } for _, keyPath := range c.keyPaths { @@ -317,28 +316,29 @@ func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop log.Trace(context.Background(), "expand keypath", log.FileAttr(keyPath), log.ErrorAttr(err)) continue } - if am, ok := authMethodCache.Load(keyPath); ok { - switch authM := am.(type) { - case ssh.AuthMethod: + if am, ok := signerCache.Load(keyPath); ok { + switch signerCacheItem := am.(type) { + case ssh.Signer: log.Trace(context.Background(), "using cached auth method", log.FileAttr(keyPath)) - config.Auth = append(config.Auth, authM) + signers = append(signers, signerCacheItem) case error: - log.Trace(context.Background(), "already discarded key", log.FileAttr(keyPath), log.ErrorAttr(authM)) + log.Trace(context.Background(), "already discarded key", log.FileAttr(keyPath), log.ErrorAttr(signerCacheItem)) default: log.Trace(context.Background(), fmt.Sprintf("unexpected type %T for cached auth method for %s", am, keyPath)) } continue } - privateKeyAuth, err := c.pkeySigner(signers, keyPath) + signer, err := c.pkeySigner(signers, keyPath) if err != nil { c.Log().Debug("failed to obtain a signer for identity", log.KeyFile, keyPath, log.ErrorAttr(err)) // store the error so this key won't be loaded again - authMethodCache.Store(keyPath, err) + signerCache.Store(keyPath, err) } else { - authMethodCache.Store(keyPath, privateKeyAuth) - config.Auth = append(config.Auth, privateKeyAuth) + signerCache.Store(keyPath, signer) + signers = append(signers, signer) } } + config.Auth = append(config.Auth, ssh.PublicKeys(signers...)) if len(config.Auth) == 0 { return nil, fmt.Errorf("%w: no usable authentication method found", protocol.ErrAbort) @@ -433,7 +433,7 @@ func (c *Connection) Connect() error { return nil } -func (c *Connection) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { +func (c *Connection) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.Signer, error) { if len(signers) == 0 { return nil, fmt.Errorf("%w: signer not found for public key", protocol.ErrAbort) } @@ -441,14 +441,14 @@ func (c *Connection) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh. for _, s := range signers { if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) { c.Log().Debug("signer for public key available in ssh agent") - return ssh.PublicKeys(s), nil + return s, nil } } return nil, fmt.Errorf("%w: the provided key is a public key and is not known by agent", protocol.ErrAbort) } -func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { +func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.Signer, error) { path, err := homedir.ExpandFile(path) if err != nil { return nil, fmt.Errorf("expand keyfile path: %w", err) @@ -468,7 +468,7 @@ func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMeth signer, err := ssh.ParsePrivateKey(key) if err == nil { c.Log().Debug("using an unencrypted private key", log.KeyFile, path) - return ssh.PublicKeys(signer), nil + return signer, nil } var ppErr *ssh.PassphraseMissingError @@ -491,7 +491,7 @@ func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMeth if err != nil { return nil, fmt.Errorf("%w: encrypted key %s decoding failed: %w", protocol.ErrAbort, path, err) } - return ssh.PublicKeys(signer), nil + return signer, nil } }