Separated waiting for message from acting on message.

Necessary refactoring of parameter types propagated through source.
This commit is contained in:
2019-08-21 15:55:33 +01:00
parent 2f8e417d00
commit 2987cbe26e
12 changed files with 51 additions and 45 deletions

View File

@@ -33,11 +33,11 @@ uint32_t Channel::getCID() const
return cid;
}
void Channel::handle(const shared_ptr<U2FMessage> uMsg)
void Channel::handle(const U2FMessage& uMsg)
{
if (uMsg->cmd == U2FHID_INIT)
if (uMsg.cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised;
else if (uMsg->cid != this->cid)
else if (uMsg.cid != this->cid)
throw runtime_error{ "CID of request invalid for this channel" };
if (this->initState != ChannelInitState::Initialised)

View File

@@ -42,8 +42,8 @@ class Channel
public:
Channel(const uint32_t channelID);
void handle(const std::shared_ptr<U2FMessage> uMsg);
void handle(const U2FMessage& uMsg);
uint32_t getCID() const;
void init(const ChannelInitState newInitState);
void lock(const ChannelLockedState newLockedState);

View File

@@ -29,6 +29,16 @@ Controller::Controller(const uint32_t startChannel)
{}
void Controller::handleTransaction()
{
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
handleTransaction(*msg);
}
void Controller::handleTransaction(const U2FMessage& msg)
{
try
{
@@ -36,20 +46,15 @@ void Controller::handleTransaction()
toggleACTLED();
else
enableACTLED(false);
}
}
catch (runtime_error& ignored)
{}
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
lastMessage = chrono::system_clock::now();
auto opChannel = msg->cid;
auto opChannel = msg.cid;
if (msg->cmd == U2FHID_INIT)
if (msg.cmd == U2FHID_INIT)
{
opChannel = nextChannel();
auto channel = Channel{ opChannel };
@@ -66,7 +71,7 @@ void Controller::handleTransaction()
#ifdef DEBUG_MSGS
clog << "Message:" << endl;
clog << "cid: " << msg->cid << ", cmd: " << static_cast<unsigned int>(msg->cmd) << endl;
clog << "cid: " << msg.cid << ", cmd: " << static_cast<unsigned int>(msg.cmd) << endl;
#endif
channels.at(opChannel).handle(msg);

View File

@@ -26,11 +26,12 @@ class Controller
protected:
std::map<uint32_t, Channel> channels;
uint32_t currChannel;
std::chrono::system_clock::time_point lastMessage;
std::chrono::system_clock::time_point lastMessage;
public:
Controller(const uint32_t startChannel = 1);
void handleTransaction();
void handleTransaction(const U2FMessage& msg);
uint32_t nextChannel();
};

View File

@@ -24,11 +24,11 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std;
shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg)
shared_ptr<U2F_CMD> U2F_CMD::get(const U2FMessage& uMsg)
{
try
{
switch (uMsg->cmd)
switch (uMsg.cmd)
{
case U2FHID_PING:
return make_shared<U2F_Ping_CMD>(uMsg);
@@ -37,13 +37,13 @@ shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg)
case U2FHID_INIT:
return make_shared<U2F_Init_CMD>(uMsg);
default:
U2FMessage::error(uMsg->cid, ERR_INVALID_CMD);
U2FMessage::error(uMsg.cid, ERR_INVALID_CMD);
return {};
}
}
catch (runtime_error& ignored)
{
U2FMessage::error(uMsg->cid, ERR_OTHER);
U2FMessage::error(uMsg.cid, ERR_OTHER);
return {};
}
}

View File

@@ -27,6 +27,6 @@ struct U2F_CMD
public:
virtual ~U2F_CMD() = default;
static std::shared_ptr<U2F_CMD> get(const std::shared_ptr<U2FMessage> uMsg);
static std::shared_ptr<U2F_CMD> get(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const = 0;
}; //For polymorphic type casting

View File

@@ -23,22 +23,22 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std;
U2F_Init_CMD::U2F_Init_CMD(const shared_ptr<U2FMessage> uMsg)
U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg)
{
if (uMsg->cmd != U2FHID_INIT)
if (uMsg.cmd != U2FHID_INIT)
throw runtime_error{ "Failed to get U2F Init message" };
else if (uMsg->cid != CID_BROADCAST)
else if (uMsg.cid != CID_BROADCAST)
{
U2FMessage::error(uMsg->cid, ERR_OTHER);
U2FMessage::error(uMsg.cid, ERR_OTHER);
throw runtime_error{ "Invalid CID for init command" };
}
else if (uMsg->data.size() != INIT_NONCE_SIZE)
else if (uMsg.data.size() != INIT_NONCE_SIZE)
{
U2FMessage::error(uMsg->cid, ERR_INVALID_LEN);
U2FMessage::error(uMsg.cid, ERR_INVALID_LEN);
throw runtime_error{ "Init nonce is incorrect size" };
}
this->nonce = *reinterpret_cast<const uint64_t*>(uMsg->data.data());
this->nonce = *reinterpret_cast<const uint64_t*>(uMsg.data.data());
}
void U2F_Init_CMD::respond(const uint32_t channelID) const

View File

@@ -27,6 +27,6 @@ struct U2F_Init_CMD : U2F_CMD
uint64_t nonce;
public:
U2F_Init_CMD(const std::shared_ptr<U2FMessage> uMsg);
U2F_Init_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override;
};

View File

@@ -35,7 +35,7 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
if (byteCount != 0)
{
//Le must be length of data in bytes
switch (byteCount)
{
case 1:
@@ -60,24 +60,24 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
return 0;
}
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const U2FMessage& uMsg)
{
if (uMsg->cmd != U2FHID_MSG)
if (uMsg.cmd != U2FHID_MSG)
throw runtime_error{ "Failed to get U2F Msg uMsg" };
else if (uMsg->data.size() < 4)
else if (uMsg.data.size() < 4)
{
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH);
U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw runtime_error{ "Msg data is incorrect size" };
}
U2F_Msg_CMD cmd;
auto &dat = uMsg->data;
auto &dat = uMsg.data;
cmd.cla = dat[0];
if (cmd.cla != 0)
{
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED);
U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_COMMAND_NOT_ALLOWED);
throw runtime_error{ "Invalid CLA value in U2F Message" };
}
@@ -93,7 +93,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
{
if (cBCount == 0)
{
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH);
U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw runtime_error{ "Invalid command - should have attached data" };
}
@@ -116,7 +116,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
}
catch (runtime_error& ignored)
{
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH);
U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw;
}
}
@@ -131,7 +131,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
}
catch (runtime_error& ignored)
{
U2F_Msg_CMD::error(uMsg->cid, APDU_STATUS::SW_WRONG_LENGTH);
U2F_Msg_CMD::error(uMsg.cid, APDU_STATUS::SW_WRONG_LENGTH);
throw;
}
}
@@ -161,7 +161,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
"\t\t\t\t\t<td>%u</td>\n"
"\t\t\t\t\t<td>%3u</td>\n"
"\t\t\t\t\t<td class=\"data\">", cmd.cla, cmd.ins, cmd.p1, cmd.p2, cmd.lc);
for (auto b : dBytes)
fprintf(hAS, "%3u ", b);
@@ -190,7 +190,7 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
}
catch (const APDU_STATUS e)
{
U2F_Msg_CMD::error(uMsg->cid, e);
U2F_Msg_CMD::error(uMsg.cid, e);
throw runtime_error{ "APDU construction error" };
return {};
}

View File

@@ -39,7 +39,7 @@ struct U2F_Msg_CMD : U2F_CMD
U2F_Msg_CMD() = default;
public:
static std::shared_ptr<U2F_Msg_CMD> generate(const std::shared_ptr<U2FMessage> uMsg);
static std::shared_ptr<U2F_Msg_CMD> generate(const U2FMessage& uMsg);
static void error(const uint32_t channelID, const uint16_t errCode);
void respond(const uint32_t channelID) const;
};

View File

@@ -21,10 +21,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
using namespace std;
U2F_Ping_CMD::U2F_Ping_CMD(const shared_ptr<U2FMessage> uMsg)
: nonce{ uMsg->data }
U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg)
: nonce{ uMsg.data }
{
if (uMsg->cmd != U2FHID_PING)
if (uMsg.cmd != U2FHID_PING)
throw runtime_error{ "Failed to get U2F ping message" };
}

View File

@@ -27,6 +27,6 @@ struct U2F_Ping_CMD : U2F_CMD
std::vector<uint8_t> nonce;
public:
U2F_Ping_CMD(const std::shared_ptr<U2FMessage> uMsg);
U2F_Ping_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override;
};