diff --git a/Channel.cpp b/Channel.cpp index 3d35749..f7a3bc4 100644 --- a/Channel.cpp +++ b/Channel.cpp @@ -33,11 +33,11 @@ uint32_t Channel::getCID() const return cid; } -void Channel::handle(const shared_ptr uMsg) +void Channel::handle(const U2FMessage& uMsg) { - if (uMsg->cmd == U2FHID_INIT) + if (uMsg.cmd == U2FHID_INIT) this->initState = ChannelInitState::Initialised; - else if (uMsg->cid != this->cid) + else if (uMsg.cid != this->cid) throw runtime_error{ "CID of request invalid for this channel" }; if (this->initState != ChannelInitState::Initialised) diff --git a/Channel.hpp b/Channel.hpp index 61c454f..ab4f1dc 100644 --- a/Channel.hpp +++ b/Channel.hpp @@ -42,8 +42,8 @@ class Channel public: Channel(const uint32_t channelID); - void handle(const std::shared_ptr uMsg); - + void handle(const U2FMessage& uMsg); + uint32_t getCID() const; void init(const ChannelInitState newInitState); void lock(const ChannelLockedState newLockedState); diff --git a/Controller.cpp b/Controller.cpp index 97cf7dd..76a7be9 100644 --- a/Controller.cpp +++ b/Controller.cpp @@ -29,6 +29,16 @@ Controller::Controller(const uint32_t startChannel) {} void Controller::handleTransaction() +{ + auto msg = U2FMessage::readNonBlock(); + + if (!msg) + return; + + handleTransaction(*msg); +} + +void Controller::handleTransaction(const U2FMessage& msg) { try { @@ -36,20 +46,15 @@ void Controller::handleTransaction() toggleACTLED(); else enableACTLED(false); - } + } catch (runtime_error& ignored) {} - auto msg = U2FMessage::readNonBlock(); - - if (!msg) - return; - lastMessage = chrono::system_clock::now(); - auto opChannel = msg->cid; + auto opChannel = msg.cid; - if (msg->cmd == U2FHID_INIT) + if (msg.cmd == U2FHID_INIT) { opChannel = nextChannel(); auto channel = Channel{ opChannel }; @@ -66,7 +71,7 @@ void Controller::handleTransaction() #ifdef DEBUG_MSGS clog << "Message:" << endl; - clog << "cid: " << msg->cid << ", cmd: " << static_cast(msg->cmd) << endl; + clog << "cid: " << msg.cid << ", cmd: " << static_cast(msg.cmd) << endl; #endif channels.at(opChannel).handle(msg); diff --git a/Controller.hpp b/Controller.hpp index f365ebd..00c7712 100644 --- a/Controller.hpp +++ b/Controller.hpp @@ -26,11 +26,12 @@ class Controller protected: std::map channels; uint32_t currChannel; - std::chrono::system_clock::time_point lastMessage; + std::chrono::system_clock::time_point lastMessage; public: Controller(const uint32_t startChannel = 1); void handleTransaction(); + void handleTransaction(const U2FMessage& msg); uint32_t nextChannel(); }; diff --git a/U2F_CMD.cpp b/U2F_CMD.cpp index 7cdaf43..c734551 100644 --- a/U2F_CMD.cpp +++ b/U2F_CMD.cpp @@ -24,11 +24,11 @@ along with this program. If not, see . using namespace std; -shared_ptr U2F_CMD::get(const shared_ptr uMsg) +shared_ptr U2F_CMD::get(const U2FMessage& uMsg) { try { - switch (uMsg->cmd) + switch (uMsg.cmd) { case U2FHID_PING: return make_shared(uMsg); @@ -37,13 +37,13 @@ shared_ptr U2F_CMD::get(const shared_ptr uMsg) case U2FHID_INIT: return make_shared(uMsg); default: - U2FMessage::error(uMsg->cid, ERR_INVALID_CMD); + U2FMessage::error(uMsg.cid, ERR_INVALID_CMD); return {}; } } catch (runtime_error& ignored) { - U2FMessage::error(uMsg->cid, ERR_OTHER); + U2FMessage::error(uMsg.cid, ERR_OTHER); return {}; } } diff --git a/U2F_CMD.hpp b/U2F_CMD.hpp index 2533992..33c2bb3 100644 --- a/U2F_CMD.hpp +++ b/U2F_CMD.hpp @@ -27,6 +27,6 @@ struct U2F_CMD public: virtual ~U2F_CMD() = default; - static std::shared_ptr get(const std::shared_ptr uMsg); + static std::shared_ptr get(const U2FMessage& uMsg); virtual void respond(const uint32_t channelID) const = 0; }; //For polymorphic type casting diff --git a/U2F_Init_CMD.cpp b/U2F_Init_CMD.cpp index b35f643..2ba9e60 100644 --- a/U2F_Init_CMD.cpp +++ b/U2F_Init_CMD.cpp @@ -23,22 +23,22 @@ along with this program. If not, see . using namespace std; -U2F_Init_CMD::U2F_Init_CMD(const shared_ptr uMsg) +U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg) { - if (uMsg->cmd != U2FHID_INIT) + if (uMsg.cmd != U2FHID_INIT) throw runtime_error{ "Failed to get U2F Init message" }; - else if (uMsg->cid != CID_BROADCAST) + else if (uMsg.cid != CID_BROADCAST) { - U2FMessage::error(uMsg->cid, ERR_OTHER); + U2FMessage::error(uMsg.cid, ERR_OTHER); throw runtime_error{ "Invalid CID for init command" }; } - else if (uMsg->data.size() != INIT_NONCE_SIZE) + else if (uMsg.data.size() != INIT_NONCE_SIZE) { - U2FMessage::error(uMsg->cid, ERR_INVALID_LEN); + U2FMessage::error(uMsg.cid, ERR_INVALID_LEN); throw runtime_error{ "Init nonce is incorrect size" }; } - this->nonce = *reinterpret_cast(uMsg->data.data()); + this->nonce = *reinterpret_cast(uMsg.data.data()); } void U2F_Init_CMD::respond(const uint32_t channelID) const diff --git a/U2F_Init_CMD.hpp b/U2F_Init_CMD.hpp index acdfcdc..8f2b864 100644 --- a/U2F_Init_CMD.hpp +++ b/U2F_Init_CMD.hpp @@ -27,6 +27,6 @@ struct U2F_Init_CMD : U2F_CMD uint64_t nonce; public: - U2F_Init_CMD(const std::shared_ptr uMsg); + U2F_Init_CMD(const U2FMessage& uMsg); virtual void respond(const uint32_t channelID) const override; }; diff --git a/U2F_Msg_CMD.cpp b/U2F_Msg_CMD.cpp index b78d0b2..14e4794 100644 --- a/U2F_Msg_CMD.cpp +++ b/U2F_Msg_CMD.cpp @@ -35,7 +35,7 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector bytes) if (byteCount != 0) { //Le must be length of data in bytes - + switch (byteCount) { case 1: @@ -60,24 +60,24 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector bytes) return 0; } -shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) +shared_ptr U2F_Msg_CMD::generate(const U2FMessage& uMsg) { - if (uMsg->cmd != U2FHID_MSG) + if (uMsg.cmd != U2FHID_MSG) throw runtime_error{ "Failed to get U2F Msg uMsg" }; - else if (uMsg->data.size() < 4) + else if (uMsg.data.size() < 4) { - U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); + U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH); throw runtime_error{ "Msg data is incorrect size" }; } U2F_Msg_CMD cmd; - auto &dat = uMsg->data; + auto &dat = uMsg.data; cmd.cla = dat[0]; if (cmd.cla != 0) { - U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED); + U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED); throw runtime_error{ "Invalid CLA value in U2F Message" }; } @@ -93,7 +93,7 @@ shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) { if (cBCount == 0) { - U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); + U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH); throw runtime_error{ "Invalid command - should have attached data" }; } @@ -116,7 +116,7 @@ shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) } catch (runtime_error& ignored) { - U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); + U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH); throw; } } @@ -131,7 +131,7 @@ shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) } catch (runtime_error& ignored) { - U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); + U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH); throw; } } @@ -161,7 +161,7 @@ shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) "\t\t\t\t\t%u\n" "\t\t\t\t\t%3u\n" "\t\t\t\t\t", cmd.cla, cmd.ins, cmd.p1, cmd.p2, cmd.lc); - + for (auto b : dBytes) fprintf(hAS, "%3u ", b); @@ -190,7 +190,7 @@ shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) } catch (const APDU_STATUS e) { - U2F_Msg_CMD::error(uMsg->cid, e); + U2F_Msg_CMD::error(uMsg.cid, e); throw runtime_error{ "APDU construction error" }; return {}; } diff --git a/U2F_Msg_CMD.hpp b/U2F_Msg_CMD.hpp index 5a7f69a..74c913c 100644 --- a/U2F_Msg_CMD.hpp +++ b/U2F_Msg_CMD.hpp @@ -39,7 +39,7 @@ struct U2F_Msg_CMD : U2F_CMD U2F_Msg_CMD() = default; public: - static std::shared_ptr generate(const std::shared_ptr uMsg); + static std::shared_ptr generate(const U2FMessage& uMsg); static void error(const uint32_t channelID, const uint16_t errCode); void respond(const uint32_t channelID) const; }; diff --git a/U2F_Ping_CMD.cpp b/U2F_Ping_CMD.cpp index 57f1f71..c778c17 100644 --- a/U2F_Ping_CMD.cpp +++ b/U2F_Ping_CMD.cpp @@ -21,10 +21,10 @@ along with this program. If not, see . using namespace std; -U2F_Ping_CMD::U2F_Ping_CMD(const shared_ptr uMsg) - : nonce{ uMsg->data } +U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg) + : nonce{ uMsg.data } { - if (uMsg->cmd != U2FHID_PING) + if (uMsg.cmd != U2FHID_PING) throw runtime_error{ "Failed to get U2F ping message" }; } diff --git a/U2F_Ping_CMD.hpp b/U2F_Ping_CMD.hpp index c7e8c5c..62fb028 100644 --- a/U2F_Ping_CMD.hpp +++ b/U2F_Ping_CMD.hpp @@ -27,6 +27,6 @@ struct U2F_Ping_CMD : U2F_CMD std::vector nonce; public: - U2F_Ping_CMD(const std::shared_ptr uMsg); + U2F_Ping_CMD(const U2FMessage& uMsg); virtual void respond(const uint32_t channelID) const override; };