LatticeHacks

Polynomials

    Zx.<x> = ZZ[]

This creates a Zx class. A Zx object is a polynomial in x with integer coefficients. For example:

    sage: f = Zx([3,1,4])
    sage: f
    4*x^2 + x + 3

This polynomial f is a sum of three terms: 4*x^2, x, and 3. Each term is an integer coefficient (4, 1, and 3 respectively) times a power of x (x^2, x, and 1 respectively).

Easier version of the same example to copy and paste into Sage:

    f = Zx([3,1,4])
    f
    # output: 4*x^2 + x + 3

Polynomial multiplication

The Zx class has a built-in multiplication method that multiplies each term of one polynomial by each term of the other and adds the results. For example, multiplying f by x, or by any power of x, shifts the list of coefficients of f:

    f
    # output: 4*x^2 + x + 3
    f*x
    # output: 4*x^3 + x^2 + 3*x

Another example:

    g = Zx([2,7,1])
    g
    # output: x^2 + 7*x + 2
    f*g
    # output: 4*x^4 + 29*x^3 + 18*x^2 + 23*x + 6
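To see what this multiplication computes, here is a minimal plain-Python sketch. It is not Sage, and not how Sage implements it. A polynomial is stored as a list of coefficients, constant term first, which is the same order that Zx([3,1,4]) uses. The name polymul is our own.

```python
def polymul(f, g):
    """Multiply two polynomials given as coefficient lists, constant term first."""
    if not f or not g:
        return []
    result = [0] * (len(f) + len(g) - 1)
    # Each term f[i]*x^i times each term g[j]*x^j contributes to the x^(i+j) coefficient.
    for i, fi in enumerate(f):
        for j, gj in enumerate(g):
            result[i + j] += fi * gj
    return result

f = [3, 1, 4]              # 4*x^2 + x + 3
g = [2, 7, 1]              # x^2 + 7*x + 2
print(polymul(f, g))       # [6, 23, 18, 29, 4], i.e. 4*x^4 + 29*x^3 + 18*x^2 + 23*x + 6
print(polymul(f, [0, 1]))  # [0, 3, 1, 4]: multiplying by x shifts the list
```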

Cyclic convolution

    def convolution(f,g):
      return (f * g) % (x^n-1)

This is the multiplication operation used in NTRU. It is the same as polynomial multiplication but reduces the output "modulo x^n-1": this means replacing x^n with 1, replacing x^(n+1) with x, replacing x^(n+2) with x^2, etc.

The inputs are two n-coefficient polynomials f and g, meaning polynomials built only from the powers 1, x, and so on through x^(n-1). The output is also an n-coefficient polynomial, since x^n etc. have been eliminated. Some or all of the n output coefficients can be 0 (and, similarly, input coefficients can be 0). Saying n-coefficient doesn't mean that all of 1, x, ..., x^(n-1) appear; it means that x^n, x^(n+1), etc. don't appear.

Example:

    n = 3
    f*g
    # output: 4*x^4 + 29*x^3 + 18*x^2 + 23*x + 6
    convolution(f,g)
    # output: 18*x^2 + 27*x + 35

As another example, convolution of f by powers of x rotates the list of n coefficients of f:

    n = 3
    f
    # output: 4*x^2 + x + 3
    convolution(f,x)
    # output: x^2 + 3*x + 4
    convolution(f,x^2)
    # output: 3*x^2 + 4*x + 1
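Here is a plain-Python sketch of the same operation, again with coefficient lists written constant term first. The name cyclic_convolution is a hypothetical helper name, and n is passed as a parameter rather than read from a global variable:

```python
def cyclic_convolution(f, g, n):
    """Multiply f and g, then reduce modulo x^n - 1 (x^(n+k) becomes x^k)."""
    result = [0] * n
    for i, fi in enumerate(f):
        for j, gj in enumerate(g):
            # The term x^(i+j) wraps around to x^((i+j) % n).
            result[(i + j) % n] += fi * gj
    return result

print(cyclic_convolution([3, 1, 4], [2, 7, 1], 3))  # [35, 27, 18], i.e. 18*x^2 + 27*x + 35
print(cyclic_convolution([3, 1, 4], [0, 1], 3))     # [4, 3, 1]: convolution by x rotates
print(cyclic_convolution([3, 1, 4], [0, 0, 1], 3))  # [1, 4, 3]: convolution by x^2
```

Passing n explicitly avoids relying on a global n, at the cost of one extra argument.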

Beware that n is a global variable. Python exercise: Write a function ring(n) whose output is a class R that is similar to Zx but automatically reduces the results of addition, subtraction, and multiplication modulo x^n-1. Sage already knows how to do this (R = Zx.quotient(x^n-1)) but the exercise is to make your own class starting from Zx, defining __mul__, etc.

Math exercise: Can you figure out two input polynomials where the convolution is 0 (all output coefficients 0)? This is automatic if one input is 0, or if the other input is 0, or both, but can you figure out any other examples?

The name "cyclic convolution" (or "circular convolution") comes from signal processing. Polynomial multiplication is called "acyclic convolution".

Modular reduction

    def balancedmod(f,q):
      g = list(((f[i] + q//2) % q) - q//2 for i in range(n))
      return Zx(g)

There are two inputs to balancedmod: an n-coefficient integer polynomial f; and a positive integer q. The output is the same polynomial except that each coefficient is reduced modulo q. Mathematicians normally define reduction to produce outputs between 0 and q-1, but balancedmod instead produces outputs between -q/2 and q/2: more precisely, between -q/2 and q/2-1 if q is even, or between -(q-1)/2 and (q-1)/2 if q is odd.

For example:

    u = Zx([3,1,4,1,5,9])
    u
    # output: 9*x^5 + 5*x^4 + x^3 + 4*x^2 + x + 3
    n = 7
    balancedmod(u,10)
    # output: -x^5 - 5*x^4 + x^3 + 4*x^2 + x + 3
    balancedmod(u,3)
    # output: -x^4 + x^3 + x^2 + x

Internally, balancedmod uses Sage's % q operation for integers, which always produces outputs between 0 and q-1. balancedmod adjusts the input and output by q//2, which means q/2 rounded down.
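Here is the same arithmetic on a single integer, as a plain-Python sketch. Python's % with a positive modulus also returns a value between 0 and q-1. The name balanced_mod is our own, for a one-coefficient version of balancedmod:

```python
def balanced_mod(c, q):
    """Reduce integer c modulo q into the balanced range that balancedmod uses."""
    # c % q is in 0..q-1; shifting by q//2 before and after centers the range on 0.
    return ((c + q // 2) % q) - q // 2

# Coefficients of u = 9*x^5 + 5*x^4 + x^3 + 4*x^2 + x + 3, constant term first
u = [3, 1, 4, 1, 5, 9]
print([balanced_mod(c, 10) for c in u])  # [3, 1, 4, 1, -5, -1]
print([balanced_mod(c, 3) for c in u])   # [0, 1, 1, 1, -1, 0]
```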

Beware that in lower-level languages such as C, applying % q to a negative input typically produces a negative result, so the output of % q leaks the sign of the input (unless the input happens to be divisible by q, in which case the output is 0). This leak can be a serious security problem. Sage also supports a % q operation for polynomials, and a similar leak is sometimes visible in the output:

    u = 314-159*x
    u % 200
    # output: -159*x + 114
    (u - 400) % 200
    # output: -159*x - 86
    (u - 600) % 200
    # output: -159*x + 114
    balancedmod(u,200)
    # output: 41*x - 86
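The integer-level leak can be reproduced in plain Python by comparing Python's own % with a remainder that truncates toward zero. The truncating version is how C's % behaves (C99 and later). Both function names below are hypothetical; truncated_mod only mimics C and is not a library call:

```python
def floor_mod(c, q):
    # Python (and Sage) integer %: for q > 0 the result is always in 0..q-1
    return c % q

def truncated_mod(c, q):
    # Mimics C's %: the quotient is truncated toward zero,
    # so a nonzero remainder has the same sign as c
    r = abs(c) % q
    return r if c >= 0 else -r

for c in (314, 314 - 400, 314 - 600):
    print(c, floor_mod(c, 200), truncated_mod(c, 200))
# 314 114 114
# -86 114 -86
# -286 114 -86
```

The last column reveals which inputs were negative, while the middle column does not.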

Random polynomials with d nonzero coefficients

    def randomdpoly():
      assert d <= n
      result = n*[0]
      for j in range(d):
        while True:
          r = randrange(n)
          if not result[r]: break
        result[r] = 1-2*randrange(2)
      return Zx(result)

randomdpoly returns an n-coefficient polynomial where exactly d coefficients are nonzero (d are nonzero, the other n-d are all zero). Each nonzero coefficient is either 1 or -1. Beware that d and n are both global variables.

For example:

    n = 7
    d = 5
    f = randomdpoly()
    f
    # output: x^6 + x^5 - x^3 + x^2 - 1
    f = randomdpoly()
    f
    # output: -x^4 + x^3 + x^2 - x + 1
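Here is a plain-Python sketch of the same output distribution, using a swapped-in technique. Instead of retrying until it hits an unused position, it picks the d positions all at once with sample(), which never repeats a position. It also draws from the operating system's randomness through secrets.SystemRandom, because Python's default random module is not designed for cryptographic use. That choice of randomness source is ours, not the page's, and random_d_poly is a hypothetical name:

```python
import secrets

_rng = secrets.SystemRandom()   # OS randomness source (our assumption, not from the page)

def random_d_poly(n, d, rng=_rng):
    """Return n coefficients, exactly d of them nonzero, each nonzero one 1 or -1."""
    assert 0 <= d <= n
    coeffs = [0] * n
    # sample() picks d distinct positions, so no retry loop is needed
    for position in rng.sample(range(n), d):
        coeffs[position] = rng.choice((1, -1))
    return coeffs

print(random_d_poly(7, 5))  # a random list of 7 coefficients, five of them 1 or -1
```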

Division modulo primes

    def invertmodprime(f,p):
      T = Zx.change_ring(Integers(p)).quotient(x^n-1)
      return Zx(lift(1 / T(f)))

invertmodprime computes a reciprocal of a polynomial modulo x^n-1 modulo p. There are two inputs: an n-coefficient polynomial f; and a prime number p (for example, 3). The output is an n-coefficient polynomial g so that convolution(f,g) is 1+p*u for some polynomial u. invertmodprime raises an exception if no such g exists.

For example:

    n = 7
    f
    # output: -x^4 + x^3 + x^2 - x + 1
    f3 = invertmodprime(f,3)
    f3
    # output: x^6 + 2*x^4 + x
    convolution(f,f3)
    # output: 3*x^6 - 3*x^5 + 3*x^4 + 1

This convolution is 1+3*u where u is the polynomial x^6 - x^5 + x^4.

The concept of invertmodprime also makes sense for non-primes (see invertmodpowerof2 below), but invertmodprime internally uses some Sage subroutines that aren't smart enough to handle non-primes.
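invertmodprime relies on Sage's quotient ring. The same reciprocal can be found with plain linear algebra, a technique swapped in here for illustration. Convolution by f is multiplication by a circulant matrix, so finding g means solving a linear system modulo p, which this sketch does by Gauss-Jordan elimination. It assumes p is prime, so every nonzero pivot has an inverse, and Python 3.8 or later for pow(a, -1, p). The name invert_mod_prime is hypothetical:

```python
def invert_mod_prime(f, n, p):
    """Find g with convolution(f, g) = 1 modulo x^n - 1 and modulo the prime p.

    f is a coefficient list, constant term first, with at most n entries.
    Returns g with coefficients in 0..p-1; raises ValueError if no such g exists.
    """
    assert len(f) <= n
    f = [c % p for c in f] + [0] * (n - len(f))
    # Row k of A is the k-th output coefficient: sum over j of f[(k-j) % n] * g[j].
    # The extra last column is the target polynomial 1.
    A = [[f[(k - j) % n] for j in range(n)] + [1 if k == 0 else 0] for k in range(n)]
    for col in range(n):
        pivot = next((row for row in range(col, n) if A[row][col]), None)
        if pivot is None:
            raise ValueError("f is not invertible modulo x^n - 1 and p")
        A[col], A[pivot] = A[pivot], A[col]
        inv = pow(A[col][col], -1, p)                 # Python 3.8+
        A[col] = [(v * inv) % p for v in A[col]]
        for row in range(n):
            if row != col and A[row][col]:
                factor = A[row][col]
                A[row] = [(a - factor * b) % p for a, b in zip(A[row], A[col])]
    return [A[k][n] for k in range(n)]

# f = -x^4 + x^3 + x^2 - x + 1, constant term first
print(invert_mod_prime([1, -1, 1, 1, -1, 0, 0], 7, 3))  # [0, 1, 0, 0, 2, 0, 1]
```

The output is x^6 + 2*x^4 + x, matching f3 in the Sage session above. If a reciprocal exists it is unique, so both methods must agree.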

Division modulo powers of 2

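    def invertmodpowerof2(f,q):
      assert q.is_power_of(2) and q >= 4
      g = invertmodprime(f,2)
      while True:
        r = balancedmod(convolution(g,f),q)
        if r == 1: return g
        g = balancedmod(convolution(g,2 - r),q)

Just like invertmodprime above, except that the second input q is a power of 2 that is at least 4: 4 or 8 or 16 or ... (For q = 2, use invertmodprime instead. balancedmod modulo 2 turns the coefficient 1 into -1, so the test r == 1 would never succeed.) Example: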

    n = 7
    q = 256
    f
    # output: -x^4 + x^3 + x^2 - x + 1
    fq = invertmodpowerof2(f,q)
    convolution(f,fq)
    # output: -256*x^6 + 256*x^4 - 256*x^2 + 257

This convolution is 1+256*u where u is the polynomial -x^6 + x^4 - x^2 + 1.

Math exercise: Watch what's happening internally in invertmodpowerof2, and explain why it works. Hint: How does r relate to the r in the previous loop?
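To watch the iteration, it may help to start with the scalar case. The same update, g = g*(2 - r) followed by balanced reduction modulo q, also inverts an odd integer modulo a power of 2. The plain-Python sketch below records each r; its function names are ours. It requires q to be at least 4, because with q = 2 balanced reduction turns 1 into -1 and the test r == 1 could never succeed.

```python
def balanced_mod(c, q):
    return ((c + q // 2) % q) - q // 2

def invert_odd_mod_power_of_2(a, q):
    """Scalar analogue of invertmodpowerof2: returns (g, list of r values) with a*g = 1 mod q."""
    assert q >= 4 and q & (q - 1) == 0, "q must be 4, 8, 16, ..."
    assert a % 2 == 1, "only odd numbers are invertible modulo a power of 2"
    g = 1                                  # the reciprocal of any odd a modulo 2
    history = []
    while True:
        r = balanced_mod(g * a, q)
        history.append(r)
        if r == 1:
            return g, history
        g = balanced_mod(g * (2 - r), q)

g, history = invert_odd_mod_power_of_2(3, 256)
print(g, g % 256)   # -85 171
print(history)      # [3, -3, -15, 1]
```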

NTRU key generation

    def keypair():
      while True:
        try:
          f = randomdpoly()
          f3 = invertmodprime(f,3)
          fq = invertmodpowerof2(f,q)
          break
        except:
          pass
      g = randomdpoly()
      publickey = balancedmod(3 * convolution(fq,g),q)
      secretkey = f,f3
      return publickey,secretkey

This returns an NTRU public key h and a corresponding secret key f,f3. The public key looks like a random n-coefficient polynomial modulo q. For example, if n = 7 (much too small for security!) and q = 256, then the public key looks like 7 random bytes:

    n = 7
    d = 5
    q = 256
    publickey,secretkey = keypair()
    publickey
    # output: 54*x^6 - 40*x^5 + 90*x^4 + 101*x^3 - 108*x^2 + 80*x + 76

The f part of the secret key is a polynomial with small coefficients. Convolution of f with the public key modulo q produces another polynomial with small coefficients, namely 3 times the g that appeared in key generation:

    f,f3 = secretkey
    f
    # output: -x^6 + x^5 - x^4 + x^2 + 1
    convolution(f,publickey)
    # output: 256*x^6 + 3*x^5 - 3*x^3 - 3*x^2 + 253*x - 253
    balancedmod(_,q)
    # output: 3*x^5 - 3*x^3 - 3*x^2 - 3*x + 3
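The numbers above can be rechecked outside Sage with a short plain-Python sketch, writing each polynomial as a coefficient list, constant term first. The helper names are ours.

```python
def cyclic_convolution(f, g):
    """Multiply two length-n coefficient lists modulo x^n - 1."""
    n = len(f)
    result = [0] * n
    for i, fi in enumerate(f):
        for j, gj in enumerate(g):
            result[(i + j) % n] += fi * gj
    return result

def balanced_mod(c, q):
    return ((c + q // 2) % q) - q // 2

q = 256
f = [1, 0, 1, 0, -1, 1, -1]                   # -x^6 + x^5 - x^4 + x^2 + 1
publickey = [76, 80, -108, 101, 90, -40, 54]  # 54*x^6 - 40*x^5 + ... + 80*x + 76
product = cyclic_convolution(f, publickey)
print(product)                                # [-253, 253, -3, -3, 0, 3, 256]
print([balanced_mod(c, q) for c in product])  # [3, -3, -3, -3, 0, 3, 0]
```

Every entry of the reduced list is a multiple of 3, as expected for 3 times g.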

Messages for encryption

    def randommessage():
      result = list(randrange(3) - 1 for j in range(n))
      return Zx(result)

randommessage returns an n-coefficient polynomial where each coefficient is either 1 or 0 or -1. For example:

    n = 7
    randommessage()
    # output: -x^6 - x^5 + x^4
    randommessage()
    # output: x^6 + x^5 - x^4 - 1
    randommessage()
    # output: -x^4 - x^3 - x + 1
    randommessage()
    # output: -x^6 + x^4 - x^2 + 1

Encryption

    def encrypt(message,publickey):
      r = randomdpoly()
      return balancedmod(convolution(publickey,r) + message,q)

encrypt returns an NTRU ciphertext given a message and a public key h. The ciphertext is h*r+m modulo x^n-1 modulo q, where m is the message and r is a fresh random polynomial from randomdpoly (d nonzero coefficients, each 1 or -1).

Example:

    n = 7
    d = 5
    q = 256
    h,secretkey = keypair()
    h
    # output: -82*x^6 + 118*x^5 - 94*x^4 + 108*x^3 + 70*x^2 - 122*x + 5
    m = randommessage()
    m
    # output: -x^6 - x^4 + x^2 + 1
    c = encrypt(m,h)
    c
    # output: -66*x^6 + 37*x^5 + 115*x^4 - 15*x^3 - 6*x^2 - 89*x + 27

Decryption

    def decrypt(ciphertext,secretkey):
      f,f3 = secretkey
      a = balancedmod(convolution(ciphertext,f),q)
      return balancedmod(convolution(a,f3),3)

decrypt returns a message given an NTRU ciphertext and a secret key.

Example of testing decryption for 10 ciphertexts using 10 different keys with reasonable NTRU parameters:

    n = 743
    d = 495
    q = 2048
    for tests in range(10):
      publickey,secretkey = keypair()
      m = randommessage()
      c = encrypt(m,publickey)
      print(m == decrypt(c,secretkey))

    # output: True
    # output: True
    # output: True
    # output: True
    # output: True
    # output: True
    # output: True
    # output: True
    # output: True
    # output: True

Here is an example of the internal calculations in encryption and decryption to illustrate why decryption works:

    n = 7
    d = 5
    q = 256
    h,secretkey = keypair()
    h
    # output: -82*x^6 + 118*x^5 - 94*x^4 + 108*x^3 + 70*x^2 - 122*x + 5
    m = randommessage()
    m
    # output: x^6 + x^5 - x^4 - x^3 + x - 1
    r = randomdpoly()
    r
    # output: -x^6 + x^5 + x^4 + x^3 - x^2
    f = secretkey[0]
    f
    # output: -x^6 - x^5 - x^4 - x^3 - x
    g3 = balancedmod(convolution(f,h),q)
    g3
    # output: -3*x^6 - 3*x^3 + 3*x^2 - 3*x - 3
    c = balancedmod(convolution(h,r) + m,q)
    c
    # output: -93*x^6 - 105*x^5 - 110*x^4 - 95*x^3 - 106*x^2 - 111*x - 95
    a = balancedmod(convolution(f,c),q)
    a
    # output: 3*x^5 - 13*x^4 - 3*x^3 + 2*x^2 - x + 3
    convolution(g3,r) + convolution(f,m)
    # output: 3*x^5 - 13*x^4 - 3*x^3 + 2*x^2 - x + 3
    balancedmod(a,3)
    # output: -x^4 - x^2 - x
    balancedmod(convolution(f,m),3)
    # output: -x^4 - x^2 - x

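The idea is that c = h*r+m modulo q, and f*h = 3*g modulo q, so a = f*c = f*h*r+f*m = 3*g*r+f*m modulo q. Here every product is a convolution, that is, multiplication modulo x^n-1. The polynomials g,r,f,m have small coefficients, so there is a limit on how big the coefficients of 3*g*r+f*m can be. If this limit is small enough that every coefficient lies between -q/2 and q/2-1, then a is exactly 3*g*r+f*m, not merely equal to it modulo q. Reducing a modulo 3 is then the same as reducing f*m modulo 3. Multiplying by f3 produces m modulo 3, and since every coefficient of m is 1, 0, or -1, the final balancedmod(...,3) returns m itself.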
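The limit can be made concrete. g has only d nonzero coefficients, each 1 or -1, and every coefficient of r is 1, 0, or -1. So each coefficient of g*r is a sum of at most d terms of size at most 1. The same holds for f*m. Every coefficient of 3*g*r+f*m therefore has absolute value at most 3d + d = 4d, and 4d <= q/2-1 is enough to rule out decryption failures for valid ciphertexts when q is even. The plain-Python sketch below checks this worst-case bound for the two parameter sets used on this page. It also prints the largest coefficient seen for one random choice of f, g, r, m. The helper names and the fixed seed are our own; a fixed seed is only for repeatable experiments, never for real keys.

```python
import random

def cyclic_convolution(f, g):
    """Multiply two length-n coefficient lists modulo x^n - 1."""
    n = len(f)
    result = [0] * n
    for i, fi in enumerate(f):
        if fi:                                    # skip zero coefficients
            for j, gj in enumerate(g):
                result[(i + j) % n] += fi * gj
    return result

def random_d_poly(n, d, rng):
    """n coefficients, exactly d of them nonzero, each nonzero one 1 or -1."""
    coeffs = [0] * n
    for position in rng.sample(range(n), d):
        coeffs[position] = rng.choice((1, -1))
    return coeffs

def largest_coefficient(n, d, rng):
    """Largest |coefficient| of 3*g*r + f*m for one random draw, before any reduction mod q."""
    f, g, r = (random_d_poly(n, d, rng) for _ in range(3))
    m = [rng.randrange(3) - 1 for _ in range(n)]  # like randommessage
    gr = cyclic_convolution(g, r)
    fm = cyclic_convolution(f, m)
    return max(abs(3 * a + b) for a, b in zip(gr, fm))

rng = random.Random(2017)   # fixed seed: repeatable experiments only
for n, d, q in [(7, 5, 256), (743, 495, 2048)]:
    bound = 4 * d           # worst case for |3*g*r + f*m|
    verdict = "never fails" if bound <= q // 2 - 1 else "not guaranteed by this bound"
    print(n, d, q, bound, verdict, largest_coefficient(n, d, rng))
```

For n = 743, d = 495, q = 2048, the worst case 4d = 1980 exceeds q/2-1 = 1023, so this crude bound alone does not rule out failures. The largest coefficient seen for random inputs is typically far below the worst case.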

Beware that decryption can fail. One can standardize values of n, d, and q that rule out decryption failures for all valid ciphertexts, but an attacker can still deliberately construct invalid ciphertexts and check whether decryption works. It is important for security to take extra steps to protect against these chosen-ciphertext attacks.

An attack example with very small NTRU parameters

The following example uses n = 7, d = 5, and q = 256.

The attacker starts from the public key h, which is 3*g/f modulo q. It's simpler to work with g/f, which is computed here as h3 by multiplying h by the reciprocal of 3 modulo q:

    h
    # output: -82*x^6 + 118*x^5 - 94*x^4 + 108*x^3 + 70*x^2 - 122*x + 5
    Integers(q)(1/3)
    # output: 171
    h3 = (171*h)%q
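In plain Python 3.8 or later, pow(3, -1, q) computes the same reciprocal. Multiplying each coefficient of h by 171 and reducing modulo q then gives the h3 printed below. This is a sketch, with h copied from the session above as a list, constant term first:

```python
q = 256
recip3 = pow(3, -1, q)        # modular reciprocal of 3 (Python 3.8+)
print(recip3)                 # 171
print((3 * recip3) % q)       # 1

# h = -82*x^6 + 118*x^5 - 94*x^4 + 108*x^3 + 70*x^2 - 122*x + 5, constant term first
h = [5, -122, 70, 108, -94, 118, -82]
h3 = [(recip3 * c) % q for c in h]
print(h3)                     # [87, 130, 194, 36, 54, 210, 58]
```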

Think of the secret g as the secret f times h3. Remember that the secret f is obtained from 1, x, x^2, x^3, x^4, x^5, x^6 by some additions and subtractions. The secret g is correspondingly obtained by the same additions and subtractions from the polynomials h3, h3*x, h3*x^2, h3*x^3, h3*x^4, h3*x^5, h3*x^6, each reduced modulo x^n-1. Here are these polynomials:

    h3
    # output: 58*x^6 + 210*x^5 + 54*x^4 + 36*x^3 + 194*x^2 + 130*x + 87
    convolution(h3,x)
    # output: 210*x^6 + 54*x^5 + 36*x^4 + 194*x^3 + 130*x^2 + 87*x + 58
    convolution(h3,x^2)
    # output: 54*x^6 + 36*x^5 + 194*x^4 + 130*x^3 + 87*x^2 + 58*x + 210
    convolution(h3,x^3)
    # output: 36*x^6 + 194*x^5 + 130*x^4 + 87*x^3 + 58*x^2 + 210*x + 54
    convolution(h3,x^4)
    # output: 194*x^6 + 130*x^5 + 87*x^4 + 58*x^3 + 210*x^2 + 54*x + 36
    convolution(h3,x^5)
    # output: 130*x^6 + 87*x^5 + 58*x^4 + 210*x^3 + 54*x^2 + 36*x + 194
    convolution(h3,x^6)
    # output: 87*x^6 + 58*x^5 + 210*x^4 + 54*x^3 + 36*x^2 + 194*x + 130

Actually, g is f*h3 modulo q: q can be added to or subtracted from any coefficient. This means that g is obtained as a combination of the polynomials q, q*x, q*x^2, q*x^3, q*x^4, q*x^5, q*x^6, h3, h3*x, h3*x^2, h3*x^3, h3*x^4, h3*x^5, h3*x^6.

Finally, concatenating the coefficients of g and f produces a vector that is an integer combination of the rows of the following matrix:

    M = matrix(2*n)
    for i in range(n): M[i,i] = q
    for i in range(n,2*n): M[i,i] = 1
    for i in range(n):
      for j in range(n):
        M[i+n,j] = convolution(h3,x^i)[j]

    M
    # output: [256   0   0   0   0   0   0   0   0   0   0   0   0   0]
    # output: [  0 256   0   0   0   0   0   0   0   0   0   0   0   0]
    # output: [  0   0 256   0   0   0   0   0   0   0   0   0   0   0]
    # output: [  0   0   0 256   0   0   0   0   0   0   0   0   0   0]
    # output: [  0   0   0   0 256   0   0   0   0   0   0   0   0   0]
    # output: [  0   0   0   0   0 256   0   0   0   0   0   0   0   0]
    # output: [  0   0   0   0   0   0 256   0   0   0   0   0   0   0]
    # output: [ 87 130 194  36  54 210  58   1   0   0   0   0   0   0]
    # output: [ 58  87 130 194  36  54 210   0   1   0   0   0   0   0]
    # output: [210  58  87 130 194  36  54   0   0   1   0   0   0   0]
    # output: [ 54 210  58  87 130 194  36   0   0   0   1   0   0   0]
    # output: [ 36  54 210  58  87 130 194   0   0   0   0   1   0   0]
    # output: [194  36  54 210  58  87 130   0   0   0   0   0   1   0]
    # output: [130 194  36  54 210  58  87   0   0   0   0   0   0   1]

The LLL algorithm quickly finds short combinations of these rows:

    M.LLL()
    # output: [ -1  -1   1  -1   1   0   0  -1   1  -1  -1   1   0   0]
    # output: [  0  -1  -1   1  -1   1   0   0  -1   1  -1  -1   1   0]
    # output: [  1  -1   1  -1   0   0   1  -1   1   1  -1   0   0   1]
    # output: [ -1   1  -1   0   0   1   1   1   1  -1   0   0   1  -1]
    # output: [  1  -1   0   0   1   1  -1   1  -1   0   0   1  -1   1]
    # output: [  1   1   1   1   1   1   1   1   1   1   1   1   1   1]
    # output: [  0   0   1   1  -1   1  -1   0   0   1  -1   1   1  -1]
    # output: [ 39 -28  19  12  11 -48  -4  47   6 -31 -20 -19  36 -18]
    # output: [ -5 -34 -14  -3   9 -39 -43  47  54  22   1 -17  19   1]
    # output: [  4 -39  28 -19 -12 -11  48  18 -47  -6  31  20  19 -36]
    # output: [  9 -40 -43  -5 -32 -13  -1 -17  20   1  47  54  23   3]
    # output: [ -1   9 -40 -43  -5 -32 -13   3 -17  20   1  47  54  23]
    # output: [ 14   3  -9  40  43   4  32 -22  -3  17 -18  -1 -48 -54]
    # output: [ 28 -19 -12 -11  48   4 -39  -6  31  20  19 -36  18 -47]

The first row is the secret g followed by the secret f. Actually, it turns out to be the negative of g followed by the negative of f:

    M.LLL()[0][n:]
    # output: (-1, 1, -1, -1, 1, 0, 0)
    Zx(list(_))
    # output: x^4 - x^3 - x^2 + x - 1
    f
    # output: -x^4 + x^3 + x^2 - x + 1
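The claim that g followed by f is an integer combination of the rows of M can be checked in plain Python from the numbers printed above. Use the coefficients of f as multipliers for the bottom n rows, then use multiples of the q rows to bring each coefficient into the balanced range modulo q. This is a sketch; the variable names are ours.

```python
q, n = 256, 7
h3 = [87, 130, 194, 36, 54, 210, 58]   # constant term first, from the session above
f = [1, -1, 1, 1, -1, 0, 0]            # -x^4 + x^3 + x^2 - x + 1

# Rows of M: n rows q*e_i, then n rows (h3*x^i, e_i), where h3*x^i is h3 rotated by i.
M = [[q if j == i else 0 for j in range(2 * n)] for i in range(n)]
M += [[h3[(j - i) % n] for j in range(n)] + [1 if j == i else 0 for j in range(n)]
      for i in range(n)]

# Multipliers f for the bottom rows give f*h3 (not yet reduced modulo q) ...
unreduced = [sum(f[i] * M[n + i][j] for i in range(n)) for j in range(n)]
g = [((c + q // 2) % q) - q // 2 for c in unreduced]   # balanced reduction mod q
# ... and integer multipliers for the top rows perform that reduction.
top = [(g[j] - unreduced[j]) // q for j in range(n)]
vector = [sum(top[i] * M[i][j] for i in range(n)) +
          sum(f[i] * M[n + i][j] for i in range(n)) for j in range(2 * n)]
print(g)                          # [1, 1, -1, 1, -1, 0, 0], i.e. -x^4 + x^3 - x^2 + x + 1
print(vector)                     # [1, 1, -1, 1, -1, 0, 0, 1, -1, 1, 1, -1, 0, 0]
print(sum(v * v for v in vector)) # 10
```

The vector is exactly the negation of the first LLL row. Its squared length is 10, tiny compared with 256^2 for each of the first n rows of M.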

But the attacker can simply use the decryption algorithm without worrying about this negation: decrypting with -f and its reciprocal -f3 gives the same message, because the two sign changes cancel. Similarly, LLL could have produced (for example) x*g and x*f, but this would again have worked for decryption. NTRU needs much larger n for security.

Automating the attack

    def attack(publickey):
      recip3 = lift(1/Integers(q)(3))
      publickeyover3 = balancedmod(recip3 * publickey,q)
      M = matrix(2 * n)
      for i in range(n):
        M[i,i] = q
      for i in range(n):
        M[i+n,i+n] = 1
        c = convolution(x^i,publickeyover3)
        for j in range(n):
          M[i+n,j] = c[j]
      M = M.LLL()
      for j in range(2 * n):
        try:
          f = Zx(list(M[j][n:]))
          f3 = invertmodprime(f,3)
          return (f,f3)
        except:
          pass
      return (f,f)

    n = 120
    q = 2^32
    d = 81

    publickey,secretkey = keypair()
    donald = attack(publickey)
    print(donald[0])
    try:
      m = randommessage()
      c = encrypt(m,publickey)
      assert decrypt(c,donald) == m
      print('attack successfully decrypts')
    except:
      print('attack was unsuccessful')

Version: This is version 2017.12.28 of the "NTRU" web page.