commit f409b5da40e607819f94093f4ca6a91f8a2b71f0
parent 6424163d4bad3de72733849db797d10f11b47479
Author: SomberNight <somber.night@protonmail.com>
Date: Thu, 20 Jun 2019 17:45:56 +0200
coinchooser: refactor so that penalty_func has access to change outputs
Diffstat:
1 file changed, 88 insertions(+), 76 deletions(-)
diff --git a/electrum/coinchooser.py b/electrum/coinchooser.py
@@ -24,7 +24,7 @@
# SOFTWARE.
from collections import defaultdict
from math import floor, log10
-from typing import NamedTuple, List
+from typing import NamedTuple, List, Callable
from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address
from .transaction import Transaction, TxOutput
@@ -79,6 +79,12 @@ class Bucket(NamedTuple):
witness: bool # whether any coin uses segwit
+class ScoredCandidate(NamedTuple):
+ penalty: float
+ tx: Transaction
+ buckets: List[Bucket]
+
+
def strip_unneeded(bkts, sufficient_funds):
'''Remove buckets that are unnecessary in achieving the spend amount'''
if sufficient_funds([], bucket_value_sum=0):
@@ -121,12 +127,10 @@ class CoinChooserBase(Logger):
return list(map(make_Bucket, buckets.keys(), buckets.values()))
- def penalty_func(self, tx, *, fee_for_buckets):
- def penalty(candidate):
- return 0
- return penalty
+ def penalty_func(self, base_tx, *, tx_from_buckets) -> Callable[[List[Bucket]], ScoredCandidate]:
+ raise NotImplementedError
- def change_amounts(self, tx, count, fee_estimator, dust_threshold):
+ def _change_amounts(self, tx, count, fee_estimator):
# Break change up if bigger than max_change
output_amounts = [o.value for o in tx.outputs()]
# Don't split change of less than 0.02 BTC
@@ -180,22 +184,60 @@ class CoinChooserBase(Logger):
return amounts
- def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
- amounts = self.change_amounts(tx, len(change_addrs), fee_estimator,
- dust_threshold)
+ def _change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
+ amounts = self._change_amounts(tx, len(change_addrs), fee_estimator)
assert min(amounts) >= 0
assert len(change_addrs) >= len(amounts)
# If change is above dust threshold after accounting for the
# size of the change output, add it to the transaction.
- dust = sum(amount for amount in amounts if amount < dust_threshold)
amounts = [amount for amount in amounts if amount >= dust_threshold]
change = [TxOutput(TYPE_ADDRESS, addr, amount)
for addr, amount in zip(change_addrs, amounts)]
- self.logger.info(f'change: {change}')
- if dust:
- self.logger.info(f'not keeping dust {dust}')
return change
+ def _construct_tx_from_selected_buckets(self, *, buckets, base_tx, change_addrs,
+ fee_estimator_w, dust_threshold, base_weight):
+ # make a copy of base_tx so it won't get mutated
+ tx = Transaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:])
+
+ tx.add_inputs([coin for b in buckets for coin in b.coins])
+ tx_weight = self._get_tx_weight(buckets, base_weight=base_weight)
+
+ # change is sent back to sending address unless specified
+ if not change_addrs:
+ change_addrs = [tx.inputs()[0]['address']]
+ # note: this is not necessarily the final "first input address"
+ # because the inputs had not been sorted at this point
+ assert is_address(change_addrs[0])
+
+ # This takes a count of change outputs and returns a tx fee
+ output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
+ fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
+ change = self._change_outputs(tx, change_addrs, fee, dust_threshold)
+ tx.add_outputs(change)
+
+ return tx, change
+
+ def _get_tx_weight(self, buckets, *, base_weight) -> int:
+ """Given a collection of buckets, return the total weight of the
+ resulting transaction.
+ base_weight is the weight of the tx that includes the fixed (non-change)
+ outputs and potentially some fixed inputs. Note that the change outputs
+ at this point are not yet known so they are NOT accounted for.
+ """
+ total_weight = base_weight + sum(bucket.weight for bucket in buckets)
+ is_segwit_tx = any(bucket.witness for bucket in buckets)
+ if is_segwit_tx:
+ total_weight += 2 # marker and flag
+ # non-segwit inputs were previously assumed to have
+ # a witness of '' instead of '00' (hex)
+ # note that mixed legacy/segwit buckets are already ok
+ num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
+ for bucket in buckets)
+ total_weight += num_legacy_inputs
+
+ return total_weight
+
def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator,
dust_threshold):
"""Select unspent coins to spend to pay outputs. If the change is
@@ -211,34 +253,20 @@ class CoinChooserBase(Logger):
self.p = PRNG(''.join(sorted(utxos)))
# Copy the outputs so when adding change we don't modify "outputs"
- tx = Transaction.from_io(inputs[:], outputs[:])
- input_value = tx.input_value()
+ base_tx = Transaction.from_io(inputs[:], outputs[:])
+ input_value = base_tx.input_value()
# Weight of the transaction with no inputs and no change
# Note: this will use legacy tx serialization as the need for "segwit"
# would be detected from inputs. The only side effect should be that the
# marker and flag are excluded, which is compensated in get_tx_weight()
# FIXME calculation will be off by this (2 wu) in case of RBF batching
- base_weight = tx.estimated_weight()
- spent_amount = tx.output_value()
+ base_weight = base_tx.estimated_weight()
+ spent_amount = base_tx.output_value()
def fee_estimator_w(weight):
return fee_estimator(Transaction.virtual_size_from_weight(weight))
- def get_tx_weight(buckets):
- total_weight = base_weight + sum(bucket.weight for bucket in buckets)
- is_segwit_tx = any(bucket.witness for bucket in buckets)
- if is_segwit_tx:
- total_weight += 2 # marker and flag
- # non-segwit inputs were previously assumed to have
- # a witness of '' instead of '00' (hex)
- # note that mixed legacy/segwit buckets are already ok
- num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
- for bucket in buckets)
- total_weight += num_legacy_inputs
-
- return total_weight
-
def sufficient_funds(buckets, *, bucket_value_sum):
'''Given a list of buckets, return True if it has enough
value to pay for the transaction'''
@@ -248,45 +276,30 @@ class CoinChooserBase(Logger):
return False
# note re performance: so far this was constant time
# what follows is linear in len(buckets)
- total_weight = get_tx_weight(buckets)
+ total_weight = self._get_tx_weight(buckets, base_weight=base_weight)
return total_input >= spent_amount + fee_estimator_w(total_weight)
- def fee_for_buckets(buckets) -> int:
- """Given a list of buckets, return the total fee paid by the
- transaction, in satoshis.
- Note that the change output(s) are not yet known here,
- so fees for those are excluded and hence this is a lower bound.
- """
- total_weight = get_tx_weight(buckets)
- return fee_estimator_w(total_weight)
+ def tx_from_buckets(buckets):
+ return self._construct_tx_from_selected_buckets(buckets=buckets,
+ base_tx=base_tx,
+ change_addrs=change_addrs,
+ fee_estimator_w=fee_estimator_w,
+ dust_threshold=dust_threshold,
+ base_weight=base_weight)
# Collect the coins into buckets, choose a subset of the buckets
- buckets = self.bucketize_coins(coins)
- buckets = self.choose_buckets(buckets, sufficient_funds,
- self.penalty_func(tx, fee_for_buckets=fee_for_buckets))
-
- tx.add_inputs([coin for b in buckets for coin in b.coins])
- tx_weight = get_tx_weight(buckets)
-
- # change is sent back to sending address unless specified
- if not change_addrs:
- change_addrs = [tx.inputs()[0]['address']]
- # note: this is not necessarily the final "first input address"
- # because the inputs had not been sorted at this point
- assert is_address(change_addrs[0])
-
- # This takes a count of change outputs and returns a tx fee
- output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
- fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
- change = self.change_outputs(tx, change_addrs, fee, dust_threshold)
- tx.add_outputs(change)
+ all_buckets = self.bucketize_coins(coins)
+ scored_candidate = self.choose_buckets(all_buckets, sufficient_funds,
+ self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets))
+ tx = scored_candidate.tx
self.logger.info(f"using {len(tx.inputs())} inputs")
- self.logger.info(f"using buckets: {[bucket.desc for bucket in buckets]}")
+ self.logger.info(f"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}")
return tx
- def choose_buckets(self, buckets, sufficient_funds, penalty_func):
+ def choose_buckets(self, buckets, sufficient_funds,
+ penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate:
raise NotImplemented('To be subclassed')
@@ -368,12 +381,14 @@ class CoinChooserRandom(CoinChooserBase):
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
- penalties = [penalty_func(cand) for cand in candidates]
- winner = candidates[penalties.index(min(penalties))]
- self.logger.info(f"Bucket sets: {len(buckets)}")
- self.logger.info(f"Winning penalty: {min(penalties)}")
+ scored_candidates = [penalty_func(cand) for cand in candidates]
+ winner = min(scored_candidates, key=lambda x: x.penalty)
+ self.logger.info(f"Total number of buckets: {len(buckets)}")
+ self.logger.info(f"Num candidates considered: {len(candidates)}. "
+ f"Winning penalty: {winner.penalty}")
return winner
+
class CoinChooserPrivacy(CoinChooserRandom):
"""Attempts to better preserve user privacy.
First, if any coin is spent from a user address, all coins are.
@@ -388,18 +403,15 @@ class CoinChooserPrivacy(CoinChooserRandom):
def keys(self, coins):
return [coin['address'] for coin in coins]
- def penalty_func(self, tx, *, fee_for_buckets):
- min_change = min(o.value for o in tx.outputs()) * 0.75
- max_change = max(o.value for o in tx.outputs()) * 1.33
- spent_amount = sum(o.value for o in tx.outputs())
+ def penalty_func(self, base_tx, *, tx_from_buckets):
+ min_change = min(o.value for o in base_tx.outputs()) * 0.75
+ max_change = max(o.value for o in base_tx.outputs()) * 1.33
- def penalty(buckets):
+ def penalty(buckets) -> ScoredCandidate:
+ # Penalize using many buckets (~inputs)
badness = len(buckets) - 1
- total_input = sum(bucket.value for bucket in buckets)
- # FIXME fee_for_buckets does not include fees needed to cover the change output(s)
- # so fee here is a lower bound
- fee = fee_for_buckets(buckets)
- change = float(total_input - spent_amount - fee)
+ tx, change_outputs = tx_from_buckets(buckets)
+ change = sum(o.value for o in change_outputs)
# Penalize change not roughly in output range
if change < min_change:
badness += (min_change - change) / (min_change + 10000)
@@ -407,7 +419,7 @@ class CoinChooserPrivacy(CoinChooserRandom):
badness += (change - max_change) / (max_change + 10000)
# Penalize large change; 5 BTC excess ~= using 1 more input
badness += change / (COIN * 5)
- return badness
+ return ScoredCandidate(badness, tx, buckets)
return penalty