package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"nhooyr.io/websocket/internal/errd"
)
// DialOptions configures Dial. A nil *DialOptions behaves like the zero
// value. Dial works on a shallow copy with defaults filled in, so the
// caller's struct itself is not modified.
type DialOptions struct {
// HTTPClient sends the handshake request; http.DefaultClient is used when
// it is nil. A positive HTTPClient.Timeout is applied to the handshake as a
// deadline on the dial context rather than as a client timeout. The client
// is copied and its CheckRedirect wrapped so that redirect targets with the
// ws/wss schemes are rewritten to http/https before any CheckRedirect set
// here is called.
HTTPClient *http .Client
// HTTPHeader holds extra headers for the handshake request. It is cloned
// before use. Connection, Upgrade, Sec-WebSocket-Version and
// Sec-WebSocket-Key are always set by Dial and override values given here;
// Sec-WebSocket-Protocol and Sec-WebSocket-Extensions are overridden when
// Subprotocols are configured or compression is offered.
HTTPHeader http .Header
// Host, when non-empty, overrides the Host of the handshake request.
Host string
// Subprotocols are offered to the server in a single comma-separated
// Sec-WebSocket-Protocol header. If the server selects a subprotocol, it
// must match one of these case-insensitively, otherwise Dial fails.
Subprotocols []string
// CompressionMode selects compression: unless it is CompressionDisabled,
// compression options are offered to the server via
// Sec-WebSocket-Extensions. NOTE(review): the meaning of the zero value
// depends on the CompressionMode constants declared elsewhere — confirm.
CompressionMode CompressionMode
// CompressionThreshold is passed to the connection as its flate threshold;
// presumably the minimum message size worth compressing — confirm in newConn.
CompressionThreshold int
}
// cloneWithDefaults returns a copy of opts with defaults filled in, together
// with the context the handshake should run under. The receiver is never
// modified; a nil receiver is treated as the zero DialOptions.
//
// A nil HTTPClient becomes http.DefaultClient and a nil HTTPHeader becomes an
// empty header. The returned options always hold a private copy of the
// client, so neither the caller's client nor http.DefaultClient is changed:
//   - a positive client Timeout is moved onto the returned context and zeroed
//     on the copy; only then is the returned CancelFunc non-nil, and the
//     caller must call it to release the timer;
//   - CheckRedirect is wrapped so redirect targets using the ws/wss schemes
//     are rewritten to http/https before any user-supplied CheckRedirect runs.
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
	var o DialOptions
	if opts != nil {
		o = *opts
	}

	base := o.HTTPClient
	if base == nil {
		base = http.DefaultClient
	}
	client := *base

	// http.Client.Timeout would also bound reads on the hijacked connection
	// after the upgrade, so it is enforced on the handshake context instead.
	var cancel context.CancelFunc
	if client.Timeout > 0 {
		ctx, cancel = context.WithTimeout(ctx, client.Timeout)
		client.Timeout = 0
	}

	if o.HTTPHeader == nil {
		o.HTTPHeader = http.Header{}
	}

	userCheckRedirect := client.CheckRedirect
	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		// net/http cannot follow ws/wss URLs, so map them to their HTTP forms.
		switch req.URL.Scheme {
		case "ws":
			req.URL.Scheme = "http"
		case "wss":
			req.URL.Scheme = "https"
		}
		if userCheckRedirect == nil {
			return nil
		}
		return userCheckRedirect(req, via)
	}
	o.HTTPClient = &client

	return ctx, cancel, &o
}
// Dial performs the client side of the WebSocket opening handshake against
// the ws, wss, http or https URL u. On success it returns the resulting
// client Conn and the server's handshake response.
//
// opts may be nil. Defaults are applied to a copy, so the caller's value is
// not modified. If the server answered but the handshake was rejected, the
// response is returned along with the error. Its Body then holds at most the
// first 1024 bytes of what the server sent. On success the response Body is
// nil, because the underlying connection is now owned by the Conn.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
	// A nil entropy source makes the handshake key come from crypto/rand.
	conn, resp, err := dial(ctx, u, opts, nil)
	return conn, resp, err
}
// dial implements Dial. rand is the entropy source for the
// Sec-WebSocket-Key; nil selects crypto/rand (see secWebSocketKey).
//
// Flow: defaults are applied to a copy of opts, which may derive a timeout
// context from HTTPClient.Timeout. The handshake request is then sent and the
// response validated. On success, the response body, which must implement
// io.ReadWriteCloser, becomes the transport of the returned client Conn, and
// the returned response has a nil Body.
//
// If a response was received but the handshake fails afterwards, the
// response is still returned with the error. Its Body is replaced by an
// in-memory copy of at most the first 1024 bytes of the original body.
// errd.Wrap is deferred with "failed to WebSocket dial", presumably to
// prefix every returned error — NOTE(review): confirm in internal/errd.
func dial (ctx context .Context , urls string , opts *DialOptions , rand io .Reader ) (_ *Conn , _ *http .Response , err error ) {
defer errd .Wrap (&err , "failed to WebSocket dial" )
var cancel context .CancelFunc
ctx , cancel , opts = opts .cloneWithDefaults (ctx )
// cancel is non-nil only when a client timeout was turned into a context deadline.
if cancel != nil {
defer cancel ()
}
// The new local variable shadows the secWebSocketKey function from here on;
// the right-hand side still calls the function. err is the named result:
// := reuses it rather than shadowing it, which the deferred closures rely on.
secWebSocketKey , err := secWebSocketKey (rand )
if err != nil {
return nil , nil , fmt .Errorf ("failed to generate Sec-WebSocket-Key: %w" , err )
}
// copts stays nil when compression is disabled, so no extension is offered.
var copts *compressionOptions
if opts .CompressionMode != CompressionDisabled {
copts = opts .CompressionMode .opts ()
}
resp , err := handshakeRequest (ctx , urls , opts , copts , secWebSocketKey )
if err != nil {
// handshakeRequest returns a nil response whenever it fails.
return nil , resp , err
}
// Detach the raw body: on success it becomes the connection (resp.Body stays
// nil); on failure the deferred func below swaps in a bounded copy.
respBody := resp .Body
resp .Body = nil
defer func () {
if err != nil {
// Keep up to 1 KiB of the body so callers can inspect the server's reply.
r := io .LimitReader (respBody , 1024 )
// Close the body after 3 seconds, presumably to unblock a read that
// stalls (e.g. an upgraded connection that sends nothing).
timer := time .AfterFunc (time .Second *3 , func () {
respBody .Close ()
})
defer timer .Stop ()
// Read errors are deliberately ignored: the captured bytes are best-effort.
b , _ := io .ReadAll (r )
respBody .Close ()
resp .Body = io .NopCloser (bytes .NewReader (b ))
}
}()
// Validation may adjust the compression options to what the server accepted.
copts , err = verifyServerResponse (opts , copts , secWebSocketKey , resp )
if err != nil {
return nil , resp , err
}
// The body of a successful upgrade must be usable as a bidirectional stream.
rwc , ok := respBody .(io .ReadWriteCloser )
if !ok {
return nil , resp , fmt .Errorf ("response body is not a io.ReadWriteCloser: %T" , respBody )
}
return newConn (connConfig {
subprotocol : resp .Header .Get ("Sec-WebSocket-Protocol" ),
rwc : rwc ,
client : true ,
copts : copts ,
flateThreshold : opts .CompressionThreshold ,
br : getBufioReader (rwc ),
bw : getBufioWriter (rwc ),
}), resp , nil
}
// handshakeRequest sends the client's opening handshake for urls through
// opts.HTTPClient. It returns the raw HTTP response without validating it.
//
// The ws and wss schemes are mapped to http and https. http and https are
// sent as-is, and any other scheme is rejected. The request is a GET bound to
// ctx. It carries a clone of opts.HTTPHeader, so the caller's header is not
// modified, with the WebSocket upgrade headers set on top, replacing any
// values given for those keys. Sec-WebSocket-Protocol is sent only when
// subprotocols are configured, and Sec-WebSocket-Extensions only when copts
// is non-nil. On error the returned response is always nil.
//
// opts.HTTPHeader must be non-nil (cloneWithDefaults guarantees this).
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
	target, err := url.Parse(urls)
	if err != nil {
		return nil, fmt.Errorf("failed to parse url: %w", err)
	}
	switch target.Scheme {
	case "http", "https":
		// Already an HTTP scheme; send unchanged.
	case "ws":
		target.Scheme = "http"
	case "wss":
		target.Scheme = "https"
	default:
		return nil, fmt.Errorf("unexpected url scheme: %q", target.Scheme)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create new http request: %w", err)
	}
	if opts.Host != "" {
		req.Host = opts.Host
	}

	hdr := opts.HTTPHeader.Clone()
	hdr.Set("Connection", "Upgrade")
	hdr.Set("Upgrade", "websocket")
	hdr.Set("Sec-WebSocket-Version", "13")
	hdr.Set("Sec-WebSocket-Key", secWebSocketKey)
	if len(opts.Subprotocols) != 0 {
		hdr.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
	}
	if copts != nil {
		hdr.Set("Sec-WebSocket-Extensions", copts.String())
	}
	req.Header = hdr

	resp, err := opts.HTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send handshake request: %w", err)
	}
	return resp, nil
}
// secWebSocketKey returns a fresh Sec-WebSocket-Key: 16 bytes read from rr
// (crypto/rand's Reader when rr is nil), base64-encoded with standard
// padding. A short or failing read yields "" and a wrapped error.
func secWebSocketKey(rr io.Reader) (string, error) {
	src := rr
	if src == nil {
		src = rand.Reader
	}

	var nonce [16]byte
	if _, err := io.ReadFull(src, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
// verifyServerResponse checks the server's handshake response against the
// request the client sent and returns the negotiated compression options.
//
// The checks run in this order, and the first failure is returned:
//  1. status 101 Switching Protocols;
//  2. a Connection header containing the Upgrade token;
//  3. an Upgrade header containing the WebSocket token (2 and 3 via
//     headerContainsTokenIgnoreCase);
//  4. a Sec-WebSocket-Accept value equal to secWebSocketAccept(secWebSocketKey);
//  5. a subprotocol accepted by verifySubprotocol for opts.Subprotocols;
//  6. extensions accepted by verifyServerExtensions, whose result is returned.
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
	hdr := resp.Header

	if got := resp.StatusCode; got != http.StatusSwitchingProtocols {
		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, got)
	}
	if !headerContainsTokenIgnoreCase(hdr, "Connection", "Upgrade") {
		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", hdr.Get("Connection"))
	}
	if !headerContainsTokenIgnoreCase(hdr, "Upgrade", "WebSocket") {
		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", hdr.Get("Upgrade"))
	}
	if accept := hdr.Get("Sec-WebSocket-Accept"); accept != secWebSocketAccept(secWebSocketKey) {
		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", accept, secWebSocketKey)
	}
	if err := verifySubprotocol(opts.Subprotocols, resp); err != nil {
		return nil, err
	}
	return verifyServerExtensions(copts, hdr)
}
// verifySubprotocol checks the subprotocol the server selected in its
// Sec-WebSocket-Protocol response header. An absent or empty header is
// accepted. Otherwise the value must equal one of the subprotocols the client
// offered (subprotos), compared case-insensitively with strings.EqualFold.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	selected := resp.Header.Get("Sec-WebSocket-Protocol")
	if selected == "" {
		return nil
	}
	for i := range subprotos {
		if strings.EqualFold(subprotos[i], selected) {
			return nil
		}
	}
	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", selected)
}
// verifyServerExtensions validates the Sec-WebSocket-Extensions header h of
// the server's handshake response against the compression options the client
// offered (copts, nil when compression was not requested).
//
// It returns nil options and no error when the server accepted no
// extensions. The only acceptable answer is a single permessage-deflate
// extension, and only when the client offered compression. In that case a
// copy of copts is returned with the server's context-takeover parameters
// applied; the caller's copts is left untouched. Any other extension, or an
// unknown permessage-deflate parameter, is reported as an error.
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
	exts := websocketExtensions(h)
	if len(exts) == 0 {
		return nil, nil
	}

	ext := exts[0]
	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
		// Report every extension the server sent: the offending one may be
		// exts[0] itself (unknown name, or compression not requested), which
		// a slice starting at index 1 would hide.
		return nil, fmt.Errorf("WebSocket protocol violation: unsupported extensions from server: %+v", exts)
	}

	negotiated := *copts
	for _, p := range ext.params {
		switch {
		case p == "client_no_context_takeover":
			negotiated.clientNoContextTakeover = true
		case p == "server_no_context_takeover":
			negotiated.serverNoContextTakeover = true
		case strings.HasPrefix(p, "server_max_window_bits="):
			// Accepted without being recorded in the negotiated options.
		default:
			return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
		}
	}
	return &negotiated, nil
}
// bufioReaderPool recycles the *bufio.Reader values handed out by
// getBufioReader. It has no New func, so Get returns nil when it is empty.
var bufioReaderPool sync.Pool

// getBufioReader returns a bufio.Reader reading from r. It reuses a pooled
// reader when one is available and otherwise allocates a default-sized one.
func getBufioReader(r io.Reader) *bufio.Reader {
	br, ok := bufioReaderPool.Get().(*bufio.Reader)
	if !ok {
		return bufio.NewReader(r)
	}
	br.Reset(r)
	return br
}

// putBufioReader returns br to the pool; br must not be used afterwards.
// Any unread buffered data is discarded. br is detached from its source
// first, so the pooled reader does not keep the underlying reader (e.g. a
// closed connection) reachable. A nil br is ignored rather than pooled,
// because a pooled typed nil would make a later getBufioReader panic.
func putBufioReader(br *bufio.Reader) {
	if br == nil {
		return
	}
	br.Reset(nil)
	bufioReaderPool.Put(br)
}
// bufioWriterPool recycles the *bufio.Writer values handed out by
// getBufioWriter. It has no New func, so Get returns nil when it is empty.
var bufioWriterPool sync.Pool

// getBufioWriter returns a bufio.Writer writing to w. It reuses a pooled
// writer when one is available and otherwise allocates a default-sized one.
func getBufioWriter(w io.Writer) *bufio.Writer {
	bw, ok := bufioWriterPool.Get().(*bufio.Writer)
	if !ok {
		return bufio.NewWriter(w)
	}
	bw.Reset(w)
	return bw
}

// putBufioWriter returns bw to the pool; bw must not be used afterwards.
// Any unflushed data is discarded, not written. bw is detached from its
// destination first, so the pooled writer does not keep the underlying
// writer (e.g. a closed connection) reachable. A nil bw is ignored rather
// than pooled, because a pooled typed nil would make a later getBufioWriter
// panic.
func putBufioWriter(bw *bufio.Writer) {
	if bw == nil {
		return
	}
	bw.Reset(nil)
	bufioWriterPool.Put(bw)
}
The pages are generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.