diff --git a/cmd/args.go b/cmd/args.go index c3b359a..20508ee 100644 --- a/cmd/args.go +++ b/cmd/args.go @@ -40,14 +40,6 @@ func registerNetworkFlags(fs *pflag.FlagSet) { //cmd.MarkFlagsMutuallyExclusive("no-epm", "epm-filter") } -// FUTURE: automatically stage & execute file -/* -func registerStageFlags(fs *pflag.FlagSet) { - fs.StringVarP(&stageFilePath, "stage", "E", "", "File to stage and execute") - //fs.StringVarP(&stageArgs ...) -} -*/ - func registerExecutionFlags(fs *pflag.FlagSet) { fs.StringVarP(&exec.Input.Executable, "exec", "e", "", "Remote Windows `executable` to invoke") fs.StringVarP(&exec.Input.Arguments, "args", "a", "", "Process command line arguments") @@ -57,6 +49,12 @@ func registerExecutionFlags(fs *pflag.FlagSet) { //cmd.MarkFlagsMutuallyExclusive("executable", "command") } +func registerExecutionUploadFlags(fs *pflag.FlagSet) { + fs.StringVar(&uploadSource, "upload", "", "Upload local `file` to remote filesystem") + fs.StringVar(&uploadDest, "upload-dest", "", "Remote destination `path` for uploaded file (default: random temp path)") + fs.BoolVar(&exec.Upload.NoConfirm, "no-upload-confirm", false, "Skip upload confirmation check") +} + func registerExecutionOutputFlags(fs *pflag.FlagSet) { fs.StringVarP(&outputPath, "out", "o", "", "Fetch execution output to `file` or \"-\" for standard output") fs.StringVarP(&outputMethod, "out-method", "m", "smb", "Method to fetch execution output") @@ -153,24 +151,46 @@ func argsRpcClient(proto string, endpoint string) func(cmd *cobra.Command, args func argsOutput(methods ...string) func(cmd *cobra.Command, args []string) error { - var as []func(*cobra.Command, []string) error + var as []func(*cobra.Command, []string) error - for _, method := range methods { - if method == "smb" { - as = append(as, argsSmbClient()) - } - } + for _, method := range methods { + if method == "smb" { + as = append(as, argsSmbClient()) + } + } - return args(append(as, func(*cobra.Command, []string) (err error) { + return args(append(as, func(*cobra.Command, []string) (err error) { - if outputPath != "" { - if outputPath == "-" { - exec.Output.Writer = os.Stdout + if outputPath != "" { + if outputPath == "-" { + exec.Output.Writer = os.Stdout - } else if exec.Output.Writer, err = os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { - log.Fatal().Err(err).Msg("Failed to open output file") - } - } - return - })...) + } else if exec.Output.Writer, err = os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { + log.Fatal().Err(err).Msg("Failed to open output file") + } + } + return + })...) +} + +func argsUpload(methods ...string) func(cmd *cobra.Command, args []string) error { + + var as []func(*cobra.Command, []string) error + + for _, method := range methods { + if method == "smb" { + as = append(as, argsSmbClient()) + } + } + + return args(append(as, func(*cobra.Command, []string) (err error) { + + if uploadSource != "" { + exec.Upload.Reader, err = os.Open(uploadSource) + if err != nil { + return fmt.Errorf("open upload file: %w", err) + } + } + return + })...) } diff --git a/cmd/dcom.go b/cmd/dcom.go index 79fbd67..bcd289b 100644 --- a/cmd/dcom.go +++ b/cmd/dcom.go @@ -61,127 +61,133 @@ func dcomVisualStudioCmdInit() { } func dcomMmcCmdInit() { - dcomMmcExecFlags := newFlagSet("Execution") - registerExecutionFlags(dcomMmcExecFlags.Flags) - registerExecutionOutputFlags(dcomMmcExecFlags.Flags) - dcomMmcExecFlags.Flags.StringVar(&dcomMmc.WorkingDirectory, "directory", `C:\`, "Working `directory`") - dcomMmcExecFlags.Flags.StringVar(&dcomMmc.WindowState, "window", "Minimized", "Window state") - - cmdFlags[dcomMmcCmd] = []*flagSet{ - dcomMmcExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomMmcCmd.Flags().AddFlagSet(dcomMmcExecFlags.Flags) - - // Constraints - dcomMmcCmd.MarkFlagsOneRequired("command", "exec") + dcomMmcExecFlags := newFlagSet("Execution") + registerExecutionFlags(dcomMmcExecFlags.Flags) + registerExecutionOutputFlags(dcomMmcExecFlags.Flags) + registerExecutionUploadFlags(dcomMmcExecFlags.Flags) + dcomMmcExecFlags.Flags.StringVar(&dcomMmc.WorkingDirectory, "directory", `C:\`, "Working `directory`") + dcomMmcExecFlags.Flags.StringVar(&dcomMmc.WindowState, "window", "Minimized", "Window state") + + cmdFlags[dcomMmcCmd] = []*flagSet{ + dcomMmcExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomMmcCmd.Flags().AddFlagSet(dcomMmcExecFlags.Flags) + + // Constraints + dcomMmcCmd.MarkFlagsOneRequired("command", "exec", "upload") } func dcomShellWindowsCmdInit() { - dcomShellWindowsExecFlags := newFlagSet("Execution") - registerExecutionFlags(dcomShellWindowsExecFlags.Flags) - registerExecutionOutputFlags(dcomShellWindowsExecFlags.Flags) - dcomShellWindowsExecFlags.Flags.StringVar(&dcomShellWindows.WorkingDirectory, "directory", `C:\`, "Working directory `path`") - dcomShellWindowsExecFlags.Flags.StringVar(&dcomShellWindows.WindowState, "app-window", "0", "Application window state `ID`") - - cmdFlags[dcomShellWindowsCmd] = []*flagSet{ - dcomShellWindowsExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomShellWindowsCmd.Flags().AddFlagSet(dcomShellWindowsExecFlags.Flags) - - // Constraints - dcomShellWindowsCmd.MarkFlagsOneRequired("command", "exec") + dcomShellWindowsExecFlags := newFlagSet("Execution") + registerExecutionFlags(dcomShellWindowsExecFlags.Flags) + registerExecutionOutputFlags(dcomShellWindowsExecFlags.Flags) + registerExecutionUploadFlags(dcomShellWindowsExecFlags.Flags) + dcomShellWindowsExecFlags.Flags.StringVar(&dcomShellWindows.WorkingDirectory, "directory", `C:\`, "Working directory `path`") + dcomShellWindowsExecFlags.Flags.StringVar(&dcomShellWindows.WindowState, "app-window", "0", "Application window state `ID`") + + cmdFlags[dcomShellWindowsCmd] = []*flagSet{ + dcomShellWindowsExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomShellWindowsCmd.Flags().AddFlagSet(dcomShellWindowsExecFlags.Flags) + + // Constraints + dcomShellWindowsCmd.MarkFlagsOneRequired("command", "exec", "upload") } func dcomShellBrowserWindowCmdInit() { - dcomShellBrowserWindowExecFlags := newFlagSet("Execution") - registerExecutionFlags(dcomShellBrowserWindowExecFlags.Flags) - registerExecutionOutputFlags(dcomShellBrowserWindowExecFlags.Flags) - dcomShellBrowserWindowExecFlags.Flags.StringVar(&dcomShellBrowserWindow.WorkingDirectory, "directory", `C:\`, "Working directory `path`") - dcomShellBrowserWindowExecFlags.Flags.StringVar(&dcomShellBrowserWindow.WindowState, "app-window", "0", "Application window state `ID`") - - cmdFlags[dcomShellBrowserWindowCmd] = []*flagSet{ - dcomShellBrowserWindowExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomShellBrowserWindowCmd.Flags().AddFlagSet(dcomShellBrowserWindowExecFlags.Flags) - - // Constraints - dcomShellBrowserWindowCmd.MarkFlagsOneRequired("command", "exec") + dcomShellBrowserWindowExecFlags := newFlagSet("Execution") + registerExecutionFlags(dcomShellBrowserWindowExecFlags.Flags) + registerExecutionOutputFlags(dcomShellBrowserWindowExecFlags.Flags) + registerExecutionUploadFlags(dcomShellBrowserWindowExecFlags.Flags) + dcomShellBrowserWindowExecFlags.Flags.StringVar(&dcomShellBrowserWindow.WorkingDirectory, "directory", `C:\`, "Working directory `path`") + dcomShellBrowserWindowExecFlags.Flags.StringVar(&dcomShellBrowserWindow.WindowState, "app-window", "0", "Application window state `ID`") + + cmdFlags[dcomShellBrowserWindowCmd] = []*flagSet{ + dcomShellBrowserWindowExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomShellBrowserWindowCmd.Flags().AddFlagSet(dcomShellBrowserWindowExecFlags.Flags) + + // Constraints + dcomShellBrowserWindowCmd.MarkFlagsOneRequired("command", "exec", "upload") } func dcomHtafileCmdInit() { - dcomHtafileExecFlags := newFlagSet("Execution") - dcomHtafileExecFlags.Flags.StringVarP(&dcomHtafile.Url, "url", "U", "", "Load custom `URL`") - dcomHtafileExecFlags.Flags.StringVar(&dcomHtafile.Javascript, "js", "", "Execute JavaScript one-liner") - dcomHtafileExecFlags.Flags.StringVar(&dcomHtafile.Vbscript, "vbs", "", "Execute VBScript one-liner") - registerExecutionFlags(dcomHtafileExecFlags.Flags) - registerExecutionOutputFlags(dcomHtafileExecFlags.Flags) - - cmdFlags[dcomHtafileCmd] = []*flagSet{ - dcomHtafileExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomHtafileCmd.Flags().AddFlagSet(dcomHtafileExecFlags.Flags) - - // Constraints - dcomHtafileCmd.MarkFlagsOneRequired("command", "exec", "url", "js", "vbs") + dcomHtafileExecFlags := newFlagSet("Execution") + dcomHtafileExecFlags.Flags.StringVarP(&dcomHtafile.Url, "url", "U", "", "Load custom `URL`") + dcomHtafileExecFlags.Flags.StringVar(&dcomHtafile.Javascript, "js", "", "Execute JavaScript one-liner") + dcomHtafileExecFlags.Flags.StringVar(&dcomHtafile.Vbscript, "vbs", "", "Execute VBScript one-liner") + registerExecutionFlags(dcomHtafileExecFlags.Flags) + registerExecutionOutputFlags(dcomHtafileExecFlags.Flags) + registerExecutionUploadFlags(dcomHtafileExecFlags.Flags) + + cmdFlags[dcomHtafileCmd] = []*flagSet{ + dcomHtafileExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomHtafileCmd.Flags().AddFlagSet(dcomHtafileExecFlags.Flags) + + // Constraints + dcomHtafileCmd.MarkFlagsOneRequired("command", "exec", "url", "js", "vbs", "upload") } func dcomExcelMacroCmdInit() { - dcomExcelMacroExecFlags := newFlagSet("Execution") - dcomExcelMacroExecFlags.Flags.StringArrayVarP(&dcomExcelMacro.Macros, "macro", "M", nil, "XLM macro `code`") - dcomExcelMacroExecFlags.Flags.StringVar(&dcomExcelMacro.MacroFile, "macro-file", "", "XLM macro `file`") - registerExecutionFlags(dcomExcelMacroExecFlags.Flags) - registerExecutionOutputFlags(dcomExcelMacroExecFlags.Flags) - - cmdFlags[dcomExcelMacroCmd] = []*flagSet{ - dcomExcelMacroExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomExcelMacroCmd.Flags().AddFlagSet(dcomExcelMacroExecFlags.Flags) - - // Constraints - dcomExcelMacroCmd.MarkFlagsOneRequired("command", "exec", "macro", "macro-file") - dcomExcelMacroCmd.MarkFlagsMutuallyExclusive("command", "exec", "macro", "macro-file") - dcomExcelMacroCmd.MarkFlagsMutuallyExclusive("macro", "macro-file", "out") + dcomExcelMacroExecFlags := newFlagSet("Execution") + dcomExcelMacroExecFlags.Flags.StringArrayVarP(&dcomExcelMacro.Macros, "macro", "M", nil, "XLM macro `code`") + dcomExcelMacroExecFlags.Flags.StringVar(&dcomExcelMacro.MacroFile, "macro-file", "", "XLM macro `file`") + registerExecutionFlags(dcomExcelMacroExecFlags.Flags) + registerExecutionOutputFlags(dcomExcelMacroExecFlags.Flags) + registerExecutionUploadFlags(dcomExcelMacroExecFlags.Flags) + + cmdFlags[dcomExcelMacroCmd] = []*flagSet{ + dcomExcelMacroExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomExcelMacroCmd.Flags().AddFlagSet(dcomExcelMacroExecFlags.Flags) + + // Constraints + dcomExcelMacroCmd.MarkFlagsOneRequired("command", "exec", "macro", "macro-file", "upload") + dcomExcelMacroCmd.MarkFlagsMutuallyExclusive("command", "exec", "macro", "macro-file") + dcomExcelMacroCmd.MarkFlagsMutuallyExclusive("macro", "macro-file", "out") } func dcomVisualStudioDteCmdInit() { - dcomVisualStudioDteVsFlags := newFlagSet("Visual Studio") - dcomVisualStudioDteVsFlags.Flags.BoolVar(&dcomVisualStudioDte.Is2019, "vs-2019", false, "Target Visual Studio 2019") - dcomVisualStudioDteVsFlags.Flags.StringVar(&dcomVisualStudioDte.CommandName, "vs-command", "", "Visual Studio DTE command to execute") - dcomVisualStudioDteVsFlags.Flags.StringVar(&dcomVisualStudioDte.CommandArgs, "vs-args", "", "Visual Studio DTE command arguments") - - dcomVisualStudioDteExecFlags := newFlagSet("Execution") - registerExecutionFlags(dcomVisualStudioDteExecFlags.Flags) - registerExecutionOutputFlags(dcomVisualStudioDteExecFlags.Flags) - - cmdFlags[dcomVisualStudioDteCmd] = []*flagSet{ - dcomVisualStudioDteVsFlags, - dcomVisualStudioDteExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - dcomVisualStudioDteCmd.Flags().AddFlagSet(dcomVisualStudioDteVsFlags.Flags) - dcomVisualStudioDteCmd.Flags().AddFlagSet(dcomVisualStudioDteExecFlags.Flags) - - // Constraints - dcomVisualStudioDteCmd.MarkFlagsOneRequired("command", "exec", "vs-command") - dcomVisualStudioDteCmd.MarkFlagsMutuallyExclusive("command", "exec", "vs-command") - dcomVisualStudioDteCmd.MarkFlagsMutuallyExclusive("vs-command", "out") + dcomVisualStudioDteVsFlags := newFlagSet("Visual Studio") + dcomVisualStudioDteVsFlags.Flags.BoolVar(&dcomVisualStudioDte.Is2019, "vs-2019", false, "Target Visual Studio 2019") + dcomVisualStudioDteVsFlags.Flags.StringVar(&dcomVisualStudioDte.CommandName, "vs-command", "", "Visual Studio DTE command to execute") + dcomVisualStudioDteVsFlags.Flags.StringVar(&dcomVisualStudioDte.CommandArgs, "vs-args", "", "Visual Studio DTE command arguments") + + dcomVisualStudioDteExecFlags := newFlagSet("Execution") + registerExecutionFlags(dcomVisualStudioDteExecFlags.Flags) + registerExecutionOutputFlags(dcomVisualStudioDteExecFlags.Flags) + registerExecutionUploadFlags(dcomVisualStudioDteExecFlags.Flags) + + cmdFlags[dcomVisualStudioDteCmd] = []*flagSet{ + dcomVisualStudioDteVsFlags, + dcomVisualStudioDteExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + dcomVisualStudioDteCmd.Flags().AddFlagSet(dcomVisualStudioDteVsFlags.Flags) + dcomVisualStudioDteCmd.Flags().AddFlagSet(dcomVisualStudioDteExecFlags.Flags) + + // Constraints + dcomVisualStudioDteCmd.MarkFlagsOneRequired("command", "exec", "vs-command", "upload") + dcomVisualStudioDteCmd.MarkFlagsMutuallyExclusive("command", "exec", "vs-command") + dcomVisualStudioDteCmd.MarkFlagsMutuallyExclusive("vs-command", "out") } func dcomExcelXllCmdInit() { @@ -240,12 +246,13 @@ var ( Long: `Description: The mmc method uses the exposed MMC20.Application object to call Document.ActiveView.ShellExec, and ultimately spawn a process on the remote host.`, - Args: args(argsRpcClient("cifs", ""), - argsOutput("smb"), - argsAcceptValues("window", &dcomMmc.WindowState, "Minimized", "Maximized", "Restored"), - ), - Run: func(cmd *cobra.Command, args []string) { - dcomMmc.Client = &rpcClient + Args: args(argsRpcClient("cifs", ""), + argsOutput("smb"), + argsUpload("smb"), + argsAcceptValues("window", &dcomMmc.WindowState, "Minimized", "Maximized", "Restored"), + ), + Run: func(cmd *cobra.Command, args []string) { + dcomMmc.Client = &rpcClient ctx := log.With().Str("module", dcomexec.ModuleName).Str("method", dcomexec.MethodMmc). Logger().WithContext(gssapi.NewSecurityContext(context.Background())) @@ -261,12 +268,13 @@ var ( Long: `Description: The shellwindows method uses the exposed ShellWindows DCOM object on older Windows installations to call Item().Document.Application.ShellExecute, and spawn the provided process.`, - Args: args(argsRpcClient("host", ""), - argsOutput("smb"), - argsAcceptValues("app-window", &dcomShellWindows.WindowState, "0", "1", "2", "3", "4", "5", "7", "10"), - ), - Run: func(cmd *cobra.Command, args []string) { - dcomShellWindows.Client = &rpcClient + Args: args(argsRpcClient("host", ""), + argsOutput("smb"), + argsUpload("smb"), + argsAcceptValues("app-window", &dcomShellWindows.WindowState, "0", "1", "2", "3", "4", "5", "7", "10"), + ), + Run: func(cmd *cobra.Command, args []string) { + dcomShellWindows.Client = &rpcClient ctx := log.With().Str("module", dcomexec.ModuleName).Str("method", dcomexec.MethodShellWindows). Logger().WithContext(gssapi.NewSecurityContext(context.Background())) @@ -282,12 +290,13 @@ var ( Long: `Description: The shellbrowserwindow method uses the exposed ShellBrowserWindow DCOM object on older Windows installations to call Document.Application.ShellExecute, and spawn the provided process.`, - Args: args(argsRpcClient("host", ""), - argsOutput("smb"), - argsAcceptValues("app-window", &dcomShellBrowserWindow.WindowState, "0", "1", "2", "3", "4", "5", "7", "10"), - ), - Run: func(cmd *cobra.Command, args []string) { - dcomShellBrowserWindow.Client = &rpcClient + Args: args(argsRpcClient("host", ""), + argsOutput("smb"), + argsUpload("smb"), + argsAcceptValues("app-window", &dcomShellBrowserWindow.WindowState, "0", "1", "2", "3", "4", "5", "7", "10"), + ), + Run: func(cmd *cobra.Command, args []string) { + dcomShellBrowserWindow.Client = &rpcClient ctx := log.With().Str("module", dcomexec.ModuleName).Str("method", dcomexec.MethodShellBrowserWindow). Logger().WithContext(gssapi.NewSecurityContext(context.Background())) @@ -303,9 +312,9 @@ var ( Long: `Description: The htafile method uses the exposed "HTML Application" DCOM object to load a remote HTA application or execute inline. This is made possible by the Load method of the IPersistMoniker interface.`, - Args: args(argsRpcClient("host", ""), argsOutput("smb")), - RunE: func(cmd *cobra.Command, args []string) error { - dcomHtafile.Client = &rpcClient + Args: args(argsRpcClient("host", ""), argsOutput("smb"), argsUpload("smb")), + RunE: func(cmd *cobra.Command, args []string) error { + dcomHtafile.Client = &rpcClient dcomHtafile.Url = dcomexec.HtafileGetUrl(dcomHtafile.Url, dcomHtafile.Javascript, dcomHtafile.Vbscript, &exec) if url := strings.ToLower(dcomHtafile.Url); (strings.HasPrefix(url, "javascript:") || strings.HasPrefix(url, "vbscript:")) && len(url) > 508 { @@ -327,9 +336,9 @@ var ( Long: `Description: The macro method uses the exposed Excel.Application DCOM object to call ExecuteExcel4Macro, thus executing XLM macros at will. This method requires that the remote host has Microsoft Excel installed.`, - Args: args(argsRpcClient("host", ""), argsOutput("smb"), - func(*cobra.Command, []string) error { - if dcomExcelMacro.MacroFile != "" { + Args: args(argsRpcClient("host", ""), argsOutput("smb"), argsUpload("smb"), + func(*cobra.Command, []string) error { + if dcomExcelMacro.MacroFile != "" { f, err := os.Open(dcomExcelMacro.MacroFile) if err != nil { return fmt.Errorf("open macro file: %w", err) @@ -380,9 +389,9 @@ var ( Long: `Description: The dte method uses the exposed VisualStudio.DTE object to spawn a process via the ExecuteCommand method. This method requires that the remote host has Microsoft Visual Studio installed.`, - Args: args(argsRpcClient("host", ""), argsOutput("smb")), - Run: func(*cobra.Command, []string) { - dcomVisualStudioDte.Client = &rpcClient + Args: args(argsRpcClient("host", ""), argsOutput("smb"), argsUpload("smb")), + Run: func(*cobra.Command, []string) { + dcomVisualStudioDte.Client = &rpcClient ctx := log.With().Str("module", dcomexec.ModuleName).Str("method", dcomexec.MethodVisualStudioDTE). Logger().WithContext(gssapi.NewSecurityContext(context.Background())) diff --git a/cmd/root.go b/cmd/root.go index 208b58d..6315736 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -60,11 +60,12 @@ var ( returnCode int toClose []io.Closer - // === IO === - //stageFilePath string // FUTURE - outputMethod string - outputPath string - // ========== + // === IO === + outputMethod string + outputPath string + uploadSource string + uploadDest string + // ========== // === Logging === logJson bool // Log output in JSON lines @@ -89,10 +90,11 @@ var ( memProfileFile io.WriteCloser // ========================== - exec = goexec.ExecutionIO{ - Input: new(goexec.ExecutionInput), - Output: new(goexec.ExecutionOutput), - } + exec = goexec.ExecutionIO{ + Input: new(goexec.ExecutionInput), + Output: new(goexec.ExecutionOutput), + Upload: new(goexec.ExecutionUpload), + } adAuthOpts *adauth.Options credential *adauth.Credential @@ -167,50 +169,69 @@ Authors: FalconOps LLC (@FalconOpsLLC), smbClient.Proxy = proxy } - if outputPath != "" { - if outputMethod == "smb" { - if exec.Output.RemotePath == "" { - exec.Output.RemotePath = `C:\Windows\Temp\` + uuid.NewString() - } - exec.Output.Provider = &smb.OutputFileFetcher{ - Client: &smbClient, - Share: `ADMIN$`, // TODO: dynamic - SharePath: `C:\Windows`, - File: exec.Output.RemotePath, - DeleteOutputFile: !exec.Output.NoDelete, - } - } - } - return - }, - - PersistentPostRun: func(cmd *cobra.Command, args []string) { - - if memProfileFile != nil { - if err := pprof.WriteHeapProfile(memProfileFile); err != nil { - log.Error().Err(err).Msg("Failed to write memory profile") - return - } - } - - if cpuProfileFile != nil { - pprof.StopCPUProfile() - } - - if exec.Input != nil && exec.Input.StageFile != nil { - if err := exec.Input.StageFile.Close(); err != nil { - log.Warn().Err(err).Msg("Failed to close stage file") - } - } - - for _, c := range toClose { - if c != nil { - if err := c.Close(); err != nil { - log.Warn().Err(err).Msg("Failed to close stream") - } - } - } - }, + if outputPath != "" { + if outputMethod == "smb" { + if exec.Output.RemotePath == "" { + exec.Output.RemotePath = `C:\Windows\Temp\` + uuid.NewString() + } + exec.Output.Provider = &smb.OutputFileFetcher{ + Client: &smbClient, + Share: `ADMIN$`, // TODO: dynamic + SharePath: `C:\Windows`, + File: exec.Output.RemotePath, + DeleteOutputFile: !exec.Output.NoDelete, + } + } + } + + if uploadSource != "" { + if uploadDest == "" { + uploadDest = `C:\Windows\Temp\` + uuid.NewString() + } + exec.Upload.RemotePath = uploadDest + exec.Upload.Provider = &smb.FileStager{ + Client: &smbClient, + Share: `C$`, + SharePath: `C:\`, + File: uploadDest, + } + } + return + }, + + PersistentPostRun: func(cmd *cobra.Command, args []string) { + + if memProfileFile != nil { + if err := pprof.WriteHeapProfile(memProfileFile); err != nil { + log.Error().Err(err).Msg("Failed to write memory profile") + return + } + } + + if cpuProfileFile != nil { + pprof.StopCPUProfile() + } + + if exec.Input != nil && exec.Input.StageFile != nil { + if err := exec.Input.StageFile.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close stage file") + } + } + + if exec.Upload != nil && exec.Upload.Reader != nil { + if err := exec.Upload.Reader.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close upload file") + } + } + + for _, c := range toClose { + if c != nil { + if err := c.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close stream") + } + } + } + }, } ) diff --git a/cmd/scmr.go b/cmd/scmr.go index 0b45287..354768b 100644 --- a/cmd/scmr.go +++ b/cmd/scmr.go @@ -28,76 +28,73 @@ func scmrCmdInit() { } func scmrCreateCmdInit() { - scmrCreateFlags := newFlagSet("Service") - - scmrCreateFlags.Flags.StringVarP(&scmrCreate.DisplayName, "display-name", "n", "", "Display name of service to create") - scmrCreateFlags.Flags.StringVarP(&scmrCreate.ServiceName, "service", "s", "", "Name of service to create") - scmrCreateFlags.Flags.BoolVar(&scmrCreate.NoDelete, "no-delete", false, "Don't delete service after execution") - scmrCreateFlags.Flags.BoolVar(&scmrCreate.NoStart, "no-start", false, "Don't start service") - - scmrCreateExecFlags := newFlagSet("Execution") - - // TODO: SCMR output - //registerExecutionOutputFlags(scmrCreateExecFlags.Flags) - - scmrCreateExecFlags.Flags.StringVarP(&exec.Input.ExecutablePath, "executable-path", "f", "", "Full path to a remote Windows executable") - scmrCreateExecFlags.Flags.StringVarP(&exec.Input.Arguments, "args", "a", "", "Arguments to pass to the executable") - - scmrCreateCmd.Flags().AddFlagSet(scmrCreateFlags.Flags) - scmrCreateCmd.Flags().AddFlagSet(scmrCreateExecFlags.Flags) - - cmdFlags[scmrCreateCmd] = []*flagSet{ - scmrCreateExecFlags, - scmrCreateFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - - // Constraints - { - //scmrCreateCmd.MarkFlagsMutuallyExclusive("no-delete", "no-start") - if err := scmrCreateCmd.MarkFlagRequired("executable-path"); err != nil { - panic(err) - } - } + scmrCreateFlags := newFlagSet("Service") + + scmrCreateFlags.Flags.StringVarP(&scmrCreate.DisplayName, "display-name", "n", "", "Display name of service to create") + scmrCreateFlags.Flags.StringVarP(&scmrCreate.ServiceName, "service", "s", "", "Name of service to create") + scmrCreateFlags.Flags.BoolVar(&scmrCreate.NoDelete, "no-delete", false, "Don't delete service after execution") + scmrCreateFlags.Flags.BoolVar(&scmrCreate.NoStart, "no-start", false, "Don't start service") + + scmrCreateExecFlags := newFlagSet("Execution") + + // TODO: SCMR output + //registerExecutionOutputFlags(scmrCreateExecFlags.Flags) + registerExecutionUploadFlags(scmrCreateExecFlags.Flags) + + scmrCreateExecFlags.Flags.StringVarP(&exec.Input.ExecutablePath, "executable-path", "f", "", "Full path to a remote Windows executable") + scmrCreateExecFlags.Flags.StringVarP(&exec.Input.Arguments, "args", "a", "", "Arguments to pass to the executable") + + scmrCreateCmd.Flags().AddFlagSet(scmrCreateFlags.Flags) + scmrCreateCmd.Flags().AddFlagSet(scmrCreateExecFlags.Flags) + + cmdFlags[scmrCreateCmd] = []*flagSet{ + scmrCreateExecFlags, + scmrCreateFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + + // Constraints + { + //scmrCreateCmd.MarkFlagsMutuallyExclusive("no-delete", "no-start") + scmrCreateCmd.MarkFlagsOneRequired("executable-path", "upload") + } } func scmrChangeCmdInit() { - scmrChangeFlags := newFlagSet("Service Control") - - scmrChangeFlags.Flags.StringVarP(&scmrChange.ServiceName, "service-name", "s", "", "Name of service to modify") - scmrChangeFlags.Flags.BoolVar(&scmrChange.NoStart, "no-start", false, "Don't start service") - - scmrChangeExecFlags := newFlagSet("Execution") - - scmrChangeExecFlags.Flags.StringVarP(&exec.Input.ExecutablePath, "executable-path", "f", "", "Full path to remote Windows executable") - scmrChangeExecFlags.Flags.StringVarP(&exec.Input.Arguments, "args", "a", "", "Arguments to pass to executable") - - // TODO: SCMR output - //registerExecutionOutputFlags(scmrChangeExecFlags.Flags) - //registerStageFlags(scmrChangeExecFlags.Flags) - - cmdFlags[scmrChangeCmd] = []*flagSet{ - scmrChangeFlags, - scmrChangeExecFlags, - defaultAuthFlags, - defaultLogFlags, - defaultNetRpcFlags, - } - - scmrChangeCmd.Flags().AddFlagSet(scmrChangeFlags.Flags) - scmrChangeCmd.Flags().AddFlagSet(scmrChangeExecFlags.Flags) - - // Constraints - { - if err := scmrChangeCmd.MarkFlagRequired("service-name"); err != nil { - panic(err) - } - if err := scmrCreateCmd.MarkFlagRequired("executable-path"); err != nil { - panic(err) - } - } + scmrChangeFlags := newFlagSet("Service Control") + + scmrChangeFlags.Flags.StringVarP(&scmrChange.ServiceName, "service-name", "s", "", "Name of service to modify") + scmrChangeFlags.Flags.BoolVar(&scmrChange.NoStart, "no-start", false, "Don't start service") + + scmrChangeExecFlags := newFlagSet("Execution") + + scmrChangeExecFlags.Flags.StringVarP(&exec.Input.ExecutablePath, "executable-path", "f", "", "Full path to remote Windows executable") + scmrChangeExecFlags.Flags.StringVarP(&exec.Input.Arguments, "args", "a", "", "Arguments to pass to executable") + + // TODO: SCMR output + //registerExecutionOutputFlags(scmrChangeExecFlags.Flags) + registerExecutionUploadFlags(scmrChangeExecFlags.Flags) + + cmdFlags[scmrChangeCmd] = []*flagSet{ + scmrChangeFlags, + scmrChangeExecFlags, + defaultAuthFlags, + defaultLogFlags, + defaultNetRpcFlags, + } + + scmrChangeCmd.Flags().AddFlagSet(scmrChangeFlags.Flags) + scmrChangeCmd.Flags().AddFlagSet(scmrChangeExecFlags.Flags) + + // Constraints + { + if err := scmrChangeCmd.MarkFlagRequired("service-name"); err != nil { + panic(err) + } + scmrChangeCmd.MarkFlagsOneRequired("executable-path", "upload") + } } func scmrDeleteCmdInit() { @@ -139,13 +136,14 @@ var ( Long: `Description: The create method calls RCreateServiceW to create a new Windows service on the remote target with the provided executable & arguments as the lpBinaryPathName`, - Args: args( - argsRpcClient("cifs", "ncacn_np:[svcctl]"), - argsSmbClient(), - ), - - Run: func(cmd *cobra.Command, args []string) { - scmrCreate.Client = &rpcClient + Args: args( + argsRpcClient("cifs", "ncacn_np:[svcctl]"), + argsSmbClient(), + argsUpload("smb"), + ), + + Run: func(cmd *cobra.Command, args []string) { + scmrCreate.Client = &rpcClient scmrCreate.IO = exec log = log.With(). @@ -184,10 +182,13 @@ var ( using the RChangeServiceConfigW method rather than calling RCreateServiceW like scmr create. The modified service is restored to its original state after execution`, - Args: argsRpcClient("cifs", "ncacn_np:[svcctl]"), + Args: args( + argsRpcClient("cifs", "ncacn_np:[svcctl]"), + argsUpload("smb"), + ), - Run: func(cmd *cobra.Command, args []string) { - scmrChange.Client = &rpcClient + Run: func(cmd *cobra.Command, args []string) { + scmrChange.Client = &rpcClient scmrChange.IO = exec ctx := log.With(). diff --git a/cmd/tsch.go b/cmd/tsch.go index a9b0d71..8803ca3 100644 --- a/cmd/tsch.go +++ b/cmd/tsch.go @@ -36,12 +36,13 @@ func tschDemandCmdInit() { tschDemandFlags.Flags.StringVar(&tschDemand.UserSid, "sid", "S-1-5-18", "User `SID` to impersonate") tschDemandFlags.Flags.BoolVar(&tschDemand.NoDelete, "no-delete", false, "Don't delete task after execution") - tschDemandExecFlags := newFlagSet("Execution") + tschDemandExecFlags := newFlagSet("Execution") - registerExecutionFlags(tschDemandExecFlags.Flags) - registerExecutionOutputFlags(tschDemandExecFlags.Flags) + registerExecutionFlags(tschDemandExecFlags.Flags) + registerExecutionOutputFlags(tschDemandExecFlags.Flags) + registerExecutionUploadFlags(tschDemandExecFlags.Flags) - cmdFlags[tschDemandCmd] = []*flagSet{ + cmdFlags[tschDemandCmd] = []*flagSet{ tschDemandFlags, tschDemandExecFlags, defaultAuthFlags, @@ -49,9 +50,9 @@ func tschDemandCmdInit() { defaultNetRpcFlags, } - tschDemandCmd.Flags().AddFlagSet(tschDemandFlags.Flags) - tschDemandCmd.Flags().AddFlagSet(tschDemandExecFlags.Flags) - tschDemandCmd.MarkFlagsOneRequired("exec", "command") + tschDemandCmd.Flags().AddFlagSet(tschDemandFlags.Flags) + tschDemandCmd.Flags().AddFlagSet(tschDemandExecFlags.Flags) + tschDemandCmd.MarkFlagsOneRequired("exec", "command", "upload") } func tschCreateCmdInit() { @@ -65,12 +66,13 @@ func tschCreateCmdInit() { tschCreateFlags.Flags.BoolVar(&tschCreate.CallDelete, "call-delete", false, "Directly call SchRpcDelete to delete task") tschCreateFlags.Flags.StringVar(&tschCreate.UserSid, "sid", "S-1-5-18", "User `SID` to impersonate") - tschCreateExecFlags := newFlagSet("Execution") + tschCreateExecFlags := newFlagSet("Execution") - registerExecutionFlags(tschCreateExecFlags.Flags) - registerExecutionOutputFlags(tschCreateExecFlags.Flags) + registerExecutionFlags(tschCreateExecFlags.Flags) + registerExecutionOutputFlags(tschCreateExecFlags.Flags) + registerExecutionUploadFlags(tschCreateExecFlags.Flags) - cmdFlags[tschCreateCmd] = []*flagSet{ + cmdFlags[tschCreateCmd] = []*flagSet{ tschCreateFlags, tschCreateExecFlags, defaultAuthFlags, @@ -78,9 +80,9 @@ func tschCreateCmdInit() { defaultNetRpcFlags, } - tschCreateCmd.Flags().AddFlagSet(tschCreateFlags.Flags) - tschCreateCmd.Flags().AddFlagSet(tschCreateExecFlags.Flags) - tschCreateCmd.MarkFlagsOneRequired("exec", "command") + tschCreateCmd.Flags().AddFlagSet(tschCreateFlags.Flags) + tschCreateCmd.Flags().AddFlagSet(tschCreateExecFlags.Flags) + tschCreateCmd.MarkFlagsOneRequired("exec", "command", "upload") } func tschChangeCmdInit() { @@ -90,12 +92,13 @@ func tschChangeCmdInit() { tschChangeFlags.Flags.BoolVar(&tschChange.NoStart, "no-start", false, "Don't start the task") tschChangeFlags.Flags.BoolVar(&tschChange.NoRevert, "no-revert", false, "Don't restore the original task definition") - tschChangeExecFlags := newFlagSet("Execution") + tschChangeExecFlags := newFlagSet("Execution") - registerExecutionFlags(tschChangeExecFlags.Flags) - registerExecutionOutputFlags(tschChangeExecFlags.Flags) + registerExecutionFlags(tschChangeExecFlags.Flags) + registerExecutionOutputFlags(tschChangeExecFlags.Flags) + registerExecutionUploadFlags(tschChangeExecFlags.Flags) - cmdFlags[tschChangeCmd] = []*flagSet{ + cmdFlags[tschChangeCmd] = []*flagSet{ tschChangeFlags, tschChangeExecFlags, defaultAuthFlags, @@ -108,10 +111,10 @@ func tschChangeCmdInit() { // Constraints { - if err := tschChangeCmd.MarkFlagRequired("task"); err != nil { - panic(err) - } - tschChangeCmd.MarkFlagsOneRequired("exec", "command") + if err := tschChangeCmd.MarkFlagRequired("task"); err != nil { + panic(err) + } + tschChangeCmd.MarkFlagsOneRequired("exec", "command", "upload") } } @@ -152,14 +155,15 @@ var ( Similar to the create method, the demand method will call SchRpcRegisterTask, But rather than setting a defined time when the task will start, it will additionally call SchRpcRun to forcefully start the task.`, - Args: args( - argsRpcClient("cifs", "ncacn_np:[atsvc]"), - argsOutput("smb"), - argsTask, - ), - - Run: func(*cobra.Command, []string) { - tschDemand.Client = &rpcClient + Args: args( + argsRpcClient("cifs", "ncacn_np:[atsvc]"), + argsOutput("smb"), + argsUpload("smb"), + argsTask, + ), + + Run: func(*cobra.Command, []string) { + tschDemand.Client = &rpcClient tschDemand.TaskPath = tschTask ctx := log.With(). @@ -180,14 +184,15 @@ var ( with an automatic start time.This method avoids directly calling SchRpcRun, and can even avoid calling SchRpcDelete by populating the DeleteExpiredTaskAfter Setting.`, - Args: args( - argsRpcClient("cifs", "ncacn_np:[atsvc]"), - argsOutput("smb"), - argsTask, - ), - - Run: func(*cobra.Command, []string) { - tschCreate.Client = &rpcClient + Args: args( + argsRpcClient("cifs", "ncacn_np:[atsvc]"), + argsOutput("smb"), + argsUpload("smb"), + argsTask, + ), + + Run: func(*cobra.Command, []string) { + tschCreate.Client = &rpcClient tschCreate.TaskPath = tschTask ctx := log.With(). @@ -206,14 +211,15 @@ var ( Long: `Description: The change method calls SchRpcRetrieveTask to fetch the definition of an existing task (-t), then modifies the task definition to spawn a process`, - Args: args( - argsRpcClient("cifs", "ncacn_np:[atsvc]"), - argsOutput("smb"), - - func(*cobra.Command, []string) error { - return tschexec.ValidateTaskPath(tschChange.TaskPath) - }, - ), + Args: args( + argsRpcClient("cifs", "ncacn_np:[atsvc]"), + argsOutput("smb"), + argsUpload("smb"), + + func(*cobra.Command, []string) error { + return tschexec.ValidateTaskPath(tschChange.TaskPath) + }, + ), Run: func(*cobra.Command, []string) { tschChange.Client = &rpcClient diff --git a/cmd/wmi.go b/cmd/wmi.go index e63feba..009fac2 100644 --- a/cmd/wmi.go +++ b/cmd/wmi.go @@ -50,12 +50,13 @@ func wmiCallCmdInit() { } func wmiProcCmdInit() { - wmiProcExecFlags := newFlagSet("Execution") + wmiProcExecFlags := newFlagSet("Execution") - registerExecutionFlags(wmiProcExecFlags.Flags) - registerExecutionOutputFlags(wmiProcExecFlags.Flags) + registerExecutionFlags(wmiProcExecFlags.Flags) + registerExecutionOutputFlags(wmiProcExecFlags.Flags) + registerExecutionUploadFlags(wmiProcExecFlags.Flags) - wmiProcExecFlags.Flags.StringVarP(&wmiProc.WorkingDirectory, "directory", "d", `C:\`, "Working directory") + wmiProcExecFlags.Flags.StringVarP(&wmiProc.WorkingDirectory, "directory", "d", `C:\`, "Working directory") cmdFlags[wmiProcCmd] = []*flagSet{ wmiProcExecFlags, @@ -117,13 +118,14 @@ var ( The proc method creates an instance of the Win32_Process WMI class, then calls the Win32_Process.Create method with the provided command (-c), and optional working directory (-d).`, - Args: args( - argsRpcClient("cifs", ""), - argsOutput("smb"), - ), - - Run: func(cmd *cobra.Command, args []string) { - wmiProc.Client = &rpcClient + Args: args( + argsRpcClient("cifs", ""), + argsOutput("smb"), + argsUpload("smb"), + ), + + Run: func(cmd *cobra.Command, args []string) { + wmiProc.Client = &rpcClient wmiProc.IO = exec wmiProc.Resource = "//./root/cimv2" diff --git a/pkg/goexec/io.go b/pkg/goexec/io.go index 874dd77..d372119 100644 --- a/pkg/goexec/io.go +++ b/pkg/goexec/io.go @@ -9,23 +9,48 @@ import ( ) type OutputProvider interface { - GetOutput(ctx context.Context, writer io.Writer) (err error) - Clean(ctx context.Context) (err error) + GetOutput(ctx context.Context, writer io.Writer) (err error) + Clean(ctx context.Context) (err error) +} + +type InputProvider interface { + Upload(ctx context.Context, reader io.Reader) (err error) + Clean(ctx context.Context) (err error) } type ExecutionIO struct { - Cleaner + Cleaner - Input *ExecutionInput - Output *ExecutionOutput + Input *ExecutionInput + Output *ExecutionOutput + Upload *ExecutionUpload } type ExecutionOutput struct { - NoDelete bool - RemotePath string - Timeout time.Duration - Provider OutputProvider - Writer io.WriteCloser + NoDelete bool + RemotePath string + Timeout time.Duration + Provider OutputProvider + Writer io.WriteCloser +} + +type ExecutionUpload struct { + NoConfirm bool + RemotePath string + Provider InputProvider + Reader io.ReadCloser +} + +// UploadConfirmer is an optional interface that InputProvider implementations +// can satisfy to confirm a file was successfully uploaded. +type UploadConfirmer interface { + ConfirmUpload(ctx context.Context) error +} + +// UploadRemover is an optional interface that InputProvider implementations +// can satisfy to remove a previously uploaded file from the remote filesystem. +type UploadRemover interface { + RemoveUploadedFile(ctx context.Context) error } type ExecutionInput struct { @@ -36,6 +61,20 @@ type ExecutionInput struct { Command string } +func (execIO *ExecutionIO) DoUpload(ctx context.Context) (err error) { + if execIO.Upload != nil && execIO.Upload.Provider != nil && execIO.Upload.Reader != nil { + return execIO.Upload.Provider.Upload(ctx, execIO.Upload.Reader) + } + return nil +} + +func (execIO *ExecutionIO) CleanUpload(ctx context.Context) (err error) { + if execIO.Upload != nil && execIO.Upload.Provider != nil { + return execIO.Upload.Provider.Clean(ctx) + } + return nil +} + func (execIO *ExecutionIO) GetOutput(ctx context.Context) (err error) { if execIO.Output.Provider != nil { ctx = context.WithValue(ctx, ContextOptionOutputTimeout, execIO.Output.Timeout) diff --git a/pkg/goexec/method.go b/pkg/goexec/method.go index 7fa6d94..bed980e 100644 --- a/pkg/goexec/method.go +++ b/pkg/goexec/method.go @@ -101,31 +101,82 @@ func ExecuteCleanAuxiliaryMethod(ctx context.Context, module CleanAuxiliaryMetho } func ExecuteCleanMethod(ctx context.Context, module CleanExecutionMethod, execIO *ExecutionIO) (err error) { - log := zerolog.Ctx(ctx) - - if err = ExecuteMethod(ctx, module, execIO); err != nil { - return - } - - if err = module.Clean(ctx); err != nil { - log.Error().Err(err).Msg("Module cleanup failed") - err = nil - } - - if execIO.Output != nil && execIO.Output.Provider != nil { - log.Info().Msg("Collecting output") - - defer func() { - if cleanErr := execIO.Clean(ctx); cleanErr != nil { - log.Debug().Err(cleanErr).Msg("Output provider cleanup failed") - } - }() - - if err := execIO.GetOutput(ctx); err != nil { - log.Error().Err(err).Msg("Output collection failed") - return fmt.Errorf("get output: %w", err) - } - log.Debug().Msg("Output collection succeeded") - } - return + log := zerolog.Ctx(ctx) + + // Connect + if err = module.Connect(ctx); err != nil { + log.Error().Err(err).Msg("Connection failed") + return fmt.Errorf("connect: %w", err) + } + log.Debug().Msg("Module connected") + + // Init + if err = module.Init(ctx); err != nil { + log.Error().Err(err).Msg("Module initialization failed") + return fmt.Errorf("init module: %w", err) + } + log.Debug().Msg("Module initialized") + + // Upload file (before execution) + if execIO.Upload != nil && execIO.Upload.Provider != nil { + log.Info().Str("dest", execIO.Upload.RemotePath).Msg("Uploading file") + if err = execIO.DoUpload(ctx); err != nil { + log.Error().Err(err).Msg("Upload failed") + return fmt.Errorf("upload: %w", err) + } + if !execIO.Upload.NoConfirm { + if confirmer, ok := execIO.Upload.Provider.(UploadConfirmer); ok { + if err = confirmer.ConfirmUpload(ctx); err != nil { + log.Warn().Err(err).Msg("Upload confirmation failed") + } + } + } + // Clean up upload provider resources (close file handles) + if cleanErr := execIO.CleanUpload(ctx); cleanErr != nil { + log.Debug().Err(cleanErr).Msg("Upload cleanup failed") + } + } + + // Execute (only if a command/executable was provided) + executed := false + if execIO.Input != nil && (execIO.Input.Executable != "" || execIO.Input.Command != "" || execIO.Input.ExecutablePath != "") { + if err = module.Execute(ctx, execIO); err != nil { + log.Error().Err(err).Msg("Execution failed") + return fmt.Errorf("execute: %w", err) + } + executed = true + } + + // Remove uploaded file after execution (upload+execute mode only) + if executed && execIO.Upload != nil && execIO.Upload.Provider != nil { + if remover, ok := execIO.Upload.Provider.(UploadRemover); ok { + if removeErr := remover.RemoveUploadedFile(ctx); removeErr != nil { + log.Warn().Err(removeErr).Msg("Failed to remove uploaded file") + } + } + } + + // Module cleanup + if err = module.Clean(ctx); err != nil { + log.Error().Err(err).Msg("Module cleanup failed") + err = nil + } + + // Output collection + if execIO.Output != nil && execIO.Output.Provider != nil { + log.Info().Msg("Collecting output") + + defer func() { + if cleanErr := execIO.Clean(ctx); cleanErr != nil { + log.Debug().Err(cleanErr).Msg("Output provider cleanup failed") + } + }() + + if err := execIO.GetOutput(ctx); err != nil { + log.Error().Err(err).Msg("Output collection failed") + return fmt.Errorf("get output: %w", err) + } + log.Debug().Msg("Output collection succeeded") + } + return } diff --git a/pkg/goexec/smb/input.go b/pkg/goexec/smb/input.go index b9cb3bc..fe76e33 100644 --- a/pkg/goexec/smb/input.go +++ b/pkg/goexec/smb/input.go @@ -1,66 +1,104 @@ package smb import ( - "context" - "fmt" - "github.com/FalconOpsLLC/goexec/pkg/goexec" - "io" - "os" - "path" - "strings" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/FalconOpsLLC/goexec/pkg/goexec" + "github.com/rs/zerolog" ) type FileStager struct { - goexec.Cleaner + goexec.Cleaner + + Client *Client + + Share string + SharePath string + File string + relativePath string + ForceReconnect bool +} + +func (o *FileStager) Upload(ctx context.Context, reader io.Reader) (err error) { + + // Calculate relative path from share root to target file (matches OutputFileFetcher pattern) + shp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.SharePath, `\`, "/")), "") + fp := pathPrefix.ReplaceAllString(strings.ToLower(strings.ReplaceAll(o.File, `\`, "/")), "") + + if o.relativePath, err = filepath.Rel(shp, fp); err != nil { + return fmt.Errorf("calculate relative path: %w", err) + } + + if o.ForceReconnect || !o.Client.connected { + err = o.Client.Connect(ctx) + if err != nil { + return + } + defer o.AddCleaners(o.Client.Close) + } + + if o.ForceReconnect || o.Client.share != o.Share { + err = o.Client.Mount(ctx, o.Share) + if err != nil { + return + } + } + + writer, err := o.Client.mount.OpenFile(o.relativePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("open remote file for writing: %w", err) + } - Client *Client + if _, err = io.Copy(writer, reader); err != nil { + return + } - Share string - SharePath string - File string - relativePath string - ForceReconnect bool - DeleteStage bool + o.AddCleaners(func(_ context.Context) error { return writer.Close() }) + + return +} + +// RemoveUploadedFile deletes the uploaded file from the remote filesystem. +// The share must already be mounted from a prior Upload call. +func (o *FileStager) RemoveUploadedFile(ctx context.Context) error { + log := zerolog.Ctx(ctx) + + if o.Client.mount == nil { + return fmt.Errorf("share not mounted") + } + + if err := o.Client.mount.Remove(o.relativePath); err != nil { + return fmt.Errorf("remove remote file: %w", err) + } + + log.Info().Str("path", o.File).Msg("Removed uploaded file") + return nil } -func (o *FileStager) Stage(ctx context.Context, reader io.Reader) (err error) { - - o.relativePath = path.Join( - strings.ReplaceAll(pathPrefix.ReplaceAllString(o.SharePath, ""), `\`, "/"), - strings.ReplaceAll(pathPrefix.ReplaceAllString(o.File, ""), `\`, "/"), - ) - - if o.ForceReconnect || !o.Client.connected { - err = o.Client.Connect(ctx) - if err != nil { - return - } - defer o.AddCleaners(o.Client.Close) - } - - if o.ForceReconnect || o.Client.share != o.Share { - err = o.Client.Mount(ctx, o.Share) - if err != nil { - return - } - } - - writer, err := o.Client.mount.OpenFile(o.relativePath, os.O_WRONLY, 0644) - if err != nil { - return fmt.Errorf("open remote file for writing: %w", err) - } - - if _, err = io.Copy(writer, reader); err != nil { - return - } - - o.AddCleaners(func(_ context.Context) error { return writer.Close() }) - - if o.DeleteStage { - o.AddCleaners(func(_ context.Context) error { - return o.Client.mount.Remove(o.relativePath) - }) - } - - return +// ConfirmUpload checks that the uploaded file exists on the remote filesystem +// and logs the file path and size. The share must already be mounted from a +// prior Upload call. +func (o *FileStager) ConfirmUpload(ctx context.Context) error { + log := zerolog.Ctx(ctx) + + if o.Client.mount == nil { + return fmt.Errorf("share not mounted") + } + + info, err := o.Client.mount.Stat(o.relativePath) + if err != nil { + return fmt.Errorf("stat remote file: %w", err) + } + + log.Info(). + Str("path", o.File). + Int64("size", info.Size()). + Msg("Upload confirmed") + + return nil }