Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions protocol/ssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func NewConnection(cfg Config, opts ...Option) (*Connection, error) {
}

var (
authMethodCache = sync.Map{}
signerCache = sync.Map{}

knownHostsMU sync.Mutex
globalOnce sync.Once
Expand Down Expand Up @@ -307,8 +307,7 @@ func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop
log.Trace(context.Background(), "using passed-in auth methods", "count", len(c.AuthMethods))
config.Auth = c.AuthMethods
} else if len(signers) > 0 {
c.Log().Debug("using all keys from ssh agent because a keypath was not explicitly given", "count", len(signers))
config.Auth = append(config.Auth, ssh.PublicKeys(signers...))
c.Log().Debug("using all keys from ssh agent", "count", len(signers))
}

for _, keyPath := range c.keyPaths {
Expand All @@ -317,28 +316,29 @@ func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop
log.Trace(context.Background(), "expand keypath", log.FileAttr(keyPath), log.ErrorAttr(err))
continue
}
if am, ok := authMethodCache.Load(keyPath); ok {
switch authM := am.(type) {
case ssh.AuthMethod:
if am, ok := signerCache.Load(keyPath); ok {
switch signerCacheItem := am.(type) {
case ssh.Signer:
log.Trace(context.Background(), "using cached auth method", log.FileAttr(keyPath))
config.Auth = append(config.Auth, authM)
signers = append(signers, signerCacheItem)
case error:
log.Trace(context.Background(), "already discarded key", log.FileAttr(keyPath), log.ErrorAttr(authM))
log.Trace(context.Background(), "already discarded key", log.FileAttr(keyPath), log.ErrorAttr(signerCacheItem))
default:
log.Trace(context.Background(), fmt.Sprintf("unexpected type %T for cached auth method for %s", am, keyPath))
}
continue
}
privateKeyAuth, err := c.pkeySigner(signers, keyPath)
signer, err := c.pkeySigner(signers, keyPath)
if err != nil {
c.Log().Debug("failed to obtain a signer for identity", log.KeyFile, keyPath, log.ErrorAttr(err))
// store the error so this key won't be loaded again
authMethodCache.Store(keyPath, err)
signerCache.Store(keyPath, err)
} else {
authMethodCache.Store(keyPath, privateKeyAuth)
config.Auth = append(config.Auth, privateKeyAuth)
signerCache.Store(keyPath, signer)
signers = append(signers, signer)
}
}
config.Auth = append(config.Auth, ssh.PublicKeys(signers...))

if len(config.Auth) == 0 {
return nil, fmt.Errorf("%w: no usable authentication method found", protocol.ErrAbort)
Expand Down Expand Up @@ -433,22 +433,22 @@ func (c *Connection) Connect() error {
return nil
}

func (c *Connection) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) {
func (c *Connection) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.Signer, error) {
if len(signers) == 0 {
return nil, fmt.Errorf("%w: signer not found for public key", protocol.ErrAbort)
}

for _, s := range signers {
if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) {
c.Log().Debug("signer for public key available in ssh agent")
return ssh.PublicKeys(s), nil
return s, nil
}
}

return nil, fmt.Errorf("%w: the provided key is a public key and is not known by agent", protocol.ErrAbort)
}

func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) {
func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.Signer, error) {
path, err := homedir.ExpandFile(path)
if err != nil {
return nil, fmt.Errorf("expand keyfile path: %w", err)
Expand All @@ -468,7 +468,7 @@ func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMeth
signer, err := ssh.ParsePrivateKey(key)
if err == nil {
c.Log().Debug("using an unencrypted private key", log.KeyFile, path)
return ssh.PublicKeys(signer), nil
return signer, nil
}

var ppErr *ssh.PassphraseMissingError
Expand All @@ -491,7 +491,7 @@ func (c *Connection) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMeth
if err != nil {
return nil, fmt.Errorf("%w: encrypted key %s decoding failed: %w", protocol.ErrAbort, path, err)
}
return ssh.PublicKeys(signer), nil
return signer, nil
}
}

Expand Down