package tls
import (
"errors"
"fmt"
"github.com/refraction-networking/utls/internal/tls13"
)
// LoadSessionTrackerState records whether loadSession has been invoked for a
// connection and, if so, which code path was responsible for the call.
type LoadSessionTrackerState int

const (
	// NeverCalled is the initial tracker value assigned by newSessionController.
	NeverCalled LoadSessionTrackerState = iota
	// UtlsAboutToCall is recorded by utlsAboutToLoadSession before loadSession runs.
	UtlsAboutToCall
	// CalledByULoadSession is recorded by onLoadSessionReturn when the tracker
	// was UtlsAboutToCall on entry.
	CalledByULoadSession
	// CalledByGoTLS is recorded by onLoadSessionReturn when the tracker was
	// still NeverCalled on entry.
	CalledByGoTLS
)
// sessionControllerState describes how far the sessionController has
// progressed in attaching a session (ticket or PSK) to the client hello.
type sessionControllerState int

const (
	// NoSession is the initial state: no session has been attached.
	NoSession sessionControllerState = iota
	// SessionTicketExtInitialized is set once the session ticket extension
	// has been initialized (see initSessionTicketExt and overrideExtension).
	SessionTicketExtInitialized
	// SessionTicketExtAllSet is set by setSessionTicketToUConn after the
	// ticket data has been copied to the connection.
	SessionTicketExtAllSet
	// PskExtInitialized is set once the pre-shared key extension has been
	// initialized (see initPskExt and overrideExtension).
	PskExtInitialized
	// PskExtAllSet is set by setPskToUConn after the PSK data has been
	// copied to the connection.
	PskExtAllSet
)
// sessionController coordinates the session-related extensions (session
// ticket and pre-shared key) of a UConn and guards the order in which the
// session may be loaded, applied to the connection, and finally locked.
type sessionController struct {
// sessionTicketExt is the session ticket extension under management; nil when there is none.
sessionTicketExt ISessionTicketExtension
// pskExtension is the pre-shared key extension under management; nil when there is none.
pskExtension PreSharedKeyExtension
// uconnRef is the connection whose extension list and handshake state the controller reads and writes.
uconnRef *UConn
// state tracks how far session setup has progressed.
state sessionControllerState
// loadSessionTracker records whether loadSession has run and which code path triggered it.
loadSessionTracker LoadSessionTrackerState
// callingLoadSession is true between onEnterLoadSessionCheck and onLoadSessionReturn.
callingLoadSession bool
// locked is set by finalCheck; the assert helpers panic on modification attempts once it is true.
locked bool
}
// newSessionController returns a controller bound to uconn with no session
// extensions attached, in state NoSession, unlocked, and with the
// load-session tracker at NeverCalled.
func newSessionController(uconn *UConn) *sessionController {
	// The extension fields and the boolean flags keep their zero values
	// (nil / false); the enum fields are spelled out for readability.
	controller := new(sessionController)
	controller.uconnRef = uconn
	controller.state = NoSession
	controller.loadSessionTracker = NeverCalled
	return controller
}
// isSessionLocked reports whether the controller has been locked, which
// happens in finalCheck.
func (s *sessionController) isSessionLocked() bool {
	return s.locked
}
// shouldLoadSessionResult is the decision returned by
// sessionController.shouldLoadSession.
type shouldLoadSessionResult int

const (
	// shouldReturn: there is nothing to load, or the client hello is no
	// longer in the NotBuilt status.
	shouldReturn shouldLoadSessionResult = iota
	// shouldSetTicket: the controller is in state SessionTicketExtInitialized.
	shouldSetTicket
	// shouldSetPsk: the controller is in state PskExtInitialized.
	shouldSetPsk
	// shouldLoad: none of the above applies.
	shouldLoad
)
// shouldLoadSession decides what the caller should do about session loading.
//
// It returns shouldReturn when neither session extension is present or when
// the client hello status is anything other than NotBuilt. Otherwise the
// answer depends on the controller state: shouldSetTicket for
// SessionTicketExtInitialized, shouldSetPsk for PskExtInitialized, and
// shouldLoad for every other state.
func (s *sessionController) shouldLoadSession() shouldLoadSessionResult {
	// The parentheses only make the existing precedence explicit.
	if (s.sessionTicketExt == nil && s.pskExtension == nil) || s.uconnRef.clientHelloBuildStatus != NotBuilt {
		return shouldReturn
	}
	switch s.state {
	case SessionTicketExtInitialized:
		return shouldSetTicket
	case PskExtInitialized:
		return shouldSetPsk
	default:
		return shouldLoad
	}
}
// utlsAboutToLoadSession marks that uTLS itself is about to call loadSession.
// It asserts that the controller is unlocked and still in state NoSession.
func (s *sessionController) utlsAboutToLoadSession() {
	mayLoad := !s.locked && s.state == NoSession
	uAssert(mayLoad, "tls: aboutToLoadSession failed: must only load session when the session of the client hello is not locked and when there's currently no session")
	s.loadSessionTracker = UtlsAboutToCall
}
// assertHelloNotBuilt panics, naming caller in the message, unless the
// connection's client hello build status is NotBuilt.
func (s *sessionController) assertHelloNotBuilt(caller string) {
	if s.uconnRef.clientHelloBuildStatus == NotBuilt {
		return
	}
	panic(fmt.Sprintf("tls: %s failed: we can't modify the session after the clientHello is built", caller))
}
// assertControllerState panics, naming caller and the current state in the
// message, unless the controller state equals desired or one of
// moreDesiredStates.
func (s *sessionController) assertControllerState(caller string, desired sessionControllerState, moreDesiredStates ...sessionControllerState) {
	if s.state == desired {
		return
	}
	matchesCurrent := func(_ int, candidate *sessionControllerState) bool {
		return s.state == *candidate
	}
	if anyTrue(moreDesiredStates, matchesCurrent) {
		return
	}
	panic(fmt.Sprintf("tls: %s failed: undesired controller state %d", caller, s.state))
}
// assertNotLocked panics, naming caller in the message, if the controller
// has already been locked.
func (s *sessionController) assertNotLocked(caller string) {
	if !s.locked {
		return
	}
	panic(fmt.Sprintf("tls: %s failed: you must not modify the session after it's locked", caller))
}
// assertCanSkip panics unless the connection's skipResumptionOnNilExtension
// flag is set. The panic message names caller and the missing extension
// (extensionName).
func (s *sessionController) assertCanSkip(caller, extensionName string) {
	if s.uconnRef.skipResumptionOnNilExtension {
		return
	}
	panic(fmt.Sprintf("tls: %s failed: session resumption is enabled, but there is no %s in the ClientHelloSpec; Please consider provide one in the ClientHelloSpec; If this is intentional, you may consider disable resumption by setting Config.SessionTicketsDisabled to true, or set Config.PreferSkipResumptionOnNilExtension to true to suppress this exception", caller, extensionName))
}
// finalCheck asserts that the controller is in one of the states
// PskExtAllSet, SessionTicketExtAllSet or NoSession, then locks it.
func (s *sessionController) finalCheck() {
	const caller = "SessionController.finalCheck"
	s.assertControllerState(caller, PskExtAllSet, SessionTicketExtAllSet, NoSession)
	s.locked = true
}
// initializationGuard runs initializer on extension, asserting that the
// extension reports itself uninitialized beforehand and initialized
// afterwards.
func initializationGuard[E Initializable, I func(E)](extension E, initializer I) {
	const (
		alreadyInitialized = "tls: initialization failed: the extension is already initialized"
		stillUninitialized = "tls: initialization failed: the extension is not initialized after initialization"
	)
	initializedBefore := extension.IsInitialized()
	uAssert(!initializedBefore, alreadyInitialized)
	initializer(extension)
	uAssert(extension.IsInitialized(), stillUninitialized)
}
// initSessionTicketExt initializes the managed session ticket extension with
// session and ticket and moves the controller to SessionTicketExtInitialized.
//
// Preconditions, each enforced by an assertion: the controller is unlocked,
// the client hello is not built, the state is NoSession, and the arguments
// pass panicOnNil. When no session ticket extension is present the call is a
// no-op, provided assertCanSkip allows skipping.
func (s *sessionController) initSessionTicketExt(session *SessionState, ticket []byte) {
	const caller = "initSessionTicketExt"
	s.assertNotLocked(caller)
	s.assertHelloNotBuilt(caller)
	s.assertControllerState(caller, NoSession)
	panicOnNil(caller, session, ticket)
	if s.sessionTicketExt == nil {
		s.assertCanSkip(caller, "session ticket extension")
		return
	}
	// The guard hands the same extension back to the callback, so it is
	// initialized through the callback parameter.
	initializationGuard(s.sessionTicketExt, func(ext ISessionTicketExtension) {
		ext.InitializeByUtls(session, ticket)
	})
	s.state = SessionTicketExtInitialized
}
// initPskExt initializes the managed pre-shared key extension from session,
// earlySecret, binderKey and pskIdentities, then moves the controller to
// PskExtInitialized.
//
// Preconditions, each enforced by an assertion: the controller is unlocked,
// the client hello is not built, the state is NoSession, and session,
// earlySecret and pskIdentities pass panicOnNil. When no PSK extension is
// present the call is a no-op, provided assertCanSkip allows skipping.
func (s *sessionController) initPskExt(session *SessionState, earlySecret *tls13.EarlySecret, binderKey []byte, pskIdentities []pskIdentity) {
	const caller = "initPskExt"
	s.assertNotLocked(caller)
	s.assertHelloNotBuilt(caller)
	s.assertControllerState(caller, NoSession)
	panicOnNil(caller, session, earlySecret, pskIdentities)
	if s.pskExtension == nil {
		s.assertCanSkip(caller, "pre-shared key extension")
		return
	}
	// toPublic converts the private identity representation to the exported one.
	toPublic := func(private pskIdentity) PskIdentity {
		return PskIdentity{
			Label:               private.label,
			ObfuscatedTicketAge: private.obfuscatedTicketAge,
		}
	}
	initializationGuard(s.pskExtension, func(ext PreSharedKeyExtension) {
		exported := mapSlice(pskIdentities, toPublic)
		ext.InitializeByUtls(session, earlySecret.Secret(), binderKey, exported)
	})
	s.state = PskExtInitialized
}
// setSessionTicketToUConn copies the session and ticket held by the session
// ticket extension into the connection's handshake state and hello, then
// moves the controller to SessionTicketExtAllSet. It asserts that the
// extension is present and the state is SessionTicketExtInitialized.
func (s *sessionController) setSessionTicketToUConn() {
	ready := s.sessionTicketExt != nil && s.state == SessionTicketExtInitialized
	uAssert(ready, "tls: setSessionTicketExt failed: invalid state")
	uconn := s.uconnRef
	uconn.HandshakeState.Session = s.sessionTicketExt.GetSession()
	uconn.HandshakeState.Hello.SessionTicket = s.sessionTicketExt.GetTicket()
	s.state = SessionTicketExtAllSet
}
// setPskToUConn applies the pre-shared key data held by the PSK extension to
// the connection and moves the controller to PskExtAllSet.
//
// It asserts that the PSK extension is present and that the state is either
// PskExtInitialized or PskExtAllSet.
//
//   - PskExtInitialized: the early secret, session, identities and binders
//     are copied to the connection.
//   - PskExtAllSet: nothing but the binder key is written; the session, early
//     secret and identities already on the connection must match the
//     extension's, otherwise the assertion fails.
//
// In both cases the binder key is copied afterwards.
func (s *sessionController) setPskToUConn() {
	uAssert(s.pskExtension != nil && (s.state == PskExtInitialized || s.state == PskExtAllSet), "tls: setPskToUConn failed: invalid state")
	pskCommon := s.pskExtension.GetPreSharedKeyCommon()
	uconn := s.uconnRef
	switch s.state {
	case PskExtInitialized:
		uconn.HandshakeState.State13.EarlySecret = pskCommon.EarlySecret
		uconn.HandshakeState.Session = pskCommon.Session
		uconn.HandshakeState.Hello.PskIdentities = pskCommon.Identities
		uconn.HandshakeState.Hello.PskBinders = pskCommon.Binders
	case PskExtAllSet:
		// The length comparison must come before the element-wise check:
		// without it, a longer hello list would index past the end of
		// pskCommon.Identities (runtime panic instead of the assertion), and
		// a shorter one would silently pass even though identities changed.
		uAssert(uconn.HandshakeState.Session == pskCommon.Session && sliceEq(uconn.HandshakeState.State13.EarlySecret, pskCommon.EarlySecret) &&
			len(uconn.HandshakeState.Hello.PskIdentities) == len(pskCommon.Identities) &&
			allTrue(uconn.HandshakeState.Hello.PskIdentities, func(i int, psk *PskIdentity) bool {
				return pskCommon.Identities[i].ObfuscatedTicketAge == psk.ObfuscatedTicketAge && sliceEq(pskCommon.Identities[i].Label, psk.Label)
			}), "tls: setPskToUConn failed: only binders are allowed to change on state `PskAllSet`")
	}
	uconn.HandshakeState.State13.BinderKey = pskCommon.BinderKey
	s.state = PskExtAllSet
}
// shouldUpdateBinders reports whether a PSK extension is present and the
// controller is in state PskExtInitialized or PskExtAllSet.
func (s *sessionController) shouldUpdateBinders() bool {
	if s.pskExtension == nil {
		return false
	}
	switch s.state {
	case PskExtInitialized, PskExtAllSet:
		return true
	default:
		return false
	}
}
// updateBinders asks the PSK extension to patch the connection's built
// hello. It asserts that shouldUpdateBinders holds.
func (s *sessionController) updateBinders() {
	allowed := s.shouldUpdateBinders()
	uAssert(allowed, "tls: updateBinders failed: shouldn't update binders")
	hello := s.uconnRef.HandshakeState.Hello
	s.pskExtension.PatchBuiltHello(hello)
}
// overrideExtension runs override to install a caller-supplied extension and,
// if that extension reports itself initialized, moves the controller to
// initializedState.
//
// It asserts that extension passes panicOnNil, that the controller is
// unlocked, and that the state is NoSession. The returned error is always nil.
func (s *sessionController) overrideExtension(extension Initializable, override func(), initializedState sessionControllerState) error {
	const caller = "overrideExtension"
	panicOnNil(caller, extension)
	s.assertNotLocked(caller)
	s.assertControllerState(caller, NoSession)
	override()
	if !extension.IsInitialized() {
		return nil
	}
	s.state = initializedState
	return nil
}
// overridePskExt replaces the managed PSK extension with pskExt via
// overrideExtension; the state becomes PskExtInitialized when pskExt is
// already initialized.
func (s *sessionController) overridePskExt(pskExt PreSharedKeyExtension) error {
	install := func() {
		s.pskExtension = pskExt
	}
	return s.overrideExtension(pskExt, install, PskExtInitialized)
}
// overrideSessionTicketExt replaces the managed session ticket extension
// with sessionTicketExt via overrideExtension; the state becomes
// SessionTicketExtInitialized when sessionTicketExt is already initialized.
func (s *sessionController) overrideSessionTicketExt(sessionTicketExt ISessionTicketExtension) error {
	install := func() {
		s.sessionTicketExt = sessionTicketExt
	}
	return s.overrideExtension(sessionTicketExt, install, SessionTicketExtInitialized)
}
// syncSessionExts reconciles the controller's session extensions with the
// connection's extension list.
//
// For each ISessionTicketExtension / PreSharedKeyExtension found in the list,
// the controller either adopts it (when it holds none yet) or writes its own
// extension into that slot. It asserts that at most one session ticket
// extension exists and that a PSK extension, if any, is the last entry.
//
// When the list has no extension of a given kind, the controller's reference
// and the related handshake fields are cleared. If the controller state says
// that kind was already initialized, an error is returned instead.
//
// Preconditions (asserted): the client hello is not built, the controller is
// unlocked, and the state is NoSession, SessionTicketExtInitialized or
// PskExtInitialized.
func (s *sessionController) syncSessionExts() error {
	const caller = "checkSessionExts"
	uAssert(s.uconnRef.clientHelloBuildStatus == NotBuilt, "tls: checkSessionExts failed: we can't modify the session after the clientHello is built")
	s.assertNotLocked(caller)
	s.assertHelloNotBuilt(caller)
	s.assertControllerState(caller, NoSession, SessionTicketExtInitialized, PskExtInitialized)

	uconn := s.uconnRef
	sawTicketExt := false
	sawPskExt := false
	for idx, candidate := range uconn.Extensions {
		switch ext := candidate.(type) {
		case ISessionTicketExtension:
			uAssert(!sawTicketExt, "tls: checkSessionExts failed: multiple ISessionTicketExtensions in the extension list")
			if s.sessionTicketExt != nil {
				// The controller's extension takes the slot in the list.
				uconn.Extensions[idx] = s.sessionTicketExt
			} else {
				s.sessionTicketExt = ext
			}
			sawTicketExt = true
		case PreSharedKeyExtension:
			uAssert(idx == len(uconn.Extensions)-1, "tls: checkSessionExts failed: PreSharedKeyExtension must be the last extension")
			if s.pskExtension != nil {
				// The controller's extension takes the slot in the list.
				uconn.Extensions[idx] = s.pskExtension
			} else {
				s.pskExtension = ext
			}
			s.pskExtension.SetOmitEmptyPsk(uconn.config.OmitEmptyPsk)
			sawPskExt = true
		}
	}

	if !sawTicketExt {
		if s.state == SessionTicketExtInitialized {
			return errors.New("tls: checkSessionExts failed: the user provided a session ticket, but the specification doesn't contain one")
		}
		s.sessionTicketExt = nil
		uconn.HandshakeState.Session = nil
		uconn.HandshakeState.Hello.SessionTicket = nil
	}

	if !sawPskExt {
		if s.state == PskExtInitialized {
			return errors.New("tls: checkSessionExts failed: the user provided a psk, but the specification doesn't contain one")
		}
		s.pskExtension = nil
		uconn.HandshakeState.State13.BinderKey = nil
		uconn.HandshakeState.State13.EarlySecret = nil
		uconn.HandshakeState.Session = nil
		uconn.HandshakeState.Hello.PskIdentities = nil
	}
	return nil
}
// onEnterLoadSessionCheck marks the start of a loadSession call. It asserts
// that the controller is unlocked and panics if loadSession has already
// completed once (tracker is CalledByULoadSession or CalledByGoTLS) or the
// tracker holds an unknown value.
func (s *sessionController) onEnterLoadSessionCheck() {
	uAssert(!s.locked, "tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: session is set and locked, no call to loadSession is allowed")
	tracker := s.loadSessionTracker
	if tracker == UtlsAboutToCall || tracker == NeverCalled {
		s.callingLoadSession = true
		return
	}
	if tracker == CalledByULoadSession || tracker == CalledByGoTLS {
		panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: you must not call loadSession() twice")
	}
	panic("tls: LoadSessionCoordinator.onEnterLoadSessionCheck failed: unimplemented state")
}
// onLoadSessionReturn marks the end of a loadSession call. The tracker
// advances from NeverCalled to CalledByGoTLS, or from UtlsAboutToCall to
// CalledByULoadSession; any other tracker value panics. It asserts that a
// loadSession call is in progress and clears that flag on success.
func (s *sessionController) onLoadSessionReturn() {
	uAssert(s.callingLoadSession, "tls: LoadSessionCoordinator.onLoadSessionReturn failed: it's not loading sessions, perhaps this function is not being called by loadSession.")
	if s.loadSessionTracker == NeverCalled {
		s.loadSessionTracker = CalledByGoTLS
	} else if s.loadSessionTracker == UtlsAboutToCall {
		s.loadSessionTracker = CalledByULoadSession
	} else {
		panic("tls: LoadSessionCoordinator.onLoadSessionReturn failed: unimplemented state")
	}
	s.callingLoadSession = false
}
// shouldLoadSessionWriteBinders reports whether the in-progress loadSession
// call should write binders: true when the tracker is NeverCalled, false
// when it is UtlsAboutToCall. Any other tracker value panics. It asserts
// that a loadSession call is in progress.
func (s *sessionController) shouldLoadSessionWriteBinders() bool {
	uAssert(s.callingLoadSession, "tls: shouldWriteBinders failed: LoadSessionCoordinator isn't loading sessions, perhaps this function is not being called by loadSession.")
	if s.loadSessionTracker == NeverCalled {
		return true
	}
	if s.loadSessionTracker == UtlsAboutToCall {
		return false
	}
	panic("tls: shouldWriteBinders failed: unimplemented state")
}
The pages are generated with Golds v0.8.4. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.