package websocket
import (
"bufio"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"compress/flate"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
)
// Writer returns an io.WriteCloser that writes one message of type typ to the
// connection. ctx is kept by the writer and used for every frame it sends.
//
// A failure to obtain the writer is wrapped with "failed to get writer".
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	msgW, err := c.writer(ctx, typ)
	if err != nil {
		return nil, fmt.Errorf("failed to get writer: %w", err)
	}
	return msgW, nil
}
// Write sends p as one complete message of type typ.
//
// The byte count reported by the internal write is discarded; any error is
// wrapped with "failed to write msg".
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
	if _, err := c.write(ctx, typ, p); err != nil {
		return fmt.Errorf("failed to write msg: %w", err)
	}
	return nil
}
// msgWriter is the io.WriteCloser returned by Conn.writer. The connection
// hands out the same instance for every message; reset re-arms it each time.
type msgWriter struct {
// c is the connection whose writeFrame method receives every frame.
c *Conn
// mu is locked by reset and unlocked when the message has been written
// (in Close, or directly in Conn.write on the single-frame path).
mu *mu
// writeMu is held for the duration of each Write and Close call and is
// force-locked by close.
writeMu *mu
// closed is set by Close and cleared by reset; Write refuses to run once set.
closed bool
// ctx is the context passed to reset; it is used for locking and for
// every frame of the current message.
ctx context .Context
// opcode is the opcode for the next frame: the message type after reset,
// opContinuation once a frame has been written.
opcode opcode
// flate reports whether the current message is being compressed.
flate bool
// trimWriter sits between flateWriter and mw.write; created lazily by ensureFlate.
trimWriter *trimLastFourBytesWriter
// flateWriter is obtained via getFlateWriter and handed back via putFlateWriter.
flateWriter *flate .Writer
}
// newMsgWriter builds the message writer for c with its two locks created
// via newMu. The flate-related fields are left at their zero values.
func newMsgWriter(c *Conn) *msgWriter {
	return &msgWriter{
		c:       c,
		mu:      newMu(c),
		writeMu: newMu(c),
	}
}
// ensureFlate switches the current message to compressed mode, creating the
// trim writer and fetching a flate writer on first use.
//
// The chain is flateWriter -> trimWriter -> mw.write, so the trim writer has
// to exist before the flate writer is requested.
func (mw *msgWriter) ensureFlate() {
	if mw.trimWriter == nil {
		mw.trimWriter = &trimLastFourBytesWriter{w: util.WriterFunc(mw.write)}
	}

	if mw.flateWriter == nil {
		mw.flateWriter = getFlateWriter(mw.trimWriter)
	}

	mw.flate = true
}
// flateContextTakeover reports whether context takeover is enabled for this
// side of the connection: clients consult clientNoContextTakeover, servers
// consult serverNoContextTakeover.
func (mw *msgWriter) flateContextTakeover() bool {
	if !mw.c.client {
		return !mw.c.copts.serverNoContextTakeover
	}
	return !mw.c.copts.clientNoContextTakeover
}
// writer re-arms the connection's shared msgWriter for a new message of type
// typ and returns it. The error from reset is returned unwrapped.
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	if err := c.msgWriter.reset(ctx, typ); err != nil {
		return nil, err
	}
	return c.msgWriter, nil
}
// write sends p as a whole message of type typ and returns the byte count
// reported by the underlying write.
//
// When c.flate() is true the message goes through the msgWriter's Write and
// Close. Otherwise it is written as a single fin frame, and the message lock
// taken by writer is released here because Close is never called.
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
	mw, err := c.writer(ctx, typ)
	if err != nil {
		return 0, err
	}

	if c.flate() {
		n, err := mw.Write(p)
		if err != nil {
			return n, err
		}
		return n, mw.Close()
	}

	// Single-frame path: bypass the writer and release its lock ourselves.
	defer c.msgWriter.mu.unlock()
	return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
// reset acquires the message lock and prepares the writer for a new message
// of type typ: it stores ctx, sets the opcode, and clears the flate and
// closed flags.
//
// The lock acquired here stays held when reset returns nil.
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
	if err := mw.mu.lock(ctx); err != nil {
		return err
	}

	mw.ctx, mw.opcode = ctx, opcode(typ)
	mw.flate, mw.closed = false, false

	// NOTE(review): trimWriter is only created by ensureFlate, so it can be
	// nil here; this relies on reset tolerating a nil receiver — confirm.
	mw.trimWriter.reset()

	return nil
}
// putFlateWriter hands the flate writer back via the package-level
// putFlateWriter and clears the field. It is a no-op when none is held.
func (mw *msgWriter) putFlateWriter() {
	if mw.flateWriter == nil {
		return
	}
	putFlateWriter(mw.flateWriter)
	mw.flateWriter = nil
}
// Write writes p as part of the current message.
//
// Compression is switched on only when the connection has flate enabled,
// no frame of this message has been written yet, and p is at least
// flateThreshold bytes long. Once on, all data goes through the flate writer.
//
// Errors from the write itself are wrapped with "failed to write" and close
// the connection. Writing to a closed writer returns an error without
// touching the connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
	if err = mw.writeMu.lock(mw.ctx); err != nil {
		return 0, fmt.Errorf("failed to write: %w", err)
	}
	defer mw.writeMu.unlock()

	if mw.closed {
		return 0, errors.New("cannot use closed writer")
	}

	// Registered after the closed check so that only real write failures
	// tear the connection down.
	defer func() {
		if err == nil {
			return
		}
		err = fmt.Errorf("failed to write: %w", err)
		mw.c.close(err)
	}()

	// opcode is still the message type (not opContinuation) until the first
	// frame has been written.
	if mw.c.flate() && mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
		mw.ensureFlate()
	}

	if !mw.flate {
		return mw.write(p)
	}
	return mw.flateWriter.Write(p)
}
// write sends p as one non-final data frame with the current opcode and
// flate flag. After a successful frame, later frames of the message use
// opContinuation.
func (mw *msgWriter) write(p []byte) (int, error) {
	written, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
	if err != nil {
		return written, fmt.Errorf("failed to write data frame: %w", err)
	}

	mw.opcode = opContinuation
	return written, nil
}
// Close finishes the current message: it flushes pending compressed data (if
// the message is compressed), writes a final (fin) frame with an empty
// payload, and releases mw.mu.
//
// Calling Close again on the same message returns "writer already closed".
// Returned errors pass through errd.Wrap with "failed to close writer".
// mw.mu is only released on the success path.
func (mw *msgWriter ) Close () (err error ) {
defer errd .Wrap (&err , "failed to close writer" )
err = mw .writeMu .lock (mw .ctx )
if err != nil {
return err
}
defer mw .writeMu .unlock ()
if mw .closed {
return errors .New ("writer already closed" )
}
// Mark closed before any I/O so a failed Close cannot be retried.
mw .closed = true
if mw .flate {
// Flush pushes buffered compressed bytes through trimWriter into mw.write.
err = mw .flateWriter .Flush ()
if err != nil {
return fmt .Errorf ("failed to flush flate: %w" , err )
}
}
// Terminate the message with an empty fin frame.
_, err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
if err != nil {
return fmt .Errorf ("failed to write fin frame: %w" , err )
}
// Without context takeover the flate writer is not kept between messages.
if mw .flate && !mw .flateContextTakeover () {
mw .putFlateWriter ()
}
// Releases the lock taken in reset.
mw .mu .unlock ()
return nil
}
// close releases the writer's pooled resources.
//
// On a client connection the frame lock is force-locked first and the bufio
// writer is then handed to putBufioWriter. On every connection the write lock
// is force-locked before the flate writer is returned.
func (mw *msgWriter) close() {
	if conn := mw.c; conn.client {
		conn.writeFrameMu.forceLock()
		putBufioWriter(conn.bw)
	}

	mw.writeMu.forceLock()
	mw.putFlateWriter()
}
// writeControl writes a single unfragmented, uncompressed frame with the
// given opcode and payload, limiting the write to five seconds on top of
// whatever deadline ctx already carries.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if _, err := c.writeFrame(ctx, true, false, opcode, p); err != nil {
		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
	}
	return nil
}
// writeFrame writes one frame (header plus payload p) to the connection and
// returns the number of payload bytes written.
//
// fin and opcode are copied into the frame header. flate sets rsv1, but only
// for text and binary opcodes. The buffered writer is flushed only when fin
// is set.
//
// Errors returned before the deferred handler below is registered (lock
// failure, close already written, connection already closed) are returned
// as-is. Any later error closes the connection and is wrapped with
// "failed to write frame".
func (c *Conn ) writeFrame (ctx context .Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
err = c .writeFrameMu .lock (ctx )
if err != nil {
return 0 , err
}
// Snapshot wroteClose under closeMu.
c .closeMu .Lock ()
wroteClose := c .wroteClose
c .closeMu .Unlock ()
// Once a close has been written only another close frame may pass; any
// other write gives up the frame lock and waits for ctx or connection close.
if wroteClose && opcode != opClose {
c .writeFrameMu .unlock ()
select {
case <- ctx .Done ():
return 0 , ctx .Err ()
case <- c .closed :
return 0 , net .ErrClosed
}
}
defer c .writeFrameMu .unlock ()
// Hand ctx to writeTimeout unless the connection is already closed.
// NOTE(review): presumably this arms a write deadline elsewhere — confirm.
select {
case <- c .closed :
return 0 , net .ErrClosed
case c .writeTimeout <- ctx :
}
// From here on any error closes the connection. A closed connection or a
// finished ctx takes precedence over the original error.
defer func () {
if err != nil {
select {
case <- c .closed :
err = net .ErrClosed
case <- ctx .Done ():
err = ctx .Err ()
default :
}
c .close (err )
err = fmt .Errorf ("failed to write frame: %w" , err )
}
}()
c .writeHeader .fin = fin
c .writeHeader .opcode = opcode
c .writeHeader .payloadLength = int64 (len (p ))
// Clients mask every frame with a fresh 4-byte key from crypto/rand;
// writeHeaderBuf is used as scratch space for the key bytes.
if c .client {
c .writeHeader .masked = true
_, err = io .ReadFull (rand .Reader , c .writeHeaderBuf [:4 ])
if err != nil {
return 0 , fmt .Errorf ("failed to generate masking key: %w" , err )
}
c .writeHeader .maskKey = binary .LittleEndian .Uint32 (c .writeHeaderBuf [:])
}
// rsv1 is set only for compressed text/binary frames, never for
// continuation or control opcodes.
c .writeHeader .rsv1 = false
if flate && (opcode == opText || opcode == opBinary ) {
c .writeHeader .rsv1 = true
}
err = writeFrameHeader (c .writeHeader , c .bw , c .writeHeaderBuf [:])
if err != nil {
return 0 , err
}
n , err := c .writeFramePayload (p )
if err != nil {
return n , err
}
// Non-final frames may stay in the buffer until the fin frame flushes them.
if c .writeHeader .fin {
err = c .bw .Flush ()
if err != nil {
return n , fmt .Errorf ("failed to flush: %w" , err )
}
}
// If the connection closed meanwhile, a close frame still counts as
// written; anything else reports net.ErrClosed. Otherwise send
// context.Background() on writeTimeout.
// NOTE(review): presumably this disarms the deadline — confirm.
select {
case <- c .closed :
if opcode == opClose {
return n , nil
}
return n , net .ErrClosed
case c .writeTimeout <- context .Background ():
}
return n , nil
}
// writeFramePayload writes p to the buffered writer and returns how many
// bytes of p were written. Errors pass through errd.Wrap with
// "failed to write frame payload".
//
// Unmasked frames are written in one call. Masked frames are written in
// chunks that fit the space left in the buffer, so each chunk can be masked
// in place after it lands there; the key returned by mask is carried into
// the next chunk.
//
// NOTE(review): this assumes c.writeBuf is the slice backing c.bw (see
// extractBufioWriterBuf) — confirm at the call site that sets it up.
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
	defer errd.Wrap(&err, "failed to write frame payload")

	if !c.writeHeader.masked {
		return c.bw.Write(p)
	}

	maskKey := c.writeHeader.maskKey
	for len(p) > 0 {
		// A full buffer must be flushed before more can be copied in.
		if c.bw.Available() == 0 {
			err = c.bw.Flush()
			if err != nil {
				return n, err
			}
		}

		// Offset in the buffer where this chunk will start.
		start := c.bw.Buffered()

		// Never write more than fits, so Write cannot flush the bytes
		// away before they are masked.
		chunk := len(p)
		if avail := c.bw.Available(); chunk > avail {
			chunk = avail
		}

		// Assign to the named result instead of shadowing it with := so the
		// deferred errd.Wrap always sees the error being returned.
		_, err = c.bw.Write(p[:chunk])
		if err != nil {
			return n, err
		}

		maskKey = mask(maskKey, c.writeBuf[start:c.bw.Buffered()])

		p = p[chunk:]
		n += chunk
	}

	return n, nil
}
// extractBufioWriterBuf returns the byte slice that bw uses as its internal
// buffer and leaves bw reset to write to w.
//
// It points bw at a writer that records the slice it is given, stretched to
// its full capacity, then writes and flushes one byte so that writer is called.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
	var captured []byte
	capture := util.WriterFunc(func(b []byte) (int, error) {
		captured = b[:cap(b)]
		return len(b), nil
	})

	bw.Reset(capture)
	// The capture writer never fails, so these results carry no information.
	_ = bw.WriteByte(0)
	_ = bw.Flush()

	bw.Reset(w)
	return captured
}
// writeError records err as the close error, writes a close with the given
// status code and err's message as the reason, and then closes the
// connection. The result of writeClose is not inspected.
func (c *Conn) writeError(code StatusCode, err error) {
	c.setCloseErr(err)

	reason := err.Error()
	c.writeClose(code, reason)

	c.close(nil)
}
// The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.