diff --git a/INTERNALS.md b/INTERNALS.md index 9f93e17..ef04b88 100644 --- a/INTERNALS.md +++ b/INTERNALS.md @@ -13,8 +13,14 @@ Go pot as the names suggests is surprisingly 🥁... written in go! It is a HTTP ## Components Go pot is made up of a few different components that come together to make the staller. Some of the more idiomatic components are: -* **Staller**: A http handler that will stall for a request for a given amount of time. It gets a generator instance it will keep on calling for new data until just before the timeout it has been given is reached. At which point it will correctly terminate the response. +* **Staller**: A special handler that will stall for a request for a given amount of time. It gets a generator instance it will keep on calling for new data until just before the timeout it has been given is reached. At which point it will correctly terminate the response. * **Generator**: A generator will provide an infinite stream of fake structured data. That can be serialized into a number of different formats. * **TimeoutWatcher**: The timeout watcher will keep track of how long a bot is willing to wait for a response. It will do this by watching when a given IP address disconnects. If it gets a few similar disconnects in a row it will assume that that is the maximum time a bot is willing to wait for a response and then give a time just under that to the staller. * **Cluster**: The cluster is a way of sharing information about how long bots are willing to wait for a response to other nodes in the cluster. It uses memberlist (go) -* **Recast**: Recast is a way of restarting / reallocating IP addresses to avoid being blacklisted by connecting clients. It uses telemetry to see if stalling connections and moves to a different IP block if not. \ No newline at end of file +* **Recast**: Recast is a way of restarting / reallocating IP addresses to avoid being blacklisted by connecting clients. It uses telemetry to see if stalling connections and moves to a different IP block if not. +* **Detect / Multi protocol listener** Detect aims to watch for traffic on a TCP listener and make a guess at which protocol data being sent down the pipe belongs to. It does this by. +When a new connection is opened: + * Wait for some data to be sent by the client + * If no data is sent in X seconds begin to "probe" by sending different protocol headers back down the pipe + * Wait for data while probes are sent + * If still no data is sent change to the fallback handler diff --git a/README.md b/README.md index 683da2d..c230036 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # go-pot 🍯 -A HTTP tarpit written in Go designed to maximize bot misery through very slowly feeding them an infinite stream of fake secrets. +A Multi Protocol tarpit written in Go designed to maximize bot misery through very slowly feeding them an infinite stream of fake secrets. @@ -7,8 +7,8 @@ A HTTP tarpit written in Go designed to maximize bot misery through very slowly - **Realistic output**: Go pot will respond to requests with an infinite stream of realistic looking, parseable structured data full of fake secrets. `xml`, `json`, `yaml`, `hcl`, `toml`, `csv`, `ini`, and `sql` are all supported. - **Multiple protocols**: Both `http` and `ftp` are supported out of the box. Each with a tailored implementation. *More protocols are planned.* - **Intelligent stalling**: Go pot will attempt to work out how long a bot is willing to wait for a response and stall for exactly that long. This is done gradually making requests slower and slower until a timeout is reached. (Or the bot hangs forever!) -- **Small Profile**: Go pot can run on extremely low resource machines and is designed to be as lightweight as possible. -- **Clustering Support**: Go pot can be run in a clustered mode where multiple instances can share information about how long bots are willing to wait for a response. Also in cluster mode nodes can be configured to restart / reallocate IP addresses to avoid being blacklisted by connecting clients. +- **Small Profile**: Go pot aims to target fairly low end hardware. +- **Clustering Support**: Go pot can be run in a clustered mode where multiple instances can share information about how long bots are willing to wait for a response. Also in cluster mode nodes can be configured to restart / reallocate IP addresses to avoid being blacklisted by connecting clients. (Currently tested on AWS ECS) - **Customizable**: Go pot can be customized to respond with different different response times. ## Installation diff --git a/cmd/ftp.go b/cmd/ftp.go index 6c77e5d..e3699e6 100644 --- a/cmd/ftp.go +++ b/cmd/ftp.go @@ -23,7 +23,10 @@ var ftpCommand = &cobra.Command{ // Make sure only the FTP server is enabled conf.FtpServer.Enabled = true + + // Disable the server and multi protocol conf.Server.Disable = true + conf.MultiProtocol.Enabled = false di := di.CreateContainer(conf) di.Run() diff --git a/cmd/http.go b/cmd/http.go index 2986fa9..1cb3574 100644 --- a/cmd/http.go +++ b/cmd/http.go @@ -21,9 +21,12 @@ var httpCommand = &cobra.Command{ os.Exit(1) } + // Enable the server + conf.Server.Disable = false + // Make sure only the HTTP server is enabled conf.FtpServer.Enabled = false - conf.Server.Disable = false + conf.MultiProtocol.Enabled = false di := di.CreateContainer(conf) di.Run() diff --git a/config/config.go b/config/config.go index 3e42a98..b1d4420 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "strings" "github.com/knadh/koanf/providers/env" @@ -14,6 +15,7 @@ type ( Config struct { Server serverConfig `koanf:"server"` FtpServer ftpServerConfig `koanf:"ftp_server"` + MultiProtocol multiProtocolConfig `koanf:"multi_protocol"` Logging loggingConfig `koanf:"logging"` Cluster clusterConfig `koanf:"cluster"` TimeoutWatcher timeoutWatcherConfig `koanf:"timeout_watcher"` @@ -22,13 +24,28 @@ type ( Staller stallerConfig `koanf:"staller"` } + // Multi protocol configuration + multiProtocolConfig struct { + // If the multi protocol server should be started or not + Enabled bool `koanf:"enabled"` + + // The port to listen on. + Port int `koanf:"port" validate:"required,min=1,max=65535,no_duplicate_port"` + + // The protocol detectors to use + Host string `koanf:"host" validate:"omitempty"` + + // The protocol detectors to use + Protocols []string `koanf:"protocols" validate:"required,dive,oneof=http ftp all"` + } + // Server specific configuration serverConfig struct { // If the http server should be disabled Disable bool `koanf:"disable"` // Server port to listen on - Port int `koanf:"port" validate:"required,min=1,max=65535"` + Port int `koanf:"port" validate:"required,min=1,max=65535,no_duplicate_port"` // Server host to listen on Host string `koanf:"host" validate:"required"` @@ -74,13 +91,13 @@ type ( // The port to listen on N.b this is the control port // port 20 is used for data transfer by default in active mode. - Port int `koanf:"port" validate:"required,min=1,max=65535"` + Port int `koanf:"port" validate:"required,min=1,max=65535,no_duplicate_port"` // Host to listen on Host string `koanf:"host" validate:"required"` // Lower bound of ports exposed for passive mode default 50000-50100 - PassivePortRange string `koanf:"passive_port_range" validate:"omitempty,port_range"` + PassivePortRange string `koanf:"passive_port_range" validate:"omitempty,port_range,no_duplicate_port_range"` // The common name for the self signed certificate CertCommonName string `koanf:"cert_common_name" validate:"omitempty"` @@ -121,7 +138,7 @@ type ( Mode string `koanf:"mode" validate:"required_if=Enabled true,omitempty,oneof=fargate_ecs lan wan"` // The bind address for the cluster to listen on - BindPort int `koanf:"bind_port" validate:"required_if=Enabled true,omitempty,min=1,max=65535"` + BindPort int `koanf:"bind_port" validate:"required_if=Enabled true,omitempty,min=1,max=65535,no_duplicate_port"` // Known Peers KnownPeerIps []string `koanf:"known_peer_ips" validate:"required_if=Mode lan Mode wan,omitempty"` @@ -240,7 +257,7 @@ type ( Enabled bool `koanf:"enabled"` // The port for the prometheus collection endpoint - Port int `koanf:"prometheus_port" validate:"required,min=1,max=65535"` + Port int `koanf:"prometheus_port" validate:"required,min=1,max=65535,no_duplicate_port"` // The path for the prometheus endpoint Path string `koanf:"prometheus_path" validate:"required"` @@ -341,6 +358,29 @@ func NewConfig(cmd *cobra.Command, flagsUsed flagMap) (*Config, error) { setStringSlice(k, "server.access_log.fields_to_log") setStringSlice(k, "ftp_server.command_log.commands_to_log") setStringSlice(k, "ftp_server.command_log.additional_fields") + setStringSlice(k, "multi_protocol.protocols") + + // Implicitly enable each server type if specific configuration changes have been made to the default configuration + if err := setIfNotDefault(k, "ftp_server.enabled", true, map[string]interface{}{ + "ftp_server.port": defaultConfig.FtpServer.Port, + "ftp_server.host": defaultConfig.FtpServer.Host, + }); err != nil { + return nil, err + } + + if err := setIfNotDefault(k, "server.disable", false, map[string]interface{}{ + "server.port": defaultConfig.Server.Port, + "server.host": defaultConfig.Server.Host, + }); err != nil { + return nil, err + } + + if err := setIfNotDefault(k, "multi_protocol.enabled", false, map[string]interface{}{ + "multi_protocol.port": defaultConfig.MultiProtocol.Port, + "multi_protocol.host": defaultConfig.MultiProtocol.Host, + }); err != nil { + return nil, err + } var cfg *Config if err := k.UnmarshalWithConf("", &cfg, koanf.UnmarshalConf{Tag: "koanf"}); err != nil { @@ -378,3 +418,25 @@ func setStringSlice(k *koanf.Koanf, key string) { k.Delete(key) } } + +// In the event any of the given flags are not set to the given default value then the target key will be set to the target state +func setIfNotDefault(k *koanf.Koanf, targetKey string, targetState interface{}, configToCheckForChanges map[string]interface{}) error { + // Assert that the desired key exists + if !k.Exists(targetKey) { + return fmt.Errorf("key %s does not exist when trying to set implicitly", targetKey) + } + + // If the key is not the default value then we don't need to do anything + if k.Get(targetKey) == targetState { + return nil + } + + for key, value := range configToCheckForChanges { + if k.Get(key) != value { + k.Set(targetKey, targetState) + return nil + } + } + + return nil +} diff --git a/config/default.go b/config/default.go index 6128184..0bf0c5f 100644 --- a/config/default.go +++ b/config/default.go @@ -4,6 +4,12 @@ import "go.uber.org/zap/zapcore" // Default configuration values for the application var defaultConfig = Config{ + MultiProtocol: multiProtocolConfig{ + Enabled: false, + Host: "0.0.0.0", + Protocols: []string{"all"}, + Port: 8081, + }, Server: serverConfig{ Disable: false, Port: 8080, diff --git a/config/flags.go b/config/flags.go index 5325c30..34a3b43 100644 --- a/config/flags.go +++ b/config/flags.go @@ -222,6 +222,7 @@ var ftpFlags = flagMap{ } var startFlags = flagMap{ + // @todo - Next Major release [Swap this to be enabled] "http-disabled": { flagName: "http-disabled", configKey: "server.disable", @@ -237,6 +238,36 @@ var startFlags = flagMap{ configType: "bool", defaultValue: defaultConfig.FtpServer.Enabled, }, + + // Multi-protocol options + "multi-protocol": { + flagName: "multi-protocol", + configKey: "multi_protocol.enabled", + description: "Allows for multiple honeypots to bind to the same pot (will override --[protocol]-enabled / --[protocol]-disabled) flags", + configType: "bool", + defaultValue: defaultConfig.MultiProtocol.Enabled, + }, + "multi-protocol-port": { + flagName: "multi-protocol-port", + configKey: "multi_protocol.port", + description: "The port to use for the multi protocol. Default(8081). Has no effect unless --multi-protocol specified", + configType: "int", + defaultValue: defaultConfig.MultiProtocol.Port, + }, + "multi-protocol-host": { + flagName: "multi-protocol-host", + configKey: "multi_protocol.host", + description: "The host to bin the multi protocol listener to", + configType: "string", + defaultValue: defaultConfig.MultiProtocol.Host, + }, + "multi-protocol-protocols": { + flagName: "multi-protocol-protocols", + configKey: "multi_protocol.protocols", + description: "Comma separated list of protocols to enable as part of the 'multi-protocol' listener. Can be one of 'ftp', 'http' or 'all'", + configType: "string", + defaultValue: strings.Join(defaultConfig.MultiProtocol.Protocols, ","), + }, } func GetStartFlags() flagMap { diff --git a/config/validation.go b/config/validation.go index 9fe4154..c41e32a 100644 --- a/config/validation.go +++ b/config/validation.go @@ -10,6 +10,76 @@ import ( var portRangeRegex = regexp.MustCompile(`^(\d+)-(\d+)$`) +type ( + duplicatePortValidatorMeta struct { + name string + ports []int + } + + // Stateful validator use to assert all ports with the "no_duplicate_port" and "no_duplicate_port_range" tags are unique to stop later bind conflicts + duplicatePortValidator struct { + registeredPorts map[int]*duplicatePortValidatorMeta + registeredPortRanges []*duplicatePortValidatorMeta + } +) + +func NewDuplicatePortValidator() *duplicatePortValidator { + return &duplicatePortValidator{ + registeredPorts: make(map[int]*duplicatePortValidatorMeta), + registeredPortRanges: make([]*duplicatePortValidatorMeta, 0), + } +} + +func (v *duplicatePortValidator) ValidatePort(fl validator.FieldLevel) bool { + port := int(fl.Field().Int()) + if _, ok := v.registeredPorts[port]; ok { + return false + } + + for _, portRange := range v.registeredPortRanges { + if port >= portRange.ports[0] && port <= portRange.ports[1] { + return false + } + } + + // Register the port + v.registeredPorts[port] = &duplicatePortValidatorMeta{ + name: fl.FieldName(), + ports: []int{port}, + } + + return true +} + +func (v *duplicatePortValidator) ValidatePortRange(fl validator.FieldLevel) bool { + minPort, maxPort, err := ParsePortRange(fl.Field().String()) + if err != nil { + return false + } + + // Assert there are no conflicts with existing port ranges + for _, port := range v.registeredPortRanges { + if max(port.ports[0], minPort) <= min(port.ports[1], maxPort) { + return false + } + } + + // Assert there are no conflicts with existing ports + for port, _ := range v.registeredPorts { + if port >= minPort && port <= maxPort { + return false + } + } + + // Register the port range + v.registeredPortRanges = append(v.registeredPortRanges, &duplicatePortValidatorMeta{ + name: fl.Param(), + ports: []int{minPort, maxPort}, + }) + + return true +} + func ParsePortRange(portRange string) (int, int, error) { matches := portRangeRegex.FindStringSubmatch(portRange) if len(matches) < 3 { @@ -51,5 +121,15 @@ func newConfigValidator() (*validator.Validate, error) { if err := v.RegisterValidation("port_range", validatePortRange); err != nil { return nil, err } + + dpv := NewDuplicatePortValidator() + if err := v.RegisterValidation("no_duplicate_port", dpv.ValidatePort); err != nil { + return nil, err + } + + if err := v.RegisterValidation("no_duplicate_port_range", dpv.ValidatePortRange); err != nil { + return nil, err + } + return v, nil } diff --git a/di/di.go b/di/di.go index 70d4a98..1552ec9 100644 --- a/di/di.go +++ b/di/di.go @@ -16,6 +16,9 @@ import ( "github.com/ryanolee/go-pot/core/metrics" "github.com/ryanolee/go-pot/core/stall" "github.com/ryanolee/go-pot/generator" + "github.com/ryanolee/go-pot/protocol/detect" + "github.com/ryanolee/go-pot/protocol/detect/detector" + "github.com/ryanolee/go-pot/protocol/fallback" "github.com/ryanolee/go-pot/protocol/ftp" ftpDi "github.com/ryanolee/go-pot/protocol/ftp/di" "github.com/ryanolee/go-pot/protocol/ftp/driver" @@ -35,8 +38,8 @@ import ( // Creates the dependency injection container for the application func CreateContainer(conf *config.Config) *fx.App { - if !conf.FtpServer.Enabled && conf.Server.Disable { - fmt.Print("Both FTP and HTTP servers are disabled. There is nothing to do. Exiting.") + if !conf.FtpServer.Enabled && !conf.MultiProtocol.Enabled && conf.Server.Disable { + fmt.Println("All honeypots disabled. There is nothing to do. Exiting.") os.Exit(0) } @@ -73,6 +76,16 @@ func CreateContainer(conf *config.Config) *fx.App { fx.As(new(gossip.IMemberlist)), ), + // Detectors + fx.Annotate(detector.NewHttpDetector, fx.As(new(detector.ProtocolDetector)), fx.ResultTags(`group:"detectors"`)), + fx.Annotate(detector.NewFtpDetector, fx.As(new(detector.ProtocolDetector)), fx.ResultTags(`group:"detectors"`)), + + // Multi Protocol Listener + fx.Annotate(detect.NewMulitProtocolListener, fx.ParamTags(``, `group:"detectors"`)), + + // Fallback Server + fallback.NewFallbackProtocolServer, + // Http Server http.NewServer, fx.Annotate( @@ -105,49 +118,82 @@ func CreateContainer(conf *config.Config) *fx.App { }), // Shutdown hook - fx.Invoke(func(shutdown fx.Shutdowner) { + fx.Invoke(func(shutdown fx.Shutdowner, pool *stall.StallerPool, logger *zap.Logger) { go func() { shutdownChannel := make(chan os.Signal, 1) signal.Notify(shutdownChannel, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) <-shutdownChannel - zap.L().Info("Shutting down...") + + logger.Warn("Shutting down...") + + // Stop pool before normal FX lifecycle hook so that all active connections are closed + logger.Info("Flushing stall pool") + pool.Stop() + err := shutdown.Shutdown() if err != nil { - zap.L().Sugar().Fatalf("Error shutting down, Forcing shutdown", zap.Error(err)) + logger.Sugar().Fatalf("Error shutting down, Forcing shutdown", zap.Error(err)) } time.Sleep(time.Second * 30) - zap.L().Sugar().Fatal("Deadline has passed after 30 seconds, Forcing shutdown.") + logger.Error("Deadline has passed after 30 seconds, Forcing shutdown.") + os.Exit(0) }() }), // Start HTTP server - fx.Invoke(func(c *config.Config, s *http.Server) { - zap.L().Info("HTTP Server Enabled: ", zap.Bool("enabled", !conf.Server.Disable)) - if conf.Server.Disable { - zap.L().Info("Http is disabled") + fx.Invoke(func(c *config.Config, s *http.Server, logger *zap.Logger, mlp *detect.MultiProtocolListener) { + logger.Info("HTTP Server Enabled: ", zap.Bool("enabled", !conf.Server.Disable)) + if conf.Server.Disable && !mlp.ProtocolEnabled("http") { + logger.Info("Http is disabled") return } - zap.L().Info("Starting Http server", zap.Int("port", s.ListenPort), zap.String("host", s.ListenHost)) + logger.Info("Starting Http server", + zap.Int("port", s.ListenPort), + zap.String("host", s.ListenHost), + zap.Bool("managed_by_multi_protocol_listener", mlp.ProtocolEnabled("http")), + ) go func() { if err := s.Start(); err != nil { - zap.L().Fatal("Failed to start Http server", zap.Error(err)) + logger.Fatal("Failed to start Http server", zap.Error(err)) + os.Exit(1) } }() }), // Start Ftp server - fx.Invoke(func(c *config.Config, s *ftpserver.FtpServer) { - if !conf.FtpServer.Enabled { - zap.L().Info("Ftp is disabled") + fx.Invoke(func(c *config.Config, s *ftpserver.FtpServer, logger *zap.Logger, mlp *detect.MultiProtocolListener) { + if !conf.FtpServer.Enabled && !mlp.ProtocolEnabled("ftp") { + logger.Info("Ftp is disabled") return } - zap.L().Info("Starting Ftp server", zap.Int("port", c.FtpServer.Port), zap.String("host", c.FtpServer.Host), zap.String("passive_port_range", c.FtpServer.PassivePortRange)) + + logger.Info("Starting Ftp server", + zap.Int("port", c.FtpServer.Port), + zap.String("host", c.FtpServer.Host), + zap.String("passive_port_range", c.FtpServer.PassivePortRange), + zap.Bool("managed_by_multi_protocol_listener", mlp.ProtocolEnabled("ftp")), + ) + go func() { if err := s.ListenAndServe(); err != nil { - zap.L().Sugar().Fatalf("Failed to start Ftp server", "error", err) + logger.Sugar().Warn("Failed to start Ftp server", "error", err) + } + }() + }), + + // Start multi protocol listener + fx.Invoke(func(ls *detect.MultiProtocolListener, conf *config.Config, logger *zap.Logger) { + if ls == nil { + return + } + + logger.Info("Starting Multi Protocol Listener", zap.Int("port", conf.MultiProtocol.Port), zap.String("host", conf.MultiProtocol.Host)) + go func() { + if err := ls.Listen(); err != nil { + logger.Sugar().Fatalf("Failed to start Multi Protocol Listener", zap.Error(err)) } }() }), diff --git a/examples/config/reference.yml b/examples/config/reference.yml index 4b109ce..d1b3398 100644 --- a/examples/config/reference.yml +++ b/examples/config/reference.yml @@ -5,6 +5,7 @@ # Configuration for the go-pot server server: # If the http staller should be enabled + # Implicitly enabled by changing the port or host enabled: true # Port for the go-pot server to listen on @@ -218,8 +219,8 @@ staller: # Metric configuration for the FTP side of the staller ftp_server: - # If the fep server should be enabled or not + # Implicitly enabled by changing the port or host enabled: false # Port the FTP server should bind to @@ -305,3 +306,23 @@ ftp_server: # - type: always "ftp" # - none: No fields additional_fields: "id" + +# Enables connection muxing for multiple honeypot protocols to bind to the same port +# This is useful for when proxying a large number of ports to a single running instance of the honeypot +multi_protocol: + # If the multi protocol server should be enabled or not + # Implicitly enabled by changing the port or host + enabled: false + + # Port the multi protocol server should bind to + port: 2021 + + # The host for the go pot server + host: 0.0.0.0 + + # The supported protocols for the multi protocol server + # The following protocols are available: + # - ftp: The FTP protocol (Including FTPS) + # - http: The HTTP protocol + # - all: All available protocols + protocols: ['all'] \ No newline at end of file diff --git a/protocol/detect/conditionallistener.go b/protocol/detect/conditionallistener.go new file mode 100644 index 0000000..d96b6da --- /dev/null +++ b/protocol/detect/conditionallistener.go @@ -0,0 +1,55 @@ +package detect + +import ( + "net" + + "github.com/ryanolee/go-pot/config" +) + +type ( + // Listener that binds itself to a multi protocol "detector" + ConditionalListener struct { + receiverChannel chan net.Conn + shutdownChannel chan bool + address net.Addr + } +) + +func NewConditionalListenerFromConfig(config *config.Config) *ConditionalListener { + return NewConditionalListener( + &net.TCPAddr{ + IP: net.ParseIP(config.MultiProtocol.Host), + Port: config.MultiProtocol.Port, + }, + ) +} + +func NewConditionalListener(address net.Addr) *ConditionalListener { + return &ConditionalListener{ + receiverChannel: make(chan net.Conn), + shutdownChannel: make(chan bool, 1), + address: address, + } +} + +func (l *ConditionalListener) Close() error { + l.shutdownChannel <- true + return nil +} + +func (l *ConditionalListener) Accept() (net.Conn, error) { + select { + case <-l.shutdownChannel: + return nil, net.ErrClosed + case connection := <-l.receiverChannel: + return connection, nil + } +} + +func (l *ConditionalListener) Addr() net.Addr { + return l.address +} + +func (l *ConditionalListener) Dispatch(conn net.Conn) { + l.receiverChannel <- conn +} diff --git a/protocol/detect/conn.go b/protocol/detect/conn.go new file mode 100644 index 0000000..5511931 --- /dev/null +++ b/protocol/detect/conn.go @@ -0,0 +1,123 @@ +package detect + +import ( + "context" + "net" + "time" +) + +type ( + // Connection that can be replayed + rewindableConn struct { + conn net.Conn + buffer []byte + playBufferOnNextRead bool + } +) + +const ( + rewindBufferSize = 128 + pollInterval = time.Millisecond * 100 +) + +func newRewindableConnFromConn(conn net.Conn) *rewindableConn { + return &rewindableConn{ + conn: conn, + buffer: make([]byte, rewindBufferSize), + } +} + +// Custom method to rewind the connection +func (c *rewindableConn) Rewind() { + c.playBufferOnNextRead = true +} + +func (c *rewindableConn) Erase() { + c.playBufferOnNextRead = false + c.buffer = nil +} + +func (c *rewindableConn) Close() error { + return c.conn.Close() +} + +func (c *rewindableConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *rewindableConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *rewindableConn) ReadWithTimeout(ctx context.Context, b []byte, timeout time.Duration) (int, error) { + ticker := time.NewTicker(pollInterval) + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + for { + select { + case <-ticker.C: + n, err := c.Read(b) + if (n > 0 && !allZeroes(b)) || err != nil { + return n, err + } + case <-timeoutCtx.Done(): + return 0, nil + } + } +} + +func (c *rewindableConn) Read(b []byte) (int, error) { + // Replay buffer if needed + if c.playBufferOnNextRead && c.buffer != nil { + n := copy(b, c.buffer) + + // Once entire buffer is read, erase it + if len(c.buffer) == 0 || allZeroes(c.buffer) { + c.Erase() + return n, nil + } + + // Cut off part of the buffer that was read + c.buffer = c.buffer[n:] + + return n, nil + } + + // Read from the connection & handle error + n, err := c.conn.Read(b) + if err != nil { + return n, err + } + + // Copy back to buffer + if c.buffer != nil { + copy(c.buffer, b) + } + + return n, nil +} + +func (c *rewindableConn) Write(b []byte) (n int, err error) { + return c.conn.Write(b) +} + +func (c *rewindableConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *rewindableConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *rewindableConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func allZeroes(b []byte) bool { + for _, v := range b { + if v != 0 { + return false + } + } + return true +} diff --git a/protocol/detect/detector/detector.go b/protocol/detect/detector/detector.go new file mode 100644 index 0000000..4f19bfc --- /dev/null +++ b/protocol/detect/detector/detector.go @@ -0,0 +1,9 @@ +package detector + +type ( + ProtocolDetector interface { + ProtocolName() string + IsMatch([]byte) bool + GetProbe() []byte + } +) diff --git a/protocol/detect/detector/ftp.go b/protocol/detect/detector/ftp.go new file mode 100644 index 0000000..5ac599b --- /dev/null +++ b/protocol/detect/detector/ftp.go @@ -0,0 +1,26 @@ +package detector + +import "regexp" + +type ( + FtpDetector struct { + } +) + +var ftpDetectionRegex = regexp.MustCompile(`^(USER\s{1}[\w-_./]+|AUTH TLS)`) + +func NewFtpDetector() *FtpDetector { + return &FtpDetector{} +} + +func (d *FtpDetector) ProtocolName() string { + return "ftp" +} + +func (d *FtpDetector) IsMatch(buffer []byte) bool { + return ftpDetectionRegex.Match(buffer) +} + +func (d *FtpDetector) GetProbe() []byte { + return []byte("220 FTP Server\r\n") +} diff --git a/protocol/detect/detector/http.go b/protocol/detect/detector/http.go new file mode 100644 index 0000000..4cf7900 --- /dev/null +++ b/protocol/detect/detector/http.go @@ -0,0 +1,26 @@ +package detector + +import "regexp" + +type ( + HttpDetector struct { + } +) + +var httpDetectionRegex = regexp.MustCompile(`(?i)^(GET|POST|HEAD|PUT|DELETE|OPTIONS|TRACE|CONNECT)`) + +func NewHttpDetector() *HttpDetector { + return &HttpDetector{} +} + +func (d *HttpDetector) ProtocolName() string { + return "http" +} + +func (d *HttpDetector) IsMatch(buffer []byte) bool { + return httpDetectionRegex.Match(buffer) +} + +func (d *HttpDetector) GetProbe() []byte { + return nil +} diff --git a/protocol/detect/listener.go b/protocol/detect/listener.go new file mode 100644 index 0000000..9df3ea2 --- /dev/null +++ b/protocol/detect/listener.go @@ -0,0 +1,259 @@ +package detect + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "sync" + "time" + + "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect/detector" + "github.com/thoas/go-funk" + "go.uber.org/fx" + "go.uber.org/zap" +) + +type ( + MultiProtocolListener struct { + port int + host string + protocolDetectors map[string]detector.ProtocolDetector + protocolListeners map[string]*ConditionalListener + shutdownChannel chan bool + listenerContext context.Context + listenerCancel context.CancelFunc + enableAll bool + logger *zap.Logger + } +) + +const ( + initialReadTimeout = 2 * time.Second + detectReadTimeout = 6 * time.Second + probeInterval = 500 * time.Millisecond +) + +func NewMulitProtocolListener(lf fx.Lifecycle, detectors []detector.ProtocolDetector, config *config.Config, logger *zap.Logger) *MultiProtocolListener { + if !config.MultiProtocol.Enabled { + return &MultiProtocolListener{} + } + + protocolDetectors := make(map[string]detector.ProtocolDetector) + protocolListeners := make(map[string]*ConditionalListener) + enableAll := funk.ContainsString(config.MultiProtocol.Protocols, "all") + + for _, detector := range detectors { + if !funk.ContainsString(config.MultiProtocol.Protocols, detector.ProtocolName()) && !enableAll { + continue + } + protocolDetectors[detector.ProtocolName()] = detector + protocolListeners[detector.ProtocolName()] = NewConditionalListenerFromConfig(config) + + } + + protocolListeners["fallback"] = NewConditionalListenerFromConfig(config) + + shutdownChannel := make(chan bool, 1) + protocolListener := &MultiProtocolListener{ + protocolDetectors: protocolDetectors, + protocolListeners: protocolListeners, + shutdownChannel: shutdownChannel, + port: config.MultiProtocol.Port, + host: config.MultiProtocol.Host, + logger: logger, + enableAll: enableAll, + } + + lf.Append(fx.StopHook(func(ctx context.Context) error { + protocolListener.Shutdown() + return nil + })) + + return protocolListener +} + +func (l *MultiProtocolListener) Shutdown() { + l.listenerCancel() + for _, listener := range l.protocolListeners { + listener.Close() + } +} + +func (l *MultiProtocolListener) Listen() error { + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", l.host, l.port)) + if err != nil { + return err + } + + listenerContext, cancel := context.WithCancel(context.Background()) + l.listenerCancel = cancel + + go func() { + <-listenerContext.Done() + listener.Close() + }() + + for { + + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + l.logger.Info("listener closed shutting down...") + return nil + } + + l.logger.Error("error accepting connection", zap.Error(err)) + continue + } + + go l.HandleConnection(listenerContext, conn) + } +} + +func (l *MultiProtocolListener) HandleConnection(ctx context.Context, conn net.Conn) { + + // Wait for any data to be available on the connection + rewindableConn := newRewindableConnFromConn(conn) + context, cancel := context.WithCancel(ctx) + defer cancel() + + successful, err := l.attemptHandoff(context, rewindableConn, initialReadTimeout) + + // Attempt to handoff the connection to the associated listener + if err != nil { + cancel() + l.logger.Error("Failed to read data from connection", zap.Error(err)) + return + } + + // If the connection was not successfully handed off begin a probe + // to determine the protocol + if successful { + return + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + l.probe(context, rewindableConn) + if err != nil { + cancel() + } + wg.Done() + }() + + go func() { + successful, err := l.attemptHandoff(context, rewindableConn, detectReadTimeout) + + if err != nil { + l.logger.Error("Failed to read data from connection during probe attempt", zap.Error(err)) + cancel() + } + + if !successful { + l.logger.Info("No data read from connection during probe attempt") + l.HandoffConnection([]byte{}, rewindableConn) + } + + wg.Done() + }() + + wg.Wait() +} + +// Writes various probes down the pipe for connections that require banners to be sent +func (l *MultiProtocolListener) probe(context context.Context, conn net.Conn) error { + ticker := time.NewTicker(probeInterval) + for _, probe := range l.protocolDetectors { + if probe.GetProbe() == nil { + continue + } + + select { + case <-context.Done(): + return nil + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(probeInterval)) + _, err := conn.Write(probe.GetProbe()) + if err != nil { + l.logger.Error("Failed to write probe", zap.Error(err)) + return err + } + } + } + return nil +} + +func (l *MultiProtocolListener) attemptHandoff(ctx context.Context, conn *rewindableConn, timeout time.Duration) (bool, error) { + data := make([]byte, rewindBufferSize) + + conn.SetReadDeadline(time.Now().Add(timeout)) + n, err := conn.ReadWithTimeout(ctx, data, timeout) + + // If the connection errored out + // not because of a timeout, close the connection + if err != nil && !os.IsTimeout(err) { + conn.Close() + return false, err + } + + // If the context was cancelled for any reason close the overall connection + if errors.Is(ctx.Err(), context.Canceled) { + conn.Close() + return false, fmt.Errorf("Context cancelled") + } + + // No data was read from the connection + if n == 0 { + return false, nil + } + + l.HandoffConnection(data, conn) + return true, nil +} + +func (l *MultiProtocolListener) FindMatchingProtocol(data []byte) (string, error) { + for name, listener := range l.protocolDetectors { + if listener.IsMatch(data) { + return name, nil + } + } + + return "", fmt.Errorf("No Protocol found") +} + +func (l *MultiProtocolListener) ProtocolEnabled(protocol string) bool { + _, ok := l.protocolDetectors[protocol] + return l.enableAll || ok +} + +func (l *MultiProtocolListener) GetListenerForProtocol(protocol string) *ConditionalListener { + listener, ok := l.protocolListeners[protocol] + if !ok { + return l.protocolListeners["fallback"] + } + + return listener +} + +// Once we have any data commit to a decision and handoff the connection the +func (l *MultiProtocolListener) HandoffConnection(data []byte, conn *rewindableConn) { + protocol, err := l.FindMatchingProtocol(data) + if err != nil { + l.logger.Info("error finding matching protocol reverting to fallback handler", zap.String("data", string(data))) + protocol = "fallback" + } else { + l.logger.Info("found protocol handler for sent data", zap.String("protocol", protocol), zap.String("data", string(data))) + } + + protocolListener := l.GetListenerForProtocol(protocol) + + // Reset the connection + conn.Rewind() + + // Perform the handoff + protocolListener.Dispatch(conn) +} diff --git a/protocol/fallback/server.go b/protocol/fallback/server.go new file mode 100644 index 0000000..fc5dac6 --- /dev/null +++ b/protocol/fallback/server.go @@ -0,0 +1,93 @@ +package fallback + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect" + "go.uber.org/fx" + "go.uber.org/zap" +) + +type ( + // Connection handler to use in the event no other protocols are matched + fallbackProtocolServer struct { + multiProtocolListener *detect.MultiProtocolListener + } +) + +const slowSendInterval = 2000 * time.Millisecond + +var defaultMessage = []byte("‍") + +func NewFallbackProtocolServer(config *config.Config, lf fx.Lifecycle, mpl *detect.MultiProtocolListener) (*fallbackProtocolServer, error) { + if !config.MultiProtocol.Enabled { + return nil, fmt.Errorf("MultiProtocol is not enabled") + } + + cancelContext, cancel := context.WithCancel(context.Background()) + f := &fallbackProtocolServer{ + multiProtocolListener: mpl, + } + + lf.Append(fx.Hook{ + OnStart: func(context context.Context) error { + go func() { + err := f.Start(cancelContext) + if err != nil { + zap.S().Errorf("Error starting fallback protocol server: %v", err) + } + }() + return nil + }, + OnStop: func(context context.Context) error { + cancel() + return nil + }, + }) + + return f, nil +} + +func (f *fallbackProtocolServer) Start(ctx context.Context) error { + listener := f.multiProtocolListener.GetListenerForProtocol("fallback") // Send an infinite stream of Zero width joiners + defer listener.Close() + + if listener == nil { + return fmt.Errorf("Listener not found") + } + + go func() { + <-ctx.Done() + listener.Close() + }() + + for { + + conn, err := listener.Accept() + if err != nil { + return err + } + + go f.handleConnection(ctx, conn) + } +} + +func (f *fallbackProtocolServer) handleConnection(ctx context.Context, conn net.Conn) { + defer conn.Close() + ticker := time.NewTicker(slowSendInterval) + for { + select { + case <-ticker.C: + _, err := conn.Write(defaultMessage) + if err != nil { + return + } + case <-ctx.Done(): + return + } + } +} diff --git a/protocol/ftp/driver/server.go b/protocol/ftp/driver/server.go index 6096e72..32b608b 100644 --- a/protocol/ftp/driver/server.go +++ b/protocol/ftp/driver/server.go @@ -3,9 +3,11 @@ package driver import ( "crypto/tls" "fmt" + "net" ftpserver "github.com/fclairamb/ftpserverlib" "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect" "github.com/ryanolee/go-pot/protocol/ftp/logging" "github.com/ryanolee/go-pot/protocol/ftp/throttle" "go.uber.org/zap" @@ -22,7 +24,12 @@ type FtpServerDriver struct { logger *logging.FtpCommandLogger } -func NewFtpServerDriver(c *config.Config, cf *FtpClientDriverFactory, throttle *throttle.FtpThrottle, logger *logging.FtpCommandLogger) (*FtpServerDriver, error) { +func NewFtpServerDriver(c *config.Config, cf *FtpClientDriverFactory, throttle *throttle.FtpThrottle, logger *logging.FtpCommandLogger, ls *detect.MultiProtocolListener) (*FtpServerDriver, error) { + multiProtocolEnabled := ls.ProtocolEnabled("ftp") + if !c.FtpServer.Enabled && !multiProtocolEnabled { + return nil, nil + } + cert, err := getSelfSignedCert(c) if err != nil { return nil, err @@ -33,6 +40,13 @@ func NewFtpServerDriver(c *config.Config, cf *FtpClientDriverFactory, throttle * return nil, err } + var listener net.Listener = nil + listenerAddr := fmt.Sprintf("%s:%d", c.FtpServer.Host, c.FtpServer.Port) + if multiProtocolEnabled { + listener = ls.GetListenerForProtocol("ftp") + listenerAddr = "" + } + return &FtpServerDriver{ clientFactory: cf, throttle: throttle, @@ -51,8 +65,11 @@ func NewFtpServerDriver(c *config.Config, cf *FtpClientDriverFactory, throttle * DisableMLST: false, DisableMFMT: false, + // Listener set + ListenAddr: listenerAddr, + Listener: listener, + // Connection port range - ListenAddr: fmt.Sprintf("%s:%d", c.FtpServer.Host, c.FtpServer.Port), PassiveTransferPortRange: &ftpserver.PortRange{ Start: lowerRange, End: upperRange, diff --git a/protocol/ftp/server.go b/protocol/ftp/server.go index f71183f..c380463 100644 --- a/protocol/ftp/server.go +++ b/protocol/ftp/server.go @@ -3,13 +3,15 @@ package ftp import ( ftpserver "github.com/fclairamb/ftpserverlib" "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect" "github.com/ryanolee/go-pot/protocol/ftp/driver" ) -func NewServer(driver *driver.FtpServerDriver, conf *config.Config) *ftpserver.FtpServer { - if !conf.FtpServer.Enabled { +func NewServer(driver *driver.FtpServerDriver, conf *config.Config, ls *detect.MultiProtocolListener) *ftpserver.FtpServer { + if !conf.FtpServer.Enabled && !ls.ProtocolEnabled("ftp") { return nil } - return ftpserver.NewFtpServer(driver) + server := ftpserver.NewFtpServer(driver) + return server } diff --git a/protocol/ftp/throttle/ftp_throttle.go b/protocol/ftp/throttle/ftp_throttle.go index 9d6d11a..61a9832 100644 --- a/protocol/ftp/throttle/ftp_throttle.go +++ b/protocol/ftp/throttle/ftp_throttle.go @@ -7,6 +7,7 @@ import ( "time" "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect" "go.uber.org/fx" "go.uber.org/zap" ) @@ -27,8 +28,8 @@ type FtpThrottle struct { closeChannel chan bool } -func NewFtpThrottle(lf fx.Lifecycle, cfg *config.Config) *FtpThrottle { - if !cfg.FtpServer.Enabled { +func NewFtpThrottle(lf fx.Lifecycle, cfg *config.Config, mlp *detect.MultiProtocolListener) *FtpThrottle { + if !cfg.FtpServer.Enabled && !mlp.ProtocolEnabled("ftp") { return nil } diff --git a/protocol/http/server.go b/protocol/http/server.go index 8926b52..f6978c2 100644 --- a/protocol/http/server.go +++ b/protocol/http/server.go @@ -3,6 +3,7 @@ package http import ( "context" "fmt" + "net" "time" "github.com/gofiber/fiber/v2" @@ -10,6 +11,7 @@ import ( "go.uber.org/zap" "github.com/ryanolee/go-pot/config" + "github.com/ryanolee/go-pot/protocol/detect" "github.com/ryanolee/go-pot/protocol/http/logging" "github.com/ryanolee/go-pot/protocol/http/stall" ) @@ -19,8 +21,9 @@ type ( App *fiber.App ListenPort int ListenHost string - Logger *zap.Logger + // Custom listener for the server (Overriding the default listener) + Listener net.Listener stallerFactory *stall.HttpStallerFactory } ) @@ -31,10 +34,17 @@ func NewServer( cfg *config.Config, logging logging.IServerLogger, stallerFactory *stall.HttpStallerFactory, + ls *detect.MultiProtocolListener, + logger *zap.Logger, ) *Server { // Only enable the trusted proxy check if we have trusted proxies trustedProxyCheck := len(cfg.Server.TrustedProxies) > 0 + var listener net.Listener = nil + if ls.ProtocolEnabled("http") { + listener = ls.GetListenerForProtocol("http") + } + server := &Server{ App: fiber.New(fiber.Config{ IdleTimeout: time.Second * 15, @@ -47,7 +57,7 @@ func NewServer( EnableTrustedProxyCheck: trustedProxyCheck, ErrorHandler: func(c *fiber.Ctx, err error) error { // All is always ok even if we have an error. Just log it and return an empty response - zap.L().Error("Error in request", zap.Error(err)) + logger.Error("Error in request", zap.Error(err)) return c.Status(fiber.StatusOK).SendString("{}") }, }), @@ -55,13 +65,15 @@ func NewServer( ListenPort: cfg.Server.Port, ListenHost: cfg.Server.Host, + Listener: listener, + stallerFactory: stallerFactory, } lf.Append(fx.Hook{ OnStop: func(ctx context.Context) error { - zap.L().Sugar().Info("Shutting down server") - return server.App.Shutdown() + logger.Sugar().Info("Http Shutting down server") + return server.App.ShutdownWithTimeout(time.Second * 5) }, }) @@ -86,5 +98,10 @@ func (s *Server) Start() error { return staller.StallContextBuffer(c) }) + if s.Listener != nil { + s.App.Listener(s.Listener) + return s.Start() + } + return s.App.Listen(fmt.Sprintf("%s:%d", s.ListenHost, s.ListenPort)) }