diff --git a/src/key.cpp b/src/key.cpp index 803c1b552407..84a068d71454 100644 --- a/src/key.cpp +++ b/src/key.cpp @@ -156,21 +156,21 @@ bool CKey::Check(const unsigned char *vch) { } void CKey::MakeNewKey(bool fCompressedIn) { + MakeKeyData(); do { - GetStrongRandBytes(keydata); - } while (!Check(keydata.data())); - fValid = true; + GetStrongRandBytes(*keydata); + } while (!Check(keydata->data())); fCompressed = fCompressedIn; } bool CKey::Negate() { - assert(fValid); - return secp256k1_ec_seckey_negate(secp256k1_context_sign, keydata.data()); + assert(keydata); + return secp256k1_ec_seckey_negate(secp256k1_context_sign, keydata->data()); } CPrivKey CKey::GetPrivKey() const { - assert(fValid); + assert(keydata); CPrivKey seckey; int ret; size_t seckeylen; @@ -183,7 +183,7 @@ CPrivKey CKey::GetPrivKey() const { } CPubKey CKey::GetPubKey() const { - assert(fValid); + assert(keydata); secp256k1_pubkey pubkey; size_t clen = CPubKey::SIZE; CPubKey result; @@ -209,7 +209,7 @@ bool SigHasLowR(const secp256k1_ecdsa_signature* sig) } bool CKey::Sign(const uint256 &hash, std::vector& vchSig, bool grind, uint32_t test_case) const { - if (!fValid) + if (!keydata) return false; vchSig.resize(CPubKey::SIGNATURE_SIZE); size_t nSigLen = CPubKey::SIGNATURE_SIZE; @@ -250,7 +250,7 @@ bool CKey::VerifyPubKey(const CPubKey& pubkey) const { } bool CKey::SignCompact(const uint256 &hash, std::vector& vchSig) const { - if (!fValid) + if (!keydata) return false; vchSig.resize(CPubKey::COMPACT_SIGNATURE_SIZE); int rec = -1; @@ -273,10 +273,12 @@ bool CKey::SignCompact(const uint256 &hash, std::vector& vchSig) } bool CKey::Load(const CPrivKey &seckey, const CPubKey &vchPubKey, bool fSkipCheck=false) { - if (!ec_seckey_import_der(secp256k1_context_sign, (unsigned char*)begin(), seckey.data(), seckey.size())) + MakeKeyData(); + if (!ec_seckey_import_der(secp256k1_context_sign, (unsigned char*)begin(), seckey.data(), seckey.size())) { + ClearKeyData(); return false; + } fCompressed = vchPubKey.IsCompressed(); - fValid = true; if (fSkipCheck) return true; @@ -297,22 +299,21 @@ bool CKey::Derive(CKey& keyChild, ChainCode &ccChild, unsigned int nChild, const BIP32Hash(cc, nChild, 0, begin(), vout.data()); } memcpy(ccChild.begin(), vout.data()+32, 32); - memcpy((unsigned char*)keyChild.begin(), begin(), 32); + keyChild.Set(begin(), begin() + 32, true); bool ret = secp256k1_ec_seckey_tweak_add(secp256k1_context_sign, (unsigned char*)keyChild.begin(), vout.data()); - keyChild.fCompressed = true; - keyChild.fValid = ret; + if (!ret) keyChild.ClearKeyData(); return ret; } EllSwiftPubKey CKey::EllSwiftCreate(Span ent32) const { - assert(fValid); + assert(keydata); assert(ent32.size() == 32); std::array encoded_pubkey; auto success = secp256k1_ellswift_create(secp256k1_context_sign, UCharCast(encoded_pubkey.data()), - keydata.data(), + keydata->data(), UCharCast(ent32.data())); // Should always succeed for valid keys (asserted above). @@ -322,7 +323,7 @@ EllSwiftPubKey CKey::EllSwiftCreate(Span ent32) const ECDHSecret CKey::ComputeBIP324ECDHSecret(const EllSwiftPubKey& their_ellswift, const EllSwiftPubKey& our_ellswift, bool initiating) const { - assert(fValid); + assert(keydata); ECDHSecret output; // BIP324 uses the initiator as party A, and the responder as party B. Remap the inputs @@ -331,7 +332,7 @@ ECDHSecret CKey::ComputeBIP324ECDHSecret(const EllSwiftPubKey& their_ellswift, c UCharCast(output.data()), UCharCast(initiating ? our_ellswift.data() : their_ellswift.data()), UCharCast(initiating ? their_ellswift.data() : our_ellswift.data()), - keydata.data(), + keydata->data(), initiating ? 0 : 1, secp256k1_ellswift_xdh_hash_function_bip324, nullptr); diff --git a/src/key.h b/src/key.h index 0a7b7569709b..c4c713c02f24 100644 --- a/src/key.h +++ b/src/key.h @@ -46,57 +46,77 @@ class CKey "COMPRESSED_SIZE is larger than SIZE"); private: - //! Whether this private key is valid. We check for correctness when modifying the key - //! data, so fValid should always correspond to the actual state. - bool fValid{false}; + /** Internal data container for private key material. */ + using KeyType = std::array; //! Whether the public key corresponding to this private key is (to be) compressed. bool fCompressed{false}; - //! The actual byte data - std::vector > keydata; + //! The actual byte data. nullptr for invalid keys. + secure_unique_ptr keydata; //! Check whether the 32-byte array pointed to by vch is valid keydata. bool static Check(const unsigned char* vch); + void MakeKeyData() + { + if (!keydata) keydata = make_secure_unique(); + } + + void ClearKeyData() + { + keydata.reset(); + } + public: - //! Construct an invalid private key. - CKey() + CKey() noexcept = default; + CKey(CKey&&) noexcept = default; + CKey& operator=(CKey&&) noexcept = default; + + CKey& operator=(const CKey& other) { - // Important: vch must be 32 bytes in length to not break serialization - keydata.resize(32); + if (other.keydata) { + MakeKeyData(); + *keydata = *other.keydata; + } else { + ClearKeyData(); + } + fCompressed = other.fCompressed; + return *this; } + CKey(const CKey& other) { *this = other; } + friend bool operator==(const CKey& a, const CKey& b) { return a.fCompressed == b.fCompressed && a.size() == b.size() && - memcmp(a.keydata.data(), b.keydata.data(), a.size()) == 0; + memcmp(a.data(), b.data(), a.size()) == 0; } //! Initialize using begin and end iterators to byte data. template void Set(const T pbegin, const T pend, bool fCompressedIn) { - if (size_t(pend - pbegin) != keydata.size()) { - fValid = false; + if (size_t(pend - pbegin) != std::tuple_size_v) { + ClearKeyData(); } else if (Check(&pbegin[0])) { - memcpy(keydata.data(), (unsigned char*)&pbegin[0], keydata.size()); - fValid = true; + MakeKeyData(); + memcpy(keydata->data(), (unsigned char*)&pbegin[0], keydata->size()); fCompressed = fCompressedIn; } else { - fValid = false; + ClearKeyData(); } } //! Simple read-only vector-like interface. - unsigned int size() const { return (fValid ? keydata.size() : 0); } - const std::byte* data() const { return reinterpret_cast(keydata.data()); } - const unsigned char* begin() const { return keydata.data(); } - const unsigned char* end() const { return keydata.data() + size(); } + unsigned int size() const { return keydata ? keydata->size() : 0; } + const std::byte* data() const { return keydata ? reinterpret_cast(keydata->data()) : nullptr; } + const unsigned char* begin() const { return keydata ? keydata->data() : nullptr; } + const unsigned char* end() const { return begin() + size(); } //! Check whether this private key is valid. - bool IsValid() const { return fValid; } + bool IsValid() const { return !!keydata; } //! Check whether the public key corresponding to this private key is (to be) compressed. bool IsCompressed() const { return fCompressed; } diff --git a/src/rpc/mining.cpp b/src/rpc/mining.cpp index ea13095ed37f..fb1ef58baa60 100644 --- a/src/rpc/mining.cpp +++ b/src/rpc/mining.cpp @@ -118,7 +118,7 @@ static RPCHelpMan getnetworkhashps() ChainstateManager& chainman = EnsureAnyChainman(request.context); LOCK(cs_main); - return GetNetworkHashPS(!request.params[0].isNull() ? request.params[0].getInt() : 120, !request.params[1].isNull() ? request.params[1].getInt() : -1, chainman.ActiveChain()); + return GetNetworkHashPS(self.Arg(0), self.Arg(1), chainman.ActiveChain()); }, }; } @@ -230,12 +230,12 @@ static RPCHelpMan generatetodescriptor() "\nGenerate 11 blocks to mydesc\n" + HelpExampleCli("generatetodescriptor", "11 \"mydesc\"")}, [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue { - const int num_blocks{request.params[0].getInt()}; - const uint64_t max_tries{request.params[2].isNull() ? DEFAULT_MAX_TRIES : request.params[2].getInt()}; + const auto num_blocks{self.Arg(0)}; + const auto max_tries{self.Arg(2)}; CScript coinbase_script; std::string error; - if (!getScriptFromDescriptor(request.params[1].get_str(), coinbase_script, error)) { + if (!getScriptFromDescriptor(self.Arg(1), coinbase_script, error)) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, error); } diff --git a/src/rpc/util.cpp b/src/rpc/util.cpp index 738a5dea162e..02be3a9b80a4 100644 --- a/src/rpc/util.cpp +++ b/src/rpc/util.cpp @@ -525,13 +525,59 @@ UniValue RPCHelpMan::HandleRequest(const JSONRPCRequest& request) const if (request.mode == JSONRPCRequest::GET_HELP || !IsValidNumArgs(request.params.size())) { throw std::runtime_error(ToString()); } + CHECK_NONFATAL(m_req == nullptr); + m_req = &request; UniValue ret = m_fun(*this, request); + m_req = nullptr; if (gArgs.GetBoolArg("-rpcdoccheck", DEFAULT_RPC_DOC_CHECK)) { CHECK_NONFATAL(std::any_of(m_results.m_results.begin(), m_results.m_results.end(), [&ret](const RPCResult& res) { return res.MatchesType(ret); })); } return ret; } +using CheckFn = void(const RPCArg&); +static const UniValue* DetailMaybeArg(CheckFn* check, const std::vector& params, const JSONRPCRequest* req, size_t i) +{ + CHECK_NONFATAL(i < params.size()); + const UniValue& arg{CHECK_NONFATAL(req)->params[i]}; + const RPCArg& param{params.at(i)}; + if (check) check(param); + + if (!arg.isNull()) return &arg; + if (!std::holds_alternative(param.m_fallback)) return nullptr; + return &std::get(param.m_fallback); +} + +static void CheckRequiredOrDefault(const RPCArg& param) +{ + // Must use `Arg(i)` to get the argument or its default value. + const bool required{ + std::holds_alternative(param.m_fallback) && RPCArg::Optional::NO == std::get(param.m_fallback), + }; + CHECK_NONFATAL(required || std::holds_alternative(param.m_fallback)); +} + +#define TMPL_INST(check_param, ret_type, return_code) \ + template <> \ + ret_type RPCHelpMan::ArgValue(size_t i) const \ + { \ + const UniValue* maybe_arg{ \ + DetailMaybeArg(check_param, m_args, m_req, i), \ + }; \ + return return_code \ + } \ + void force_semicolon(ret_type) + +// Optional arg (without default). Can also be called on required args, if needed. +TMPL_INST(nullptr, std::optional, maybe_arg ? std::optional{maybe_arg->get_real()} : std::nullopt;); +TMPL_INST(nullptr, std::optional, maybe_arg ? std::optional{maybe_arg->get_bool()} : std::nullopt;); +TMPL_INST(nullptr, const std::string*, maybe_arg ? &maybe_arg->get_str() : nullptr;); + +// Required arg or optional arg with default value. +TMPL_INST(CheckRequiredOrDefault, int, CHECK_NONFATAL(maybe_arg)->getInt();); +TMPL_INST(CheckRequiredOrDefault, uint64_t, CHECK_NONFATAL(maybe_arg)->getInt();); +TMPL_INST(CheckRequiredOrDefault, const std::string&, CHECK_NONFATAL(maybe_arg)->get_str();); + bool RPCHelpMan::IsValidNumArgs(size_t num_args) const { size_t num_required_args = 0; diff --git a/src/rpc/util.h b/src/rpc/util.h index cbbd8831b283..5e113e992c13 100644 --- a/src/rpc/util.h +++ b/src/rpc/util.h @@ -5,6 +5,7 @@ #ifndef BITCOIN_RPC_UTIL_H #define BITCOIN_RPC_UTIL_H +#include #include #include #include @@ -13,14 +14,30 @@ #include