Separated waiting for message from acting on message.

Necessary refactoring of parameter types propagated through source.
This commit is contained in:
2019-08-21 15:55:33 +01:00
parent 2f8e417d00
commit 2987cbe26e
12 changed files with 51 additions and 45 deletions

View File

@@ -33,11 +33,11 @@ uint32_t Channel::getCID() const
return cid; return cid;
} }
void Channel::handle(const shared_ptr<U2FMessage> uMsg) void Channel::handle(const U2FMessage& uMsg)
{ {
if (uMsg->cmd == U2FHID_INIT) if (uMsg.cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised; this->initState = ChannelInitState::Initialised;
else if (uMsg->cid != this->cid) else if (uMsg.cid != this->cid)
throw runtime_error{ "CID of request invalid for this channel" }; throw runtime_error{ "CID of request invalid for this channel" };
if (this->initState != ChannelInitState::Initialised) if (this->initState != ChannelInitState::Initialised)

View File

@@ -42,7 +42,7 @@ class Channel
public: public:
Channel(const uint32_t channelID); Channel(const uint32_t channelID);
void handle(const std::shared_ptr<U2FMessage> uMsg); void handle(const U2FMessage& uMsg);
uint32_t getCID() const; uint32_t getCID() const;
void init(const ChannelInitState newInitState); void init(const ChannelInitState newInitState);

View File

@@ -29,6 +29,16 @@ Controller::Controller(const uint32_t startChannel)
{} {}
void Controller::handleTransaction() void Controller::handleTransaction()
{
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
handleTransaction(*msg);
}
void Controller::handleTransaction(const U2FMessage& msg)
{ {
try try
{ {
@@ -36,20 +46,15 @@ void Controller::handleTransaction()
toggleACTLED(); toggleACTLED();
else else
enableACTLED(false); enableACTLED(false);
} }
catch (runtime_error& ignored) catch (runtime_error& ignored)
{} {}
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
lastMessage = chrono::system_clock::now(); lastMessage = chrono::system_clock::now();
auto opChannel = msg->cid; auto opChannel = msg.cid;
if (msg->cmd == U2FHID_INIT) if (msg.cmd == U2FHID_INIT)
{ {
opChannel = nextChannel(); opChannel = nextChannel();
auto channel = Channel{ opChannel }; auto channel = Channel{ opChannel };
@@ -66,7 +71,7 @@ void Controller::handleTransaction()
#ifdef DEBUG_MSGS #ifdef DEBUG_MSGS
clog << "Message:" << endl; clog << "Message:" << endl;
clog << "cid: " << msg->cid << ", cmd: " << static_cast<unsigned int>(msg->cmd) << endl; clog << "cid: " << msg.cid << ", cmd: " << static_cast<unsigned int>(msg.cmd) << endl;
#endif #endif
channels.at(opChannel).handle(msg); channels.at(opChannel).handle(msg);

View File

@@ -32,5 +32,6 @@ class Controller
Controller(const uint32_t startChannel = 1); Controller(const uint32_t startChannel = 1);
void handleTransaction(); void handleTransaction();
void handleTransaction(const U2FMessage& msg);
uint32_t nextChannel(); uint32_t nextChannel();
}; };

View File

@@ -24,11 +24,11 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std; using namespace std;
shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg) shared_ptr<U2F_CMD> U2F_CMD::get(const U2FMessage& uMsg)
{ {
try try
{ {
switch (uMsg->cmd) switch (uMsg.cmd)
{ {
case U2FHID_PING: case U2FHID_PING:
return make_shared<U2F_Ping_CMD>(uMsg); return make_shared<U2F_Ping_CMD>(uMsg);
@@ -37,13 +37,13 @@ shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg)
case U2FHID_INIT: case U2FHID_INIT:
return make_shared<U2F_Init_CMD>(uMsg); return make_shared<U2F_Init_CMD>(uMsg);
default: default:
U2FMessage::error(uMsg->cid, ERR_INVALID_CMD); U2FMessage::error(uMsg.cid, ERR_INVALID_CMD);
return {}; return {};
} }
} }
catch (runtime_error& ignored) catch (runtime_error& ignored)
{ {
U2FMessage::error(uMsg->cid, ERR_OTHER); U2FMessage::error(uMsg.cid, ERR_OTHER);
return {}; return {};
} }
} }

View File

@@ -27,6 +27,6 @@ struct U2F_CMD
public: public:
virtual ~U2F_CMD() = default; virtual ~U2F_CMD() = default;
static std::shared_ptr<U2F_CMD> get(const std::shared_ptr<U2FMessage> uMsg); static std::shared_ptr<U2F_CMD> get(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const = 0; virtual void respond(const uint32_t channelID) const = 0;
}; //For polymorphic type casting }; //For polymorphic type casting

View File

@@ -23,22 +23,22 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std; using namespace std;
U2F_Init_CMD::U2F_Init_CMD(const shared_ptr<U2FMessage> uMsg) U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg)
{ {
if (uMsg->cmd != U2FHID_INIT) if (uMsg.cmd != U2FHID_INIT)
throw runtime_error{ "Failed to get U2F Init message" }; throw runtime_error{ "Failed to get U2F Init message" };
else if (uMsg->cid != CID_BROADCAST) else if (uMsg.cid != CID_BROADCAST)
{ {
U2FMessage::error(uMsg->cid, ERR_OTHER); U2FMessage::error(uMsg.cid, ERR_OTHER);
throw runtime_error{ "Invalid CID for init command" }; throw runtime_error{ "Invalid CID for init command" };
} }
else if (uMsg->data.size() != INIT_NONCE_SIZE) else if (uMsg.data.size() != INIT_NONCE_SIZE)
{ {
U2FMessage::error(uMsg->cid, ERR_INVALID_LEN); U2FMessage::error(uMsg.cid, ERR_INVALID_LEN);
throw runtime_error{ "Init nonce is incorrect size" }; throw runtime_error{ "Init nonce is incorrect size" };
} }
this->nonce = *reinterpret_cast<const uint64_t*>(uMsg->data.data()); this->nonce = *reinterpret_cast<const uint64_t*>(uMsg.data.data());
} }
void U2F_Init_CMD::respond(const uint32_t channelID) const void U2F_Init_CMD::respond(const uint32_t channelID) const

View File

@@ -27,6 +27,6 @@ struct U2F_Init_CMD : U2F_CMD
uint64_t nonce; uint64_t nonce;
public: public:
U2F_Init_CMD(const std::shared_ptr<U2FMessage> uMsg); U2F_Init_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override; virtual void respond(const uint32_t channelID) const override;
}; };

View File

@@ -60,24 +60,24 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
return 0; return 0;
} }
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg) shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const U2FMessage& uMsg)
{ {
if (uMsg->cmd != U2FHID_MSG) if (uMsg.cmd != U2FHID_MSG)
throw runtime_error{ "Failed to get U2F Msg uMsg" }; throw runtime_error{ "Failed to get U2F Msg uMsg" };
else if (uMsg->data.size() < 4) else if (uMsg.data.size() < 4)
{ {
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw runtime_error{ "Msg data is incorrect size" }; throw runtime_error{ "Msg data is incorrect size" };
} }
U2F_Msg_CMD cmd; U2F_Msg_CMD cmd;
auto &dat = uMsg->data; auto &dat = uMsg.data;
cmd.cla = dat[0]; cmd.cla = dat[0];
if (cmd.cla != 0) if (cmd.cla != 0)
{ {
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED); U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED);
throw runtime_error{ "Invalid CLA value in U2F Message" }; throw runtime_error{ "Invalid CLA value in U2F Message" };
} }
@@ -93,7 +93,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
{ {
if (cBCount == 0) if (cBCount == 0)
{ {
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw runtime_error{ "Invalid command - should have attached data" }; throw runtime_error{ "Invalid command - should have attached data" };
} }
@@ -116,7 +116,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
} }
catch (runtime_error& ignored) catch (runtime_error& ignored)
{ {
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw; throw;
} }
} }
@@ -131,7 +131,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
} }
catch (runtime_error& ignored) catch (runtime_error& ignored)
{ {
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH); U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw; throw;
} }
} }
@@ -190,7 +190,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
} }
catch (const APDU_STATUS e) catch (const APDU_STATUS e)
{ {
U2F_Msg_CMD::error(uMsg->cid, e); U2F_Msg_CMD::error(uMsg.cid, e);
throw runtime_error{ "APDU construction error" }; throw runtime_error{ "APDU construction error" };
return {}; return {};
} }

View File

@@ -39,7 +39,7 @@ struct U2F_Msg_CMD : U2F_CMD
U2F_Msg_CMD() = default; U2F_Msg_CMD() = default;
public: public:
static std::shared_ptr<U2F_Msg_CMD> generate(const std::shared_ptr<U2FMessage> uMsg); static std::shared_ptr<U2F_Msg_CMD> generate(const U2FMessage& uMsg);
static void error(const uint32_t channelID, const uint16_t errCode); static void error(const uint32_t channelID, const uint16_t errCode);
void respond(const uint32_t channelID) const; void respond(const uint32_t channelID) const;
}; };

View File

@@ -21,10 +21,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std; using namespace std;
U2F_Ping_CMD::U2F_Ping_CMD(const shared_ptr<U2FMessage> uMsg) U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg)
: nonce{ uMsg->data } : nonce{ uMsg.data }
{ {
if (uMsg->cmd != U2FHID_PING) if (uMsg.cmd != U2FHID_PING)
throw runtime_error{ "Failed to get U2F ping message" }; throw runtime_error{ "Failed to get U2F ping message" };
} }

View File

@@ -27,6 +27,6 @@ struct U2F_Ping_CMD : U2F_CMD
std::vector<uint8_t> nonce; std::vector<uint8_t> nonce;
public: public:
U2F_Ping_CMD(const std::shared_ptr<U2FMessage> uMsg); U2F_Ping_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override; virtual void respond(const uint32_t channelID) const override;
}; };