package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType identifies the kind of a message.
//
// The defined values start at 1, so the zero value of MessageType does not
// correspond to any defined message type.
type MessageType int

const (
	// MessageText marks a text message.
	MessageText MessageType = 1
	// MessageBinary marks a binary message.
	MessageBinary MessageType = 2
)
// Conn holds the state of one connection: the underlying stream, its buffered
// reader/writer, per-direction bookkeeping, and close/ping tracking.
//
// A Conn is created by newConn and must not be copied after creation (it
// embeds mutexes and a WaitGroup by value).
type Conn struct {
// noCopy is a zero-size marker field; presumably intended to let vet-style
// tooling flag accidental copies — TODO confirm.
noCopy noCopy
// subprotocol is the value returned by Subprotocol.
subprotocol string
// rwc is the underlying stream; close() closes it.
rwc io .ReadWriteCloser
// client, when true, makes newConn derive writeBuf from bw.
client bool
// copts being non-nil is what flate() reports as "compression enabled".
copts *compressionOptions
// flateThreshold is defaulted by newConn (128 or 512) when compression is
// enabled and the configured value is zero.
// NOTE(review): presumably the minimum message size worth compressing — confirm.
flateThreshold int
// br and bw are the buffered reader and writer supplied through connConfig.
br *bufio .Reader
bw *bufio .Writer
// readTimeout and writeTimeout deliver the context that timeoutLoop should
// currently watch for the read and write side respectively.
readTimeout chan context .Context
writeTimeout chan context .Context
// Read-side state. readMu is created by newConn; the buffers and
// readCloseFrameErr are used by code outside this view.
readMu *mu
readHeaderBuf [8 ]byte
readControlBuf [maxControlPayload ]byte
msgReader *msgReader
readCloseFrameErr error
// Write-side state. writeFrameMu is created by newConn; writeBuf is only
// populated by newConn for client connections.
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8 ]byte
writeHeader header
// wg counts the goroutines started on behalf of this connection
// (timeoutLoop, the cleanup goroutine in close, read-timeout reporting).
wg sync .WaitGroup
// closed is closed exactly once by close() to signal that the connection
// is shut down.
closed chan struct {}
// closeMu is held for the duration of close().
closeMu sync .Mutex
// closeErr is recorded via setCloseErr/setCloseErrLocked (defined elsewhere).
closeErr error
// wroteClose is not used in this part of the file.
// NOTE(review): presumably records that a close frame was written — confirm.
wroteClose bool
// pingCounter is incremented atomically by Ping to produce unique payloads.
pingCounter int32
// activePingsMu guards activePings, which maps an outstanding ping payload
// to the channel its ping() call is waiting on.
activePingsMu sync .Mutex
activePings map [string ]chan <- struct {}
}
// connConfig bundles the values newConn copies into a freshly created Conn.
// Each field maps one-to-one onto the Conn field of the same name.
type connConfig struct {
// subprotocol is the value later returned by Conn.Subprotocol.
subprotocol string
// rwc is the underlying stream the connection reads from and writes to.
rwc io .ReadWriteCloser
// client selects the client-only setup in newConn (writeBuf extraction).
client bool
// copts enables compression when non-nil (see Conn.flate).
copts *compressionOptions
// flateThreshold of zero lets newConn choose a default when compression is on.
flateThreshold int
// br and bw are the buffered reader and writer used by the connection.
br *bufio .Reader
bw *bufio .Writer
}
// newConn builds a Conn from cfg, wires up its locks and message
// reader/writer, applies the default compression threshold, registers a
// finalizer, and starts the timeout-watching goroutine.
//
// The returned Conn already has one goroutine (timeoutLoop) accounted for in
// its WaitGroup.
func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,
		br:             cfg.br,
		bw:             cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	// These helpers all take the Conn itself, so they can only run once the
	// literal above exists. Their relative order is kept as-is.
	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)
	c.msgReader = newMsgReader(c)
	c.msgWriter = newMsgWriter(c)

	// Only client connections take over the bufio.Writer's buffer.
	if c.client {
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	// Pick a default threshold only when compression is on and the caller
	// did not configure one: 128 with context takeover, 512 without.
	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriter.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	// Close the connection if it is ever garbage collected while still open.
	// close() removes this finalizer again.
	runtime.SetFinalizer(c, func(conn *Conn) {
		conn.close(errors.New("connection garbage collected"))
	})

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		c.timeoutLoop()
	}()

	return c
}
// Subprotocol returns the subprotocol string that was recorded for this
// connection when it was constructed.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
}
// close shuts the connection down and records err as the reason.
//
// It holds closeMu for its whole duration and returns immediately when
// isClosed already reports true, so the teardown below runs at most once.
// When err is nil, the result of closing the underlying stream is recorded
// as the close error instead.
func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}

	if err == nil {
		err = c.rwc.Close()
	}
	c.setCloseErrLocked(err)

	// Publish the closed state and drop the finalizer installed by newConn.
	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// The stream is closed only after c.closed has been closed; the result
	// is intentionally not inspected. When err was nil this is the second
	// Close call on rwc.
	// NOTE(review): presumably so goroutines woken by the I/O failure already
	// observe the closed state — confirm.
	c.rwc.Close()

	// Release writer and reader resources off the caller's goroutine.
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		c.msgWriter.close()
		c.msgReader.close()
	}()
}
// timeoutLoop watches the contexts currently associated with reads and
// writes and reacts when one of them is done.
//
// New contexts arrive on c.readTimeout and c.writeTimeout and replace the one
// being watched for that direction. The loop ends when c.closed is closed or
// when the write context is done (in which case the connection is closed
// with a "write timed out" error).
//
// When the read context is done, the close error is recorded and writeError
// is invoked with StatusPolicyViolation on a separate goroutine tracked by
// c.wg. The loop then keeps running until the connection is actually closed.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			c.wg.Add(1)
			go func() {
				defer c.wg.Done()
				c.writeError(StatusPolicyViolation, errors.New("read timed out"))
			}()
			// A done context stays done, so without disarming it this case
			// would be ready on every following iteration: the loop would spin
			// and start a new writeError goroutine each time until c.closed is
			// closed. context.Background() is never done, so the timeout is
			// handled exactly once.
			readCtx = context.Background()

		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}
// flate reports whether compression is enabled for this connection, which is
// the case exactly when compression options are set.
func (c *Conn) flate() bool {
	return c.copts != nil
}
// Ping sends a ping whose payload is the decimal form of an atomically
// incremented per-connection counter and waits for the outcome via ping.
//
// Any error from ping is returned wrapped with the prefix "failed to ping".
func (c *Conn) Ping(ctx context.Context) error {
	id := atomic.AddInt32(&c.pingCounter, 1)
	payload := strconv.Itoa(int(id))

	if err := c.ping(ctx, payload); err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}
// ping registers a notification channel under payload p, writes a ping
// control frame carrying p, and then blocks until one of the following:
//
//   - the registered channel is signalled: returns nil;
//   - the connection is closed: returns net.ErrClosed;
//   - ctx is done: closes the connection and returns an error wrapping
//     ctx.Err().
//
// The registration is always removed again before returning. An error from
// writing the control frame is returned unchanged.
func (c *Conn) ping(ctx context.Context, p string) error {
	// Buffered with capacity one, so a single signal can be delivered even
	// when nobody is receiving at that moment.
	pongCh := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = pongCh
	c.activePingsMu.Unlock()

	defer func() {
		c.activePingsMu.Lock()
		defer c.activePingsMu.Unlock()
		delete(c.activePings, p)
	}()

	if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
		return err
	}

	select {
	case <-pongCh:
		return nil
	case <-c.closed:
		return net.ErrClosed
	case <-ctx.Done():
		waitErr := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(waitErr)
		return waitErr
	}
}
// mu is a channel-based lock tied to a Conn. Holding the lock corresponds to
// one value sitting in ch; because acquisition is a channel send, it can be
// combined in a select with the connection's closed channel and a context.
type mu struct {
// c is the connection whose closed channel aborts lock attempts.
c *Conn
// ch holds at most one token (newMu creates it with capacity 1).
ch chan struct {}
}
// newMu returns an unlocked mu bound to c. The token channel has capacity one,
// so exactly one holder can exist at a time.
func newMu(c *Conn) *mu {
	m := new(mu)
	m.c = c
	m.ch = make(chan struct{}, 1)
	return m
}
// forceLock acquires the lock unconditionally. It blocks until the lock is
// free and, unlike lock, pays no attention to a context or to the connection
// being closed.
func (m *mu) forceLock() {
	var token struct{}
	m.ch <- token
}
// tryLock attempts to acquire the lock without blocking and reports whether
// it succeeded.
func (m *mu) tryLock() bool {
	select {
	case m.ch <- struct{}{}:
		return true
	default:
	}
	return false
}
// lock acquires the lock, giving up when the connection is closed or ctx is
// done.
//
// It returns net.ErrClosed if the connection is (or turns out to be) closed.
// If ctx is done first, it closes the connection and returns an error
// wrapping ctx.Err(). On a nil return the caller holds the lock.
func (m *mu) lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return net.ErrClosed
	case <-ctx.Done():
		lockErr := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(lockErr)
		return lockErr
	case m.ch <- struct{}{}:
		// Acquired; fall through to the re-check below.
	}

	// select chooses randomly among ready cases, so the lock may have been
	// taken even though the connection was already closed. Give it back in
	// that case so a closed connection never yields a held lock.
	select {
	case <-m.c.closed:
		m.unlock()
		return net.ErrClosed
	default:
		return nil
	}
}
// unlock releases the lock by taking the token back out of the channel.
// The receive is non-blocking, so calling unlock on a lock that is not held
// is a harmless no-op.
func (m *mu) unlock() {
	select {
	case <-m.ch:
		// Token removed; the lock is free again.
	default:
		// Nothing to release.
	}
}
// noCopy is a zero-size marker that can be placed in a struct to signal that
// the struct must not be copied after first use.
//
// The copylocks check in go vet only flags types whose pointer type has both
// a Lock and an Unlock method (i.e. satisfies sync.Locker), so both no-op
// methods are required for the marker to take effect.
type noCopy struct{}

// Lock is a no-op; it exists only so that *noCopy satisfies sync.Locker.
func (*noCopy) Lock() {}

// Unlock is a no-op; it exists only so that *noCopy satisfies sync.Locker.
func (*noCopy) Unlock() {}
// The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.