electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

x509.py (11489B)


      1 #!/usr/bin/env python
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2014 Thomas Voegtlin
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 
     26 import hashlib
     27 import time
     28 from datetime import datetime
     29 
     30 from . import util
     31 from .util import profiler, bh2u
     32 from .logging import get_logger
     33 
     34 
     35 _logger = get_logger(__name__)
     36 
     37 
     38 # algo OIDs
     39 ALGO_RSA_SHA1 = '1.2.840.113549.1.1.5'
     40 ALGO_RSA_SHA256 = '1.2.840.113549.1.1.11'
     41 ALGO_RSA_SHA384 = '1.2.840.113549.1.1.12'
     42 ALGO_RSA_SHA512 = '1.2.840.113549.1.1.13'
     43 ALGO_ECDSA_SHA256 = '1.2.840.10045.4.3.2'
     44 
     45 # prefixes, see http://stackoverflow.com/questions/3713774/c-sharp-how-to-calculate-asn-1-der-encoding-of-a-particular-hash-algorithm
     46 PREFIX_RSA_SHA256 = bytearray(
     47     [0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20])
     48 PREFIX_RSA_SHA384 = bytearray(
     49     [0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30])
     50 PREFIX_RSA_SHA512 = bytearray(
     51     [0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40])
     52 
     53 # types used in ASN1 structured data
     54 ASN1_TYPES = {
     55     'BOOLEAN'          : 0x01,
     56     'INTEGER'          : 0x02,
     57     'BIT STRING'       : 0x03,
     58     'OCTET STRING'     : 0x04,
     59     'NULL'             : 0x05,
     60     'OBJECT IDENTIFIER': 0x06,
     61     'SEQUENCE'         : 0x70,
     62     'SET'              : 0x71,
     63     'PrintableString'  : 0x13,
     64     'IA5String'        : 0x16,
     65     'UTCTime'          : 0x17,
     66     'GeneralizedTime'  : 0x18,
     67     'ENUMERATED'       : 0x0A,
     68     'UTF8String'       : 0x0C,
     69 }
     70 
     71 
     72 class CertificateError(Exception):
     73     pass
     74 
     75 
     76 # helper functions
     77 def bitstr_to_bytestr(s):
     78     if s[0] != 0x00:
     79         raise TypeError('no padding')
     80     return s[1:]
     81 
     82 
     83 def bytestr_to_int(s):
     84     i = 0
     85     for char in s:
     86         i <<= 8
     87         i |= char
     88     return i
     89 
     90 
     91 def decode_OID(s):
     92     r = []
     93     r.append(s[0] // 40)
     94     r.append(s[0] % 40)
     95     k = 0
     96     for i in s[1:]:
     97         if i < 128:
     98             r.append(i + 128 * k)
     99             k = 0
    100         else:
    101             k = (i - 128) + 128 * k
    102     return '.'.join(map(str, r))
    103 
    104 
    105 def encode_OID(oid):
    106     x = [int(i) for i in oid.split('.')]
    107     s = chr(x[0] * 40 + x[1])
    108     for i in x[2:]:
    109         ss = chr(i % 128)
    110         while i > 128:
    111             i //= 128
    112             ss = chr(128 + i % 128) + ss
    113         s += ss
    114     return s
    115 
    116 
    117 class ASN1_Node(bytes):
    118     def get_node(self, ix):
    119         # return index of first byte, first content byte and last byte.
    120         first = self[ix + 1]
    121         if (first & 0x80) == 0:
    122             length = first
    123             ixf = ix + 2
    124             ixl = ixf + length - 1
    125         else:
    126             lengthbytes = first & 0x7F
    127             length = bytestr_to_int(self[ix + 2:ix + 2 + lengthbytes])
    128             ixf = ix + 2 + lengthbytes
    129             ixl = ixf + length - 1
    130         return ix, ixf, ixl
    131 
    132     def root(self):
    133         return self.get_node(0)
    134 
    135     def next_node(self, node):
    136         ixs, ixf, ixl = node
    137         return self.get_node(ixl + 1)
    138 
    139     def first_child(self, node):
    140         ixs, ixf, ixl = node
    141         if self[ixs] & 0x20 != 0x20:
    142             raise TypeError('Can only open constructed types.', hex(self[ixs]))
    143         return self.get_node(ixf)
    144 
    145     def is_child_of(node1, node2):
    146         ixs, ixf, ixl = node1
    147         jxs, jxf, jxl = node2
    148         return ((ixf <= jxs) and (jxl <= ixl)) or ((jxf <= ixs) and (ixl <= jxl))
    149 
    150     def get_all(self, node):
    151         # return type + length + value
    152         ixs, ixf, ixl = node
    153         return self[ixs:ixl + 1]
    154 
    155     def get_value_of_type(self, node, asn1_type):
    156         # verify type byte and return content
    157         ixs, ixf, ixl = node
    158         if ASN1_TYPES[asn1_type] != self[ixs]:
    159             raise TypeError('Wrong type:', hex(self[ixs]), hex(ASN1_TYPES[asn1_type]))
    160         return self[ixf:ixl + 1]
    161 
    162     def get_value(self, node):
    163         ixs, ixf, ixl = node
    164         return self[ixf:ixl + 1]
    165 
    166     def get_children(self, node):
    167         nodes = []
    168         ii = self.first_child(node)
    169         nodes.append(ii)
    170         while ii[2] < node[2]:
    171             ii = self.next_node(ii)
    172             nodes.append(ii)
    173         return nodes
    174 
    175     def get_sequence(self):
    176         return list(map(lambda j: self.get_value(j), self.get_children(self.root())))
    177 
    178     def get_dict(self, node):
    179         p = {}
    180         for ii in self.get_children(node):
    181             for iii in self.get_children(ii):
    182                 iiii = self.first_child(iii)
    183                 oid = decode_OID(self.get_value_of_type(iiii, 'OBJECT IDENTIFIER'))
    184                 iiii = self.next_node(iiii)
    185                 value = self.get_value(iiii)
    186                 p[oid] = value
    187         return p
    188 
    189     def decode_time(self, ii):
    190         GENERALIZED_TIMESTAMP_FMT = '%Y%m%d%H%M%SZ'
    191         UTCTIME_TIMESTAMP_FMT = '%y%m%d%H%M%SZ'
    192 
    193         try:
    194             return time.strptime(self.get_value_of_type(ii, 'UTCTime').decode('ascii'), UTCTIME_TIMESTAMP_FMT)
    195         except TypeError:
    196             return time.strptime(self.get_value_of_type(ii, 'GeneralizedTime').decode('ascii'), GENERALIZED_TIMESTAMP_FMT)
    197 
    198 class X509(object):
    199     def __init__(self, b):
    200 
    201         self.bytes = bytearray(b)
    202 
    203         der = ASN1_Node(b)
    204         root = der.root()
    205         cert = der.first_child(root)
    206         # data for signature
    207         self.data = der.get_all(cert)
    208 
    209         # optional version field
    210         if der.get_value(cert)[0] == 0xa0:
    211             version = der.first_child(cert)
    212             serial_number = der.next_node(version)
    213         else:
    214             serial_number = der.first_child(cert)
    215         self.serial_number = bytestr_to_int(der.get_value_of_type(serial_number, 'INTEGER'))
    216 
    217         # signature algorithm
    218         sig_algo = der.next_node(serial_number)
    219         ii = der.first_child(sig_algo)
    220         self.sig_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))
    221 
    222         # issuer
    223         issuer = der.next_node(sig_algo)
    224         self.issuer = der.get_dict(issuer)
    225 
    226         # validity
    227         validity = der.next_node(issuer)
    228         ii = der.first_child(validity)
    229         self.notBefore = der.decode_time(ii)
    230         ii = der.next_node(ii)
    231         self.notAfter = der.decode_time(ii)
    232 
    233         # subject
    234         subject = der.next_node(validity)
    235         self.subject = der.get_dict(subject)
    236         subject_pki = der.next_node(subject)
    237         public_key_algo = der.first_child(subject_pki)
    238         ii = der.first_child(public_key_algo)
    239         self.public_key_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))
    240 
    241         if self.public_key_algo != '1.2.840.10045.2.1':  # for non EC public key
    242             # pubkey modulus and exponent
    243             subject_public_key = der.next_node(public_key_algo)
    244             spk = der.get_value_of_type(subject_public_key, 'BIT STRING')
    245             spk = ASN1_Node(bitstr_to_bytestr(spk))
    246             r = spk.root()
    247             modulus = spk.first_child(r)
    248             exponent = spk.next_node(modulus)
    249             rsa_n = spk.get_value_of_type(modulus, 'INTEGER')
    250             rsa_e = spk.get_value_of_type(exponent, 'INTEGER')
    251             self.modulus = int.from_bytes(rsa_n, byteorder='big', signed=False)
    252             self.exponent = int.from_bytes(rsa_e, byteorder='big', signed=False)
    253         else:
    254             subject_public_key = der.next_node(public_key_algo)
    255             spk = der.get_value_of_type(subject_public_key, 'BIT STRING')
    256             self.ec_public_key = spk
    257 
    258         # extensions
    259         self.CA = False
    260         self.AKI = None
    261         self.SKI = None
    262         i = subject_pki
    263         while i[2] < cert[2]:
    264             i = der.next_node(i)
    265             d = der.get_dict(i)
    266             for oid, value in d.items():
    267                 value = ASN1_Node(value)
    268                 if oid == '2.5.29.19':
    269                     # Basic Constraints
    270                     self.CA = bool(value)
    271                 elif oid == '2.5.29.14':
    272                     # Subject Key Identifier
    273                     r = value.root()
    274                     value = value.get_value_of_type(r, 'OCTET STRING')
    275                     self.SKI = bh2u(value)
    276                 elif oid == '2.5.29.35':
    277                     # Authority Key Identifier
    278                     self.AKI = bh2u(value.get_sequence()[0])
    279                 else:
    280                     pass
    281 
    282         # cert signature
    283         cert_sig_algo = der.next_node(cert)
    284         ii = der.first_child(cert_sig_algo)
    285         self.cert_sig_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))
    286         cert_sig = der.next_node(cert_sig_algo)
    287         self.signature = der.get_value(cert_sig)[1:]
    288 
    289     def get_keyID(self):
    290         # http://security.stackexchange.com/questions/72077/validating-an-ssl-certificate-chain-according-to-rfc-5280-am-i-understanding-th
    291         return self.SKI if self.SKI else repr(self.subject)
    292 
    293     def get_issuer_keyID(self):
    294         return self.AKI if self.AKI else repr(self.issuer)
    295 
    296     def get_common_name(self):
    297         return self.subject.get('2.5.4.3', b'unknown').decode()
    298 
    299     def get_signature(self):
    300         return self.cert_sig_algo, self.signature, self.data
    301 
    302     def check_ca(self):
    303         return self.CA
    304 
    305     def check_date(self):
    306         now = time.gmtime()
    307         if self.notBefore > now:
    308             raise CertificateError('Certificate has not entered its valid date range. (%s)' % self.get_common_name())
    309         if self.notAfter <= now:
    310             dt = datetime.utcfromtimestamp(time.mktime(self.notAfter))
    311             raise CertificateError(f'Certificate ({self.get_common_name()}) has expired (at {dt} UTC).')
    312 
    313     def getFingerprint(self):
    314         return hashlib.sha1(self.bytes).digest()
    315 
    316 
    317 @profiler
    318 def load_certificates(ca_path):
    319     from . import pem
    320     ca_list = {}
    321     ca_keyID = {}
    322     # ca_path = '/tmp/tmp.txt'
    323     with open(ca_path, 'r', encoding='utf-8') as f:
    324         s = f.read()
    325     bList = pem.dePemList(s, "CERTIFICATE")
    326     for b in bList:
    327         try:
    328             x = X509(b)
    329             x.check_date()
    330         except BaseException as e:
    331             # with open('/tmp/tmp.txt', 'w') as f:
    332             #     f.write(pem.pem(b, 'CERTIFICATE').decode('ascii'))
    333             _logger.info(f"cert error: {e}")
    334             continue
    335 
    336         fp = x.getFingerprint()
    337         ca_list[fp] = x
    338         ca_keyID[x.get_keyID()] = fp
    339 
    340     return ca_list, ca_keyID
    341 
    342 
    343 if __name__ == "__main__":
    344     import certifi
    345 
    346     ca_path = certifi.where()
    347     ca_list, ca_keyID = load_certificates(ca_path)