package websocket
import (
"bufio"
"context"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType identifies the kind of a WebSocket data message.
type MessageType int

// Data message types. Numbering starts at 1, so the zero value of
// MessageType does not correspond to any declared type.
const (
	// MessageText is the text message type.
	MessageText MessageType = 1
	// MessageBinary is the binary message type.
	MessageBinary MessageType = 2
)
// Conn represents a WebSocket connection. Instances are built by newConn.
type Conn struct {
	// noCopy is a marker field meant to discourage copying a Conn.
	noCopy noCopy

	// Connection parameters copied from connConfig by newConn.
	// flateThreshold may be given a default there when compression is on.
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	// Stop functions of the context.AfterFunc callbacks armed by
	// setupReadTimeout / setupWriteTimeout (nil when nothing is armed).
	readTimeoutStop  atomic.Pointer[func() bool]
	writeTimeoutStop atomic.Pointer[func() bool]

	// Read-side state.
	readMu         *mu
	readHeaderBuf  [8]byte
	readControlBuf [maxControlPayload]byte
	msgReader      *msgReader

	// Write-side state. writeBuf is only populated for client connections.
	msgWriter      *msgWriter
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	// Close-handshake bookkeeping. NOTE(review): these fields are used
	// outside this view; their exact protocol is not visible here.
	closeStateMu     sync.RWMutex
	closeReceivedErr error
	closeSentErr     error
	closeReadMu      sync.Mutex
	closeReadCtx     context.Context
	closeReadDone    chan struct{}
	closing          atomic.Bool

	// closeMu serializes close(); closed is closed by close() to signal
	// that the connection has shut down.
	closeMu sync.Mutex
	closed  chan struct{}

	// Ping bookkeeping: Ping derives payloads from pingCounter, and ping
	// registers a pong channel per payload in activePings (guarded by
	// activePingsMu).
	pingCounter   atomic.Int64
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}

	// Optional user callbacks taken from connConfig.
	onPingReceived func(context.Context, []byte) bool
	onPongReceived func(context.Context, []byte)
}
// connConfig collects the parameters newConn needs to build a Conn.
// Every field is copied verbatim into the corresponding Conn field.
type connConfig struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int

	// Optional callbacks invoked for received ping/pong control frames.
	// NOTE(review): the meaning of onPingReceived's bool result is defined
	// by code outside this view.
	onPingReceived func(context.Context, []byte) bool
	onPongReceived func(context.Context, []byte)

	br *bufio.Reader
	bw *bufio.Writer
}
// newConn builds a Conn from cfg.
//
// Besides copying cfg, it allocates the read and frame-write locks and the
// message reader/writer, extracts the bufio.Writer's buffer for client
// connections, chooses a default flateThreshold when compression is enabled
// and none was configured, and registers a finalizer so a Conn that becomes
// unreachable without being closed is still closed.
func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,
		br:             cfg.br,
		bw:             cfg.bw,
		onPingReceived: cfg.onPingReceived,
		onPongReceived: cfg.onPongReceived,
		closed:         make(chan struct{}),
		activePings:    make(map[string]chan<- struct{}),
	}

	c.readMu, c.writeFrameMu = newMu(c), newMu(c)
	c.msgReader = newMsgReader(c)
	c.msgWriter = newMsgWriter(c)

	if c.client {
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	// No threshold configured: default to 128, raised to 512 when the
	// writer reports no compression context takeover.
	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriter.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	// The finalizer's error result is ignored by the runtime; close()
	// clears this finalizer when the connection is closed explicitly.
	runtime.SetFinalizer(c, (*Conn).close)

	return c
}
// Subprotocol returns the subprotocol that was recorded for this
// connection when it was constructed.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
}
// close tears down the connection exactly once.
//
// closeMu serializes concurrent callers. If the connection is already closed
// (per isClosed, which is defined outside this view), close returns
// net.ErrClosed and does nothing else. Otherwise it removes the finalizer
// installed by newConn, closes c.closed so goroutines selecting on it (for
// example mu.lock and ping) observe the shutdown, closes the underlying rwc,
// and then closes the message writer and reader. The error from rwc.Close is
// returned to the caller.
func (c *Conn) close() error {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()
	if c.isClosed() {
		return net.ErrClosed
	}
	// Closed explicitly, so the GC-triggered fallback is no longer needed.
	runtime.SetFinalizer(c, nil)
	// Signal shutdown before closing the transport. NOTE(review): this
	// ordering presumably ensures a goroutine woken by rwc.Close already
	// sees c.closed closed — confirm before reordering.
	close(c.closed)
	err := c.rwc.Close()
	// The writer and reader are closed only once the transport is down.
	c.msgWriter.close()
	c.msgReader.close()
	return err
}
// setupWriteTimeout ties the pending write to ctx. Once ctx is done, the
// write timeout is cleared and the connection is closed. Any previously
// armed write timeout is stopped and replaced by this one.
func (c *Conn) setupWriteTimeout(ctx context.Context) {
	onDone := func() {
		c.clearWriteTimeout()
		c.close()
	}
	stopFn := context.AfterFunc(ctx, onDone)
	swapTimeoutStop(&c.writeTimeoutStop, &stopFn)
}
// clearWriteTimeout disarms the current write timeout, if any. It invokes
// the registered stop function and leaves no timeout registered.
func (c *Conn) clearWriteTimeout() {
	swapTimeoutStop(&c.writeTimeoutStop, nil)
}
// setupReadTimeout ties the pending read to ctx. Once ctx is done, the read
// timeout is cleared and the connection is closed. Any previously armed
// read timeout is stopped and replaced by this one.
func (c *Conn) setupReadTimeout(ctx context.Context) {
	onDone := func() {
		c.clearReadTimeout()
		c.close()
	}
	stopFn := context.AfterFunc(ctx, onDone)
	swapTimeoutStop(&c.readTimeoutStop, &stopFn)
}
// clearReadTimeout disarms the current read timeout, if any. It invokes
// the registered stop function and leaves no timeout registered.
func (c *Conn) clearReadTimeout() {
	swapTimeoutStop(&c.readTimeoutStop, nil)
}
// swapTimeoutStop atomically installs newStop (which may be nil) in p and
// then calls the stop function that was previously stored there, if any.
// The previous stop function's bool result is discarded.
func swapTimeoutStop(p *atomic.Pointer[func() bool], newStop *func() bool) {
	if prev := p.Swap(newStop); prev != nil {
		stop := *prev
		stop()
	}
}
// flate reports whether compression is enabled for this connection, which
// is the case exactly when compression options were supplied.
func (c *Conn) flate() bool {
	enabled := c.copts != nil
	return enabled
}
// Ping sends a ping control frame whose payload is the next value of a
// per-connection counter (formatted in base 10). It then waits until the
// pong channel registered for that payload is signalled, ctx is done, or the
// connection closes. Any error is wrapped with "failed to ping".
//
// NOTE(review): the pong signal is delivered by code outside this view,
// presumably the read path, so a concurrent reader may be required.
func (c *Conn) Ping(ctx context.Context) error {
	id := strconv.FormatInt(c.pingCounter.Add(1), 10)
	if err := c.ping(ctx, id); err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}
// ping registers a one-slot pong channel under payload p, writes a ping
// control frame carrying p, and blocks until that channel is signalled, ctx
// is done, or the connection is closed. The registration is always removed
// before returning. An error from writeControl is returned unwrapped.
func (c *Conn) ping(ctx context.Context, p string) error {
	pong := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()

	unregister := func() {
		c.activePingsMu.Lock()
		defer c.activePingsMu.Unlock()
		delete(c.activePings, p)
	}
	defer unregister()

	if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
		return err
	}

	select {
	case <-pong:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
	case <-c.closed:
		return net.ErrClosed
	}
}
// mu is a channel-based mutex: the lock is held while a token occupies ch
// (newMu gives ch a buffer of one). It keeps a reference to its Conn so that
// lock can give up once the connection is closed.
type mu struct {
	c  *Conn
	ch chan struct{}
}
// newMu returns an unlocked mu bound to c, backed by a single-slot channel.
func newMu(c *Conn) *mu {
	m := &mu{c: c}
	m.ch = make(chan struct{}, 1)
	return m
}
// forceLock acquires the lock, blocking until it is free. Unlike lock, it
// observes neither a context nor connection closure.
func (m *mu) forceLock() {
	var token struct{}
	m.ch <- token
}
// tryLock attempts to acquire the lock without blocking and reports whether
// it succeeded.
func (m *mu) tryLock() bool {
	acquired := false
	select {
	case m.ch <- struct{}{}:
		acquired = true
	default:
	}
	return acquired
}
// lock acquires the lock, giving up if ctx is done or the connection is
// closed.
//
// It returns nil on success, net.ErrClosed if the connection is (or becomes)
// closed, or a wrapped ctx.Err() if ctx finishes first.
func (m *mu) lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return net.ErrClosed
	case <-ctx.Done():
		return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
	case m.ch <- struct{}{}:
	}

	// select chooses randomly among ready cases, so the token may have been
	// placed even though the connection was already closed. Re-check, and
	// release the lock if so.
	select {
	case <-m.c.closed:
		m.unlock()
		return net.ErrClosed
	default:
		return nil
	}
}
// unlock releases the lock. Calling it on an unlocked mu is a no-op; it
// neither blocks nor panics.
func (m *mu) unlock() {
	select {
	default:
		// Nothing held; nothing to release.
	case <-m.ch:
	}
}
// noCopy is a zero-size marker placed in structs that must not be copied
// after first use.
//
// It has no runtime effect. go vet's copylocks analyzer reports copies of a
// type whose pointer implements sync.Locker (both Lock and Unlock) but whose
// value does not. With only Lock defined, *noCopy is not a Locker and the
// analyzer silently ignores it. Both no-op methods are therefore required,
// mirroring the standard library's sync.noCopy.
type noCopy struct{}

// Lock is a no-op used only by go vet's copylocks analyzer.
func (*noCopy) Lock() {}

// Unlock is a no-op used only by go vet's copylocks analyzer.
func (*noCopy) Unlock() {}
The pages are generated with Golds v0.8.4 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.