// Source file: nat.go
// Package: crypto/internal/bigmod
//
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bigmod
import (
	"encoding/binary"
	"errors"
	"math/big"
	"math/bits"
)
const (
// _W is the size in bits of our limbs (the width of the native uint).
_W = bits.UintSize
// _S is the size in bytes of our limbs.
_S = _W / 8
)
// choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use a uint instead of bool in order to make decisions in
// constant time by turning it into a mask.
type choice uint

// not returns the logical negation of c: yes becomes no and no becomes yes.
func not(c choice) choice { return 1 ^ c }

const yes = choice(1)
const no = choice(0)

// ctMask is all 1s if on is yes, and all 0s otherwise.
func ctMask(on choice) uint { return -uint(on) }
// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func (, uint) choice {
// If x != y, then either x - y or y - x will generate a carry.
, := bits.Sub(, , 0)
, := bits.Sub(, , 0)
return not(choice( | ))
}
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func (, uint) choice {
// If x < y, then x - y generates a carry.
, := bits.Sub(, , 0)
return not(choice())
}
// Nat represents an arbitrary natural number.
//
// Each Nat has an announced length, which is the number of limbs it has stored.
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
// limbs is little-endian in base 2^_W with _W = bits.UintSize.
limbs []uint
}
// preallocTarget is the size in bits of the numbers used to implement the most
// common and most performant RSA key size. It's also enough to cover some of
// the operations of key sizes up to 4096.
const preallocTarget = 2048
// preallocLimbs is preallocTarget expressed in limbs, rounded up.
const preallocLimbs = (preallocTarget + _W - 1) / _W
// NewNat returns a new nat with a size of zero, just like new(Nat), but with
// the preallocated capacity to hold a number of up to preallocTarget bits.
// NewNat inlines, so the allocation can live on the stack.
func () *Nat {
:= make([]uint, 0, preallocLimbs)
return &Nat{}
}
// expand expands x to n limbs, leaving its value unchanged.
func ( *Nat) ( int) *Nat {
if len(.limbs) > {
panic("bigmod: internal error: shrinking nat")
}
if cap(.limbs) < {
:= make([]uint, )
copy(, .limbs)
.limbs =
return
}
:= .limbs[len(.limbs):]
for := range {
[] = 0
}
.limbs = .limbs[:]
return
}
// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
func ( *Nat) ( int) *Nat {
if cap(.limbs) < {
.limbs = make([]uint, )
return
}
for := range .limbs {
.limbs[] = 0
}
.limbs = .limbs[:]
return
}
// set assigns x = y, optionally resizing x to the appropriate size.
func ( *Nat) ( *Nat) *Nat {
.reset(len(.limbs))
copy(.limbs, .limbs)
return
}
// setBig assigns x = n, optionally resizing n to the appropriate size.
//
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func ( *Nat) ( *big.Int) *Nat {
:= .Bits()
.reset(len())
for := range {
.limbs[] = uint([])
}
return
}
// Bytes returns x as a zero-extended big-endian byte slice. The size of the
// slice will match the size of m.
//
// x must have the same size as m and it must be reduced modulo m.
func ( *Nat) ( *Modulus) []byte {
:= .Size()
:= make([]byte, )
for , := range .limbs {
for := 0; < _S; ++ {
--
if < 0 {
if == 0 {
break
}
panic("bigmod: modulus is smaller than nat")
}
[] = byte()
>>= 8
}
}
return
}
// SetBytes assigns x = b, where b is a slice of big-endian bytes.
// SetBytes returns an error if b >= m.
//
// The output will be resized to the size of m and overwritten.
func ( *Nat) ( []byte, *Modulus) (*Nat, error) {
if := .setBytes(, ); != nil {
return nil,
}
if .cmpGeq(.nat) == yes {
return nil, errors.New("input overflows the modulus")
}
return , nil
}
// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
// SetOverflowingBytes returns an error if b has a longer bit length than m, but
// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func ( *Nat) ( []byte, *Modulus) (*Nat, error) {
if := .setBytes(, ); != nil {
return nil,
}
:= _W - bitLen(.limbs[len(.limbs)-1])
if < .leading {
return nil, errors.New("input overflows the modulus size")
}
.maybeSubtractModulus(no, )
return , nil
}
// bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value.
func ( []byte) uint {
if _W == 64 {
return uint(binary.BigEndian.Uint64())
}
return uint(binary.BigEndian.Uint32())
}
func ( *Nat) ( []byte, *Modulus) error {
.resetFor()
, := len(), 0
for < len(.limbs) && >= _S {
.limbs[] = bigEndianUint([-_S : ])
-= _S
++
}
for := 0; < _W && < len(.limbs) && > 0; += 8 {
.limbs[] |= uint([-1]) <<
--
}
if > 0 {
return errors.New("input overflows the modulus size")
}
return nil
}
// Equal returns 1 if x == y, and 0 otherwise.
//
// Both operands must have the same announced length.
func ( *Nat) ( *Nat) choice {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= .limbs[:]
:= yes
for := 0; < ; ++ {
&= ctEq([], [])
}
return
}
// IsZero returns 1 if x == 0, and 0 otherwise.
func ( *Nat) () choice {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= yes
for := 0; < ; ++ {
&= ctEq([], 0)
}
return
}
// cmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
func ( *Nat) ( *Nat) choice {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= .limbs[:]
var uint
for := 0; < ; ++ {
_, = bits.Sub([], [], )
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice())
}
// assign sets x <- y if on == 1, and does nothing otherwise.
//
// Both operands must have the same announced length.
func ( *Nat) ( choice, *Nat) *Nat {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= .limbs[:]
:= ctMask()
for := 0; < ; ++ {
[] ^= & ([] ^ [])
}
return
}
// add computes x += y and returns the carry.
//
// Both operands must have the same announced length.
func ( *Nat) ( *Nat) ( uint) {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= .limbs[:]
for := 0; < ; ++ {
[], = bits.Add([], [], )
}
return
}
// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
func ( *Nat) ( *Nat) ( uint) {
// Eliminate bounds checks in the loop.
:= len(.limbs)
:= .limbs[:]
:= .limbs[:]
for := 0; < ; ++ {
[], = bits.Sub([], [], )
}
return
}
// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
// number of bits needed to store their value, and are stored without padding.
//
// Their actual value is still kept secret.
type Modulus struct {
// The underlying natural number for this modulus.
//
// This will be stored without any padding, and shouldn't alias with any
// other natural number being used.
nat *Nat
leading int // number of leading zero bits in the most significant limb of the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod 2^_W
rr *Nat // R*R mod m, for montgomeryRepresentation
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
func ( *Modulus) *Nat {
:= NewNat().ExpandFor()
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
// most significant limb to 1. We then get to R*R by shifting left by _W
// n + 1 times.
:= len(.limbs)
.limbs[-1] = 1
for := - 1; < 2*; ++ {
.shiftIn(0, ) // x = x * 2^_W mod m
}
return
}
// minusInverseModW computes -x⁻¹ mod 2^_W with x odd.
//
// This operation is used to precompute a constant involved in Montgomery
// multiplication.
func minusInverseModW(x uint) uint {
	// Every iteration of this loop doubles the least-significant bits of
	// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
	// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
	// for 64 bits (and wastes only one iteration for 32 bits).
	//
	// See https://crypto.stackexchange.com/a/47496.
	y := x
	for i := 0; i < 5; i++ {
		y = y * (2 - x*y)
	}
	return -y
}
// NewModulusFromBig creates a new Modulus from a [big.Int].
//
// The Int must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
func ( *big.Int) (*Modulus, error) {
if := .Bits(); len() == 0 {
return nil, errors.New("modulus must be >= 0")
} else if [0]&1 != 1 {
return nil, errors.New("modulus must be odd")
}
:= &Modulus{}
.nat = NewNat().setBig()
.leading = _W - bitLen(.nat.limbs[len(.nat.limbs)-1])
.m0inv = minusInverseModW(.nat.limbs[0])
.rr = rr()
return , nil
}
// bitLen is a version of bits.Len that only leaks the bit length of n, but not
// its value. bits.Len and bits.LeadingZeros use a lookup table for the
// low-order bits on some architectures.
func bitLen(n uint) int {
	var length int
	// We assume, here and elsewhere, that comparison to zero is constant time
	// with respect to different non-zero values.
	for n != 0 {
		length++
		n >>= 1
	}
	return length
}
// Size returns the size of m in bytes.
func ( *Modulus) () int {
return (.BitLen() + 7) / 8
}
// BitLen returns the size of m in bits.
func ( *Modulus) () int {
return len(.nat.limbs)*_W - int(.leading)
}
// Nat returns m as a Nat. The return value must not be written to.
func ( *Modulus) () *Nat {
return .nat
}
// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m.
func ( *Nat) ( uint, *Modulus) *Nat {
:= NewNat().resetFor()
// Eliminate bounds checks in the loop.
:= len(.nat.limbs)
:= .limbs[:]
:= .limbs[:]
:= .nat.limbs[:]
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time,
// reducing it every time.
//
// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
// The next iteration (and finally the return line) will use either result
// based on whether 2x + b overflows m.
:= no
for := _W - 1; >= 0; -- {
:= ( >> ) & 1
var uint
:= ctMask()
for := 0; < ; ++ {
:= [] ^ ( & ([] ^ []))
[], = bits.Add(, , )
[], = bits.Sub([], [], )
}
// Like in maybeSubtractModulus, we need the subtraction if either it
// didn't underflow (meaning 2x + b > m) or if computing 2x + b
// overflowed (meaning 2x + b > 2^_W*n > m).
= not(choice()) | choice()
}
return .assign(, )
}
// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func ( *Nat) ( *Nat, *Modulus) *Nat {
.resetFor()
// Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the
// correct number of bits. We can insert at least N - 1 limbs without
// overflowing m. After that, we need to reduce every time we shift.
:= len(.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i).
:= len(.nat.limbs) - 2
if < {
=
}
for := ; >= 0; -- {
.limbs[] = .limbs[]
--
}
// We shift in the remaining limbs, reducing modulo m each time.
for >= 0 {
.shiftIn(.limbs[], )
--
}
return
}
// ExpandFor ensures x has the right size to work with operations modulo m.
//
// The announced size of x must be smaller than or equal to that of m.
func ( *Nat) ( *Modulus) *Nat {
return .expand(len(.nat.limbs))
}
// resetFor ensures out has the right size to work with operations modulo m.
//
// out is zeroed and may start at any size.
func ( *Nat) ( *Modulus) *Nat {
return .reset(len(.nat.limbs))
}
// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
//
// It can be used to reduce modulo m a value up to 2m - 1, which is a common
// range for results computed by higher level operations.
//
// always is usually a carry that indicates that the operation that produced x
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
//
// x and m operands must have the same announced length.
func ( *Nat) ( choice, *Modulus) {
:= NewNat().set()
:= .sub(.nat)
// We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set.
:= not(choice()) | choice()
.assign(, )
}
// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func ( *Nat) ( *Nat, *Modulus) *Nat {
:= .sub()
// If the subtraction underflowed, add m.
:= NewNat().set()
.add(.nat)
.assign(choice(), )
return
}
// Add computes x = x + y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func ( *Nat) ( *Nat, *Modulus) *Nat {
:= .add()
.maybeSubtractModulus(choice(), )
return
}
// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// Faster Montgomery multiplication replaces standard modular multiplication for
// numbers in this representation.
//
// This assumes that x is already reduced mod m.
func ( *Nat) ( *Modulus) *Nat {
// A Montgomery multiplication (which computes a * b / R) by R * R works out
// to a multiplication by R, which takes the value out of the Montgomery domain.
return .montgomeryMul(, .rr, )
}
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// This assumes that x is already reduced mod m.
func ( *Nat) ( *Modulus) *Nat {
// By Montgomery multiplying with 1 not in Montgomery representation, we
// convert out back from Montgomery representation, because it works out to
// dividing by R.
:= NewNat().ExpandFor()
.limbs[0] = 1
return .montgomeryMul(, , )
}
// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
func ( *Nat) ( *Nat, *Nat, *Modulus) *Nat {
:= len(.nat.limbs)
:= .nat.limbs[:]
:= .limbs[:]
:= .limbs[:]
switch {
default:
// Attempt to use a stack-allocated backing array.
:= make([]uint, 0, preallocLimbs*2)
if cap() < *2 {
= make([]uint, 0, *2)
}
= [:*2]
// This loop implements Word-by-Word Montgomery Multiplication, as
// described in Algorithm 4 (Fig. 3) of "Efficient Software
// Implementations of Modular Exponentiation" by Shay Gueron
// [https://eprint.iacr.org/2011/239.pdf].
var uint
for := 0; < ; ++ {
_ = [+] // bounds check elimination hint
// Step 1 (T = a × b) is computed as a large pen-and-paper column
// multiplication of two numbers with n base-2^_W digits. If we just
// wanted to produce 2n-wide T, we would do
//
// for i := 0; i < n; i++ {
// d := bLimbs[i]
// T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
// }
//
// where d is a digit of the multiplier, T[i:n+i] is the shifted
// position of the product of that digit, and T[n+i] is the final carry.
// Note that T[i] isn't modified after processing the i-th digit.
//
// Instead of running two loops, one for Step 1 and one for Steps 2–6,
// the result of Step 1 is computed during the next loop. This is
// possible because each iteration only uses T[i] in Step 2 and then
// discards it in Step 6.
:= []
:= addMulVVW([:+], , )
// Step 6 is replaced by shifting the virtual window we operate
// over: T of the algorithm is T[i:] for us. That means that T1 in
// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
:= [] * .m0inv
// Step 4 and 5 add Y × m to T, which as mentioned above is stored
// at T[i:]. The two carries (from a × d and Y × m) are added up in
// the next word T[n+i], and the carry bit from that addition is
// brought forward to the next iteration.
:= addMulVVW([:+], , )
[+], = bits.Add(, , )
}
// Finally for Step 7 we copy the final T window into x, and subtract m
// if necessary (which as explained in maybeSubtractModulus can be the
// case both if x >= m, or if x overflowed).
//
// The paper suggests in Section 4 that we can do an "Almost Montgomery
// Multiplication" by subtracting only in the overflow case, but the
// cost is very similar since the constant time subtraction tells us if
// x >= m as a side effect, and taking care of the broken invariant is
// highly undesirable (see https://go.dev/issue/13907).
copy(.reset().limbs, [:])
.maybeSubtractModulus(choice(), )
// The following specialized cases follow the exact same algorithm, but
// optimized for the sizes most used in RSA. addMulVVW is implemented in
// assembly with loop unrolling depending on the architecture and bounds
// checks are removed by the compiler thanks to the constant size.
case 1024 / _W:
const = 1024 / _W // compiler hint
:= make([]uint, *2)
var uint
for := 0; < ; ++ {
:= []
:= addMulVVW1024(&[], &[0], )
:= [] * .m0inv
:= addMulVVW1024(&[], &[0], )
[+], = bits.Add(, , )
}
copy(.reset().limbs, [:])
.maybeSubtractModulus(choice(), )
case 1536 / _W:
const = 1536 / _W // compiler hint
:= make([]uint, *2)
var uint
for := 0; < ; ++ {
:= []
:= addMulVVW1536(&[], &[0], )
:= [] * .m0inv
:= addMulVVW1536(&[], &[0], )
[+], = bits.Add(, , )
}
copy(.reset().limbs, [:])
.maybeSubtractModulus(choice(), )
case 2048 / _W:
const = 2048 / _W // compiler hint
:= make([]uint, *2)
var uint
for := 0; < ; ++ {
:= []
:= addMulVVW2048(&[], &[0], )
:= [] * .m0inv
:= addMulVVW2048(&[], &[0], )
[+], = bits.Add(, , )
}
copy(.reset().limbs, [:])
.maybeSubtractModulus(choice(), )
}
return
}
// addMulVVW multiplies the multi-word value x by the single-word value y,
// adding the result to the multi-word value z and returning the final carry.
// It can be thought of as one row of a pen-and-paper column multiplication.
//
// z is updated in place. x must be at least as long as z, and z must not be
// empty.
func addMulVVW(z, x []uint, y uint) (carry uint) {
	_ = x[len(z)-1] // bounds check elimination hint
	for i := range z {
		hi, lo := bits.Mul(x[i], y)
		lo, c := bits.Add(lo, z[i], 0)
		// We use bits.Add with zero to get an add-with-carry instruction that
		// absorbs the carry from the previous bits.Add.
		hi, _ = bits.Add(hi, 0, c)
		lo, c = bits.Add(lo, carry, 0)
		hi, _ = bits.Add(hi, 0, c)
		carry = hi
		z[i] = lo
	}
	return carry
}
// Mul calculates x = x * y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func ( *Nat) ( *Nat, *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
:= NewNat().set().montgomeryRepresentation() // xR = x * R mod m
return .montgomeryMul(, , ) // x = xR * y / R mod m
}
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
func ( *Nat) ( *Nat, []byte, *Modulus) *Nat {
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement, but
// are likely to be more efficient if necessary.
:= [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
// newNat calls are unrolled so they are allocated on the stack.
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
[0].set().montgomeryRepresentation()
for := 1; < len(); ++ {
[].montgomeryMul([-1], [0], )
}
.resetFor()
.limbs[0] = 1
.montgomeryRepresentation()
:= NewNat().ExpandFor()
for , := range {
for , := range []int{4, 0} {
// Square four times. Optimization note: this can be implemented
// more efficiently than with generic Montgomery multiplication.
.montgomeryMul(, , )
.montgomeryMul(, , )
.montgomeryMul(, , )
.montgomeryMul(, , )
// Select x^k in constant time from the table.
:= uint(( >> ) & 0b1111)
for := range {
.assign(ctEq(, uint(+1)), [])
}
// Multiply by x^k, discarding the result if k = 0.
.montgomeryMul(, , )
.assign(not(ctEq(, 0)), )
}
}
return .montgomeryReduction()
}
// ExpShort calculates out = x^e mod m.
//
// The output will be resized to the size of m and overwritten. x must already
// be reduced modulo m. This leaks the exact bit size of the exponent.
func ( *Nat) ( *Nat, uint, *Modulus) *Nat {
:= NewNat().set().montgomeryRepresentation()
.resetFor()
.limbs[0] = 1
.montgomeryRepresentation()
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple constant-time conditional
// square-and-multiply chain, skipping the initial run of zeroes.
:= NewNat().ExpandFor()
for := bits.UintSize - bitLen(); < bits.UintSize; ++ {
.montgomeryMul(, , )
:= ( >> (bits.UintSize - - 1)) & 1
.montgomeryMul(, , )
.assign(ctEq(, 1), )
}
return .montgomeryReduction()
}
// The pages are generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64). Golds is
// a Go 101 project developed by Tapir Liu. PRs and bug reports are welcome and
// can be submitted to the issue list. Please follow @Go100and1 to get the
// latest news of Golds.