package websocket
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"strings"
"nhooyr.io/websocket/internal/errd"
)
// AcceptOptions configures the server side of the WebSocket handshake
// performed by Accept. Accept also takes a nil *AcceptOptions, which is
// treated like the zero value (accept replaces it via cloneWithDefaults).
type AcceptOptions struct {
// Subprotocols lists the subprotocols the server supports, in order of
// preference. The first entry that case-insensitively equals a token in the
// client's Sec-WebSocket-Protocol header is selected, and the client's
// spelling of it is echoed back in the response.
Subprotocols []string
// InsecureSkipVerify disables the Origin check entirely. When true, the
// handshake is accepted regardless of the request's Origin header.
InsecureSkipVerify bool
// OriginPatterns lists additional Origin hosts to accept besides the
// request's own Host. Each entry is a path/filepath.Match pattern compared
// case-insensitively against the host (including any port) of the Origin
// URL. A malformed pattern is logged and the handshake is answered with 403.
OriginPatterns []string
// CompressionMode controls permessage-deflate negotiation;
// CompressionDisabled turns negotiation off.
CompressionMode CompressionMode
// CompressionThreshold is passed to the connection as its flate threshold.
// NOTE(review): presumably the minimum message size worth compressing —
// confirm against connConfig/newConn.
CompressionThreshold int
}
// cloneWithDefaults returns a new *AcceptOptions holding a shallow copy of
// opts, or a zero-valued one when opts is nil, so the caller can read every
// field without a nil check. Slice fields still share their backing arrays
// with opts, and no field is filled in beyond what opts already held.
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
	if opts == nil {
		return &AcceptOptions{}
	}
	clone := *opts
	return &clone
}
// Accept performs the server side of the WebSocket opening handshake for r
// and, on success, hijacks the connection behind w and returns it as a
// server-side *Conn. opts may be nil. When the handshake is rejected, a
// non-nil error is returned and Accept has already replied to w via
// http.Error, so the handler should not write to w again.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
	c, err := accept(w, r, opts)
	return c, err
}
// accept implements Accept. It validates the handshake request, checks the
// Origin (unless opts.InsecureSkipVerify is set), writes the 101 Switching
// Protocols response, hijacks the underlying connection and wraps it in a
// server-side Conn.
//
// Every failure path calls http.Error on w before returning. The hijack
// failure path does so after WriteHeader(101) has already been called.
// The deferred errd.Wrap presumably attaches the context
// "failed to accept WebSocket connection" to any returned error (errd is
// defined elsewhere — confirm).
func accept (w http .ResponseWriter , r *http .Request , opts *AcceptOptions ) (_ *Conn , err error ) {
defer errd .Wrap (&err , "failed to accept WebSocket connection" )
errCode , err := verifyClientRequest (w , r )
if err != nil {
http .Error (w , err .Error(), errCode )
return nil , err
}
// From here on opts is non-nil: cloneWithDefaults returns a copy, or a zero
// value when the caller passed nil.
opts = opts .cloneWithDefaults ()
if !opts .InsecureSkipVerify {
err = authenticateOrigin (r , opts .OriginPatterns )
if err != nil {
// A malformed OriginPatterns entry is a server misconfiguration: log
// the real error and send the client only the generic 403 status text.
if errors .Is (err , filepath .ErrBadPattern ) {
log .Printf ("websocket: %v" , err )
err = errors .New (http .StatusText (http .StatusForbidden ))
}
http .Error (w , err .Error(), http .StatusForbidden )
return nil , err
}
}
// Taking over the raw connection after the 101 response requires Hijack.
hj , ok := w .(http .Hijacker )
if !ok {
err = errors .New ("http.ResponseWriter does not implement http.Hijacker" )
http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
return nil , err
}
// All handshake response headers must be set before WriteHeader below.
w .Header ().Set ("Upgrade" , "websocket" )
w .Header ().Set ("Connection" , "Upgrade" )
// verifyClientRequest has already checked that exactly one key is present
// and that it base64-decodes to 16 bytes.
key := r .Header .Get ("Sec-WebSocket-Key" )
w .Header ().Set ("Sec-WebSocket-Accept" , secWebSocketAccept (key ))
subproto := selectSubprotocol (r , opts .Subprotocols )
if subproto != "" {
w .Header ().Set ("Sec-WebSocket-Protocol" , subproto )
}
// copts is nil when no compression was negotiated (selectDeflate returns
// nil, false in that case); it is passed to newConn either way.
copts , ok := selectDeflate (websocketExtensions (r .Header ), opts .CompressionMode )
if ok {
w .Header ().Set ("Sec-WebSocket-Extensions" , copts .String ())
}
w .WriteHeader (http .StatusSwitchingProtocols )
// If w exposes WriteHeaderNow (the variable name points at gin's wrapper),
// call it so the 101 status is actually emitted before the hijack.
// NOTE(review): intent inferred from names — confirm against the wrapper.
if ginWriter , ok := w .(interface {
WriteHeaderNow ()
}); ok {
ginWriter .WriteHeaderNow ()
}
netConn , brw , err := hj .Hijack ()
if err != nil {
err = fmt .Errorf ("failed to hijack connection: %w" , err )
http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
return nil , err
}
// Re-point the buffered reader at the hijacked connection: first replay the
// bytes the HTTP server had already buffered, then read from netConn
// directly. Peek(Buffered()) only asks for bytes already in the buffer, so
// its error is ignored.
// NOTE(review): b is a view into brw.Reader's own buffer, which Reset appears
// to keep using; this relies on the first refill copying b out in a single
// overlap-safe copy — confirm against the bufio version in use.
b , _ := brw .Reader .Peek (brw .Reader .Buffered ())
brw .Reader .Reset (io .MultiReader (bytes .NewReader (b ), netConn ))
return newConn (connConfig {
// Read back from the response headers rather than using subproto, so a
// value the caller set on w beforehand survives when nothing was negotiated.
subprotocol : w .Header ().Get ("Sec-WebSocket-Protocol" ),
rwc : netConn ,
client : false ,
copts : copts ,
flateThreshold : opts .CompressionThreshold ,
br : brw .Reader ,
bw : brw .Writer ,
}), nil
}
// verifyClientRequest checks that r is an acceptable WebSocket opening
// handshake, in this order:
//   - HTTP/1.1 or newer
//   - Connection contains "Upgrade"
//   - Upgrade contains "websocket"
//   - the method is GET
//   - Sec-WebSocket-Version is 13
//   - exactly one Sec-WebSocket-Key, which base64-decodes to 16 bytes
//
// On the first failing check it returns the HTTP status to reply with and a
// descriptive error. Some failures also set response headers on w that tell
// the client what the server expects. It returns (0, nil) when every check
// passes.
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
	// advertiseUpgrade sets the headers naming the upgrade this endpoint
	// expects; used on the 426 paths for bad Connection/Upgrade headers.
	advertiseUpgrade := func() {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", "websocket")
	}

	switch {
	case !r.ProtoAtLeast(1, 1):
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
	case !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
	case !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
	case r.Method != "GET":
		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
	case r.Header.Get("Sec-WebSocket-Version") != "13":
		w.Header().Set("Sec-WebSocket-Version", "13")
		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
	}

	keys := r.Header.Values("Sec-WebSocket-Key")
	switch n := len(keys); {
	case n == 0:
		return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
	case n > 1:
		return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
	}

	key := strings.TrimSpace(keys[0])
	raw, err := base64.StdEncoding.DecodeString(key)
	if err != nil || len(raw) != 16 {
		return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", key)
	}
	return 0, nil
}
// authenticateOrigin returns nil if the request's Origin header is allowed,
// and an error describing the rejection otherwise.
//
// A request with no Origin header is allowed. Otherwise the Origin is parsed
// as a URL, and its host (including any port) is allowed in two cases:
//   - it is non-empty and case-insensitively equals r.Host
//   - it matches one of originHosts via match
//
// An Origin without a host component, such as the opaque "null" origin, is
// never treated as same-origin, even when r.Host is also empty. A malformed
// pattern yields an error wrapping filepath.ErrBadPattern.
func authenticateOrigin(r *http.Request, originHosts []string) error {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return nil
	}

	u, err := url.Parse(origin)
	if err != nil {
		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
	}

	// Only a real host may short-circuit as same-origin. Without the
	// non-empty check, an Origin like "null" would be accepted whenever the
	// request carried an empty Host header.
	if u.Host != "" && strings.EqualFold(r.Host, u.Host) {
		return nil
	}

	for _, hostPattern := range originHosts {
		matched, err := match(hostPattern, u.Host)
		if err != nil {
			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
		}
		if matched {
			return nil
		}
	}

	if u.Host == "" {
		return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
	}
	return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}
// match reports whether s matches the shell-style pattern using
// path/filepath.Match, ignoring case in both. It returns
// filepath.ErrBadPattern if pattern is malformed.
func match(pattern, s string) (bool, error) {
	lowerPattern, lowerS := strings.ToLower(pattern), strings.ToLower(s)
	return filepath.Match(lowerPattern, lowerS)
}
// selectSubprotocol picks the subprotocol to confirm in the handshake
// response. subprotocols is the server's list in order of preference. The
// first entry that case-insensitively equals any token of the client's
// Sec-WebSocket-Protocol header wins, and the client's own spelling of that
// token is returned. An empty result means no subprotocol was agreed.
func selectSubprotocol(r *http.Request, subprotocols []string) string {
	offered := headerTokens(r.Header, "Sec-WebSocket-Protocol")
	for i := 0; i < len(subprotocols); i++ {
		want := subprotocols[i]
		for j := 0; j < len(offered); j++ {
			if strings.EqualFold(want, offered[j]) {
				return offered[j]
			}
		}
	}
	return ""
}
// selectDeflate returns the compression options from the first
// "permessage-deflate" offer in extensions that acceptDeflate accepts. It
// returns (nil, false) in two cases:
//   - mode is CompressionDisabled
//   - no offer is acceptable
//
// Extensions with other names are ignored.
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
	if mode == CompressionDisabled {
		return nil, false
	}
	for i := range extensions {
		if extensions[i].name != "permessage-deflate" {
			continue
		}
		if copts, ok := acceptDeflate(extensions[i], mode); ok {
			return copts, true
		}
	}
	return nil, false
}
// acceptDeflate decides whether one permessage-deflate offer can be accepted.
// Starting from mode.opts(), it sets the client/server no_context_takeover
// flags the client asked for. Three parameters are tolerated without
// changing the options:
//   - client_max_window_bits, with or without a value
//   - server_max_window_bits=15
//
// Any other parameter makes the whole offer unacceptable, and then
// (nil, false) is returned.
// NOTE(review): assumes mode.opts() returns a fresh *compressionOptions per
// call; the flag writes below would otherwise leak between connections.
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
	copts := mode.opts()
	for _, param := range ext.params {
		switch {
		case param == "client_no_context_takeover":
			copts.clientNoContextTakeover = true
		case param == "server_no_context_takeover":
			copts.serverNoContextTakeover = true
		case param == "client_max_window_bits",
			param == "server_max_window_bits=15",
			strings.HasPrefix(param, "client_max_window_bits="):
			// Tolerated; the options are left unchanged.
		default:
			return nil, false
		}
	}
	return copts, true
}
// headerContainsTokenIgnoreCase reports whether any comma-separated token of
// header key in h equals token, ignoring case.
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
	found := false
	for _, candidate := range headerTokens(h, key) {
		if found = strings.EqualFold(candidate, token); found {
			break
		}
	}
	return found
}
// websocketExtension is one element parsed from a Sec-WebSocket-Extensions
// header by websocketExtensions.
type websocketExtension struct {
// name is the extension token, e.g. "permessage-deflate".
name string
// params holds the remaining ";"-separated parameters, each trimmed of
// surrounding whitespace and otherwise verbatim (including any "=value").
params []string
}
// websocketExtensions parses every Sec-WebSocket-Extensions value in h.
// Each comma-separated element is split on ';'. The first piece is the
// extension name and the rest are its parameters, all whitespace-trimmed.
// Empty elements are skipped. The result is nil if no extensions are found.
func websocketExtensions(h http.Header) []websocketExtension {
	var exts []websocketExtension
	for _, raw := range headerTokens(h, "Sec-WebSocket-Extensions") {
		if raw == "" {
			continue
		}
		fields := strings.Split(raw, ";")
		params := fields[1:]
		for i := range params {
			params[i] = strings.TrimSpace(params[i])
		}
		exts = append(exts, websocketExtension{
			name:   strings.TrimSpace(fields[0]),
			params: params,
		})
	}
	return exts
}
// headerTokens returns the comma-separated list elements of every value of
// header key in h, in order, each trimmed of surrounding whitespace. key is
// canonicalized, so lookup is case-insensitive.
//
// Empty elements (from "a, ,b", trailing commas, or empty values) are
// dropped, as RFC 7230 §7 requires list recipients to ignore them.
// The split is a plain split on every comma: quoted-string values
// containing commas are not understood.
func headerTokens(h http.Header, key string) []string {
	key = textproto.CanonicalMIMEHeaderKey(key)
	var tokens []string
	for _, v := range h[key] {
		for _, t := range strings.Split(v, ",") {
			if t = strings.TrimSpace(t); t != "" {
				tokens = append(tokens, t)
			}
		}
	}
	return tokens
}
// keyGUID is the fixed GUID that RFC 6455 appends to the client's key when
// computing Sec-WebSocket-Accept.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// secWebSocketAccept derives the Sec-WebSocket-Accept response value for the
// given Sec-WebSocket-Key: base64(SHA-1(key + keyGUID)).
func secWebSocketAccept(secWebSocketKey string) string {
	// The conversion yields a fresh slice, so the append never touches keyGUID.
	digest := sha1.Sum(append([]byte(secWebSocketKey), keyGUID...))
	return base64.StdEncoding.EncodeToString(digest[:])
}
The pages are generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.