electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

rsakey.py (16814B)


      1 #!/usr/bin/env python
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2015 Thomas Voegtlin
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 
     26 # This module uses functions from TLSLite (public domain)
     27 #
     28 # TLSLite Authors:
     29 #   Trevor Perrin
     30 #   Martin von Loewis - python 3 port
     31 #   Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
     32 #
     33 
     34 """Pure-Python RSA implementation."""
     35 
     36 import os
     37 import math
     38 import hashlib
     39 
     40 
     41 def SHA1(x):
     42     return hashlib.sha1(x).digest()
     43 
     44 
     45 # **************************************************************************
     46 # PRNG Functions
     47 # **************************************************************************
     48 
     49 # Check that os.urandom works
     50 import zlib
     51 length = len(zlib.compress(os.urandom(1000)))
     52 assert(length > 900)
     53 
     54 def getRandomBytes(howMany):
     55     b = bytearray(os.urandom(howMany))
     56     assert(len(b) == howMany)
     57     return b
     58 
     59 prngName = "os.urandom"
     60 
     61 
     62 # **************************************************************************
     63 # Converter Functions
     64 # **************************************************************************
     65 
     66 def bytesToNumber(b):
     67     total = 0
     68     multiplier = 1
     69     for count in range(len(b)-1, -1, -1):
     70         byte = b[count]
     71         total += multiplier * byte
     72         multiplier *= 256
     73     return total
     74 
     75 def numberToByteArray(n, howManyBytes=None):
     76     """Convert an integer into a bytearray, zero-pad to howManyBytes.
     77 
     78     The returned bytearray may be smaller than howManyBytes, but will
     79     not be larger.  The returned bytearray will contain a big-endian
     80     encoding of the input integer (n).
     81     """
     82     if howManyBytes == None:
     83         howManyBytes = numBytes(n)
     84     b = bytearray(howManyBytes)
     85     for count in range(howManyBytes-1, -1, -1):
     86         b[count] = int(n % 256)
     87         n >>= 8
     88     return b
     89 
     90 def mpiToNumber(mpi): #mpi is an openssl-format bignum string
     91     if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
     92         raise AssertionError()
     93     b = bytearray(mpi[4:])
     94     return bytesToNumber(b)
     95 
     96 def numberToMPI(n):
     97     b = numberToByteArray(n)
     98     ext = 0
     99     #If the high-order bit is going to be set,
    100     #add an extra byte of zeros
    101     if (numBits(n) & 0x7)==0:
    102         ext = 1
    103     length = numBytes(n) + ext
    104     b = bytearray(4+ext) + b
    105     b[0] = (length >> 24) & 0xFF
    106     b[1] = (length >> 16) & 0xFF
    107     b[2] = (length >> 8) & 0xFF
    108     b[3] = length & 0xFF
    109     return bytes(b)
    110 
    111 
    112 # **************************************************************************
    113 # Misc. Utility Functions
    114 # **************************************************************************
    115 
    116 def numBits(n):
    117     if n==0:
    118         return 0
    119     s = "%x" % n
    120     return ((len(s)-1)*4) + \
    121     {'0':0, '1':1, '2':2, '3':2,
    122      '4':3, '5':3, '6':3, '7':3,
    123      '8':4, '9':4, 'a':4, 'b':4,
    124      'c':4, 'd':4, 'e':4, 'f':4,
    125      }[s[0]]
    126 
    127 def numBytes(n):
    128     if n==0:
    129         return 0
    130     bits = numBits(n)
    131     return int(math.ceil(bits / 8.0))
    132 
    133 # **************************************************************************
    134 # Big Number Math
    135 # **************************************************************************
    136 
    137 def getRandomNumber(low, high):
    138     if low >= high:
    139         raise AssertionError()
    140     howManyBits = numBits(high)
    141     howManyBytes = numBytes(high)
    142     lastBits = howManyBits % 8
    143     while 1:
    144         bytes = getRandomBytes(howManyBytes)
    145         if lastBits:
    146             bytes[0] = bytes[0] % (1 << lastBits)
    147         n = bytesToNumber(bytes)
    148         if n >= low and n < high:
    149             return n
    150 
    151 def gcd(a,b):
    152     a, b = max(a,b), min(a,b)
    153     while b:
    154         a, b = b, a % b
    155     return a
    156 
    157 def lcm(a, b):
    158     return (a * b) // gcd(a, b)
    159 
    160 #Returns inverse of a mod b, zero if none
    161 #Uses Extended Euclidean Algorithm
    162 def invMod(a, b):
    163     c, d = a, b
    164     uc, ud = 1, 0
    165     while c != 0:
    166         q = d // c
    167         c, d = d-(q*c), c
    168         uc, ud = ud - (q * uc), uc
    169     if d == 1:
    170         return ud % b
    171     return 0
    172 
    173 
    174 def powMod(base, power, modulus):
    175     if power < 0:
    176         result = pow(base, power*-1, modulus)
    177         result = invMod(result, modulus)
    178         return result
    179     else:
    180         return pow(base, power, modulus)
    181 
    182 #Pre-calculate a sieve of the ~100 primes < 1000:
    183 def makeSieve(n):
    184     sieve = list(range(n))
    185     for count in range(2, int(math.sqrt(n))+1):
    186         if sieve[count] == 0:
    187             continue
    188         x = sieve[count] * 2
    189         while x < len(sieve):
    190             sieve[x] = 0
    191             x += sieve[count]
    192     sieve = [x for x in sieve[2:] if x]
    193     return sieve
    194 
    195 sieve = makeSieve(1000)
    196 
    197 def isPrime(n, iterations=5, display=False):
    198     #Trial division with sieve
    199     for x in sieve:
    200         if x >= n: return True
    201         if n % x == 0: return False
    202     #Passed trial division, proceed to Rabin-Miller
    203     #Rabin-Miller implemented per Ferguson & Schneier
    204     #Compute s, t for Rabin-Miller
    205     if display: print("*", end=' ')
    206     s, t = n-1, 0
    207     while s % 2 == 0:
    208         s, t = s//2, t+1
    209     #Repeat Rabin-Miller x times
    210     a = 2 #Use 2 as a base for first iteration speedup, per HAC
    211     for count in range(iterations):
    212         v = powMod(a, s, n)
    213         if v==1:
    214             continue
    215         i = 0
    216         while v != n-1:
    217             if i == t-1:
    218                 return False
    219             else:
    220                 v, i = powMod(v, 2, n), i+1
    221         a = getRandomNumber(2, n)
    222     return True
    223 
    224 def getRandomPrime(bits, display=False):
    225     if bits < 10:
    226         raise AssertionError()
    227     #The 1.5 ensures the 2 MSBs are set
    228     #Thus, when used for p,q in RSA, n will have its MSB set
    229     #
    230     #Since 30 is lcm(2,3,5), we'll set our test numbers to
    231     #29 % 30 and keep them there
    232     low = ((2 ** (bits-1)) * 3) // 2
    233     high = 2 ** bits - 30
    234     p = getRandomNumber(low, high)
    235     p += 29 - (p % 30)
    236     while 1:
    237         if display: print(".", end=' ')
    238         p += 30
    239         if p >= high:
    240             p = getRandomNumber(low, high)
    241             p += 29 - (p % 30)
    242         if isPrime(p, display=display):
    243             return p
    244 
    245 #Unused at the moment...
    246 def getRandomSafePrime(bits, display=False):
    247     if bits < 10:
    248         raise AssertionError()
    249     #The 1.5 ensures the 2 MSBs are set
    250     #Thus, when used for p,q in RSA, n will have its MSB set
    251     #
    252     #Since 30 is lcm(2,3,5), we'll set our test numbers to
    253     #29 % 30 and keep them there
    254     low = (2 ** (bits-2)) * 3//2
    255     high = (2 ** (bits-1)) - 30
    256     q = getRandomNumber(low, high)
    257     q += 29 - (q % 30)
    258     while 1:
    259         if display: print(".", end=' ')
    260         q += 30
    261         if (q >= high):
    262             q = getRandomNumber(low, high)
    263             q += 29 - (q % 30)
    264         #Ideas from Tom Wu's SRP code
    265         #Do trial division on p and q before Rabin-Miller
    266         if isPrime(q, 0, display=display):
    267             p = (2 * q) + 1
    268             if isPrime(p, display=display):
    269                 if isPrime(q, display=display):
    270                     return p
    271 
    272 
    273 class RSAKey(object):
    274 
    275     def __init__(self, n=0, e=0, d=0, p=0, q=0, dP=0, dQ=0, qInv=0):
    276         if (n and not e) or (e and not n):
    277             raise AssertionError()
    278         self.n = n
    279         self.e = e
    280         self.d = d
    281         self.p = p
    282         self.q = q
    283         self.dP = dP
    284         self.dQ = dQ
    285         self.qInv = qInv
    286         self.blinder = 0
    287         self.unblinder = 0
    288 
    289     def __len__(self):
    290         """Return the length of this key in bits.
    291 
    292         @rtype: int
    293         """
    294         return numBits(self.n)
    295 
    296     def hasPrivateKey(self):
    297         return self.d != 0
    298 
    299     def hashAndSign(self, bytes):
    300         """Hash and sign the passed-in bytes.
    301 
    302         This requires the key to have a private component.  It performs
    303         a PKCS1-SHA1 signature on the passed-in data.
    304 
    305         @type bytes: str or L{bytearray} of unsigned bytes
    306         @param bytes: The value which will be hashed and signed.
    307 
    308         @rtype: L{bytearray} of unsigned bytes.
    309         @return: A PKCS1-SHA1 signature on the passed-in data.
    310         """
    311         hashBytes = SHA1(bytearray(bytes))
    312         prefixedHashBytes = self._addPKCS1SHA1Prefix(hashBytes)
    313         sigBytes = self.sign(prefixedHashBytes)
    314         return sigBytes
    315 
    316     def hashAndVerify(self, sigBytes, bytes):
    317         """Hash and verify the passed-in bytes with the signature.
    318 
    319         This verifies a PKCS1-SHA1 signature on the passed-in data.
    320 
    321         @type sigBytes: L{bytearray} of unsigned bytes
    322         @param sigBytes: A PKCS1-SHA1 signature.
    323 
    324         @type bytes: str or L{bytearray} of unsigned bytes
    325         @param bytes: The value which will be hashed and verified.
    326 
    327         @rtype: bool
    328         @return: Whether the signature matches the passed-in data.
    329         """
    330         hashBytes = SHA1(bytearray(bytes))
    331 
    332         # Try it with/without the embedded NULL
    333         prefixedHashBytes1 = self._addPKCS1SHA1Prefix(hashBytes, False)
    334         prefixedHashBytes2 = self._addPKCS1SHA1Prefix(hashBytes, True)
    335         result1 = self.verify(sigBytes, prefixedHashBytes1)
    336         result2 = self.verify(sigBytes, prefixedHashBytes2)
    337         return (result1 or result2)
    338 
    339     def sign(self, bytes):
    340         """Sign the passed-in bytes.
    341 
    342         This requires the key to have a private component.  It performs
    343         a PKCS1 signature on the passed-in data.
    344 
    345         @type bytes: L{bytearray} of unsigned bytes
    346         @param bytes: The value which will be signed.
    347 
    348         @rtype: L{bytearray} of unsigned bytes.
    349         @return: A PKCS1 signature on the passed-in data.
    350         """
    351         if not self.hasPrivateKey():
    352             raise AssertionError()
    353         paddedBytes = self._addPKCS1Padding(bytes, 1)
    354         m = bytesToNumber(paddedBytes)
    355         if m >= self.n:
    356             raise ValueError()
    357         c = self._rawPrivateKeyOp(m)
    358         sigBytes = numberToByteArray(c, numBytes(self.n))
    359         return sigBytes
    360 
    361     def verify(self, sigBytes, bytes):
    362         """Verify the passed-in bytes with the signature.
    363 
    364         This verifies a PKCS1 signature on the passed-in data.
    365 
    366         @type sigBytes: L{bytearray} of unsigned bytes
    367         @param sigBytes: A PKCS1 signature.
    368 
    369         @type bytes: L{bytearray} of unsigned bytes
    370         @param bytes: The value which will be verified.
    371 
    372         @rtype: bool
    373         @return: Whether the signature matches the passed-in data.
    374         """
    375         if len(sigBytes) != numBytes(self.n):
    376             return False
    377         paddedBytes = self._addPKCS1Padding(bytes, 1)
    378         c = bytesToNumber(sigBytes)
    379         if c >= self.n:
    380             return False
    381         m = self._rawPublicKeyOp(c)
    382         checkBytes = numberToByteArray(m, numBytes(self.n))
    383         return checkBytes == paddedBytes
    384 
    385     def encrypt(self, bytes):
    386         """Encrypt the passed-in bytes.
    387 
    388         This performs PKCS1 encryption of the passed-in data.
    389 
    390         @type bytes: L{bytearray} of unsigned bytes
    391         @param bytes: The value which will be encrypted.
    392 
    393         @rtype: L{bytearray} of unsigned bytes.
    394         @return: A PKCS1 encryption of the passed-in data.
    395         """
    396         paddedBytes = self._addPKCS1Padding(bytes, 2)
    397         m = bytesToNumber(paddedBytes)
    398         if m >= self.n:
    399             raise ValueError()
    400         c = self._rawPublicKeyOp(m)
    401         encBytes = numberToByteArray(c, numBytes(self.n))
    402         return encBytes
    403 
    404     def decrypt(self, encBytes):
    405         """Decrypt the passed-in bytes.
    406 
    407         This requires the key to have a private component.  It performs
    408         PKCS1 decryption of the passed-in data.
    409 
    410         @type encBytes: L{bytearray} of unsigned bytes
    411         @param encBytes: The value which will be decrypted.
    412 
    413         @rtype: L{bytearray} of unsigned bytes or None.
    414         @return: A PKCS1 decryption of the passed-in data or None if
    415         the data is not properly formatted.
    416         """
    417         if not self.hasPrivateKey():
    418             raise AssertionError()
    419         if len(encBytes) != numBytes(self.n):
    420             return None
    421         c = bytesToNumber(encBytes)
    422         if c >= self.n:
    423             return None
    424         m = self._rawPrivateKeyOp(c)
    425         decBytes = numberToByteArray(m, numBytes(self.n))
    426         #Check first two bytes
    427         if decBytes[0] != 0 or decBytes[1] != 2:
    428             return None
    429         #Scan through for zero separator
    430         for x in range(1, len(decBytes)-1):
    431             if decBytes[x]== 0:
    432                 break
    433         else:
    434             return None
    435         return decBytes[x+1:] #Return everything after the separator
    436 
    437 
    438 
    439 
    440     # **************************************************************************
    441     # Helper Functions for RSA Keys
    442     # **************************************************************************
    443 
    444     def _addPKCS1SHA1Prefix(self, bytes, withNULL=True):
    445         # There is a long history of confusion over whether the SHA1
    446         # algorithmIdentifier should be encoded with a NULL parameter or
    447         # with the parameter omitted.  While the original intention was
    448         # apparently to omit it, many toolkits went the other way.  TLS 1.2
    449         # specifies the NULL should be included, and this behavior is also
    450         # mandated in recent versions of PKCS #1, and is what tlslite has
    451         # always implemented.  Anyways, verification code should probably
    452         # accept both.  However, nothing uses this code yet, so this is
    453         # all fairly moot.
    454         if not withNULL:
    455             prefixBytes = bytearray(\
    456             [0x30,0x1f,0x30,0x07,0x06,0x05,0x2b,0x0e,0x03,0x02,0x1a,0x04,0x14])
    457         else:
    458             prefixBytes = bytearray(\
    459             [0x30,0x21,0x30,0x09,0x06,0x05,0x2b,0x0e,0x03,0x02,0x1a,0x05,0x00,0x04,0x14])
    460         prefixedBytes = prefixBytes + bytes
    461         return prefixedBytes
    462 
    463     def _addPKCS1Padding(self, bytes, blockType):
    464         padLength = (numBytes(self.n) - (len(bytes)+3))
    465         if blockType == 1: #Signature padding
    466             pad = [0xFF] * padLength
    467         elif blockType == 2: #Encryption padding
    468             pad = bytearray(0)
    469             while len(pad) < padLength:
    470                 padBytes = getRandomBytes(padLength * 2)
    471                 pad = [b for b in padBytes if b != 0]
    472                 pad = pad[:padLength]
    473         else:
    474             raise AssertionError()
    475 
    476         padding = bytearray([0,blockType] + pad + [0])
    477         paddedBytes = padding + bytes
    478         return paddedBytes
    479 
    480 
    481 
    482 
    483     def _rawPrivateKeyOp(self, m):
    484         #Create blinding values, on the first pass:
    485         if not self.blinder:
    486             self.unblinder = getRandomNumber(2, self.n)
    487             self.blinder = powMod(invMod(self.unblinder, self.n), self.e,
    488                                   self.n)
    489 
    490         #Blind the input
    491         m = (m * self.blinder) % self.n
    492 
    493         #Perform the RSA operation
    494         c = self._rawPrivateKeyOpHelper(m)
    495 
    496         #Unblind the output
    497         c = (c * self.unblinder) % self.n
    498 
    499         #Update blinding values
    500         self.blinder = (self.blinder * self.blinder) % self.n
    501         self.unblinder = (self.unblinder * self.unblinder) % self.n
    502 
    503         #Return the output
    504         return c
    505 
    506 
    507     def _rawPrivateKeyOpHelper(self, m):
    508         #Non-CRT version
    509         #c = powMod(m, self.d, self.n)
    510 
    511         #CRT version  (~3x faster)
    512         s1 = powMod(m, self.dP, self.p)
    513         s2 = powMod(m, self.dQ, self.q)
    514         h = ((s1 - s2) * self.qInv) % self.p
    515         c = s2 + self.q * h
    516         return c
    517 
    518     def _rawPublicKeyOp(self, c):
    519         m = powMod(c, self.e, self.n)
    520         return m
    521 
    522     def acceptsPassword(self):
    523         return False
    524 
    525     def generate(bits):
    526         key = RSAKey()
    527         p = getRandomPrime(bits//2, False)
    528         q = getRandomPrime(bits//2, False)
    529         t = lcm(p-1, q-1)
    530         key.n = p * q
    531         key.e = 65537
    532         key.d = invMod(key.e, t)
    533         key.p = p
    534         key.q = q
    535         key.dP = key.d % (p-1)
    536         key.dQ = key.d % (q-1)
    537         key.qInv = invMod(q, p)
    538         return key
    539     generate = staticmethod(generate)