Source File
nat.go
Belonging Package
crypto/internal/bigmod
// Copyright 2021 The Go Authors. All rights reserved.// Use of this source code is governed by a BSD-style// license that can be found in the LICENSE file.package bigmodimport ()const (// _W is the size in bits of our limbs._W = bits.UintSize// _S is the size in bytes of our limbs._S = _W / 8)// choice represents a constant-time boolean. The value of choice is always// either 1 or 0. We use an int instead of bool in order to make decisions in// constant time by turning it into a mask.type choice uintfunc ( choice) choice { return 1 ^ }const yes = choice(1)const no = choice(0)// ctMask is all 1s if on is yes, and all 0s otherwise.func ( choice) uint { return -uint() }// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this// function does not depend on its inputs.func (, uint) choice {// If x != y, then either x - y or y - x will generate a carry., := bits.Sub(, , 0), := bits.Sub(, , 0)return not(choice( | ))}// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this// function does not depend on its inputs.func (, uint) choice {// If x < y, then x - y generates a carry., := bits.Sub(, , 0)return not(choice())}// Nat represents an arbitrary natural number//// Each Nat has an announced length, which is the number of limbs it has stored.// Operations on this number are allowed to leak this length, but will not leak// any information about the values contained in those limbs.type Nat struct {// limbs is little-endian in base 2^W with W = bits.UintSize.limbs []uint}// preallocTarget is the size in bits of the numbers used to implement the most// common and most performant RSA key size. 
It's also enough to cover some of// the operations of key sizes up to 4096.const preallocTarget = 2048const preallocLimbs = (preallocTarget + _W - 1) / _W// NewNat returns a new nat with a size of zero, just like new(Nat), but with// the preallocated capacity to hold a number of up to preallocTarget bits.// NewNat inlines, so the allocation can live on the stack.func () *Nat {:= make([]uint, 0, preallocLimbs)return &Nat{}}// expand expands x to n limbs, leaving its value unchanged.func ( *Nat) ( int) *Nat {if len(.limbs) > {panic("bigmod: internal error: shrinking nat")}if cap(.limbs) < {:= make([]uint, )copy(, .limbs).limbs =return}:= .limbs[len(.limbs):]for := range {[] = 0}.limbs = .limbs[:]return}// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).func ( *Nat) ( int) *Nat {if cap(.limbs) < {.limbs = make([]uint, )return}for := range .limbs {.limbs[] = 0}.limbs = .limbs[:]return}// set assigns x = y, optionally resizing x to the appropriate size.func ( *Nat) ( *Nat) *Nat {.reset(len(.limbs))copy(.limbs, .limbs)return}// setBig assigns x = n, optionally resizing n to the appropriate size.//// The announced length of x is set based on the actual bit size of the input,// ignoring leading zeroes.func ( *Nat) ( *big.Int) *Nat {:= .Bits().reset(len())for := range {.limbs[] = uint([])}return}// Bytes returns x as a zero-extended big-endian byte slice. 
The size of the// slice will match the size of m.//// x must have the same size as m and it must be reduced modulo m.func ( *Nat) ( *Modulus) []byte {:= .Size():= make([]byte, )for , := range .limbs {for := 0; < _S; ++ {--if < 0 {if == 0 {break}panic("bigmod: modulus is smaller than nat")}[] = byte()>>= 8}}return}// SetBytes assigns x = b, where b is a slice of big-endian bytes.// SetBytes returns an error if b >= m.//// The output will be resized to the size of m and overwritten.func ( *Nat) ( []byte, *Modulus) (*Nat, error) {if := .setBytes(, ); != nil {return nil,}if .cmpGeq(.nat) == yes {return nil, errors.New("input overflows the modulus")}return , nil}// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.// SetOverflowingBytes returns an error if b has a longer bit length than m, but// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.//// The output will be resized to the size of m and overwritten.func ( *Nat) ( []byte, *Modulus) (*Nat, error) {if := .setBytes(, ); != nil {return nil,}:= _W - bitLen(.limbs[len(.limbs)-1])if < .leading {return nil, errors.New("input overflows the modulus size")}.maybeSubtractModulus(no, )return , nil}// bigEndianUint returns the contents of buf interpreted as a// big-endian encoded uint value.func ( []byte) uint {if _W == 64 {return uint(binary.BigEndian.Uint64())}return uint(binary.BigEndian.Uint32())}func ( *Nat) ( []byte, *Modulus) error {.resetFor(), := len(), 0for < len(.limbs) && >= _S {.limbs[] = bigEndianUint([-_S : ])-= _S++}for := 0; < _W && < len(.limbs) && > 0; += 8 {.limbs[] |= uint([-1]) <<--}if > 0 {return errors.New("input overflows the modulus size")}return nil}// Equal returns 1 if x == y, and 0 otherwise.//// Both operands must have the same announced length.func ( *Nat) ( *Nat) choice {// Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= .limbs[:]:= yesfor := 0; < ; ++ {&= ctEq([], [])}return}// IsZero returns 1 if x == 0, and 0 otherwise.func ( *Nat) () choice {// 
Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= yesfor := 0; < ; ++ {&= ctEq([], 0)}return}// cmpGeq returns 1 if x >= y, and 0 otherwise.//// Both operands must have the same announced length.func ( *Nat) ( *Nat) choice {// Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= .limbs[:]var uintfor := 0; < ; ++ {_, = bits.Sub([], [], )}// If there was a carry, then subtracting y underflowed, so// x is not greater than or equal to y.return not(choice())}// assign sets x <- y if on == 1, and does nothing otherwise.//// Both operands must have the same announced length.func ( *Nat) ( choice, *Nat) *Nat {// Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= .limbs[:]:= ctMask()for := 0; < ; ++ {[] ^= & ([] ^ [])}return}// add computes x += y and returns the carry.//// Both operands must have the same announced length.func ( *Nat) ( *Nat) ( uint) {// Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= .limbs[:]for := 0; < ; ++ {[], = bits.Add([], [], )}return}// sub computes x -= y. It returns the borrow of the subtraction.//// Both operands must have the same announced length.func ( *Nat) ( *Nat) ( uint) {// Eliminate bounds checks in the loop.:= len(.limbs):= .limbs[:]:= .limbs[:]for := 0; < ; ++ {[], = bits.Sub([], [], )}return}// Modulus is used for modular arithmetic, precomputing relevant constants.//// Moduli are assumed to be odd numbers. 
Moduli can also leak the exact// number of bits needed to store their value, and are stored without padding.//// Their actual value is still kept secret.type Modulus struct {// The underlying natural number for this modulus.//// This will be stored without any padding, and shouldn't alias with any// other natural number being used.nat *Natleading int // number of leading zeros in the modulusm0inv uint // -nat.limbs[0]⁻¹ mod _Wrr *Nat // R*R for montgomeryRepresentation}// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).func ( *Modulus) *Nat {:= NewNat().ExpandFor()// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the// most significant limb to 1. We then get to R*R by shifting left by _W// n + 1 times.:= len(.limbs).limbs[-1] = 1for := - 1; < 2*; ++ {.shiftIn(0, ) // x = x * 2^_W mod m}return}// minusInverseModW computes -x⁻¹ mod _W with x odd.//// This operation is used to precompute a constant involved in Montgomery// multiplication.func ( uint) uint {// Every iteration of this loop doubles the least-significant bits of// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough// for 64 bits (and wastes only one iteration for 32 bits).//// See https://crypto.stackexchange.com/a/47496.:=for := 0; < 5; ++ {= * (2 - *)}return -}// NewModulusFromBig creates a new Modulus from a [big.Int].//// The Int must be odd. The number of significant bits (and nothing else) is// leaked through timing side-channels.func ( *big.Int) (*Modulus, error) {if := .Bits(); len() == 0 {return nil, errors.New("modulus must be >= 0")} else if [0]&1 != 1 {return nil, errors.New("modulus must be odd")}:= &Modulus{}.nat = NewNat().setBig().leading = _W - bitLen(.nat.limbs[len(.nat.limbs)-1]).m0inv = minusInverseModW(.nat.limbs[0]).rr = rr()return , nil}// bitLen is a version of bits.Len that only leaks the bit length of n, but not// its value. 
bits.Len and bits.LeadingZeros use a lookup table for the// low-order bits on some architectures.func ( uint) int {var int// We assume, here and elsewhere, that comparison to zero is constant time// with respect to different non-zero values.for != 0 {++>>= 1}return}// Size returns the size of m in bytes.func ( *Modulus) () int {return (.BitLen() + 7) / 8}// BitLen returns the size of m in bits.func ( *Modulus) () int {return len(.nat.limbs)*_W - int(.leading)}// Nat returns m as a Nat. The return value must not be written to.func ( *Modulus) () *Nat {return .nat}// shiftIn calculates x = x << _W + y mod m.//// This assumes that x is already reduced mod m.func ( *Nat) ( uint, *Modulus) *Nat {:= NewNat().resetFor()// Eliminate bounds checks in the loop.:= len(.nat.limbs):= .limbs[:]:= .limbs[:]:= .nat.limbs[:]// Each iteration of this loop computes x = 2x + b mod m, where b is a bit// from y. Effectively, it left-shifts x and adds y one bit at a time,// reducing it every time.//// To do the reduction, each iteration computes both 2x + b and 2x + b - m.// The next iteration (and finally the return line) will use either result// based on whether 2x + b overflows m.:= nofor := _W - 1; >= 0; -- {:= ( >> ) & 1var uint:= ctMask()for := 0; < ; ++ {:= [] ^ ( & ([] ^ []))[], = bits.Add(, , )[], = bits.Sub([], [], )}// Like in maybeSubtractModulus, we need the subtraction if either it// didn't underflow (meaning 2x + b > m) or if computing 2x + b// overflowed (meaning 2x + b > 2^_W*n > m).= not(choice()) | choice()}return .assign(, )}// Mod calculates out = x mod m.//// This works regardless how large the value of x is.//// The output will be resized to the size of m and overwritten.func ( *Nat) ( *Nat, *Modulus) *Nat {.resetFor()// Working our way from the most significant to the least significant limb,// we can insert each limb at the least significant position, shifting all// previous limbs left by _W. This way each limb will get shifted by the// correct number of bits. 
We can insert at least N - 1 limbs without// overflowing m. After that, we need to reduce every time we shift.:= len(.limbs) - 1// For the first N - 1 limbs we can skip the actual shifting and position// them at the shifted position, which starts at min(N - 2, i).:= len(.nat.limbs) - 2if < {=}for := ; >= 0; -- {.limbs[] = .limbs[]--}// We shift in the remaining limbs, reducing modulo m each time.for >= 0 {.shiftIn(.limbs[], )--}return}// ExpandFor ensures x has the right size to work with operations modulo m.//// The announced size of x must be smaller than or equal to that of m.func ( *Nat) ( *Modulus) *Nat {return .expand(len(.nat.limbs))}// resetFor ensures out has the right size to work with operations modulo m.//// out is zeroed and may start at any size.func ( *Nat) ( *Modulus) *Nat {return .reset(len(.nat.limbs))}// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.//// It can be used to reduce modulo m a value up to 2m - 1, which is a common// range for results computed by higher level operations.//// always is usually a carry that indicates that the operation that produced x// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.//// x and m operands must have the same announced length.func ( *Nat) ( choice, *Modulus) {:= NewNat().set():= .sub(.nat)// We keep the result if x - m didn't underflow (meaning x >= m)// or if always was set.:= not(choice()) | choice().assign(, )}// Sub computes x = x - y mod m.//// The length of both operands must be the same as the modulus. Both operands// must already be reduced modulo m.func ( *Nat) ( *Nat, *Modulus) *Nat {:= .sub()// If the subtraction underflowed, add m.:= NewNat().set().add(.nat).assign(choice(), )return}// Add computes x = x + y mod m.//// The length of both operands must be the same as the modulus. 
Both operands// must already be reduced modulo m.func ( *Nat) ( *Nat, *Modulus) *Nat {:= .add().maybeSubtractModulus(choice(), )return}// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and// n = len(m.nat.limbs).//// Faster Montgomery multiplication replaces standard modular multiplication for// numbers in this representation.//// This assumes that x is already reduced mod m.func ( *Nat) ( *Modulus) *Nat {// A Montgomery multiplication (which computes a * b / R) by R * R works out// to a multiplication by R, which takes the value out of the Montgomery domain.return .montgomeryMul(, .rr, )}// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and// n = len(m.nat.limbs).//// This assumes that x is already reduced mod m.func ( *Nat) ( *Modulus) *Nat {// By Montgomery multiplying with 1 not in Montgomery representation, we// convert out back from Montgomery representation, because it works out to// dividing by R.:= NewNat().ExpandFor().limbs[0] = 1return .montgomeryMul(, , )}// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and// n = len(m.nat.limbs), also known as a Montgomery multiplication.//// All inputs should be the same length and already reduced modulo m.// x will be resized to the size of m and overwritten.func ( *Nat) ( *Nat, *Nat, *Modulus) *Nat {:= len(.nat.limbs):= .nat.limbs[:]:= .limbs[:]:= .limbs[:]switch {default:// Attempt to use a stack-allocated backing array.:= make([]uint, 0, preallocLimbs*2)if cap() < *2 {= make([]uint, 0, *2)}= [:*2]// This loop implements Word-by-Word Montgomery Multiplication, as// described in Algorithm 4 (Fig. 3) of "Efficient Software// Implementations of Modular Exponentiation" by Shay Gueron// [https://eprint.iacr.org/2011/239.pdf].var uintfor := 0; < ; ++ {_ = [+] // bounds check elimination hint// Step 1 (T = a × b) is computed as a large pen-and-paper column// multiplication of two numbers with n base-2^_W digits. 
If we just// wanted to produce 2n-wide T, we would do//// for i := 0; i < n; i++ {// d := bLimbs[i]// T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)// }//// where d is a digit of the multiplier, T[i:n+i] is the shifted// position of the product of that digit, and T[n+i] is the final carry.// Note that T[i] isn't modified after processing the i-th digit.//// Instead of running two loops, one for Step 1 and one for Steps 2–6,// the result of Step 1 is computed during the next loop. This is// possible because each iteration only uses T[i] in Step 2 and then// discards it in Step 6.:= []:= addMulVVW([:+], , )// Step 6 is replaced by shifting the virtual window we operate// over: T of the algorithm is T[i:] for us. That means that T1 in// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.:= [] * .m0inv// Step 4 and 5 add Y × m to T, which as mentioned above is stored// at T[i:]. The two carries (from a × d and Y × m) are added up in// the next word T[n+i], and the carry bit from that addition is// brought forward to the next iteration.:= addMulVVW([:+], , )[+], = bits.Add(, , )}// Finally for Step 7 we copy the final T window into x, and subtract m// if necessary (which as explained in maybeSubtractModulus can be the// case both if x >= m, or if x overflowed).//// The paper suggests in Section 4 that we can do an "Almost Montgomery// Multiplication" by subtracting only in the overflow case, but the// cost is very similar since the constant time subtraction tells us if// x >= m as a side effect, and taking care of the broken invariant is// highly undesirable (see https://go.dev/issue/13907).copy(.reset().limbs, [:]).maybeSubtractModulus(choice(), )// The following specialized cases follow the exact same algorithm, but// optimized for the sizes most used in RSA. 
addMulVVW is implemented in// assembly with loop unrolling depending on the architecture and bounds// checks are removed by the compiler thanks to the constant size.case 1024 / _W:const = 1024 / _W // compiler hint:= make([]uint, *2)var uintfor := 0; < ; ++ {:= []:= addMulVVW1024(&[], &[0], ):= [] * .m0inv:= addMulVVW1024(&[], &[0], )[+], = bits.Add(, , )}copy(.reset().limbs, [:]).maybeSubtractModulus(choice(), )case 1536 / _W:const = 1536 / _W // compiler hint:= make([]uint, *2)var uintfor := 0; < ; ++ {:= []:= addMulVVW1536(&[], &[0], ):= [] * .m0inv:= addMulVVW1536(&[], &[0], )[+], = bits.Add(, , )}copy(.reset().limbs, [:]).maybeSubtractModulus(choice(), )case 2048 / _W:const = 2048 / _W // compiler hint:= make([]uint, *2)var uintfor := 0; < ; ++ {:= []:= addMulVVW2048(&[], &[0], ):= [] * .m0inv:= addMulVVW2048(&[], &[0], )[+], = bits.Add(, , )}copy(.reset().limbs, [:]).maybeSubtractModulus(choice(), )}return}// addMulVVW multiplies the multi-word value x by the single-word value y,// adding the result to the multi-word value z and returning the final carry.// It can be thought of as one row of a pen-and-paper column multiplication.func (, []uint, uint) ( uint) {_ = [len()-1] // bounds check elimination hintfor := range {, := bits.Mul([], ), := bits.Add(, [], 0)// We use bits.Add with zero to get an add-with-carry instruction that// absorbs the carry from the previous bits.Add., _ = bits.Add(, 0, ), = bits.Add(, , 0), _ = bits.Add(, 0, )=[] =}return}// Mul calculates x = x * y mod m.//// The length of both operands must be the same as the modulus. Both operands// must already be reduced modulo m.func ( *Nat) ( *Nat, *Modulus) *Nat {// A Montgomery multiplication by a value out of the Montgomery domain// takes the result out of Montgomery representation.:= NewNat().set().montgomeryRepresentation() // xR = x * R mod mreturn .montgomeryMul(, , ) // x = xR * y / R mod m}// Exp calculates out = x^e mod m.//// The exponent e is represented in big-endian order. 
The output will be resized// to the size of m and overwritten. x must already be reduced modulo m.func ( *Nat) ( *Nat, []byte, *Modulus) *Nat {// We use a 4 bit window. For our RSA workload, 4 bit windows are faster// than 2 bit windows, but use an extra 12 nats worth of scratch space.// Using bit sizes that don't divide 8 are more complex to implement, but// are likely to be more efficient if necessary.:= [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)// newNat calls are unrolled so they are allocated on the stack.NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),}[0].set().montgomeryRepresentation()for := 1; < len(); ++ {[].montgomeryMul([-1], [0], )}.resetFor().limbs[0] = 1.montgomeryRepresentation():= NewNat().ExpandFor()for , := range {for , := range []int{4, 0} {// Square four times. Optimization note: this can be implemented// more efficiently than with generic Montgomery multiplication..montgomeryMul(, , ).montgomeryMul(, , ).montgomeryMul(, , ).montgomeryMul(, , )// Select x^k in constant time from the table.:= uint(( >> ) & 0b1111)for := range {.assign(ctEq(, uint(+1)), [])}// Multiply by x^k, discarding the result if k = 0..montgomeryMul(, , ).assign(not(ctEq(, 0)), )}}return .montgomeryReduction()}// ExpShort calculates out = x^e mod m.//// The output will be resized to the size of m and overwritten. x must already// be reduced modulo m. This leaks the exact bit size of the exponent.func ( *Nat) ( *Nat, uint, *Modulus) *Nat {:= NewNat().set().montgomeryRepresentation().resetFor().limbs[0] = 1.montgomeryRepresentation()// For short exponents, precomputing a table and using a window like in Exp// doesn't pay off. 
Instead, we do a simple constant-time conditional// square-and-multiply chain, skipping the initial run of zeroes.:= NewNat().ExpandFor()for := bits.UintSize - bitLen(); < bits.UintSize; ++ {.montgomeryMul(, , ):= ( >> (bits.UintSize - - 1)) & 1.montgomeryMul(, , ).assign(ctEq(, 1), )}return .montgomeryReduction()}
![]() |
The pages are generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64). Golds is a Go 101 project developed by Tapir Liu. PRs and bug reports are welcome and can be submitted to the issue list. Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.