561 lines
18 KiB
Python
561 lines
18 KiB
Python
"""Classes representing TLS messages."""
|
|
|
|
from utils.compat import *
|
|
from utils.cryptomath import *
|
|
from errors import *
|
|
from utils.codec import *
|
|
from constants import *
|
|
from X509 import X509
|
|
from X509CertChain import X509CertChain
|
|
|
|
import sha
|
|
import md5
|
|
|
|
class RecordHeader3:
|
|
def __init__(self):
|
|
self.type = 0
|
|
self.version = (0,0)
|
|
self.length = 0
|
|
self.ssl2 = False
|
|
|
|
def create(self, version, type, length):
|
|
self.type = type
|
|
self.version = version
|
|
self.length = length
|
|
return self
|
|
|
|
def write(self):
|
|
w = Writer(5)
|
|
w.add(self.type, 1)
|
|
w.add(self.version[0], 1)
|
|
w.add(self.version[1], 1)
|
|
w.add(self.length, 2)
|
|
return w.bytes
|
|
|
|
def parse(self, p):
|
|
self.type = p.get(1)
|
|
self.version = (p.get(1), p.get(1))
|
|
self.length = p.get(2)
|
|
self.ssl2 = False
|
|
return self
|
|
|
|
class RecordHeader2:
|
|
def __init__(self):
|
|
self.type = 0
|
|
self.version = (0,0)
|
|
self.length = 0
|
|
self.ssl2 = True
|
|
|
|
def parse(self, p):
|
|
if p.get(1)!=128:
|
|
raise SyntaxError()
|
|
self.type = ContentType.handshake
|
|
self.version = (2,0)
|
|
#We don't support 2-byte-length-headers; could be a problem
|
|
self.length = p.get(1)
|
|
return self
|
|
|
|
|
|
class Msg:
|
|
def preWrite(self, trial):
|
|
if trial:
|
|
w = Writer()
|
|
else:
|
|
length = self.write(True)
|
|
w = Writer(length)
|
|
return w
|
|
|
|
def postWrite(self, w, trial):
|
|
if trial:
|
|
return w.index
|
|
else:
|
|
return w.bytes
|
|
|
|
class Alert(Msg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.alert
|
|
self.level = 0
|
|
self.description = 0
|
|
|
|
def create(self, description, level=AlertLevel.fatal):
|
|
self.level = level
|
|
self.description = description
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.setLengthCheck(2)
|
|
self.level = p.get(1)
|
|
self.description = p.get(1)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self):
|
|
w = Writer(2)
|
|
w.add(self.level, 1)
|
|
w.add(self.description, 1)
|
|
return w.bytes
|
|
|
|
|
|
class HandshakeMsg(Msg):
|
|
def preWrite(self, handshakeType, trial):
|
|
if trial:
|
|
w = Writer()
|
|
w.add(handshakeType, 1)
|
|
w.add(0, 3)
|
|
else:
|
|
length = self.write(True)
|
|
w = Writer(length)
|
|
w.add(handshakeType, 1)
|
|
w.add(length-4, 3)
|
|
return w
|
|
|
|
|
|
class ClientHello(HandshakeMsg):
|
|
def __init__(self, ssl2=False):
|
|
self.contentType = ContentType.handshake
|
|
self.ssl2 = ssl2
|
|
self.client_version = (0,0)
|
|
self.random = createByteArrayZeros(32)
|
|
self.session_id = createByteArraySequence([])
|
|
self.cipher_suites = [] # a list of 16-bit values
|
|
self.certificate_types = [CertificateType.x509]
|
|
self.compression_methods = [] # a list of 8-bit values
|
|
self.srp_username = None # a string
|
|
|
|
def create(self, version, random, session_id, cipher_suites,
|
|
certificate_types=None, srp_username=None):
|
|
self.client_version = version
|
|
self.random = random
|
|
self.session_id = session_id
|
|
self.cipher_suites = cipher_suites
|
|
self.certificate_types = certificate_types
|
|
self.compression_methods = [0]
|
|
self.srp_username = srp_username
|
|
return self
|
|
|
|
def parse(self, p):
|
|
if self.ssl2:
|
|
self.client_version = (p.get(1), p.get(1))
|
|
cipherSpecsLength = p.get(2)
|
|
sessionIDLength = p.get(2)
|
|
randomLength = p.get(2)
|
|
self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
|
|
self.session_id = p.getFixBytes(sessionIDLength)
|
|
self.random = p.getFixBytes(randomLength)
|
|
if len(self.random) < 32:
|
|
zeroBytes = 32-len(self.random)
|
|
self.random = createByteArrayZeros(zeroBytes) + self.random
|
|
self.compression_methods = [0]#Fake this value
|
|
|
|
#We're not doing a stopLengthCheck() for SSLv2, oh well..
|
|
else:
|
|
p.startLengthCheck(3)
|
|
self.client_version = (p.get(1), p.get(1))
|
|
self.random = p.getFixBytes(32)
|
|
self.session_id = p.getVarBytes(1)
|
|
self.cipher_suites = p.getVarList(2, 2)
|
|
self.compression_methods = p.getVarList(1, 1)
|
|
if not p.atLengthCheck():
|
|
totalExtLength = p.get(2)
|
|
soFar = 0
|
|
while soFar != totalExtLength:
|
|
extType = p.get(2)
|
|
extLength = p.get(2)
|
|
if extType == 6:
|
|
self.srp_username = bytesToString(p.getVarBytes(1))
|
|
elif extType == 7:
|
|
self.certificate_types = p.getVarList(1, 1)
|
|
else:
|
|
p.getFixBytes(extLength)
|
|
soFar += 4 + extLength
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
|
|
w.add(self.client_version[0], 1)
|
|
w.add(self.client_version[1], 1)
|
|
w.addFixSeq(self.random, 1)
|
|
w.addVarSeq(self.session_id, 1, 1)
|
|
w.addVarSeq(self.cipher_suites, 2, 2)
|
|
w.addVarSeq(self.compression_methods, 1, 1)
|
|
|
|
extLength = 0
|
|
if self.certificate_types and self.certificate_types != \
|
|
[CertificateType.x509]:
|
|
extLength += 5 + len(self.certificate_types)
|
|
if self.srp_username:
|
|
extLength += 5 + len(self.srp_username)
|
|
if extLength > 0:
|
|
w.add(extLength, 2)
|
|
|
|
if self.certificate_types and self.certificate_types != \
|
|
[CertificateType.x509]:
|
|
w.add(7, 2)
|
|
w.add(len(self.certificate_types)+1, 2)
|
|
w.addVarSeq(self.certificate_types, 1, 1)
|
|
if self.srp_username:
|
|
w.add(6, 2)
|
|
w.add(len(self.srp_username)+1, 2)
|
|
w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
|
|
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
|
|
class ServerHello(HandshakeMsg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.handshake
|
|
self.server_version = (0,0)
|
|
self.random = createByteArrayZeros(32)
|
|
self.session_id = createByteArraySequence([])
|
|
self.cipher_suite = 0
|
|
self.certificate_type = CertificateType.x509
|
|
self.compression_method = 0
|
|
|
|
def create(self, version, random, session_id, cipher_suite,
|
|
certificate_type):
|
|
self.server_version = version
|
|
self.random = random
|
|
self.session_id = session_id
|
|
self.cipher_suite = cipher_suite
|
|
self.certificate_type = certificate_type
|
|
self.compression_method = 0
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
self.server_version = (p.get(1), p.get(1))
|
|
self.random = p.getFixBytes(32)
|
|
self.session_id = p.getVarBytes(1)
|
|
self.cipher_suite = p.get(2)
|
|
self.compression_method = p.get(1)
|
|
if not p.atLengthCheck():
|
|
totalExtLength = p.get(2)
|
|
soFar = 0
|
|
while soFar != totalExtLength:
|
|
extType = p.get(2)
|
|
extLength = p.get(2)
|
|
if extType == 7:
|
|
self.certificate_type = p.get(1)
|
|
else:
|
|
p.getFixBytes(extLength)
|
|
soFar += 4 + extLength
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
|
|
w.add(self.server_version[0], 1)
|
|
w.add(self.server_version[1], 1)
|
|
w.addFixSeq(self.random, 1)
|
|
w.addVarSeq(self.session_id, 1, 1)
|
|
w.add(self.cipher_suite, 2)
|
|
w.add(self.compression_method, 1)
|
|
|
|
extLength = 0
|
|
if self.certificate_type and self.certificate_type != \
|
|
CertificateType.x509:
|
|
extLength += 5
|
|
|
|
if extLength != 0:
|
|
w.add(extLength, 2)
|
|
|
|
if self.certificate_type and self.certificate_type != \
|
|
CertificateType.x509:
|
|
w.add(7, 2)
|
|
w.add(1, 2)
|
|
w.add(self.certificate_type, 1)
|
|
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class Certificate(HandshakeMsg):
|
|
def __init__(self, certificateType):
|
|
self.certificateType = certificateType
|
|
self.contentType = ContentType.handshake
|
|
self.certChain = None
|
|
|
|
def create(self, certChain):
|
|
self.certChain = certChain
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
if self.certificateType == CertificateType.x509:
|
|
chainLength = p.get(3)
|
|
index = 0
|
|
certificate_list = []
|
|
while index != chainLength:
|
|
certBytes = p.getVarBytes(3)
|
|
x509 = X509()
|
|
x509.parseBinary(certBytes)
|
|
certificate_list.append(x509)
|
|
index += len(certBytes)+3
|
|
if certificate_list:
|
|
self.certChain = X509CertChain(certificate_list)
|
|
elif self.certificateType == CertificateType.cryptoID:
|
|
s = bytesToString(p.getVarBytes(2))
|
|
if s:
|
|
try:
|
|
import cryptoIDlib.CertChain
|
|
except ImportError:
|
|
raise SyntaxError(\
|
|
"cryptoID cert chain received, cryptoIDlib not present")
|
|
self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
|
|
else:
|
|
raise AssertionError()
|
|
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
|
|
if self.certificateType == CertificateType.x509:
|
|
chainLength = 0
|
|
if self.certChain:
|
|
certificate_list = self.certChain.x509List
|
|
else:
|
|
certificate_list = []
|
|
#determine length
|
|
for cert in certificate_list:
|
|
bytes = cert.writeBytes()
|
|
chainLength += len(bytes)+3
|
|
#add bytes
|
|
w.add(chainLength, 3)
|
|
for cert in certificate_list:
|
|
bytes = cert.writeBytes()
|
|
w.addVarSeq(bytes, 1, 3)
|
|
elif self.certificateType == CertificateType.cryptoID:
|
|
if self.certChain:
|
|
bytes = stringToBytes(self.certChain.write())
|
|
else:
|
|
bytes = createByteArraySequence([])
|
|
w.addVarSeq(bytes, 1, 2)
|
|
else:
|
|
raise AssertionError()
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class CertificateRequest(HandshakeMsg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.handshake
|
|
self.certificate_types = []
|
|
#treat as opaque bytes for now
|
|
self.certificate_authorities = createByteArraySequence([])
|
|
|
|
def create(self, certificate_types, certificate_authorities):
|
|
self.certificate_types = certificate_types
|
|
self.certificate_authorities = certificate_authorities
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
self.certificate_types = p.getVarList(1, 1)
|
|
self.certificate_authorities = p.getVarBytes(2)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
|
|
trial)
|
|
w.addVarSeq(self.certificate_types, 1, 1)
|
|
w.addVarSeq(self.certificate_authorities, 1, 2)
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class ServerKeyExchange(HandshakeMsg):
|
|
def __init__(self, cipherSuite):
|
|
self.cipherSuite = cipherSuite
|
|
self.contentType = ContentType.handshake
|
|
self.srp_N = 0L
|
|
self.srp_g = 0L
|
|
self.srp_s = createByteArraySequence([])
|
|
self.srp_B = 0L
|
|
self.signature = createByteArraySequence([])
|
|
|
|
def createSRP(self, srp_N, srp_g, srp_s, srp_B):
|
|
self.srp_N = srp_N
|
|
self.srp_g = srp_g
|
|
self.srp_s = srp_s
|
|
self.srp_B = srp_B
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
self.srp_N = bytesToNumber(p.getVarBytes(2))
|
|
self.srp_g = bytesToNumber(p.getVarBytes(2))
|
|
self.srp_s = p.getVarBytes(1)
|
|
self.srp_B = bytesToNumber(p.getVarBytes(2))
|
|
if self.cipherSuite in CipherSuite.srpRsaSuites:
|
|
self.signature = p.getVarBytes(2)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
|
|
trial)
|
|
w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
|
|
w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
|
|
w.addVarSeq(self.srp_s, 1, 1)
|
|
w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
|
|
if self.cipherSuite in CipherSuite.srpRsaSuites:
|
|
w.addVarSeq(self.signature, 1, 2)
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
def hash(self, clientRandom, serverRandom):
|
|
oldCipherSuite = self.cipherSuite
|
|
self.cipherSuite = None
|
|
try:
|
|
bytes = clientRandom + serverRandom + self.write()[4:]
|
|
s = bytesToString(bytes)
|
|
return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
|
|
finally:
|
|
self.cipherSuite = oldCipherSuite
|
|
|
|
class ServerHelloDone(HandshakeMsg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.handshake
|
|
|
|
def create(self):
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class ClientKeyExchange(HandshakeMsg):
|
|
def __init__(self, cipherSuite, version=None):
|
|
self.cipherSuite = cipherSuite
|
|
self.version = version
|
|
self.contentType = ContentType.handshake
|
|
self.srp_A = 0
|
|
self.encryptedPreMasterSecret = createByteArraySequence([])
|
|
|
|
def createSRP(self, srp_A):
|
|
self.srp_A = srp_A
|
|
return self
|
|
|
|
def createRSA(self, encryptedPreMasterSecret):
|
|
self.encryptedPreMasterSecret = encryptedPreMasterSecret
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
if self.cipherSuite in CipherSuite.srpSuites + \
|
|
CipherSuite.srpRsaSuites:
|
|
self.srp_A = bytesToNumber(p.getVarBytes(2))
|
|
elif self.cipherSuite in CipherSuite.rsaSuites:
|
|
if self.version in ((3,1), (3,2)):
|
|
self.encryptedPreMasterSecret = p.getVarBytes(2)
|
|
elif self.version == (3,0):
|
|
self.encryptedPreMasterSecret = \
|
|
p.getFixBytes(len(p.bytes)-p.index)
|
|
else:
|
|
raise AssertionError()
|
|
else:
|
|
raise AssertionError()
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
|
|
trial)
|
|
if self.cipherSuite in CipherSuite.srpSuites + \
|
|
CipherSuite.srpRsaSuites:
|
|
w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
|
|
elif self.cipherSuite in CipherSuite.rsaSuites:
|
|
if self.version in ((3,1), (3,2)):
|
|
w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
|
|
elif self.version == (3,0):
|
|
w.addFixSeq(self.encryptedPreMasterSecret, 1)
|
|
else:
|
|
raise AssertionError()
|
|
else:
|
|
raise AssertionError()
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class CertificateVerify(HandshakeMsg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.handshake
|
|
self.signature = createByteArraySequence([])
|
|
|
|
def create(self, signature):
|
|
self.signature = signature
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
self.signature = p.getVarBytes(2)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
|
|
trial)
|
|
w.addVarSeq(self.signature, 1, 2)
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class ChangeCipherSpec(Msg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.change_cipher_spec
|
|
self.type = 1
|
|
|
|
def create(self):
|
|
self.type = 1
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.setLengthCheck(1)
|
|
self.type = p.get(1)
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = Msg.preWrite(self, trial)
|
|
w.add(self.type,1)
|
|
return Msg.postWrite(self, w, trial)
|
|
|
|
|
|
class Finished(HandshakeMsg):
|
|
def __init__(self, version):
|
|
self.contentType = ContentType.handshake
|
|
self.version = version
|
|
self.verify_data = createByteArraySequence([])
|
|
|
|
def create(self, verify_data):
|
|
self.verify_data = verify_data
|
|
return self
|
|
|
|
def parse(self, p):
|
|
p.startLengthCheck(3)
|
|
if self.version == (3,0):
|
|
self.verify_data = p.getFixBytes(36)
|
|
elif self.version in ((3,1), (3,2)):
|
|
self.verify_data = p.getFixBytes(12)
|
|
else:
|
|
raise AssertionError()
|
|
p.stopLengthCheck()
|
|
return self
|
|
|
|
def write(self, trial=False):
|
|
w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
|
|
w.addFixSeq(self.verify_data, 1)
|
|
return HandshakeMsg.postWrite(self, w, trial)
|
|
|
|
class ApplicationData(Msg):
|
|
def __init__(self):
|
|
self.contentType = ContentType.application_data
|
|
self.bytes = createByteArraySequence([])
|
|
|
|
def create(self, bytes):
|
|
self.bytes = bytes
|
|
return self
|
|
|
|
def parse(self, p):
|
|
self.bytes = p.bytes
|
|
return self
|
|
|
|
def write(self):
|
|
return self.bytes |