405 lines
11 KiB
Python
405 lines
11 KiB
Python
|
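"""cryptomath module

This module has basic math/crypto code: random-number generation,
number/string/base64 conversion helpers, and big-number arithmetic
(gcd, modular inverse and exponentiation, primality testing, and
prime generation)."""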
|
||
|
|
||
|
import os
|
||
|
import sys
|
||
|
import math
|
||
|
import base64
|
||
|
import binascii
|
||
|
if sys.version_info[:2] <= (2, 4):
|
||
|
from sha import sha as sha1
|
||
|
else:
|
||
|
from hashlib import sha1
|
||
|
|
||
|
from compat import *
|
||
|
|
||
|
|
||
|
# **************************************************************************
|
||
|
# Load Optional Modules
|
||
|
# **************************************************************************
|
||
|
|
||
|
# Try to load M2Crypto/OpenSSL.
# m2cryptoLoaded records whether the optional M2Crypto binding imported.
m2cryptoLoaded = False
try:
    from M2Crypto import m2
except ImportError:
    pass
else:
    m2cryptoLoaded = True
|
||
|
|
||
|
|
||
|
# Try to load cryptlib.
# cryptlibpyLoaded records whether cryptlib_py imported and initialized.
try:
    import cryptlib_py
    try:
        cryptlib_py.cryptInit()
    except cryptlib_py.CryptException:
        # If tlslite and cryptoIDlib are both present, each may try to
        # initialize cryptlib, so an "already initialized" error is
        # tolerated; any other init error is re-raised.
        e = sys.exc_info()[1]
        if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
            raise
    cryptlibpyLoaded = True
except ImportError:
    cryptlibpyLoaded = False
|
||
|
|
||
|
#Try to load GMPY; gmpyLoaded records whether the import succeeded.
try:
    import gmpy
except ImportError:
    gmpyLoaded = False
else:
    gmpyLoaded = True
|
||
|
|
||
|
#Try to load pycrypto (its AES cipher module is what gets imported);
#pycryptoLoaded records whether that import succeeded.
try:
    import Crypto.Cipher.AES
except ImportError:
    pycryptoLoaded = False
else:
    pycryptoLoaded = True
|
||
|
|
||
|
|
||
|
# **************************************************************************
|
||
|
# PRNG Functions
|
||
|
# **************************************************************************
|
||
|
|
||
|
# Select a PRNG backend. Exactly one getRandomBytes(howMany) is defined,
# and prngName names the backend chosen. Preference order: os.urandom,
# cryptlib, /dev/urandom, Win32 CryptoAPI, then a stub that raises.

# Get os.urandom PRNG
try:
    os.urandom(1)

    def getRandomBytes(howMany):
        """Return howMany random bytes (as converted by stringToBytes)."""
        return stringToBytes(os.urandom(howMany))

    prngName = "os.urandom"

except Exception:
    # os.urandom is unusable here; fall back to the other sources.
    # Catching Exception instead of using a bare except keeps
    # KeyboardInterrupt/SystemExit (Python 2.5+) from being mistaken
    # for a missing PRNG.

    # Else get cryptlib PRNG
    if cryptlibpyLoaded:
        def getRandomBytes(howMany):
            """Return howMany random bytes generated by cryptlib.

            A fresh AES context in OFB mode with a generated key encrypts
            an all-zero buffer in place; the resulting keystream is
            returned.
            """
            randomKey = cryptlib_py.cryptCreateContext(
                cryptlib_py.CRYPT_UNUSED, cryptlib_py.CRYPT_ALGO_AES)
            try:
                cryptlib_py.cryptSetAttribute(randomKey,
                                              cryptlib_py.CRYPT_CTXINFO_MODE,
                                              cryptlib_py.CRYPT_MODE_OFB)
                cryptlib_py.cryptGenerateKey(randomKey)
                bytes = createByteArrayZeros(howMany)
                cryptlib_py.cryptEncrypt(randomKey, bytes)
                return bytes
            finally:
                # Release the per-call context so repeated calls do not
                # accumulate cryptlib contexts.
                cryptlib_py.cryptDestroyContext(randomKey)

        prngName = "cryptlib"

    else:
        #Else get UNIX /dev/urandom PRNG
        try:
            # Opened once and never closed: every call reads from it.
            devRandomFile = open("/dev/urandom", "rb")

            def getRandomBytes(howMany):
                """Return howMany random bytes read from /dev/urandom."""
                return stringToBytes(devRandomFile.read(howMany))

            prngName = "/dev/urandom"

        except IOError:
            #Else get Win32 CryptoAPI PRNG
            try:
                import win32prng

                def getRandomBytes(howMany):
                    """Return howMany random bytes from win32prng.

                    Raises AssertionError on a short read.
                    """
                    s = win32prng.getRandomBytes(howMany)
                    if len(s) != howMany:
                        raise AssertionError()
                    return stringToBytes(s)

                prngName = "CryptoAPI"

            except ImportError:
                #Else no PRNG :-(
                def getRandomBytes(howMany):
                    """Always raise: no random source was found."""
                    raise NotImplementedError("No Random Number Generator "\
                                              "available.")

                prngName = "None"
|
||
|
|
||
|
# **************************************************************************
|
||
|
# Converter Functions
|
||
|
# **************************************************************************
|
||
|
|
||
|
def bytesToNumber(bytes):
|
||
|
total = 0L
|
||
|
multiplier = 1L
|
||
|
for count in range(len(bytes)-1, -1, -1):
|
||
|
byte = bytes[count]
|
||
|
total += multiplier * byte
|
||
|
multiplier *= 256
|
||
|
return total
|
||
|
|
||
|
def numberToBytes(n):
    """Return n as a big-endian byte array of length numBytes(n).

    For n == 0 the array has length 0.
    """
    size = numBytes(n)
    out = createByteArrayZeros(size)
    position = size
    # Fill from the least-significant end, peeling off 8 bits at a time.
    while position > 0:
        position -= 1
        out[position] = int(n & 0xFF)
        n = n >> 8
    return out
|
||
|
|
||
|
def bytesToBase64(bytes):
    """Return the single-line base64 encoding of a byte array."""
    return stringToBase64(bytesToString(bytes))
|
||
|
|
||
|
def base64ToBytes(s):
    """Decode base64 text s into a byte array.

    Malformed input raises SyntaxError (from base64ToString).
    """
    return stringToBytes(base64ToString(s))
|
||
|
|
||
|
def numberToBase64(n):
    """Return the base64 encoding of n's big-endian byte representation."""
    return bytesToBase64(numberToBytes(n))
|
||
|
|
||
|
def base64ToNumber(s):
    """Decode base64 text s and read it as a big-endian unsigned number."""
    return bytesToNumber(base64ToBytes(s))
|
||
|
|
||
|
def stringToNumber(s):
    """Read the raw string s as a big-endian unsigned number."""
    return bytesToNumber(stringToBytes(s))
|
||
|
|
||
|
def numberToString(s):
    """Return the number s (despite the parameter name) as a big-endian
    raw string of minimal length."""
    return bytesToString(numberToBytes(s))
|
||
|
|
||
|
def base64ToString(s):
    """Decode base64 text s to a raw string.

    Raises SyntaxError (wrapping the binascii exception) when s is not
    valid base64.
    """
    try:
        # binascii.a2b_base64 is the decoder that base64.decodestring
        # wraps (decodestring was later removed in Python 3.9).
        decoded = binascii.a2b_base64(s)
    except (binascii.Error, binascii.Incomplete):
        raise SyntaxError(sys.exc_info()[1])
    return decoded
|
||
|
|
||
|
def stringToBase64(s):
    """Base64-encode the raw string s as one line with no newlines."""
    # b64encode never inserts line breaks, so no "\n" stripping is needed.
    return base64.b64encode(s)
|
||
|
|
||
|
def mpiToNumber(mpi): #mpi is an openssl-format bignum string
    """Convert an OpenSSL-format MPI string to a non-negative number.

    The first 4 bytes are a length prefix, which is skipped without being
    validated; the remaining bytes are the big-endian magnitude. An MPI
    with an empty magnitude (just the prefix) decodes to 0.

    Raises AssertionError if mpi is shorter than the 4-byte prefix, or if
    the magnitude's high bit is set (i.e. the number is negative).
    """
    if len(mpi) < 4:
        raise AssertionError()
    magnitude = mpi[4:]
    #Make sure this is a positive number; an empty magnitude has no sign
    #bit to inspect (previously this indexed mpi[4] and raised IndexError).
    if magnitude and (ord(magnitude[0]) & 0x80) != 0:
        raise AssertionError()
    return bytesToNumber(stringToBytes(magnitude))
|
||
|
|
||
|
def numberToMPI(n):
    """Encode the non-negative number n as an OpenSSL-format MPI string.

    The result is a 4-byte big-endian length followed by the big-endian
    magnitude. The magnitude gets a leading zero byte whenever its bit
    length is a multiple of 8, so that its high bit is not read as a sign.
    """
    magnitude = numberToBytes(n)
    #If the high-order bit is going to be set, add an extra byte of zeros
    ext = int((numBits(n) & 0x7) == 0)
    length = numBytes(n) + ext
    header = createByteArrayZeros(4 + ext)
    # Write the length big-endian into the first 4 header bytes.
    for index in range(4):
        header[index] = (length >> (8 * (3 - index))) & 0xFF
    return bytesToString(concatArrays(header, magnitude))
|
||
|
|
||
|
|
||
|
|
||
|
# **************************************************************************
|
||
|
# Misc. Utility Functions
|
||
|
# **************************************************************************
|
||
|
|
||
|
def numBytes(n):
    """Return how many bytes are needed to hold n (0 when n == 0)."""
    if n == 0:
        return 0
    # Integer ceiling of numBits(n) / 8.
    return int((numBits(n) + 7) >> 3)
|
||
|
|
||
|
def hashAndBase64(s):
    """Return the base64 encoding of the SHA-1 digest of s."""
    digest = sha1(s).digest()
    return stringToBase64(digest)
|
||
|
|
||
|
def getBase64Nonce(numChars=22): #defaults to a 132-bit nonce
    """Return a random nonce of numChars base64 characters.

    numChars random bytes are drawn, base64-encoded, and the encoding is
    truncated to numChars characters. Each kept character carries 6 bits,
    so the default of 22 characters gives 132 bits.
    """
    randomBytes = getRandomBytes(numChars)
    randomStr = "".join(map(chr, randomBytes))
    return stringToBase64(randomStr)[:numChars]
|
||
|
|
||
|
|
||
|
# **************************************************************************
|
||
|
# Big Number Math
|
||
|
# **************************************************************************
|
||
|
|
||
|
def getRandomNumber(low, high):
    """Return a random number n with low <= n < high.

    Draws numBytes(high) random bytes, masks the top byte down to the bit
    length of high, and retries (rejection sampling) until the value lands
    in range.

    Raises AssertionError if low >= high.
    """
    if low >= high:
        raise AssertionError()
    bitCount = numBits(high)
    byteCount = numBytes(high)
    topBits = bitCount % 8
    while True:
        candidateBytes = getRandomBytes(byteCount)
        if topBits:
            # Drop high bits beyond numBits(high) to cut down on retries.
            candidateBytes[0] = candidateBytes[0] % (1 << topBits)
        candidate = bytesToNumber(candidateBytes)
        if low <= candidate < high:
            return candidate
|
||
|
|
||
|
def gcd(a,b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    # Start with the larger value first.
    if a < b:
        a, b = b, a
    while b != 0:
        a, b = b, a % b
    return a
|
||
|
|
||
|
def lcm(a, b):
    """Return the least common multiple of a and b.

    lcm(0, 0) is 0. Previously it raised ZeroDivisionError, because
    gcd(0, 0) is 0.
    """
    divisor = gcd(a, b)
    if divisor == 0:
        # gcd is 0 only when a and b are both 0.
        return 0
    # divmod gives the floor quotient under both classic and true division,
    # and unlike // it is available on old Jython. divisor divides a
    # exactly, so dividing first gives the same result with a smaller
    # intermediate value.
    return divmod(a, divisor)[0] * b
|
||
|
|
||
|
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if none exists.

    Uses the extended Euclidean algorithm. The inverse is returned reduced
    by % b.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # divmod yields the floor quotient regardless of whether `/` means
        # classic or true division, and (unlike //) it works on old Jython.
        q, r = divmod(d, c)
        c, d = r, c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
|
||
|
|
||
|
|
||
|
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed with gmpy."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        return long(result)

else:
    #Copied from Bryan G. Olson's post to comp.lang.python
    #Does left-to-right instead of pow()'s right-to-left,
    #thus about 30% faster than the python built-in with small bases
    def powMod(base, power, modulus):
        """ Return base**power mod modulus, using multi bit scanning
        with nBitScan bits at a time.

        power == 0 returns 1 % modulus, as the built-in pow() does. A
        negative power returns the modular inverse of
        base**abs(power) mod modulus; AssertionError is raised if that
        inverse does not exist.
        """
        nBitScan = 5

        #TREV - Added support for negative exponents
        negativeResult = False
        if (power < 0):
            power *= -1
            negativeResult = True

        exp2 = 2**nBitScan
        mask = exp2 - 1

        # Break power into a list of digits of nBitScan bits.
        # The list is recursive so easy to read in reverse direction.
        nibbles = None
        while power:
            nibbles = int(power & mask), nibbles
            power = power >> nBitScan

        # A zero power produces no digits, and the unpacking below would
        # fail on None, so handle x**0 == 1 here.
        if nibbles is None:
            return 1 % modulus

        # Make a table of powers of base up to 2**nBitScan - 1
        lowPowers = [1]
        for i in range(1, exp2):
            lowPowers.append((lowPowers[i-1] * base) % modulus)

        # To exponentiate by the first nibble, look it up in the table
        nib, nibbles = nibbles
        prod = lowPowers[nib]

        # For the rest, square nBitScan times, then multiply by
        # base^nibble
        while nibbles:
            nib, nibbles = nibbles
            for i in range(nBitScan):
                prod = (prod * prod) % modulus
            if nib: prod = (prod * lowPowers[nib]) % modulus

        #TREV - Added support for negative exponents
        if negativeResult:
            prodInv = invMod(prod, modulus)
            #Check to make sure the inverse is correct
            if (prod * prodInv) % modulus != 1:
                raise AssertionError()
            return prodInv
        return prod
|
||
|
|
||
|
|
||
|
#Pre-calculate a sieve of the 168 primes < 1000:
def makeSieve(n):
    """Return the sorted list of primes less than n (Sieve of Eratosthenes).

    For 0 <= n <= 2 the list is empty.
    """
    # list() makes the table mutable on every Python version.
    sieve = list(range(n))
    # Every composite below n has a prime factor <= sqrt(n), so candidate
    # factors must run up to and including int(sqrt(n)). The old bound
    # stopped one short and left squares of primes, such as 9 or 961, in
    # the result.
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        # Multiples below count*count were cleared by smaller primes.
        for multiple in range(count * count, n, count):
            sieve[multiple] = 0
    sieve = [x for x in sieve[2:] if x]
    return sieve

sieve = makeSieve(1000)
|
||
|
|
||
|
def isPrime(n, iterations=5, display=False):
|
||
|
#Trial division with sieve
|
||
|
for x in sieve:
|
||
|
if x >= n: return True
|
||
|
if n % x == 0: return False
|
||
|
#Passed trial division, proceed to Rabin-Miller
|
||
|
#Rabin-Miller implemented per Ferguson & Schneier
|
||
|
#Compute s, t for Rabin-Miller
|
||
|
if display: print "*",
|
||
|
s, t = n-1, 0
|
||
|
while s % 2 == 0:
|
||
|
s, t = s/2, t+1
|
||
|
#Repeat Rabin-Miller x times
|
||
|
a = 2 #Use 2 as a base for first iteration speedup, per HAC
|
||
|
for count in range(iterations):
|
||
|
v = powMod(a, s, n)
|
||
|
if v==1:
|
||
|
continue
|
||
|
i = 0
|
||
|
while v != n-1:
|
||
|
if i == t-1:
|
||
|
return False
|
||
|
else:
|
||
|
v, i = powMod(v, 2, n), i+1
|
||
|
a = getRandomNumber(2, n)
|
||
|
return True
|
||
|
|
||
|
def getRandomPrime(bits, display=False):
|
||
|
if bits < 10:
|
||
|
raise AssertionError()
|
||
|
#The 1.5 ensures the 2 MSBs are set
|
||
|
#Thus, when used for p,q in RSA, n will have its MSB set
|
||
|
#
|
||
|
#Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||
|
#29 % 30 and keep them there
|
||
|
low = (2L ** (bits-1)) * 3/2
|
||
|
high = 2L ** bits - 30
|
||
|
p = getRandomNumber(low, high)
|
||
|
p += 29 - (p % 30)
|
||
|
while 1:
|
||
|
if display: print ".",
|
||
|
p += 30
|
||
|
if p >= high:
|
||
|
p = getRandomNumber(low, high)
|
||
|
p += 29 - (p % 30)
|
||
|
if isPrime(p, display=display):
|
||
|
return p
|
||
|
|
||
|
#Unused at the moment...
|
||
|
def getRandomSafePrime(bits, display=False):
|
||
|
if bits < 10:
|
||
|
raise AssertionError()
|
||
|
#The 1.5 ensures the 2 MSBs are set
|
||
|
#Thus, when used for p,q in RSA, n will have its MSB set
|
||
|
#
|
||
|
#Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||
|
#29 % 30 and keep them there
|
||
|
low = (2 ** (bits-2)) * 3/2
|
||
|
high = (2 ** (bits-1)) - 30
|
||
|
q = getRandomNumber(low, high)
|
||
|
q += 29 - (q % 30)
|
||
|
while 1:
|
||
|
if display: print ".",
|
||
|
q += 30
|
||
|
if (q >= high):
|
||
|
q = getRandomNumber(low, high)
|
||
|
q += 29 - (q % 30)
|
||
|
#Ideas from Tom Wu's SRP code
|
||
|
#Do trial division on p and q before Rabin-Miller
|
||
|
if isPrime(q, 0, display=display):
|
||
|
p = (2 * q) + 1
|
||
|
if isPrime(p, display=display):
|
||
|
if isPrime(q, display=display):
|
||
|
return p
|