diff --git a/errors.go b/errors.go index f5668b2..e66aee3 100644 --- a/errors.go +++ b/errors.go @@ -57,3 +57,7 @@ func (r *ConnectionRejectedError) Error() string { func (r *ConnectionRejectedError) StatusCode() int { return r.code } + +type EscapeHatch interface { + Escape() bool +} diff --git a/server.go b/server.go index f6cc8af..bb84f80 100644 --- a/server.go +++ b/server.go @@ -375,7 +375,7 @@ type Upgrader struct { // sent with appropriate HTTP error code and body set to error message. // // RejectConnectionError could be used to get more control on response. - OnHost func(host []byte) error + OnHost func(err error, host []byte) error // OnHeader is a callback that will be called after successful parsing of // header, that is not used during WebSocket handshake procedure. That is, @@ -388,7 +388,7 @@ type Upgrader struct { // sent with appropriate HTTP error code and body set to error message. // // RejectConnectionError could be used to get more control on response. - OnHeader func(key, value []byte) error + OnHeader func(err error, key, value []byte) error // OnBeforeUpgrade is a callback that will be called before sending // successful upgrade response. @@ -503,7 +503,7 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { nonce = make([]byte, nonceSize) ) - for err == nil { + for { line, e := readLine(br) if e != nil { return hs, e @@ -523,31 +523,33 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { case headerHostCanonical: headerSeen |= headerSeenHost if onHost := u.OnHost; onHost != nil { - err = onHost(v) + err = onHost(err, v) } case headerUpgradeCanonical: headerSeen |= headerSeenUpgrade if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { err = ErrHandshakeBadUpgrade + break } case headerConnectionCanonical: - headerSeen |= headerSeenConnection - if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) { - err = ErrHandshakeBadConnection + if bytes.Equal(v, specHeaderValueConnection) || btsHasToken(v, specHeaderValueConnectionLower) { + headerSeen |= headerSeenConnection } case headerSecVersionCanonical: headerSeen |= headerSeenSecVersion if !bytes.Equal(v, specHeaderValueSecVersion) { err = ErrHandshakeUpgradeRequired + break } case headerSecKeyCanonical: headerSeen |= headerSeenSecKey if len(v) != nonceSize { err = ErrHandshakeBadSecKey + break } else { copy(nonce, v) } @@ -562,12 +564,16 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { } if !ok { err = ErrMalformedRequest + break } } case headerSecExtensionsCanonical: if f := u.Negotiate; err == nil && f != nil { hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f) + if err != nil { + break + } } // DEPRECATED path. if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) { @@ -579,12 +585,13 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { } if !ok { err = ErrMalformedRequest + break } } default: if onHeader := u.OnHeader; onHeader != nil { - err = onHeader(k, v) + err = onHeader(err, k, v) } } } @@ -624,12 +631,14 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { default: panic("unknown headers state") } - case err == nil && u.OnBeforeUpgrade != nil: header[1], err = u.OnBeforeUpgrade() } if err != nil { var code int + if t, ok := err.(EscapeHatch); ok && t.Escape() { + return + } if rej, ok := err.(*ConnectionRejectedError); ok { code = rej.code header[1] = rej.header