electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

coinchooser.py (22774B)


      1 #!/usr/bin/env python
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2015 kyuupichan@gmail
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 from collections import defaultdict
     26 from math import floor, log10
     27 from typing import NamedTuple, List, Callable, Sequence, Union, Dict, Tuple
     28 from decimal import Decimal
     29 
     30 from .bitcoin import sha256, COIN, is_address
     31 from .transaction import Transaction, TxOutput, PartialTransaction, PartialTxInput, PartialTxOutput
     32 from .util import NotEnoughFunds
     33 from .logging import Logger
     34 
     35 
     36 # A simple deterministic PRNG.  Used to deterministically shuffle a
     37 # set of coins - the same set of coins should produce the same output.
     38 # Although choosing UTXOs "randomly" we want it to be deterministic,
     39 # so if sending twice from the same UTXO set we choose the same UTXOs
     40 # to spend.  This prevents attacks on users by malicious or stale
     41 # servers.
     42 class PRNG:
     43     def __init__(self, seed):
     44         self.sha = sha256(seed)
     45         self.pool = bytearray()
     46 
     47     def get_bytes(self, n: int) -> bytes:
     48         while len(self.pool) < n:
     49             self.pool.extend(self.sha)
     50             self.sha = sha256(self.sha)
     51         result, self.pool = self.pool[:n], self.pool[n:]
     52         return bytes(result)
     53 
     54     def randint(self, start, end):
     55         # Returns random integer in [start, end)
     56         n = end - start
     57         r = 0
     58         p = 1
     59         while p < n:
     60             r = self.get_bytes(1)[0] + (r << 8)
     61             p = p << 8
     62         return start + (r % n)
     63 
     64     def choice(self, seq):
     65         return seq[self.randint(0, len(seq))]
     66 
     67     def shuffle(self, x):
     68         for i in reversed(range(1, len(x))):
     69             # pick an element in x[:i+1] with which to exchange x[i]
     70             j = self.randint(0, i+1)
     71             x[i], x[j] = x[j], x[i]
     72 
     73 
     74 class Bucket(NamedTuple):
     75     desc: str
     76     weight: int                   # as in BIP-141
     77     value: int                    # in satoshis
     78     effective_value: int          # estimate of value left after subtracting fees. in satoshis
     79     coins: List[PartialTxInput]   # UTXOs
     80     min_height: int               # min block height where a coin was confirmed
     81     witness: bool                 # whether any coin uses segwit
     82 
     83 
     84 class ScoredCandidate(NamedTuple):
     85     penalty: float
     86     tx: PartialTransaction
     87     buckets: List[Bucket]
     88 
     89 
     90 def strip_unneeded(bkts: List[Bucket], sufficient_funds) -> List[Bucket]:
     91     '''Remove buckets that are unnecessary in achieving the spend amount'''
     92     if sufficient_funds([], bucket_value_sum=0):
     93         # none of the buckets are needed
     94         return []
     95     bkts = sorted(bkts, key=lambda bkt: bkt.value, reverse=True)
     96     bucket_value_sum = 0
     97     for i in range(len(bkts)):
     98         bucket_value_sum += (bkts[i]).value
     99         if sufficient_funds(bkts[:i+1], bucket_value_sum=bucket_value_sum):
    100             return bkts[:i+1]
    101     raise Exception("keeping all buckets is still not enough")
    102 
    103 
    104 class CoinChooserBase(Logger):
    105 
    106     def __init__(self, *, enable_output_value_rounding: bool):
    107         Logger.__init__(self)
    108         self.enable_output_value_rounding = enable_output_value_rounding
    109 
    110     def keys(self, coins: Sequence[PartialTxInput]) -> Sequence[str]:
    111         raise NotImplementedError
    112 
    113     def bucketize_coins(self, coins: Sequence[PartialTxInput], *, fee_estimator_vb):
    114         keys = self.keys(coins)
    115         buckets = defaultdict(list)  # type: Dict[str, List[PartialTxInput]]
    116         for key, coin in zip(keys, coins):
    117             buckets[key].append(coin)
    118         # fee_estimator returns fee to be paid, for given vbytes.
    119         # guess whether it is just returning a constant as follows.
    120         constant_fee = fee_estimator_vb(2000) == fee_estimator_vb(200)
    121 
    122         def make_Bucket(desc: str, coins: List[PartialTxInput]):
    123             witness = any(coin.is_segwit(guess_for_address=True) for coin in coins)
    124             # note that we're guessing whether the tx uses segwit based
    125             # on this single bucket
    126             weight = sum(Transaction.estimated_input_weight(coin, witness)
    127                          for coin in coins)
    128             value = sum(coin.value_sats() for coin in coins)
    129             min_height = min(coin.block_height for coin in coins)
    130             assert min_height is not None
    131             # the fee estimator is typically either a constant or a linear function,
    132             # so the "function:" effective_value(bucket) will be homomorphic for addition
    133             # i.e. effective_value(b1) + effective_value(b2) = effective_value(b1 + b2)
    134             if constant_fee:
    135                 effective_value = value
    136             else:
    137                 # when converting from weight to vBytes, instead of rounding up,
    138                 # keep fractional part, to avoid overestimating fee
    139                 fee = fee_estimator_vb(Decimal(weight) / 4)
    140                 effective_value = value - fee
    141             return Bucket(desc=desc,
    142                           weight=weight,
    143                           value=value,
    144                           effective_value=effective_value,
    145                           coins=coins,
    146                           min_height=min_height,
    147                           witness=witness)
    148 
    149         return list(map(make_Bucket, buckets.keys(), buckets.values()))
    150 
    151     def penalty_func(self, base_tx, *,
    152                      tx_from_buckets: Callable[[List[Bucket]], Tuple[PartialTransaction, List[PartialTxOutput]]]) \
    153             -> Callable[[List[Bucket]], ScoredCandidate]:
    154         raise NotImplementedError
    155 
    156     def _change_amounts(self, tx: PartialTransaction, count: int, fee_estimator_numchange) -> List[int]:
    157         # Break change up if bigger than max_change
    158         output_amounts = [o.value for o in tx.outputs()]
    159         # Don't split change of less than 0.02 BTC
    160         max_change = max(max(output_amounts) * 1.25, 0.02 * COIN)
    161 
    162         # Use N change outputs
    163         for n in range(1, count + 1):
    164             # How much is left if we add this many change outputs?
    165             change_amount = max(0, tx.get_fee() - fee_estimator_numchange(n))
    166             if change_amount // n <= max_change:
    167                 break
    168 
    169         # Get a handle on the precision of the output amounts; round our
    170         # change to look similar
    171         def trailing_zeroes(val):
    172             s = str(val)
    173             return len(s) - len(s.rstrip('0'))
    174 
    175         zeroes = [trailing_zeroes(i) for i in output_amounts]
    176         min_zeroes = min(zeroes)
    177         max_zeroes = max(zeroes)
    178 
    179         if n > 1:
    180             zeroes = range(max(0, min_zeroes - 1), (max_zeroes + 1) + 1)
    181         else:
    182             # if there is only one change output, this will ensure that we aim
    183             # to have one that is exactly as precise as the most precise output
    184             zeroes = [min_zeroes]
    185 
    186         # Calculate change; randomize it a bit if using more than 1 output
    187         remaining = change_amount
    188         amounts = []
    189         while n > 1:
    190             average = remaining / n
    191             amount = self.p.randint(int(average * 0.7), int(average * 1.3))
    192             precision = min(self.p.choice(zeroes), int(floor(log10(amount))))
    193             amount = int(round(amount, -precision))
    194             amounts.append(amount)
    195             remaining -= amount
    196             n -= 1
    197 
    198         # Last change output.  Round down to maximum precision but lose
    199         # no more than 10**max_dp_to_round_for_privacy
    200         # e.g. a max of 2 decimal places means losing 100 satoshis to fees
    201         max_dp_to_round_for_privacy = 2 if self.enable_output_value_rounding else 0
    202         N = int(pow(10, min(max_dp_to_round_for_privacy, zeroes[0])))
    203         amount = (remaining // N) * N
    204         amounts.append(amount)
    205 
    206         assert sum(amounts) <= change_amount
    207 
    208         return amounts
    209 
    210     def _change_outputs(self, tx: PartialTransaction, change_addrs, fee_estimator_numchange,
    211                         dust_threshold) -> List[PartialTxOutput]:
    212         amounts = self._change_amounts(tx, len(change_addrs), fee_estimator_numchange)
    213         assert min(amounts) >= 0
    214         assert len(change_addrs) >= len(amounts)
    215         assert all([isinstance(amt, int) for amt in amounts])
    216         # If change is above dust threshold after accounting for the
    217         # size of the change output, add it to the transaction.
    218         amounts = [amount for amount in amounts if amount >= dust_threshold]
    219         change = [PartialTxOutput.from_address_and_value(addr, amount)
    220                   for addr, amount in zip(change_addrs, amounts)]
    221         return change
    222 
    223     def _construct_tx_from_selected_buckets(self, *, buckets: Sequence[Bucket],
    224                                             base_tx: PartialTransaction, change_addrs,
    225                                             fee_estimator_w, dust_threshold,
    226                                             base_weight) -> Tuple[PartialTransaction, List[PartialTxOutput]]:
    227         # make a copy of base_tx so it won't get mutated
    228         tx = PartialTransaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:])
    229 
    230         tx.add_inputs([coin for b in buckets for coin in b.coins])
    231         tx_weight = self._get_tx_weight(buckets, base_weight=base_weight)
    232 
    233         # change is sent back to sending address unless specified
    234         if not change_addrs:
    235             change_addrs = [tx.inputs()[0].address]
    236             # note: this is not necessarily the final "first input address"
    237             # because the inputs had not been sorted at this point
    238             assert is_address(change_addrs[0])
    239 
    240         # This takes a count of change outputs and returns a tx fee
    241         output_weight = 4 * Transaction.estimated_output_size_for_address(change_addrs[0])
    242         fee_estimator_numchange = lambda count: fee_estimator_w(tx_weight + count * output_weight)
    243         change = self._change_outputs(tx, change_addrs, fee_estimator_numchange, dust_threshold)
    244         tx.add_outputs(change)
    245 
    246         return tx, change
    247 
    248     def _get_tx_weight(self, buckets: Sequence[Bucket], *, base_weight: int) -> int:
    249         """Given a collection of buckets, return the total weight of the
    250         resulting transaction.
    251         base_weight is the weight of the tx that includes the fixed (non-change)
    252         outputs and potentially some fixed inputs. Note that the change outputs
    253         at this point are not yet known so they are NOT accounted for.
    254         """
    255         total_weight = base_weight + sum(bucket.weight for bucket in buckets)
    256         is_segwit_tx = any(bucket.witness for bucket in buckets)
    257         if is_segwit_tx:
    258             total_weight += 2  # marker and flag
    259             # non-segwit inputs were previously assumed to have
    260             # a witness of '' instead of '00' (hex)
    261             # note that mixed legacy/segwit buckets are already ok
    262             num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
    263                                     for bucket in buckets)
    264             total_weight += num_legacy_inputs
    265 
    266         return total_weight
    267 
    268     def make_tx(self, *, coins: Sequence[PartialTxInput], inputs: List[PartialTxInput],
    269                 outputs: List[PartialTxOutput], change_addrs: Sequence[str],
    270                 fee_estimator_vb: Callable, dust_threshold: int) -> PartialTransaction:
    271         """Select unspent coins to spend to pay outputs.  If the change is
    272         greater than dust_threshold (after adding the change output to
    273         the transaction) it is kept, otherwise none is sent and it is
    274         added to the transaction fee.
    275 
    276         `inputs` and `outputs` are guaranteed to be a subset of the
    277         inputs and outputs of the resulting transaction.
    278         `coins` are further UTXOs we can choose from.
    279 
    280         Note: fee_estimator_vb expects virtual bytes
    281         """
    282         assert outputs, 'tx outputs cannot be empty'
    283 
    284         # Deterministic randomness from coins
    285         utxos = [c.prevout.serialize_to_network() for c in coins]
    286         self.p = PRNG(b''.join(sorted(utxos)))
    287 
    288         # Copy the outputs so when adding change we don't modify "outputs"
    289         base_tx = PartialTransaction.from_io(inputs[:], outputs[:])
    290         input_value = base_tx.input_value()
    291 
    292         # Weight of the transaction with no inputs and no change
    293         # Note: this will use legacy tx serialization as the need for "segwit"
    294         # would be detected from inputs. The only side effect should be that the
    295         # marker and flag are excluded, which is compensated in get_tx_weight()
    296         # FIXME calculation will be off by this (2 wu) in case of RBF batching
    297         base_weight = base_tx.estimated_weight()
    298         spent_amount = base_tx.output_value()
    299 
    300         def fee_estimator_w(weight):
    301             return fee_estimator_vb(Transaction.virtual_size_from_weight(weight))
    302 
    303         def sufficient_funds(buckets, *, bucket_value_sum):
    304             '''Given a list of buckets, return True if it has enough
    305             value to pay for the transaction'''
    306             # assert bucket_value_sum == sum(bucket.value for bucket in buckets)  # expensive!
    307             total_input = input_value + bucket_value_sum
    308             if total_input < spent_amount:  # shortcut for performance
    309                 return False
    310             # note re performance: so far this was constant time
    311             # what follows is linear in len(buckets)
    312             total_weight = self._get_tx_weight(buckets, base_weight=base_weight)
    313             return total_input >= spent_amount + fee_estimator_w(total_weight)
    314 
    315         def tx_from_buckets(buckets):
    316             return self._construct_tx_from_selected_buckets(buckets=buckets,
    317                                                             base_tx=base_tx,
    318                                                             change_addrs=change_addrs,
    319                                                             fee_estimator_w=fee_estimator_w,
    320                                                             dust_threshold=dust_threshold,
    321                                                             base_weight=base_weight)
    322 
    323         # Collect the coins into buckets
    324         all_buckets = self.bucketize_coins(coins, fee_estimator_vb=fee_estimator_vb)
    325         # Filter some buckets out. Only keep those that have positive effective value.
    326         # Note that this filtering is intentionally done on the bucket level
    327         # instead of per-coin, as each bucket should be either fully spent or not at all.
    328         # (e.g. CoinChooserPrivacy ensures that same-address coins go into one bucket)
    329         all_buckets = list(filter(lambda b: b.effective_value > 0, all_buckets))
    330         # Choose a subset of the buckets
    331         scored_candidate = self.choose_buckets(all_buckets, sufficient_funds,
    332                                                self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets))
    333         tx = scored_candidate.tx
    334 
    335         self.logger.info(f"using {len(tx.inputs())} inputs")
    336         self.logger.info(f"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}")
    337 
    338         return tx
    339 
    340     def choose_buckets(self, buckets: List[Bucket],
    341                        sufficient_funds: Callable,
    342                        penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate:
    343         raise NotImplemented('To be subclassed')
    344 
    345 
class CoinChooserRandom(CoinChooserBase):
    """Coin chooser that generates candidate bucket sets using the
    deterministic PRNG `self.p` (seeded from the coin set in make_tx) and
    returns the candidate with the lowest penalty.

    Subclasses still need to provide `keys` and `penalty_func`.
    """

    def bucket_candidates_any(self, buckets: List[Bucket], sufficient_funds) -> List[List[Bucket]]:
        '''Returns a list of bucket sets.

        Every returned set satisfies `sufficient_funds` and has been reduced
        with strip_unneeded. Candidates are (a) each single bucket that is
        sufficient on its own, and (b) for each of
        min(100, 10 * (len(buckets) - 1) + 1) PRNG shuffles of the buckets,
        the shortest prefix of the shuffled order that is sufficient.
        Duplicate index sets are collapsed.

        Raises NotEnoughFunds if `buckets` is empty and the empty selection
        is not sufficient, or if a shuffled pass over all buckets is not
        sufficient.
        '''
        if not buckets:
            # Only the empty selection is possible; it works if the funds
            # already accounted for by `sufficient_funds` suffice.
            if sufficient_funds([], bucket_value_sum=0):
                return [[]]
            else:
                raise NotEnoughFunds()

        # set of sorted tuples of indices into `buckets`
        candidates = set()

        # Add all singletons
        for n, bucket in enumerate(buckets):
            if sufficient_funds([bucket], bucket_value_sum=bucket.value):
                candidates.add((n, ))

        # And now some random ones
        attempts = min(100, (len(buckets) - 1) * 10 + 1)
        permutation = list(range(len(buckets)))
        for i in range(attempts):
            # Get a random permutation of the buckets, and
            # incrementally combine buckets until sufficient
            # (the same list is reshuffled in place on every attempt)
            self.p.shuffle(permutation)
            bkts = []
            bucket_value_sum = 0
            for count, index in enumerate(permutation):
                bucket = buckets[index]
                bkts.append(bucket)
                bucket_value_sum += bucket.value
                if sufficient_funds(bkts, bucket_value_sum=bucket_value_sum):
                    candidates.add(tuple(sorted(permutation[:count + 1])))
                    break
            else:
                # note: this assumes that the effective value of any bkt is >= 0
                raise NotEnoughFunds()

        # Map index tuples back to buckets, then drop unneeded buckets.
        candidates = [[buckets[n] for n in c] for c in candidates]
        return [strip_unneeded(c, sufficient_funds) for c in candidates]

    def bucket_candidates_prefer_confirmed(self, buckets: List[Bucket],
                                           sufficient_funds) -> List[List[Bucket]]:
        """Returns a list of bucket sets preferring confirmed coins.

        Any bucket can be:
        1. "confirmed" if it only contains confirmed coins; else
        2. "unconfirmed" if it does not contain coins with unconfirmed parents
        3. other: e.g. "unconfirmed parent" or "local"

        This method tries to only use buckets of type 1, and if the coins there
        are not enough, tries to use the next type but while also selecting
        all buckets of all previous types.

        Raises NotEnoughFunds if even all buckets together are not enough.
        """
        # Classify by the lowest block height among each bucket's coins.
        conf_buckets = [bkt for bkt in buckets if bkt.min_height > 0]
        unconf_buckets = [bkt for bkt in buckets if bkt.min_height == 0]
        other_buckets = [bkt for bkt in buckets if bkt.min_height < 0]

        bucket_sets = [conf_buckets, unconf_buckets, other_buckets]
        # Buckets of more preferred types that were insufficient on their own;
        # they are prepended to every candidate of a later type.
        already_selected_buckets = []
        already_selected_buckets_value_sum = 0

        for bkts_choose_from in bucket_sets:
            try:
                # sufficient_funds, but also counting already_selected_buckets.
                # The closure reads the current already_selected_* values; it is
                # only invoked during the bucket_candidates_any call below, in
                # this same iteration, before they are updated.
                def sfunds(bkts, *, bucket_value_sum):
                    bucket_value_sum += already_selected_buckets_value_sum
                    return sufficient_funds(already_selected_buckets + bkts,
                                            bucket_value_sum=bucket_value_sum)

                candidates = self.bucket_candidates_any(bkts_choose_from, sfunds)
                break
            except NotEnoughFunds:
                # Not enough even with all buckets of this type: select them
                # all and try again with the next, less preferred type.
                already_selected_buckets += bkts_choose_from
                already_selected_buckets_value_sum += sum(bucket.value for bucket in bkts_choose_from)
        else:
            raise NotEnoughFunds()

        # NOTE(review): strip_unneeded re-sorts each candidate by value, so it
        # may drop some already-selected (more preferred) buckets when larger,
        # less preferred ones suffice; confirm this is intended.
        candidates = [(already_selected_buckets + c) for c in candidates]
        return [strip_unneeded(c, sufficient_funds) for c in candidates]

    def choose_buckets(self, buckets, sufficient_funds, penalty_func):
        # Generate candidates (preferring confirmed coins), score each one and
        # keep the lowest penalty; min() returns the first one on ties.
        candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
        scored_candidates = [penalty_func(cand) for cand in candidates]
        winner = min(scored_candidates, key=lambda x: x.penalty)
        self.logger.info(f"Total number of buckets: {len(buckets)}")
        self.logger.info(f"Num candidates considered: {len(candidates)}. "
                         f"Winning penalty: {winner.penalty}")
        return winner
    433 
    434 
    435 class CoinChooserPrivacy(CoinChooserRandom):
    436     """Attempts to better preserve user privacy.
    437     First, if any coin is spent from a user address, all coins are.
    438     Compared to spending from other addresses to make up an amount, this reduces
    439     information leakage about sender holdings.  It also helps to
    440     reduce blockchain UTXO bloat, and reduce future privacy loss that
    441     would come from reusing that address' remaining UTXOs.
    442     Second, it penalizes change that is quite different to the sent amount.
    443     Third, it penalizes change that is too big.
    444     """
    445 
    446     def keys(self, coins):
    447         return [coin.scriptpubkey.hex() for coin in coins]
    448 
    449     def penalty_func(self, base_tx, *, tx_from_buckets):
    450         min_change = min(o.value for o in base_tx.outputs()) * 0.75
    451         max_change = max(o.value for o in base_tx.outputs()) * 1.33
    452 
    453         def penalty(buckets: List[Bucket]) -> ScoredCandidate:
    454             # Penalize using many buckets (~inputs)
    455             badness = len(buckets) - 1
    456             tx, change_outputs = tx_from_buckets(buckets)
    457             change = sum(o.value for o in change_outputs)
    458             # Penalize change not roughly in output range
    459             if change == 0:
    460                 pass  # no change is great!
    461             elif change < min_change:
    462                 badness += (min_change - change) / (min_change + 10000)
    463                 # Penalize really small change; under 1 mBTC ~= using 1 more input
    464                 if change < COIN / 1000:
    465                     badness += 1
    466             elif change > max_change:
    467                 badness += (change - max_change) / (max_change + 10000)
    468                 # Penalize large change; 5 BTC excess ~= using 1 more input
    469                 badness += change / (COIN * 5)
    470             return ScoredCandidate(badness, tx, buckets)
    471 
    472         return penalty
    473 
    474 
    475 COIN_CHOOSERS = {
    476     'Privacy': CoinChooserPrivacy,
    477 }
    478 
    479 def get_name(config):
    480     kind = config.get('coin_chooser')
    481     if not kind in COIN_CHOOSERS:
    482         kind = 'Privacy'
    483     return kind
    484 
    485 def get_coin_chooser(config):
    486     klass = COIN_CHOOSERS[get_name(config)]
    487     # note: we enable enable_output_value_rounding by default as
    488     #       - for sacrificing a few satoshis
    489     #       + it gives better privacy for the user re change output
    490     #       + it also helps the network as a whole as fees will become noisier
    491     #         (trying to counter the heuristic that "whole integer sat/byte feerates" are common)
    492     coinchooser = klass(
    493         enable_output_value_rounding=config.get('coin_chooser_output_rounding', True),
    494     )
    495     return coinchooser