rsakey.py (16814B)
#!/usr/bin/env python
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2015 Thomas Voegtlin
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This module uses functions from TLSLite (public domain)
#
# TLSLite Authors:
#   Trevor Perrin
#   Martin von Loewis - python 3 port
#   Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
#

"""Pure-Python RSA implementation."""

import os
import math
import hashlib
import zlib


def SHA1(x):
    """Return the 20-byte SHA-1 digest of *x*."""
    return hashlib.sha1(x).digest()


# **************************************************************************
# PRNG Functions
# **************************************************************************

# Check that os.urandom works: 1000 truly random bytes are essentially
# incompressible, so the zlib output must stay above 900 bytes.
# An explicit raise (not ``assert``) keeps the check active under ``python -O``.
length = len(zlib.compress(os.urandom(1000)))
if length <= 900:
    raise AssertionError("os.urandom output looks non-random")


def getRandomBytes(howMany):
    """Return *howMany* bytes from os.urandom as a mutable bytearray."""
    b = bytearray(os.urandom(howMany))
    if len(b) != howMany:
        raise AssertionError()
    return b


prngName = "os.urandom"


# **************************************************************************
# Converter Functions
# **************************************************************************

def bytesToNumber(b):
    """Interpret *b* as a big-endian unsigned integer.

    *b* may be bytes, a bytearray or a sequence of ints in 0..255.
    Empty input yields 0.
    """
    return int.from_bytes(bytes(b), "big")


def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a big-endian bytearray.

    The result is exactly *howManyBytes* long.
    - It is zero-padded on the left when *n* is short.
    - When *n* needs more bytes, only its low-order *howManyBytes* bytes
      are kept (the high-order bytes are dropped).
    - When *howManyBytes* is None, the minimal length ``numBytes(n)`` is
      used, which is 0 bytes for n == 0.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    # Masking reproduces the historical truncating behaviour, where
    # int.to_bytes alone would raise OverflowError.
    mask = (1 << (8 * howManyBytes)) - 1
    return bytearray((n & mask).to_bytes(howManyBytes, "big"))


def mpiToNumber(mpi):
    """Parse an OpenSSL-format MPI (4-byte length header + magnitude).

    The length header is skipped rather than validated. An MPI with an
    empty magnitude parses as 0.

    @raise AssertionError: if the number is negative (top bit set).
    """
    b = bytearray(mpi)  # works for bytes and bytearray on Python 3
    if len(b) > 4 and (b[4] & 0x80) != 0:  # must be a positive number
        raise AssertionError()
    return bytesToNumber(b[4:])


def numberToMPI(n):
    """Encode a non-negative integer as an OpenSSL-format MPI (bytes)."""
    b = numberToByteArray(n)
    # If the high-order bit is going to be set, add an extra byte of
    # zeros so the value is not read back as negative.
    ext = 1 if (numBits(n) & 0x7) == 0 else 0
    length = numBytes(n) + ext
    b = bytearray(4 + ext) + b
    b[0] = (length >> 24) & 0xFF
    b[1] = (length >> 16) & 0xFF
    b[2] = (length >> 8) & 0xFF
    b[3] = length & 0xFF
    return bytes(b)


# **************************************************************************
# Misc. Utility Functions
# **************************************************************************

def numBits(n):
    """Return the number of bits needed to represent *n* (0 for 0)."""
    return n.bit_length()


def numBytes(n):
    """Return the number of bytes needed to represent *n* (0 for 0)."""
    return (numBits(n) + 7) // 8


# **************************************************************************
# Big Number Math
# **************************************************************************

def getRandomNumber(low, high):
    """Return a uniformly random integer n with low <= n < high.

    @raise AssertionError: if the range is empty or high <= 0, since
        rejection sampling would otherwise loop forever.
    """
    if low >= high or high <= 0:
        raise AssertionError()
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        buf = getRandomBytes(howManyBytes)
        if lastBits:
            # Trim the top byte so candidates have at most howManyBits
            # bits, which keeps the rejection rate below 50%.
            buf[0] %= 1 << lastBits
        n = bytesToNumber(buf)
        if low <= n < high:
            return n


def gcd(a, b):
    """Return the greatest common divisor of *a* and *b* (Euclid)."""
    a, b = max(a, b), min(a, b)
    while b:
        a, b = b, a % b
    return a


def lcm(a, b):
    """Return the least common multiple of *a* and *b*."""
    return (a * b) // gcd(a, b)


def invMod(a, b):
    """Return the inverse of *a* modulo *b*, or 0 if none exists.

    Uses the Extended Euclidean Algorithm.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0


def powMod(base, power, modulus):
    """Return base**power mod modulus; a negative power uses invMod."""
    if power < 0:
        result = pow(base, power * -1, modulus)
        result = invMod(result, modulus)
        return result
    else:
        return pow(base, power, modulus)


def makeSieve(n):
    """Return the list of primes below *n* (Sieve of Eratosthenes)."""
    sieve = list(range(n))
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        x = sieve[count] * 2
        while x < len(sieve):
            sieve[x] = 0
            x += sieve[count]
    sieve = [x for x in sieve[2:] if x]
    return sieve


# Pre-calculate a sieve of the ~168 primes < 1000 for trial division.
sieve = makeSieve(1000)


def isPrime(n, iterations=5, display=False):
    """Probabilistic primality test.

    Runs trial division by the primes below 1000, then *iterations*
    rounds of Rabin-Miller.

    @return: False for n < 2 or any detected composite, otherwise True.
    """
    if n < 2:
        # 0 and 1 are not prime; without this guard, the "2 >= n" check
        # below would report them as prime.
        return False
    # Trial division with sieve
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    # Passed trial division, proceed to Rabin-Miller
    # Rabin-Miller implemented per Ferguson & Schneier
    # Compute s, t with n-1 == s * 2**t and s odd
    if display:
        print("*", end=' ')
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    # Repeat Rabin-Miller x times
    a = 2  # Use 2 as a base for first iteration speedup, per HAC
    for count in range(iterations):
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            else:
                v, i = powMod(v, 2, n), i + 1
        a = getRandomNumber(2, n)
    return True


def getRandomPrime(bits, display=False):
    """Return a random probable prime of exactly *bits* bits, top 2 bits set.

    @raise AssertionError: if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    # The 1.5 ensures the 2 MSBs are set.
    # Thus, when used for p,q in RSA, n will have its MSB set.
    #
    # Since 30 is lcm(2,3,5), we'll set our test numbers to
    # 29 % 30 and keep them there.
    low = ((2 ** (bits - 1)) * 3) // 2
    high = 2 ** bits - 30
    p = getRandomNumber(low, high)
    p += 29 - (p % 30)
    while True:
        if display:
            print(".", end=' ')
        p += 30
        if p >= high:
            p = getRandomNumber(low, high)
            p += 29 - (p % 30)
        if isPrime(p, display=display):
            return p


# Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2q + 1 (q prime) of *bits* bits.

    @raise AssertionError: if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    # The 1.5 ensures the 2 MSBs are set.
    # Thus, when used for p,q in RSA, n will have its MSB set.
    #
    # Since 30 is lcm(2,3,5), we'll set our test numbers to
    # 29 % 30 and keep them there.
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30
    q = getRandomNumber(low, high)
    q += 29 - (q % 30)
    while True:
        if display:
            print(".", end=' ')
        q += 30
        if q >= high:
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        # Ideas from Tom Wu's SRP code:
        # do trial division on p and q before Rabin-Miller.
        if isPrime(q, 0, display=display):
            p = (2 * q) + 1
            if isPrime(p, display=display):
                if isPrime(q, display=display):
                    return p


class RSAKey(object):
    """An RSA public key, optionally with its private (and CRT) components.

    A value of 0 for a component means "absent". A key with d == 0 is
    public-only.
    """

    def __init__(self, n=0, e=0, d=0, p=0, q=0, dP=0, dQ=0, qInv=0):
        if (n and not e) or (e and not n):
            raise AssertionError()
        self.n = n
        self.e = e
        self.d = d
        self.p = p
        self.q = q
        self.dP = dP
        self.dQ = dQ
        self.qInv = qInv
        # Blinding state for private-key operations, created lazily by
        # _rawPrivateKeyOp.
        self.blinder = 0
        self.unblinder = 0

    def __len__(self):
        """Return the length of this key in bits.

        @rtype: int
        """
        return numBits(self.n)

    def hasPrivateKey(self):
        """Return True if the private exponent d is present."""
        return self.d != 0

    def hashAndSign(self, bytes):
        """Hash and sign the passed-in bytes.

        This requires the key to have a private component. It performs
        a PKCS1-SHA1 signature on the passed-in data.

        @type bytes: str or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and signed.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1-SHA1 signature on the passed-in data.
        """
        hashBytes = SHA1(bytearray(bytes))
        prefixedHashBytes = self._addPKCS1SHA1Prefix(hashBytes)
        sigBytes = self.sign(prefixedHashBytes)
        return sigBytes

    def hashAndVerify(self, sigBytes, bytes):
        """Hash and verify the passed-in bytes with the signature.

        This verifies a PKCS1-SHA1 signature on the passed-in data.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1-SHA1 signature.

        @type bytes: str or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and verified.

        @rtype: bool
        @return: Whether the signature matches the passed-in data.
        """
        hashBytes = SHA1(bytearray(bytes))

        # Try it with/without the embedded NULL
        prefixedHashBytes1 = self._addPKCS1SHA1Prefix(hashBytes, False)
        prefixedHashBytes2 = self._addPKCS1SHA1Prefix(hashBytes, True)
        result1 = self.verify(sigBytes, prefixedHashBytes1)
        result2 = self.verify(sigBytes, prefixedHashBytes2)
        return (result1 or result2)

    def sign(self, bytes):
        """Sign the passed-in bytes.

        This requires the key to have a private component. It performs
        a PKCS1 signature on the passed-in data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be signed.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 signature on the passed-in data.

        @raise AssertionError: if the key has no private component.
        @raise ValueError: if the data is too long for the modulus.
        """
        if not self.hasPrivateKey():
            raise AssertionError()
        paddedBytes = self._addPKCS1Padding(bytes, 1)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPrivateKeyOp(m)
        sigBytes = numberToByteArray(c, numBytes(self.n))
        return sigBytes

    def verify(self, sigBytes, bytes):
        """Verify the passed-in bytes with the signature.

        This verifies a PKCS1 signature on the passed-in data.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1 signature.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be verified.

        @rtype: bool
        @return: Whether the signature matches the passed-in data.
        """
        if len(sigBytes) != numBytes(self.n):
            return False
        try:
            paddedBytes = self._addPKCS1Padding(bytes, 1)
        except ValueError:
            # Data too long to be validly signed with this key.
            return False
        c = bytesToNumber(sigBytes)
        if c >= self.n:
            return False
        m = self._rawPublicKeyOp(c)
        checkBytes = numberToByteArray(m, numBytes(self.n))
        return checkBytes == paddedBytes

    def encrypt(self, bytes):
        """Encrypt the passed-in bytes.

        This performs PKCS1 encryption of the passed-in data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be encrypted.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 encryption of the passed-in data.

        @raise ValueError: if the data is longer than numBytes(n) - 11.
        """
        paddedBytes = self._addPKCS1Padding(bytes, 2)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPublicKeyOp(m)
        encBytes = numberToByteArray(c, numBytes(self.n))
        return encBytes

    def decrypt(self, encBytes):
        """Decrypt the passed-in bytes.

        This requires the key to have a private component. It performs
        PKCS1 decryption of the passed-in data.

        @type encBytes: L{bytearray} of unsigned bytes
        @param encBytes: The value which will be decrypted.

        @rtype: L{bytearray} of unsigned bytes or None.
        @return: A PKCS1 decryption of the passed-in data or None if
        the data is not properly formatted.
        """
        # NOTE(review): the distinct None-returning paths form a padding
        # oracle (Bleichenbacher) if callers expose failures to a peer.
        if not self.hasPrivateKey():
            raise AssertionError()
        if len(encBytes) != numBytes(self.n):
            return None
        c = bytesToNumber(encBytes)
        if c >= self.n:
            return None
        m = self._rawPrivateKeyOp(c)
        decBytes = numberToByteArray(m, numBytes(self.n))
        # Check first two bytes: 0x00 0x02
        if decBytes[0] != 0 or decBytes[1] != 2:
            return None
        # Find the zero separator after the random padding string. The
        # search includes the last byte so an empty message is accepted.
        sep = decBytes.find(0, 2)
        # PKCS#1 v1.5 requires at least 8 padding bytes (indices 2..9).
        # This check also rejects a missing separator (sep == -1).
        if sep < 10:
            return None
        return decBytes[sep + 1:]  # Return everything after the separator

    # **************************************************************************
    # Helper Functions for RSA Keys
    # **************************************************************************

    def _addPKCS1SHA1Prefix(self, bytes, withNULL=True):
        """Prepend the DER DigestInfo header for SHA-1 to a 20-byte hash."""
        # There is a long history of confusion over whether the SHA1
        # algorithmIdentifier should be encoded with a NULL parameter or
        # with the parameter omitted. While the original intention was
        # apparently to omit it, many toolkits went the other way. TLS 1.2
        # specifies the NULL should be included, and this behavior is also
        # mandated in recent versions of PKCS #1, and is what tlslite has
        # always implemented. Signing uses the NULL form; hashAndVerify
        # accepts both encodings.
        if not withNULL:
            prefixBytes = bytearray(
                [0x30, 0x1f, 0x30, 0x07, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02,
                 0x1a, 0x04, 0x14])
        else:
            prefixBytes = bytearray(
                [0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02,
                 0x1a, 0x05, 0x00, 0x04, 0x14])
        prefixedBytes = prefixBytes + bytes
        return prefixedBytes

    def _addPKCS1Padding(self, bytes, blockType):
        """Return 0x00 || blockType || PS || 0x00 || bytes, numBytes(n) long.

        PS is filled according to blockType:
        - blockType 1 (signature): PS is all 0xFF bytes.
        - blockType 2 (encryption): PS is random non-zero bytes.

        @raise AssertionError: for any other blockType.
        @raise ValueError: if PS would be shorter than the 8 bytes
            PKCS #1 v1.5 requires, i.e. the data is too long for the key.
        """
        if blockType not in (1, 2):
            raise AssertionError()
        padLength = numBytes(self.n) - (len(bytes) + 3)
        if padLength < 8:
            raise ValueError("data too long for RSA modulus")
        if blockType == 1:  # Signature padding
            pad = [0xFF] * padLength
        else:  # Encryption padding
            pad = []
            while len(pad) < padLength:
                # Draw extra bytes so that dropping zeros rarely needs a
                # second round.
                pad.extend(b for b in getRandomBytes(padLength * 2) if b != 0)
            pad = pad[:padLength]

        padding = bytearray([0, blockType] + pad + [0])
        paddedBytes = padding + bytes
        return paddedBytes

    def _rawPrivateKeyOp(self, m):
        """Compute m**d mod n with RSA blinding."""
        # Create blinding values, on the first pass:
        if not self.blinder:
            self.unblinder = getRandomNumber(2, self.n)
            self.blinder = powMod(invMod(self.unblinder, self.n), self.e,
                                  self.n)

        # Blind the input
        m = (m * self.blinder) % self.n

        # Perform the RSA operation
        c = self._rawPrivateKeyOpHelper(m)

        # Unblind the output
        c = (c * self.unblinder) % self.n

        # Update blinding values (squaring keeps blinder == unblinder**-e)
        self.blinder = (self.blinder * self.blinder) % self.n
        self.unblinder = (self.unblinder * self.unblinder) % self.n

        # Return the output
        return c

    def _rawPrivateKeyOpHelper(self, m):
        """Compute m**d mod n, using CRT when the CRT parameters are set."""
        if not (self.p and self.q and self.dP and self.dQ and self.qInv):
            # Keys built from (n, e, d) alone: non-CRT version.
            return powMod(m, self.d, self.n)

        # CRT version (~3x faster)
        s1 = powMod(m, self.dP, self.p)
        s2 = powMod(m, self.dQ, self.q)
        h = ((s1 - s2) * self.qInv) % self.p
        c = s2 + self.q * h
        return c

    def _rawPublicKeyOp(self, c):
        """Compute c**e mod n."""
        m = powMod(c, self.e, self.n)
        return m

    def acceptsPassword(self):
        return False

    @staticmethod
    def generate(bits):
        """Generate a fresh RSA key of *bits* bits with e = 65537.

        Primes are redrawn until they are distinct and e is invertible
        modulo lcm(p-1, q-1), so the returned key always has a usable
        private exponent.

        @raise AssertionError: if bits // 2 < 10 (from getRandomPrime).
        """
        e = 65537
        while True:
            p = getRandomPrime(bits // 2, False)
            q = getRandomPrime(bits // 2, False)
            if p == q:
                continue
            d = invMod(e, lcm(p - 1, q - 1))
            if d:  # 0 means gcd(e, lcm(p-1, q-1)) != 1
                break
        key = RSAKey()
        key.n = p * q
        key.e = e
        key.d = d
        key.p = p
        key.q = q
        key.dP = d % (p - 1)
        key.dQ = d % (q - 1)
        key.qInv = invMod(q, p)
        return key