package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
"nhooyr.io/websocket/internal/xsync"
)
func (c *Conn ) Reader (ctx context .Context ) (MessageType , io .Reader , error ) {
return c .reader (ctx )
}
func (c *Conn ) Read (ctx context .Context ) (MessageType , []byte , error ) {
typ , r , err := c .Reader (ctx )
if err != nil {
return 0 , nil , err
}
b , err := io .ReadAll (r )
return typ , b , err
}
func (c *Conn ) CloseRead (ctx context .Context ) context .Context {
ctx , cancel := context .WithCancel (ctx )
c .wg .Add (1 )
go func () {
defer c .CloseNow ()
defer c .wg .Done ()
defer cancel ()
_ , _ , err := c .Reader (ctx )
if err == nil {
c .Close (StatusPolicyViolation , "unexpected data message" )
}
}()
return ctx
}
func (c *Conn ) SetReadLimit (n int64 ) {
if n >= 0 {
n ++
}
c .msgReader .limitReader .limit .Store (n )
}
const defaultReadLimit = 32768
func newMsgReader (c *Conn ) *msgReader {
mr := &msgReader {
c : c ,
fin : true ,
}
mr .readFunc = mr .read
mr .limitReader = newLimitReader (c , mr .readFunc , defaultReadLimit +1 )
return mr
}
func (mr *msgReader ) resetFlate () {
if mr .flateContextTakeover () {
if mr .dict == nil {
mr .dict = &slidingWindow {}
}
mr .dict .init (32768 )
}
if mr .flateBufio == nil {
mr .flateBufio = getBufioReader (mr .readFunc )
}
if mr .flateContextTakeover () {
mr .flateReader = getFlateReader (mr .flateBufio , mr .dict .buf )
} else {
mr .flateReader = getFlateReader (mr .flateBufio , nil )
}
mr .limitReader .r = mr .flateReader
mr .flateTail .Reset (deflateMessageTail )
}
func (mr *msgReader ) putFlateReader () {
if mr .flateReader != nil {
putFlateReader (mr .flateReader )
mr .flateReader = nil
}
}
func (mr *msgReader ) close () {
mr .c .readMu .forceLock ()
mr .putFlateReader ()
if mr .dict != nil {
mr .dict .close ()
mr .dict = nil
}
if mr .flateBufio != nil {
putBufioReader (mr .flateBufio )
}
if mr .c .client {
putBufioReader (mr .c .br )
mr .c .br = nil
}
}
func (mr *msgReader ) flateContextTakeover () bool {
if mr .c .client {
return !mr .c .copts .serverNoContextTakeover
}
return !mr .c .copts .clientNoContextTakeover
}
func (c *Conn ) readRSV1Illegal (h header ) bool {
if !c .flate () {
return true
}
if h .opcode != opText && h .opcode != opBinary {
return true
}
return false
}
func (c *Conn ) readLoop (ctx context .Context ) (header , error ) {
for {
h , err := c .readFrameHeader (ctx )
if err != nil {
return header {}, err
}
if h .rsv1 && c .readRSV1Illegal (h ) || h .rsv2 || h .rsv3 {
err := fmt .Errorf ("received header with unexpected rsv bits set: %v:%v:%v" , h .rsv1 , h .rsv2 , h .rsv3 )
c .writeError (StatusProtocolError , err )
return header {}, err
}
if !c .client && !h .masked {
return header {}, errors .New ("received unmasked frame from client" )
}
switch h .opcode {
case opClose , opPing , opPong :
err = c .handleControl (ctx , h )
if err != nil {
if h .opcode == opClose && CloseStatus (err ) != -1 {
return header {}, err
}
return header {}, fmt .Errorf ("failed to handle control frame %v: %w" , h .opcode , err )
}
case opContinuation , opText , opBinary :
return h , nil
default :
err := fmt .Errorf ("received unknown opcode %v" , h .opcode )
c .writeError (StatusProtocolError , err )
return header {}, err
}
}
}
func (c *Conn ) readFrameHeader (ctx context .Context ) (header , error ) {
select {
case <- c .closed :
return header {}, net .ErrClosed
case c .readTimeout <- ctx :
}
h , err := readFrameHeader (c .br , c .readHeaderBuf [:])
if err != nil {
select {
case <- c .closed :
return header {}, net .ErrClosed
case <- ctx .Done ():
return header {}, ctx .Err ()
default :
c .close (err )
return header {}, err
}
}
select {
case <- c .closed :
return header {}, net .ErrClosed
case c .readTimeout <- context .Background ():
}
return h , nil
}
func (c *Conn ) readFramePayload (ctx context .Context , p []byte ) (int , error ) {
select {
case <- c .closed :
return 0 , net .ErrClosed
case c .readTimeout <- ctx :
}
n , err := io .ReadFull (c .br , p )
if err != nil {
select {
case <- c .closed :
return n , net .ErrClosed
case <- ctx .Done ():
return n , ctx .Err ()
default :
err = fmt .Errorf ("failed to read frame payload: %w" , err )
c .close (err )
return n , err
}
}
select {
case <- c .closed :
return n , net .ErrClosed
case c .readTimeout <- context .Background ():
}
return n , err
}
func (c *Conn ) handleControl (ctx context .Context , h header ) (err error ) {
if h .payloadLength < 0 || h .payloadLength > maxControlPayload {
err := fmt .Errorf ("received control frame payload with invalid length: %d" , h .payloadLength )
c .writeError (StatusProtocolError , err )
return err
}
if !h .fin {
err := errors .New ("received fragmented control frame" )
c .writeError (StatusProtocolError , err )
return err
}
ctx , cancel := context .WithTimeout (ctx , time .Second *5 )
defer cancel ()
b := c .readControlBuf [:h .payloadLength ]
_, err = c .readFramePayload (ctx , b )
if err != nil {
return err
}
if h .masked {
mask (h .maskKey , b )
}
switch h .opcode {
case opPing :
return c .writeControl (ctx , opPong , b )
case opPong :
c .activePingsMu .Lock ()
pong , ok := c .activePings [string (b )]
c .activePingsMu .Unlock ()
if ok {
select {
case pong <- struct {}{}:
default :
}
}
return nil
}
defer func () {
c .readCloseFrameErr = err
}()
ce , err := parseClosePayload (b )
if err != nil {
err = fmt .Errorf ("received invalid close payload: %w" , err )
c .writeError (StatusProtocolError , err )
return err
}
err = fmt .Errorf ("received close frame: %w" , ce )
c .setCloseErr (err )
c .writeClose (ce .Code , ce .Reason )
c .close (err )
return err
}
func (c *Conn ) reader (ctx context .Context ) (_ MessageType , _ io .Reader , err error ) {
defer errd .Wrap (&err , "failed to get reader" )
err = c .readMu .lock (ctx )
if err != nil {
return 0 , nil , err
}
defer c .readMu .unlock ()
if !c .msgReader .fin {
err = errors .New ("previous message not read to completion" )
c .close (fmt .Errorf ("failed to get reader: %w" , err ))
return 0 , nil , err
}
h , err := c .readLoop (ctx )
if err != nil {
return 0 , nil , err
}
if h .opcode == opContinuation {
err := errors .New ("received continuation frame without text or binary frame" )
c .writeError (StatusProtocolError , err )
return 0 , nil , err
}
c .msgReader .reset (ctx , h )
return MessageType (h .opcode ), c .msgReader , nil
}
type msgReader struct {
c *Conn
ctx context .Context
flate bool
flateReader io .Reader
flateBufio *bufio .Reader
flateTail strings .Reader
limitReader *limitReader
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
readFunc util .ReaderFunc
}
func (mr *msgReader ) reset (ctx context .Context , h header ) {
mr .ctx = ctx
mr .flate = h .rsv1
mr .limitReader .reset (mr .readFunc )
if mr .flate {
mr .resetFlate ()
}
mr .setFrame (h )
}
func (mr *msgReader ) setFrame (h header ) {
mr .fin = h .fin
mr .payloadLength = h .payloadLength
mr .maskKey = h .maskKey
}
func (mr *msgReader ) Read (p []byte ) (n int , err error ) {
err = mr .c .readMu .lock (mr .ctx )
if err != nil {
return 0 , fmt .Errorf ("failed to read: %w" , err )
}
defer mr .c .readMu .unlock ()
n , err = mr .limitReader .Read (p )
if mr .flate && mr .flateContextTakeover () {
p = p [:n ]
mr .dict .write (p )
}
if errors .Is (err , io .EOF ) || errors .Is (err , io .ErrUnexpectedEOF ) && mr .fin && mr .flate {
mr .putFlateReader ()
return n , io .EOF
}
if err != nil {
err = fmt .Errorf ("failed to read: %w" , err )
mr .c .close (err )
}
return n , err
}
func (mr *msgReader ) read (p []byte ) (int , error ) {
for {
if mr .payloadLength == 0 {
if mr .fin {
if mr .flate {
return mr .flateTail .Read (p )
}
return 0 , io .EOF
}
h , err := mr .c .readLoop (mr .ctx )
if err != nil {
return 0 , err
}
if h .opcode != opContinuation {
err := errors .New ("received new data message without finishing the previous message" )
mr .c .writeError (StatusProtocolError , err )
return 0 , err
}
mr .setFrame (h )
continue
}
if int64 (len (p )) > mr .payloadLength {
p = p [:mr .payloadLength ]
}
n , err := mr .c .readFramePayload (mr .ctx , p )
if err != nil {
return n , err
}
mr .payloadLength -= int64 (n )
if !mr .c .client {
mr .maskKey = mask (mr .maskKey , p )
}
return n , nil
}
}
type limitReader struct {
c *Conn
r io .Reader
limit xsync .Int64
n int64
}
func newLimitReader (c *Conn , r io .Reader , limit int64 ) *limitReader {
lr := &limitReader {
c : c ,
}
lr .limit .Store (limit )
lr .reset (r )
return lr
}
func (lr *limitReader ) reset (r io .Reader ) {
lr .n = lr .limit .Load ()
lr .r = r
}
func (lr *limitReader ) Read (p []byte ) (int , error ) {
if lr .n < 0 {
return lr .r .Read (p )
}
if lr .n == 0 {
err := fmt .Errorf ("read limited at %v bytes" , lr .limit .Load ())
lr .c .writeError (StatusMessageTooBig , err )
return 0 , err
}
if int64 (len (p )) > lr .n {
p = p [:lr .n ]
}
n , err := lr .r .Read (p )
lr .n -= int64 (n )
if lr .n < 0 {
lr .n = 0
}
return n , err
}
The pages are generated with Golds v0.6.7 . (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds .