Added authorisation.

This commit is contained in:
2019-09-13 18:52:37 +01:00
parent 8a62dee131
commit b0d990f708
18 changed files with 72 additions and 25 deletions

View File

@@ -21,6 +21,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
#include "u2f.hpp" #include "u2f.hpp"
#include <iostream> #include <iostream>
#include <stdexcept> #include <stdexcept>
#include <android/log.h>
using namespace std; using namespace std;
@@ -33,7 +34,7 @@ uint32_t Channel::getCID() const {
return cid; return cid;
} }
void Channel::handle(const U2FMessage& uMsg) { bool Channel::handle(const U2FMessage& uMsg, AuthorisationLevel auth) {
if (uMsg.cmd == U2FHID_INIT) if (uMsg.cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised; this->initState = ChannelInitState::Initialised;
else if (uMsg.cid != this->cid) else if (uMsg.cid != this->cid)
@@ -46,8 +47,21 @@ void Channel::handle(const U2FMessage& uMsg) {
auto cmd = U2F_CMD::get(uMsg); auto cmd = U2F_CMD::get(uMsg);
if (cmd) if (!cmd)
return cmd->respond(this->cid); return true;
if (cmd->requiresAuthorisation()) {
if (auth == AuthorisationLevel::Unspecified) {
__android_log_print(ANDROID_LOG_DEBUG, "U2FDevice",
"Action requires authorisation; seeking authorisation");
return false;
}
else if (auth == AuthorisationLevel::Unauthorised)
__android_log_print(ANDROID_LOG_DEBUG, "U2FDevice", "Action requires authorisation; authorisation not granted");
}
cmd->respond(this->cid, auth == AuthorisationLevel::Authorised);
return true;
} }
void Channel::init(const ChannelInitState newInitState) { void Channel::init(const ChannelInitState newInitState) {

View File

@@ -25,6 +25,8 @@ enum class ChannelInitState { Unitialised, Initialised };
enum class ChannelLockedState { Locked, Unlocked }; enum class ChannelLockedState { Locked, Unlocked };
enum class AuthorisationLevel { Unspecified, Unauthorised, Authorised };
class Channel { class Channel {
protected: protected:
uint32_t cid; uint32_t cid;
@@ -33,7 +35,10 @@ protected:
public: public:
Channel(const uint32_t channelID); Channel(const uint32_t channelID);
void handle(const U2FMessage& uMsg);
// Returns false if requires authorisation check
// True otherwise
bool handle(const U2FMessage& uMsg, AuthorisationLevel auth);
uint32_t getCID() const; uint32_t getCID() const;
void init(const ChannelInitState newInitState); void init(const ChannelInitState newInitState);

View File

@@ -32,10 +32,10 @@ void Controller::handleTransaction() {
if (!msg) if (!msg)
return; return;
handleTransaction(*msg); handleTransaction(*msg, AuthorisationLevel::Unspecified);
} }
void Controller::handleTransaction(const U2FMessage& msg) { bool Controller::handleTransaction(const U2FMessage& msg, AuthorisationLevel auth) {
try { try {
if (channels.size() != 0 && if (channels.size() != 0 &&
chrono::duration_cast<chrono::seconds>(chrono::system_clock::now() - lastMessage) < chrono::duration_cast<chrono::seconds>(chrono::system_clock::now() - lastMessage) <
@@ -60,7 +60,7 @@ void Controller::handleTransaction(const U2FMessage& msg) {
} }
} else if (channels.find(opChannelID) == channels.end()) { } else if (channels.find(opChannelID) == channels.end()) {
U2FMessage::error(opChannelID, ERR_CHANNEL_BUSY); U2FMessage::error(opChannelID, ERR_CHANNEL_BUSY);
return; return true;
} }
#ifdef DEBUG_MSGS #ifdef DEBUG_MSGS
@@ -68,7 +68,7 @@ void Controller::handleTransaction(const U2FMessage& msg) {
clog << "cid: " << msg.cid << ", cmd: " << static_cast<unsigned int>(msg.cmd) << endl; clog << "cid: " << msg.cid << ", cmd: " << static_cast<unsigned int>(msg.cmd) << endl;
#endif #endif
channels.at(opChannelID).handle(msg); return channels.at(opChannelID).handle(msg, auth);
} }
uint32_t Controller::nextChannel() { uint32_t Controller::nextChannel() {

View File

@@ -31,6 +31,9 @@ public:
Controller(const uint32_t startChannel = 1); Controller(const uint32_t startChannel = 1);
void handleTransaction(); void handleTransaction();
void handleTransaction(const U2FMessage& msg);
// Returns false if required authentication
// Returns true otherwise
bool handleTransaction(const U2FMessage& msg, AuthorisationLevel auth);
uint32_t nextChannel(); uint32_t nextChannel();
}; };

View File

@@ -47,7 +47,7 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const vecto
copy(data.begin() + 65, data.begin() + 65 + keyHLen, back_inserter(keyH)); copy(data.begin() + 65, data.begin() + 65 + keyHLen, back_inserter(keyH));
} }
void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { void U2F_Authenticate_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const {
if (keyH.size() != sizeof(Storage::KeyHandle)) { if (keyH.size() != sizeof(Storage::KeyHandle)) {
// Respond with error code - key handle is of wrong size // Respond with error code - key handle is of wrong size
cerr << "Invalid key handle length" << endl; cerr << "Invalid key handle length" << endl;
@@ -71,6 +71,8 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
return; return;
} }
uint8_t presence;
switch (p1) { switch (p1) {
case ControlCode::CheckOnly: case ControlCode::CheckOnly:
this->error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED); this->error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED);
@@ -98,7 +100,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
auto& keyCount = Storage::keyCounts[keyHB]; auto& keyCount = Storage::keyCounts[keyHB];
keyCount++; keyCount++;
response.push_back(0x01); response.push_back(hasAuthorisation ? 1u : 0u);
response.insert(response.end(), FIELD_BE(keyCount)); response.insert(response.end(), FIELD_BE(keyCount));
Digest digest; Digest digest;
@@ -110,7 +112,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
mbedtls_sha256_update(&shaContext, reinterpret_cast<const uint8_t*>(appParam.data()), mbedtls_sha256_update(&shaContext, reinterpret_cast<const uint8_t*>(appParam.data()),
sizeof(appParam)); sizeof(appParam));
uint8_t userPresence{ 1u }; uint8_t userPresence = hasAuthorisation ? 1u : 0u;
mbedtls_sha256_update(&shaContext, &userPresence, 1); mbedtls_sha256_update(&shaContext, &userPresence, 1);
const auto beCounter = beEncode(keyCount); const auto beCounter = beEncode(keyCount);
mbedtls_sha256_update(&shaContext, beCounter.data(), beCounter.size()); mbedtls_sha256_update(&shaContext, beCounter.data(), beCounter.size());
@@ -128,3 +130,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
msg.write(); msg.write();
} }
bool U2F_Authenticate_APDU::requiresAuthorisation() const {
return p1 == ControlCode::EnforcePresenceSign;
}

View File

@@ -29,7 +29,8 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD {
public: public:
U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data); U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data);
virtual void respond(const uint32_t channelID) const override; bool requiresAuthorisation() const override;
virtual void respond(const uint32_t channelID, bool hasAuthorisation) const override;
enum ControlCode { enum ControlCode {
CheckOnly = 0x07, CheckOnly = 0x07,

View File

@@ -42,3 +42,7 @@ shared_ptr<U2F_CMD> U2F_CMD::get(const U2FMessage& uMsg) {
return {}; return {};
} }
} }
bool U2F_CMD::requiresAuthorisation() const {
return false;
}

View File

@@ -26,6 +26,7 @@ protected:
public: public:
virtual ~U2F_CMD() = default; virtual ~U2F_CMD() = default;
virtual bool requiresAuthorisation() const;
static std::shared_ptr<U2F_CMD> get(const U2FMessage& uMsg); static std::shared_ptr<U2F_CMD> get(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const = 0; virtual void respond(const uint32_t channelID, bool hasAuthorisation) const = 0;
}; // For polymorphic type casting }; // For polymorphic type casting

View File

@@ -37,7 +37,7 @@ U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg) {
this->nonce = *reinterpret_cast<const uint64_t*>(uMsg.data.data()); this->nonce = *reinterpret_cast<const uint64_t*>(uMsg.data.data());
} }
void U2F_Init_CMD::respond(const uint32_t channelID) const { void U2F_Init_CMD::respond(const uint32_t channelID, bool) const {
U2FMessage msg{}; U2FMessage msg{};
msg.cid = CID_BROADCAST; msg.cid = CID_BROADCAST;
msg.cmd = U2FHID_INIT; msg.cmd = U2FHID_INIT;

View File

@@ -27,5 +27,5 @@ struct U2F_Init_CMD : U2F_CMD {
public: public:
U2F_Init_CMD(const U2FMessage& uMsg); U2F_Init_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override; virtual void respond(const uint32_t channelID, bool) const override;
}; };

View File

@@ -194,6 +194,6 @@ const map<uint8_t, bool> U2F_Msg_CMD::usesData = { { U2F_REG, true },
{ U2F_AUTH, true }, { U2F_AUTH, true },
{ U2F_VER, false } }; { U2F_VER, false } };
void U2F_Msg_CMD::respond(const uint32_t channelID) const { void U2F_Msg_CMD::respond(const uint32_t channelID, bool) const {
U2F_Msg_CMD::error(channelID, static_cast<uint16_t>(APDU_STATUS::SW_INS_NOT_SUPPORTED)); U2F_Msg_CMD::error(channelID, static_cast<uint16_t>(APDU_STATUS::SW_INS_NOT_SUPPORTED));
} }

View File

@@ -23,7 +23,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
#include <memory> #include <memory>
#include <vector> #include <vector>
struct U2F_Msg_CMD : U2F_CMD { struct U2F_Msg_CMD : public U2F_CMD {
uint8_t cla; uint8_t cla;
uint8_t ins; uint8_t ins;
uint8_t p1; uint8_t p1;
@@ -36,9 +36,10 @@ struct U2F_Msg_CMD : U2F_CMD {
protected: protected:
static uint32_t getLe(const uint32_t byteCount, std::vector<uint8_t> bytes); static uint32_t getLe(const uint32_t byteCount, std::vector<uint8_t> bytes);
U2F_Msg_CMD() = default; U2F_Msg_CMD() = default;
virtual ~U2F_Msg_CMD() = default;
public: public:
static std::shared_ptr<U2F_Msg_CMD> generate(const U2FMessage& uMsg); static std::shared_ptr<U2F_Msg_CMD> generate(const U2FMessage& uMsg);
static void error(const uint32_t channelID, const uint16_t errCode); static void error(const uint32_t channelID, const uint16_t errCode);
void respond(const uint32_t channelID) const; void respond(const uint32_t channelID, bool hasAuthorisation) const;
}; };

View File

@@ -26,7 +26,7 @@ U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg) : nonce{ uMsg.data } {
throw runtime_error{ "Failed to get U2F ping message" }; throw runtime_error{ "Failed to get U2F ping message" };
} }
void U2F_Ping_CMD::respond(const uint32_t channelID) const { void U2F_Ping_CMD::respond(const uint32_t channelID, bool) const {
U2FMessage msg{}; U2FMessage msg{};
msg.cid = channelID; msg.cid = channelID;
msg.cmd = U2FHID_PING; msg.cmd = U2FHID_PING;

View File

@@ -27,5 +27,5 @@ struct U2F_Ping_CMD : U2F_CMD {
public: public:
U2F_Ping_CMD(const U2FMessage& uMsg); U2F_Ping_CMD(const U2FMessage& uMsg);
virtual void respond(const uint32_t channelID) const override; virtual void respond(const uint32_t channelID, bool) const override;
}; };

View File

@@ -61,7 +61,13 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD& msg, const vector<uint8_
Storage::keyCounts[this->keyH] = 0; Storage::keyCounts[this->keyH] = 0;
} }
void U2F_Register_APDU::respond(const uint32_t channelID) const { void U2F_Register_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const {
if (!hasAuthorisation) {
error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED);
return;
}
U2FMessage m{}; U2FMessage m{};
m.cid = channelID; m.cid = channelID;
m.cmd = U2FHID_MSG; m.cmd = U2FHID_MSG;
@@ -118,3 +124,7 @@ void U2F_Register_APDU::respond(const uint32_t channelID) const {
m.write(); m.write();
} }
bool U2F_Register_APDU::requiresAuthorisation() const {
return true;
}

View File

@@ -27,6 +27,8 @@ struct U2F_Register_APDU : U2F_Msg_CMD {
public: public:
U2F_Register_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data); U2F_Register_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data);
virtual ~U2F_Register_APDU() = default;
void respond(const uint32_t channelID) const override; bool requiresAuthorisation() const override;
void respond(const uint32_t channelID, bool hasAuthorisation) const override;
}; };

View File

@@ -33,7 +33,7 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector<uin
throw APDU_STATUS::SW_WRONG_LENGTH; throw APDU_STATUS::SW_WRONG_LENGTH;
} }
void U2F_Version_APDU::respond(const uint32_t channelID) const { void U2F_Version_APDU::respond(const uint32_t channelID, bool) const {
char ver[]{ 'U', '2', 'F', '_', 'V', '2' }; char ver[]{ 'U', '2', 'F', '_', 'V', '2' };
U2FMessage m{}; U2FMessage m{};

View File

@@ -22,5 +22,5 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
struct U2F_Version_APDU : U2F_Msg_CMD { struct U2F_Version_APDU : U2F_Msg_CMD {
public: public:
U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data); U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector<uint8_t>& data);
void respond(const uint32_t channelID) const override; void respond(const uint32_t channelID, bool) const override;
}; };