diff --git a/Channel.cpp b/Channel.cpp index bf16cf5..eb808fd 100644 --- a/Channel.cpp +++ b/Channel.cpp @@ -21,6 +21,7 @@ along with this program. If not, see . #include "u2f.hpp" #include #include +#include using namespace std; @@ -33,7 +34,7 @@ uint32_t Channel::getCID() const { return cid; } -void Channel::handle(const U2FMessage& uMsg) { +bool Channel::handle(const U2FMessage& uMsg, AuthorisationLevel auth) { if (uMsg.cmd == U2FHID_INIT) this->initState = ChannelInitState::Initialised; else if (uMsg.cid != this->cid) @@ -46,8 +47,21 @@ void Channel::handle(const U2FMessage& uMsg) { auto cmd = U2F_CMD::get(uMsg); - if (cmd) - return cmd->respond(this->cid); + if (!cmd) + return true; + + if (cmd->requiresAuthorisation()) { + if (auth == AuthorisationLevel::Unspecified) { + __android_log_print(ANDROID_LOG_DEBUG, "U2FDevice", + "Action requires authorisation; seeking authorisation"); + return false; + } + else if (auth == AuthorisationLevel::Unauthorised) + __android_log_print(ANDROID_LOG_DEBUG, "U2FDevice", "Action requires authorisation; authorisation not granted"); + } + + cmd->respond(this->cid, auth == AuthorisationLevel::Authorised); + return true; } void Channel::init(const ChannelInitState newInitState) { diff --git a/Channel.hpp b/Channel.hpp index 10e16e3..f96ed13 100644 --- a/Channel.hpp +++ b/Channel.hpp @@ -25,6 +25,8 @@ enum class ChannelInitState { Unitialised, Initialised }; enum class ChannelLockedState { Locked, Unlocked }; +enum class AuthorisationLevel { Unspecified, Unauthorised, Authorised }; + class Channel { protected: uint32_t cid; @@ -33,7 +35,10 @@ protected: public: Channel(const uint32_t channelID); - void handle(const U2FMessage& uMsg); + + // Returns false if requires authorisation check + // True otherwise + bool handle(const U2FMessage& uMsg, AuthorisationLevel auth); uint32_t getCID() const; void init(const ChannelInitState newInitState); diff --git a/Controller.cpp b/Controller.cpp index 91560ff..3202d04 100644 --- a/Controller.cpp +++ b/Controller.cpp @@ -32,10 +32,10 @@ void Controller::handleTransaction() { if (!msg) return; - handleTransaction(*msg); + handleTransaction(*msg, AuthorisationLevel::Unspecified); } -void Controller::handleTransaction(const U2FMessage& msg) { +bool Controller::handleTransaction(const U2FMessage& msg, AuthorisationLevel auth) { try { if (channels.size() != 0 && chrono::duration_cast(chrono::system_clock::now() - lastMessage) < @@ -60,7 +60,7 @@ void Controller::handleTransaction(const U2FMessage& msg) { } } else if (channels.find(opChannelID) == channels.end()) { U2FMessage::error(opChannelID, ERR_CHANNEL_BUSY); - return; + return true; } #ifdef DEBUG_MSGS @@ -68,7 +68,7 @@ void Controller::handleTransaction(const U2FMessage& msg) { clog << "cid: " << msg.cid << ", cmd: " << static_cast(msg.cmd) << endl; #endif - channels.at(opChannelID).handle(msg); + return channels.at(opChannelID).handle(msg, auth); } uint32_t Controller::nextChannel() { diff --git a/Controller.hpp b/Controller.hpp index 650b611..e5c3947 100644 --- a/Controller.hpp +++ b/Controller.hpp @@ -31,6 +31,9 @@ public: Controller(const uint32_t startChannel = 1); void handleTransaction(); - void handleTransaction(const U2FMessage& msg); + + // Returns false if required authentication + // Returns true otherwise + bool handleTransaction(const U2FMessage& msg, AuthorisationLevel auth); uint32_t nextChannel(); }; diff --git a/U2F_Authenticate_APDU.cpp b/U2F_Authenticate_APDU.cpp index a07fc86..a7be072 100644 --- a/U2F_Authenticate_APDU.cpp +++ b/U2F_Authenticate_APDU.cpp @@ -47,7 +47,7 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const vecto copy(data.begin() + 65, data.begin() + 65 + keyHLen, back_inserter(keyH)); } -void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { +void U2F_Authenticate_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const { if (keyH.size() != sizeof(Storage::KeyHandle)) { // Respond with error code - key handle is of wrong size cerr << "Invalid key handle length" << endl; @@ -71,6 +71,8 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { return; } + uint8_t presence; + switch (p1) { case ControlCode::CheckOnly: this->error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED); @@ -98,7 +100,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { auto& keyCount = Storage::keyCounts[keyHB]; keyCount++; - response.push_back(0x01); + response.push_back(hasAuthorisation ? 1u : 0u); response.insert(response.end(), FIELD_BE(keyCount)); Digest digest; @@ -110,7 +112,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { mbedtls_sha256_update(&shaContext, reinterpret_cast(appParam.data()), sizeof(appParam)); - uint8_t userPresence{ 1u }; + uint8_t userPresence = hasAuthorisation ? 1u : 0u; mbedtls_sha256_update(&shaContext, &userPresence, 1); const auto beCounter = beEncode(keyCount); mbedtls_sha256_update(&shaContext, beCounter.data(), beCounter.size()); @@ -128,3 +130,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { msg.write(); } + +bool U2F_Authenticate_APDU::requiresAuthorisation() const { + return p1 == ControlCode::EnforcePresenceSign; +} diff --git a/U2F_Authenticate_APDU.hpp b/U2F_Authenticate_APDU.hpp index e540682..37014d2 100644 --- a/U2F_Authenticate_APDU.hpp +++ b/U2F_Authenticate_APDU.hpp @@ -29,7 +29,8 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD { public: U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const std::vector& data); - virtual void respond(const uint32_t channelID) const override; + bool requiresAuthorisation() const override; + virtual void respond(const uint32_t channelID, bool hasAuthorisation) const override; enum ControlCode { CheckOnly = 0x07, diff --git a/U2F_CMD.cpp b/U2F_CMD.cpp index 6966e14..a1389aa 100644 --- a/U2F_CMD.cpp +++ b/U2F_CMD.cpp @@ -42,3 +42,7 @@ shared_ptr U2F_CMD::get(const U2FMessage& uMsg) { return {}; } } + +bool U2F_CMD::requiresAuthorisation() const { + return false; +} diff --git a/U2F_CMD.hpp b/U2F_CMD.hpp index 5d5973a..4e54dd8 100644 --- a/U2F_CMD.hpp +++ b/U2F_CMD.hpp @@ -26,6 +26,7 @@ protected: public: virtual ~U2F_CMD() = default; + virtual bool requiresAuthorisation() const; static std::shared_ptr get(const U2FMessage& uMsg); - virtual void respond(const uint32_t channelID) const = 0; + virtual void respond(const uint32_t channelID, bool hasAuthorisation) const = 0; }; // For polymorphic type casting diff --git a/U2F_Init_CMD.cpp b/U2F_Init_CMD.cpp index b09d54f..a217313 100644 --- a/U2F_Init_CMD.cpp +++ b/U2F_Init_CMD.cpp @@ -37,7 +37,7 @@ U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg) { this->nonce = *reinterpret_cast(uMsg.data.data()); } -void U2F_Init_CMD::respond(const uint32_t channelID) const { +void U2F_Init_CMD::respond(const uint32_t channelID, bool) const { U2FMessage msg{}; msg.cid = CID_BROADCAST; msg.cmd = U2FHID_INIT; diff --git a/U2F_Init_CMD.hpp b/U2F_Init_CMD.hpp index 255df00..10383c8 100644 --- a/U2F_Init_CMD.hpp +++ b/U2F_Init_CMD.hpp @@ -27,5 +27,5 @@ struct U2F_Init_CMD : U2F_CMD { public: U2F_Init_CMD(const U2FMessage& uMsg); - virtual void respond(const uint32_t channelID) const override; + virtual void respond(const uint32_t channelID, bool) const override; }; diff --git a/U2F_Msg_CMD.cpp b/U2F_Msg_CMD.cpp index 66533d1..fadd52b 100644 --- a/U2F_Msg_CMD.cpp +++ b/U2F_Msg_CMD.cpp @@ -194,6 +194,6 @@ const map U2F_Msg_CMD::usesData = { { U2F_REG, true }, { U2F_AUTH, true }, { U2F_VER, false } }; -void U2F_Msg_CMD::respond(const uint32_t channelID) const { +void U2F_Msg_CMD::respond(const uint32_t channelID, bool) const { U2F_Msg_CMD::error(channelID, static_cast(APDU_STATUS::SW_INS_NOT_SUPPORTED)); } diff --git a/U2F_Msg_CMD.hpp b/U2F_Msg_CMD.hpp index 7ea7782..32b126a 100644 --- a/U2F_Msg_CMD.hpp +++ b/U2F_Msg_CMD.hpp @@ -23,7 +23,7 @@ along with this program. If not, see . #include #include -struct U2F_Msg_CMD : U2F_CMD { +struct U2F_Msg_CMD : public U2F_CMD { uint8_t cla; uint8_t ins; uint8_t p1; @@ -36,9 +36,10 @@ struct U2F_Msg_CMD : U2F_CMD { protected: static uint32_t getLe(const uint32_t byteCount, std::vector bytes); U2F_Msg_CMD() = default; + virtual ~U2F_Msg_CMD() = default; public: static std::shared_ptr generate(const U2FMessage& uMsg); static void error(const uint32_t channelID, const uint16_t errCode); - void respond(const uint32_t channelID) const; + void respond(const uint32_t channelID, bool hasAuthorisation) const; }; diff --git a/U2F_Ping_CMD.cpp b/U2F_Ping_CMD.cpp index cbcdcc1..d51ef1b 100644 --- a/U2F_Ping_CMD.cpp +++ b/U2F_Ping_CMD.cpp @@ -26,7 +26,7 @@ U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg) : nonce{ uMsg.data } { throw runtime_error{ "Failed to get U2F ping message" }; } -void U2F_Ping_CMD::respond(const uint32_t channelID) const { +void U2F_Ping_CMD::respond(const uint32_t channelID, bool) const { U2FMessage msg{}; msg.cid = channelID; msg.cmd = U2FHID_PING; diff --git a/U2F_Ping_CMD.hpp b/U2F_Ping_CMD.hpp index ecfa031..f7ef01e 100644 --- a/U2F_Ping_CMD.hpp +++ b/U2F_Ping_CMD.hpp @@ -27,5 +27,5 @@ struct U2F_Ping_CMD : U2F_CMD { public: U2F_Ping_CMD(const U2FMessage& uMsg); - virtual void respond(const uint32_t channelID) const override; + virtual void respond(const uint32_t channelID, bool) const override; }; diff --git a/U2F_Register_APDU.cpp b/U2F_Register_APDU.cpp index e5cc62b..011ec27 100644 --- a/U2F_Register_APDU.cpp +++ b/U2F_Register_APDU.cpp @@ -61,7 +61,13 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD& msg, const vectorkeyH] = 0; } -void U2F_Register_APDU::respond(const uint32_t channelID) const { +void U2F_Register_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const { + if (!hasAuthorisation) { + error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED); + return; + } + + U2FMessage m{}; m.cid = channelID; m.cmd = U2FHID_MSG; @@ -118,3 +124,7 @@ void U2F_Register_APDU::respond(const uint32_t channelID) const { m.write(); } + +bool U2F_Register_APDU::requiresAuthorisation() const { + return true; +} diff --git a/U2F_Register_APDU.hpp b/U2F_Register_APDU.hpp index db6ca39..4354164 100644 --- a/U2F_Register_APDU.hpp +++ b/U2F_Register_APDU.hpp @@ -27,6 +27,8 @@ struct U2F_Register_APDU : U2F_Msg_CMD { public: U2F_Register_APDU(const U2F_Msg_CMD& msg, const std::vector& data); + virtual ~U2F_Register_APDU() = default; - void respond(const uint32_t channelID) const override; + bool requiresAuthorisation() const override; + void respond(const uint32_t channelID, bool hasAuthorisation) const override; }; diff --git a/U2F_Version_APDU.cpp b/U2F_Version_APDU.cpp index 0d087bc..a615f7d 100644 --- a/U2F_Version_APDU.cpp +++ b/U2F_Version_APDU.cpp @@ -33,7 +33,7 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector. struct U2F_Version_APDU : U2F_Msg_CMD { public: U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector& data); - void respond(const uint32_t channelID) const override; + void respond(const uint32_t channelID, bool) const override; };