diff --git a/.gitignore b/.gitignore index f425547e..f713869e 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,12 @@ darwin-arm64/sqlcmd linux-amd64/sqlcmd linux-arm64/sqlcmd linux-s390x/sqlcmd + +# Build artifacts in root +/sqlcmd +/sqlcmd_binary + +# certificates used for local testing +*.der +*.pem +*.pfx diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index bb0b4502..ea655b47 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -57,6 +57,7 @@ type SQLCmdArguments struct { ApplicationIntent string EncryptConnection string HostNameInCertificate string + ServerCertificate string DriverLoggingLevel int ExitOnError bool ErrorSeverityLevel uint8 @@ -127,6 +128,15 @@ const ( removeControlCharacters = "remove-control-characters" ) +func encryptConnectionAllowsTLS(value string) bool { + switch strings.ToLower(value) { + case "s", "strict", "m", "mandatory", "true", "t", "yes", "1": + return true + default: + return false + } +} + // Validate arguments for settings not describe func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { if a.ListServers != "" { @@ -144,6 +154,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = mutuallyExclusiveError("-E", `-U/-P`) case a.UseAad && len(a.AuthenticationMethod) > 0: err = mutuallyExclusiveError("-G", "--authentication-method") + case len(a.HostNameInCertificate) > 0 && len(a.ServerCertificate) > 0: + err = mutuallyExclusiveError("-F", "-J") case a.PacketSize != 0 && (a.PacketSize < 512 || a.PacketSize > 32767): err = localizer.Errorf(`'-a %#v': Packet size has to be a number between 512 and 32767.`, a.PacketSize) // Ignore 0 even though it's technically an invalid input @@ -157,6 +169,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = rangeParameterError("-y", fmt.Sprint(*a.VariableTypeWidth), 0, 8000, true) case a.QueryTimeout < 0 || a.QueryTimeout > 65534: err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true) + case a.ServerCertificate != "" && !encryptConnectionAllowsTLS(a.EncryptConnection): + err = localizer.Errorf("The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict).") } } if err != nil { @@ -429,6 +443,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().StringVarP(&args.ApplicationIntent, applicationIntent, "K", "default", localizer.Sprintf("Declares the application workload type when connecting to a server. The only currently supported value is ReadOnly. If %s is not specified, the sqlcmd utility will not support connectivity to a secondary replica in an Always On availability group", localizer.ApplicationIntentFlagShort)) rootCmd.Flags().StringVarP(&args.EncryptConnection, encryptConnection, "N", "default", localizer.Sprintf("This switch is used by the client to request an encrypted connection")) rootCmd.Flags().StringVarP(&args.HostNameInCertificate, "host-name-in-certificate", "F", "", localizer.Sprintf("Specifies the host name in the server certificate.")) + rootCmd.Flags().StringVarP(&args.ServerCertificate, "server-certificate", "J", "", localizer.Sprintf("Specifies the path to a server certificate file (PEM, DER, or CER) to match against the server's TLS certificate. Use when encryption is enabled (-N true, -N mandatory, or -N strict) for certificate pinning instead of standard certificate validation.")) + rootCmd.MarkFlagsMutuallyExclusive("host-name-in-certificate", "server-certificate") // Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866 //rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true" rootCmd.Flags().BoolVarP(&args.Vertical, "vertical", "", false, localizer.Sprintf("Prints the output in vertical format. This option sets the sqlcmd scripting variable %s to '%s'. The default is false", sqlcmd.SQLCMDFORMAT, "vert")) @@ -721,6 +737,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq connect.Encrypt = args.EncryptConnection } connect.HostNameInCertificate = args.HostNameInCertificate + connect.ServerCertificate = args.ServerCertificate connect.PacketSize = args.PacketSize connect.WorkstationName = args.WorkstationName connect.LogLevel = args.DriverLoggingLevel diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 73d06bc9..511816b2 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -111,6 +111,18 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-N", "s", "-F", "myserver.domain.com"}, func(args SQLCmdArguments) bool { return args.EncryptConnection == "s" && args.HostNameInCertificate == "myserver.domain.com" }}, + {[]string{"-N", "s", "-J", "/path/to/cert.pem"}, func(args SQLCmdArguments) bool { + return args.EncryptConnection == "s" && args.ServerCertificate == "/path/to/cert.pem" + }}, + {[]string{"-N", "strict", "-J", "/path/to/cert.der"}, func(args SQLCmdArguments) bool { + return args.EncryptConnection == "strict" && args.ServerCertificate == "/path/to/cert.der" + }}, + {[]string{"-N", "m", "-J", "/path/to/cert.cer"}, func(args SQLCmdArguments) bool { + return args.EncryptConnection == "m" && args.ServerCertificate == "/path/to/cert.cer" + }}, + {[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool { + return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem" + }}, } for _, test := range commands { @@ -154,7 +166,7 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-E", "-U", "someuser"}, "The -E and the -U/-P options are mutually exclusive."}, {[]string{"-L", "-q", `"select 1"`}, "The -L parameter can not be used in combination with other parameters."}, {[]string{"-i", "foo.sql", "-q", `"select 1"`}, "The i and the -Q/-q options are mutually exclusive."}, - {[]string{"-r", "5"}, `'-r 5': Unexpected argument. Argument value has to be one of [0 1].`}, + {[]string{"-r", "5"}, "'-r 5': Unexpected argument. Argument value has to be one of [0 1]."}, {[]string{"-w", "x"}, "'-w x': value must be greater than 8 and less than 65536."}, {[]string{"-y", "111111"}, "'-y 111111': value must be greater than or equal to 0 and less than or equal to 8000."}, {[]string{"-Y", "-2"}, "'-Y -2': value must be greater than or equal to 0 and less than or equal to 8000."}, @@ -162,6 +174,10 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-;"}, "';': Unknown Option. Enter '-?' for help."}, {[]string{"-t", "-2"}, "'-t -2': value must be greater than or equal to 0 and less than or equal to 65534."}, {[]string{"-N", "invalid"}, "'-N invalid': Unexpected argument. Argument value has to be one of [m[andatory] yes 1 t[rue] disable o[ptional] no 0 f[alse] s[trict]]."}, + {[]string{"-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, + {[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, + {[]string{"-N", "disable", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, + {[]string{"-N", "strict", "-F", "myserver.domain.com", "-J", "/path/to/cert.pem"}, "The -F and the -J options are mutually exclusive."}, } for _, test := range commands { diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go index 57342777..95af0871 100644 --- a/pkg/sqlcmd/connect.go +++ b/pkg/sqlcmd/connect.go @@ -60,6 +60,8 @@ type ConnectSettings struct { ChangePassword string // The HostNameInCertificate is the name to use for the host in the certificate validation HostNameInCertificate string + // ServerCertificate is the path to a certificate file to match against the server's TLS certificate + ServerCertificate string } func (c ConnectSettings) authenticationMethod() string { @@ -150,6 +152,9 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err if connect.HostNameInCertificate != "" { query.Add(msdsn.HostNameInCertificate, connect.HostNameInCertificate) } + if connect.ServerCertificate != "" { + query.Add(msdsn.ServerCertificate, connect.ServerCertificate) + } if connect.LogLevel > 0 { query.Add(msdsn.LogParam, fmt.Sprint(connect.LogLevel)) }