diff --git a/Channel.cpp b/Channel.cpp
index bf16cf5..eb808fd 100644
--- a/Channel.cpp
+++ b/Channel.cpp
@@ -21,6 +21,7 @@ along with this program. If not, see .
#include "u2f.hpp"
#include
#include
+#include
using namespace std;
@@ -33,7 +34,7 @@ uint32_t Channel::getCID() const {
return cid;
}
-void Channel::handle(const U2FMessage& uMsg) {
+bool Channel::handle(const U2FMessage& uMsg, AuthorisationLevel auth) {
if (uMsg.cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised;
else if (uMsg.cid != this->cid)
@@ -46,8 +47,21 @@ void Channel::handle(const U2FMessage& uMsg) {
auto cmd = U2F_CMD::get(uMsg);
- if (cmd)
- return cmd->respond(this->cid);
+ if (!cmd)
+ return true;
+
+ if (cmd->requiresAuthorisation()) {
+ if (auth == AuthorisationLevel::Unspecified) {
+ __android_log_print(ANDROID_LOG_DEBUG, "U2FDevice",
+ "Action requires authorisation; seeking authorisation");
+ return false;
+ }
+ else if (auth == AuthorisationLevel::Unauthorised)
+ __android_log_print(ANDROID_LOG_DEBUG, "U2FDevice", "Action requires authorisation; authorisation not granted");
+ }
+
+ cmd->respond(this->cid, auth == AuthorisationLevel::Authorised);
+ return true;
}
void Channel::init(const ChannelInitState newInitState) {
diff --git a/Channel.hpp b/Channel.hpp
index 10e16e3..f96ed13 100644
--- a/Channel.hpp
+++ b/Channel.hpp
@@ -25,6 +25,8 @@ enum class ChannelInitState { Unitialised, Initialised };
enum class ChannelLockedState { Locked, Unlocked };
+enum class AuthorisationLevel { Unspecified, Unauthorised, Authorised };
+
class Channel {
protected:
uint32_t cid;
@@ -33,7 +35,10 @@ protected:
public:
Channel(const uint32_t channelID);
- void handle(const U2FMessage& uMsg);
+
+ // Returns false if requires authorisation check
+ // True otherwise
+ bool handle(const U2FMessage& uMsg, AuthorisationLevel auth);
uint32_t getCID() const;
void init(const ChannelInitState newInitState);
diff --git a/Controller.cpp b/Controller.cpp
index 91560ff..3202d04 100644
--- a/Controller.cpp
+++ b/Controller.cpp
@@ -32,10 +32,10 @@ void Controller::handleTransaction() {
if (!msg)
return;
- handleTransaction(*msg);
+ handleTransaction(*msg, AuthorisationLevel::Unspecified);
}
-void Controller::handleTransaction(const U2FMessage& msg) {
+bool Controller::handleTransaction(const U2FMessage& msg, AuthorisationLevel auth) {
try {
if (channels.size() != 0 &&
chrono::duration_cast(chrono::system_clock::now() - lastMessage) <
@@ -60,7 +60,7 @@ void Controller::handleTransaction(const U2FMessage& msg) {
}
} else if (channels.find(opChannelID) == channels.end()) {
U2FMessage::error(opChannelID, ERR_CHANNEL_BUSY);
- return;
+ return true;
}
#ifdef DEBUG_MSGS
@@ -68,7 +68,7 @@ void Controller::handleTransaction(const U2FMessage& msg) {
clog << "cid: " << msg.cid << ", cmd: " << static_cast(msg.cmd) << endl;
#endif
- channels.at(opChannelID).handle(msg);
+ return channels.at(opChannelID).handle(msg, auth);
}
uint32_t Controller::nextChannel() {
diff --git a/Controller.hpp b/Controller.hpp
index 650b611..e5c3947 100644
--- a/Controller.hpp
+++ b/Controller.hpp
@@ -31,6 +31,9 @@ public:
Controller(const uint32_t startChannel = 1);
void handleTransaction();
- void handleTransaction(const U2FMessage& msg);
+
+ // Returns false if required authentication
+ // Returns true otherwise
+ bool handleTransaction(const U2FMessage& msg, AuthorisationLevel auth);
uint32_t nextChannel();
};
diff --git a/U2F_Authenticate_APDU.cpp b/U2F_Authenticate_APDU.cpp
index a07fc86..a7be072 100644
--- a/U2F_Authenticate_APDU.cpp
+++ b/U2F_Authenticate_APDU.cpp
@@ -47,7 +47,7 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const vecto
copy(data.begin() + 65, data.begin() + 65 + keyHLen, back_inserter(keyH));
}
-void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
+void U2F_Authenticate_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const {
if (keyH.size() != sizeof(Storage::KeyHandle)) {
// Respond with error code - key handle is of wrong size
cerr << "Invalid key handle length" << endl;
@@ -71,6 +71,8 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
return;
}
+ uint8_t presence;
+
switch (p1) {
case ControlCode::CheckOnly:
this->error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED);
@@ -98,7 +100,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
auto& keyCount = Storage::keyCounts[keyHB];
keyCount++;
- response.push_back(0x01);
+ response.push_back(hasAuthorisation ? 1u : 0u);
response.insert(response.end(), FIELD_BE(keyCount));
Digest digest;
@@ -110,7 +112,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
mbedtls_sha256_update(&shaContext, reinterpret_cast(appParam.data()),
sizeof(appParam));
- uint8_t userPresence{ 1u };
+ uint8_t userPresence = hasAuthorisation ? 1u : 0u;
mbedtls_sha256_update(&shaContext, &userPresence, 1);
const auto beCounter = beEncode(keyCount);
mbedtls_sha256_update(&shaContext, beCounter.data(), beCounter.size());
@@ -128,3 +130,7 @@ void U2F_Authenticate_APDU::respond(const uint32_t channelID) const {
msg.write();
}
+
+bool U2F_Authenticate_APDU::requiresAuthorisation() const {
+ return p1 == ControlCode::EnforcePresenceSign;
+}
diff --git a/U2F_Authenticate_APDU.hpp b/U2F_Authenticate_APDU.hpp
index e540682..37014d2 100644
--- a/U2F_Authenticate_APDU.hpp
+++ b/U2F_Authenticate_APDU.hpp
@@ -29,7 +29,8 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD {
public:
U2F_Authenticate_APDU(const U2F_Msg_CMD& msg, const std::vector& data);
- virtual void respond(const uint32_t channelID) const override;
+ bool requiresAuthorisation() const override;
+ virtual void respond(const uint32_t channelID, bool hasAuthorisation) const override;
enum ControlCode {
CheckOnly = 0x07,
diff --git a/U2F_CMD.cpp b/U2F_CMD.cpp
index 6966e14..a1389aa 100644
--- a/U2F_CMD.cpp
+++ b/U2F_CMD.cpp
@@ -42,3 +42,7 @@ shared_ptr U2F_CMD::get(const U2FMessage& uMsg) {
return {};
}
}
+
+bool U2F_CMD::requiresAuthorisation() const {
+ return false;
+}
diff --git a/U2F_CMD.hpp b/U2F_CMD.hpp
index 5d5973a..4e54dd8 100644
--- a/U2F_CMD.hpp
+++ b/U2F_CMD.hpp
@@ -26,6 +26,7 @@ protected:
public:
virtual ~U2F_CMD() = default;
+ virtual bool requiresAuthorisation() const;
static std::shared_ptr get(const U2FMessage& uMsg);
- virtual void respond(const uint32_t channelID) const = 0;
+ virtual void respond(const uint32_t channelID, bool hasAuthorisation) const = 0;
}; // For polymorphic type casting
diff --git a/U2F_Init_CMD.cpp b/U2F_Init_CMD.cpp
index b09d54f..a217313 100644
--- a/U2F_Init_CMD.cpp
+++ b/U2F_Init_CMD.cpp
@@ -37,7 +37,7 @@ U2F_Init_CMD::U2F_Init_CMD(const U2FMessage& uMsg) {
this->nonce = *reinterpret_cast(uMsg.data.data());
}
-void U2F_Init_CMD::respond(const uint32_t channelID) const {
+void U2F_Init_CMD::respond(const uint32_t channelID, bool) const {
U2FMessage msg{};
msg.cid = CID_BROADCAST;
msg.cmd = U2FHID_INIT;
diff --git a/U2F_Init_CMD.hpp b/U2F_Init_CMD.hpp
index 255df00..10383c8 100644
--- a/U2F_Init_CMD.hpp
+++ b/U2F_Init_CMD.hpp
@@ -27,5 +27,5 @@ struct U2F_Init_CMD : U2F_CMD {
public:
U2F_Init_CMD(const U2FMessage& uMsg);
- virtual void respond(const uint32_t channelID) const override;
+ virtual void respond(const uint32_t channelID, bool) const override;
};
diff --git a/U2F_Msg_CMD.cpp b/U2F_Msg_CMD.cpp
index 66533d1..fadd52b 100644
--- a/U2F_Msg_CMD.cpp
+++ b/U2F_Msg_CMD.cpp
@@ -194,6 +194,6 @@ const map U2F_Msg_CMD::usesData = { { U2F_REG, true },
{ U2F_AUTH, true },
{ U2F_VER, false } };
-void U2F_Msg_CMD::respond(const uint32_t channelID) const {
+void U2F_Msg_CMD::respond(const uint32_t channelID, bool) const {
U2F_Msg_CMD::error(channelID, static_cast(APDU_STATUS::SW_INS_NOT_SUPPORTED));
}
diff --git a/U2F_Msg_CMD.hpp b/U2F_Msg_CMD.hpp
index 7ea7782..32b126a 100644
--- a/U2F_Msg_CMD.hpp
+++ b/U2F_Msg_CMD.hpp
@@ -23,7 +23,7 @@ along with this program. If not, see .
#include
#include
-struct U2F_Msg_CMD : U2F_CMD {
+struct U2F_Msg_CMD : public U2F_CMD {
uint8_t cla;
uint8_t ins;
uint8_t p1;
@@ -36,9 +36,10 @@ struct U2F_Msg_CMD : U2F_CMD {
protected:
static uint32_t getLe(const uint32_t byteCount, std::vector bytes);
U2F_Msg_CMD() = default;
+ virtual ~U2F_Msg_CMD() = default;
public:
static std::shared_ptr generate(const U2FMessage& uMsg);
static void error(const uint32_t channelID, const uint16_t errCode);
- void respond(const uint32_t channelID) const;
+ void respond(const uint32_t channelID, bool hasAuthorisation) const;
};
diff --git a/U2F_Ping_CMD.cpp b/U2F_Ping_CMD.cpp
index cbcdcc1..d51ef1b 100644
--- a/U2F_Ping_CMD.cpp
+++ b/U2F_Ping_CMD.cpp
@@ -26,7 +26,7 @@ U2F_Ping_CMD::U2F_Ping_CMD(const U2FMessage& uMsg) : nonce{ uMsg.data } {
throw runtime_error{ "Failed to get U2F ping message" };
}
-void U2F_Ping_CMD::respond(const uint32_t channelID) const {
+void U2F_Ping_CMD::respond(const uint32_t channelID, bool) const {
U2FMessage msg{};
msg.cid = channelID;
msg.cmd = U2FHID_PING;
diff --git a/U2F_Ping_CMD.hpp b/U2F_Ping_CMD.hpp
index ecfa031..f7ef01e 100644
--- a/U2F_Ping_CMD.hpp
+++ b/U2F_Ping_CMD.hpp
@@ -27,5 +27,5 @@ struct U2F_Ping_CMD : U2F_CMD {
public:
U2F_Ping_CMD(const U2FMessage& uMsg);
- virtual void respond(const uint32_t channelID) const override;
+ virtual void respond(const uint32_t channelID, bool) const override;
};
diff --git a/U2F_Register_APDU.cpp b/U2F_Register_APDU.cpp
index e5cc62b..011ec27 100644
--- a/U2F_Register_APDU.cpp
+++ b/U2F_Register_APDU.cpp
@@ -61,7 +61,13 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD& msg, const vectorkeyH] = 0;
}
-void U2F_Register_APDU::respond(const uint32_t channelID) const {
+void U2F_Register_APDU::respond(const uint32_t channelID, bool hasAuthorisation) const {
+ if (!hasAuthorisation) {
+ error(channelID, APDU_STATUS::SW_CONDITIONS_NOT_SATISFIED);
+ return;
+ }
+
+
U2FMessage m{};
m.cid = channelID;
m.cmd = U2FHID_MSG;
@@ -118,3 +124,7 @@ void U2F_Register_APDU::respond(const uint32_t channelID) const {
m.write();
}
+
+bool U2F_Register_APDU::requiresAuthorisation() const {
+ return true;
+}
diff --git a/U2F_Register_APDU.hpp b/U2F_Register_APDU.hpp
index db6ca39..4354164 100644
--- a/U2F_Register_APDU.hpp
+++ b/U2F_Register_APDU.hpp
@@ -27,6 +27,8 @@ struct U2F_Register_APDU : U2F_Msg_CMD {
public:
U2F_Register_APDU(const U2F_Msg_CMD& msg, const std::vector& data);
+ virtual ~U2F_Register_APDU() = default;
- void respond(const uint32_t channelID) const override;
+ bool requiresAuthorisation() const override;
+ void respond(const uint32_t channelID, bool hasAuthorisation) const override;
};
diff --git a/U2F_Version_APDU.cpp b/U2F_Version_APDU.cpp
index 0d087bc..a615f7d 100644
--- a/U2F_Version_APDU.cpp
+++ b/U2F_Version_APDU.cpp
@@ -33,7 +33,7 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector.
struct U2F_Version_APDU : U2F_Msg_CMD {
public:
U2F_Version_APDU(const U2F_Msg_CMD& msg, const std::vector& data);
- void respond(const uint32_t channelID) const override;
+ void respond(const uint32_t channelID, bool) const override;
};