diff --git a/Channel.cpp b/Channel.cpp new file mode 100644 index 0000000..15e74a2 --- /dev/null +++ b/Channel.cpp @@ -0,0 +1,42 @@ +#include "Channel.hpp" +#include +#include "u2f.hpp" +#include "U2F_CMD.hpp" +#include + +using namespace std; + +Channel::Channel(const uint32_t channelID) + : cid{ channelID }, initState{ ChannelInitState::Unitialised }, lockedState{ ChannelLockedState::Unlocked } +{} + +uint32_t Channel::getCID() const +{ + return cid; +} + +void Channel::handle(const shared_ptr uMsg) +{ + if (uMsg->cmd == U2FHID_INIT) + this->initState = ChannelInitState::Initialised; + else if (uMsg->cid != this->cid) + throw runtime_error{ "CID of request invalid for this channel" }; + + if (this->initState != ChannelInitState::Initialised) + throw runtime_error{ "Channel in incorrect (uninitialized) state to handle request" }; + else if (this->lockedState != ChannelLockedState::Unlocked) + throw runtime_error{ "Channel in incorrect (locked) state to handle request" }; + + clog << "Handling uMsg with CMD: " << static_cast(uMsg->cmd) << endl; + return U2F_CMD::get(uMsg)->respond(this->cid); +} + +void Channel::init(const ChannelInitState newInitState) +{ + this->initState = newInitState; +} + +void Channel::lock(const ChannelLockedState newLockedState) +{ + this->lockedState = newLockedState; +} diff --git a/Channel.hpp b/Channel.hpp new file mode 100644 index 0000000..ffd17e9 --- /dev/null +++ b/Channel.hpp @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include "U2FMessage.hpp" + +enum class ChannelInitState +{ + Unitialised, + Initialised +}; + +enum class ChannelLockedState +{ + Locked, + Unlocked +}; + +class Channel +{ + protected: + uint32_t cid; + ChannelInitState initState; + ChannelLockedState lockedState; + + public: + Channel(const uint32_t channelID); + void handle(const std::shared_ptr uMsg); + + uint32_t getCID() const; + void init(const ChannelInitState newInitState); + void lock(const ChannelLockedState newLockedState); +}; diff --git a/Constants.hpp b/Constants.hpp deleted file mode 100644 index e2b3d6b..0000000 --- a/Constants.hpp +++ /dev/null @@ -1,4 +0,0 @@ -#pragma once -#include - -const constexpr uint16_t packetSize = 64; diff --git a/Controller.cpp b/Controller.cpp new file mode 100644 index 0000000..a008d96 --- /dev/null +++ b/Controller.cpp @@ -0,0 +1,48 @@ +#include "Controller.hpp" +#include "u2f.hpp" +#include +#include "IO.hpp" + +using namespace std; + +Controller::Controller(const uint32_t startChannel) + : channels{}, currChannel{ startChannel } +{} + +void Controller::handleTransaction() +{ + auto msg = U2FMessage::readNonBlock(); + + if (!msg) + return; + + auto opChannel = msg->cid; + + clog << "Got msg with cmd of: " << static_cast(msg->cmd) << endl; + + if (msg->cmd == U2FHID_INIT) + { + opChannel = nextChannel(); + auto channel = Channel{ opChannel }; + + try + { + channels.emplace(opChannel, channel); //In case of wrap-around replace existing one + } + catch (...) + { + channels.insert(make_pair(opChannel, channel)); + } + } + + channels.at(opChannel).handle(msg); +} + +const uint32_t Controller::nextChannel() +{ + do + currChannel++; + while (currChannel == 0xFFFFFFFF || currChannel == 0); + + return currChannel; +} diff --git a/Controller.hpp b/Controller.hpp new file mode 100644 index 0000000..8b78ed9 --- /dev/null +++ b/Controller.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include "Channel.hpp" + +class Controller +{ + protected: + std::map channels; + uint32_t currChannel; + + public: + Controller(const uint32_t startChannel = 1); + + void handleTransaction(); + const uint32_t nextChannel(); +}; diff --git a/IO.cpp b/IO.cpp index 5ec855a..16b97a9 100644 --- a/IO.cpp +++ b/IO.cpp @@ -1,26 +1,94 @@ #include "IO.hpp" #include "Streams.hpp" #include +#include +#include +//#include +#include +#include +#include +#include +#include "u2f.hpp" +#include "Macro.hpp" using namespace std; -vector readBytes(const size_t count) +bool bytesAvailable(const size_t count); +vector& getBuffer(); + +vector readNonBlock(const size_t count) { - vector bytes(count); - - size_t readByteCount; - - do + if (!bytesAvailable(count)) { - readByteCount = fread(bytes.data(), 1, count, getHostStream().get()); - fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get()); - } while (readByteCount == 0); + //clog << "No bytes available" << endl; + return vector{}; + } - clog << "Read " << readByteCount << " bytes" << endl; + auto &buffer = getBuffer(); + auto buffStart = buffer.begin(), buffEnd = buffer.begin() + count; + vector bytes{ buffStart, buffEnd }; + buffer.erase(buffStart, buffEnd); + + fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get()); - if (readByteCount != count) - throw runtime_error{ "Failed to read sufficient bytes" }; + errno = 0; return bytes; } +void write(const uint8_t* bytes, const size_t count) +{ + size_t totalBytes = 0; + auto hostDescriptor = *getHostDescriptor(); + + while (totalBytes < count) + { + auto writtenBytes = write(hostDescriptor, bytes + totalBytes, count - totalBytes); + + if (writtenBytes > 0) + totalBytes += writtenBytes; + else if (errno != 0 && errno != EAGAIN && errno != EWOULDBLOCK) //Expect file blocking behaviour + ERR(); + } + + errno = 0; +} + +bool bytesAvailable(const size_t count) +{ + return getBuffer().size() >= count; +} + +vector& bufferVar() +{ + static vector buffer{}; + return buffer; +} + +vector& getBuffer() +{ + auto &buff = bufferVar(); + array bytes{}; + auto hostDescriptor = *getHostDescriptor(); + + while (true) + { + auto readByteCount = read(hostDescriptor, bytes.data(), HID_RPT_SIZE); + + if (readByteCount > 0) + { + copy(bytes.begin(), bytes.begin() + readByteCount, back_inserter(buff)); + } + else if (errno != EAGAIN && errno != EWOULDBLOCK) //Expect read would block + { + ERR(); + } + else + { + break; //Escape loop if blocking would occur + } + } + + return buff; +} + diff --git a/IO.hpp b/IO.hpp index 60813b1..6804347 100644 --- a/IO.hpp +++ b/IO.hpp @@ -3,4 +3,9 @@ #include #include -std::vector readBytes(const size_t count); +//Returns either the number of bytes specified, +//or returns empty vector without discarding bytes from HID stream +std::vector readNonBlock(const size_t count); + +//Blocking write to HID stream - shouldn't block for too long +void write(const uint8_t* bytes, const size_t count); diff --git a/Macro.hpp b/Macro.hpp new file mode 100644 index 0000000..55072b6 --- /dev/null +++ b/Macro.hpp @@ -0,0 +1,5 @@ +#pragma once +#include +#include + +#define ERR() if (errno != 0) perror((string{ "(" } + __FILE__ + ":" + to_string(__LINE__) + ")" + " " + __PRETTY_FUNCTION__).c_str()), errno = 0 diff --git a/Makefile b/Makefile index 9a2eb53..13ceeea 100755 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ LDFLAGS := -lmbedcrypto CPPFLAGS := CXXFLAGS := --std=c++14 -CXXFLAGS += -MMD -MP +CXXFLAGS += -MMD -MP -Wall -Wfatal-errors -Wextra MODULES := $(wildcard $(SRC_DIR)/*.cpp) OBJECTS := $(MODULES:$(SRC_DIR)/%.cpp=$(OBJ_DIR)/%.o) diff --git a/Packet.cpp b/Packet.cpp index 8f8321c..60e335a 100644 --- a/Packet.cpp +++ b/Packet.cpp @@ -3,20 +3,64 @@ #include "u2f.hpp" #include #include +#include #include "Streams.hpp" using namespace std; shared_ptr InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD) { - auto p = make_shared(); - p->cid = rCID; - p->cmd = rCMD; - p->bcnth = readBytes(1)[0]; - p->bcntl = readBytes(1)[0]; + static size_t bytesRead = 0; + static uint8_t bcnth; + static uint8_t bcntl; + static decltype(InitPacket::data) dataBytes; + vector bytes{}; - const auto dataBytes = readBytes(p->data.size()); - copy(dataBytes.begin(), dataBytes.end(), p->data.begin()); + switch (bytesRead) + { + case 0: + bytes = readNonBlock(1); + + if (bytes.size() == 0) + return {}; + + bcnth = bytes[0]; + bytesRead += bytes.size(); + [[fallthrough]]; + + case 1: + bytes = readNonBlock(1); + + if (bytes.size() == 0) + return {}; + + bcntl = bytes[0]; + bytesRead += bytes.size(); + [[fallthrough]]; + + case 2: + bytes = readNonBlock(dataBytes.size()); + + if (bytes.size() == 0) + return {}; + + copy(bytes.begin(), bytes.end(), dataBytes.begin());; + bytesRead += bytes.size(); + [[fallthrough]]; + + case 2 + dataBytes.size(): + break; + + default: + throw runtime_error{ "Unknown stage in InitPacket construction" }; + } + + auto p = make_shared(); + p->cid = rCID; + p->cmd = rCMD; + p->bcnth = bcnth; + p->bcntl = bcntl; + p->data = dataBytes; auto hPStream = getHostPacketStream().get(); fprintf(hPStream, "\t\t\n" @@ -47,17 +91,33 @@ shared_ptr InitPacket::getPacket(const uint32_t rCID, const uint8_t "\t\t
"); clog << "Fully read init packet" << endl; + bytesRead = 0; return p; } shared_ptr ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq) { + static size_t readBytes = 0; + static decltype(ContPacket::data) dataBytes; + + vector bytes{}; auto p = make_shared(); + + if (readBytes != dataBytes.size()) + { + dataBytes = {}; + bytes = readNonBlock(dataBytes.size()); + + if (bytes.size() == 0) + return {}; + + copy(bytes.begin(), bytes.end(), dataBytes.begin()); + readBytes += bytes.size(); + } + p->cid = rCID; p->seq = rSeq; - - const auto dataBytes = readBytes(p->data.size()); - copy(dataBytes.begin(), dataBytes.end(), p->data.begin()); + p->data = dataBytes; auto hPStream = getHostPacketStream().get(); fprintf(hPStream, "\t\t
\n" @@ -84,49 +144,86 @@ shared_ptr ContPacket::getPacket(const uint32_t rCID, const uint8_t "\t\t
"); //clog << "Fully read cont packet" << endl; + readBytes = 0; return p; } shared_ptr Packet::getPacket() { - const uint32_t cid = *reinterpret_cast(readBytes(4).data()); - uint8_t b = readBytes(1)[0]; + static size_t bytesRead = 0; + vector bytes{}; - //clog << "Packet read 2nd byte as " << static_cast(b) << endl; + static uint32_t cid; + static uint8_t b; + shared_ptr packet{}; - if (b & TYPE_MASK) + switch (bytesRead) { - //Init packet - return InitPacket::getPacket(cid, b); - } - else - { - //Cont packet - return ContPacket::getPacket(cid, b); + case 0: + bytes = readNonBlock(4); + + if (bytes.size() == 0) + return {}; + + cid = *reinterpret_cast(bytes.data()); + bytesRead += bytes.size(); + [[fallthrough]]; + + case 4: + bytes = readNonBlock(1); + + if (bytes.size() == 0) + return {}; + + b = bytes[0]; + bytesRead += bytes.size(); + [[fallthrough]]; + + case 5: + if (b & TYPE_MASK) + { + //Init packet + //clog << "Getting init packet" << endl; + packet = InitPacket::getPacket(cid, b); + + if (packet) + bytesRead = 0; + + return packet; + } + else + { + //Cont packet + //clog << "Getting cont packet" << endl; + packet = ContPacket::getPacket(cid, b); + + if (packet) + bytesRead = 0; + + return packet; + } + default: + throw runtime_error{ "Unknown stage in Packet construction" }; } } void Packet::writePacket() { - memset(this->buf, 0, packetSize); + memset(this->buf, 0, HID_RPT_SIZE); memcpy(this->buf, &cid, 4); } void InitPacket::writePacket() { Packet::writePacket(); - auto hostStream = getHostStream().get(); auto devStream = getComDevStream().get(); memcpy(this->buf + 4, &cmd, 1); memcpy(this->buf + 5, &bcnth, 1); memcpy(this->buf + 6, &bcntl, 1); memcpy(this->buf + 7, data.data(), data.size()); - fwrite(this->buf, packetSize, 1, hostStream); - fwrite(this->buf, packetSize, 1, devStream); - - if (errno != 0) - perror("perror " __FILE__ " 85"); + write(this->buf, sizeof(this->buf)); + fwrite(this->buf, 1, sizeof(this->buf), devStream); auto dPStream = getDevPacketStream().get(); fprintf(dPStream, "\t\t
\n" @@ -155,23 +252,17 @@ void InitPacket::writePacket() "\t\t\t\n" "\t\t
" "\t\t
"); - - clog << "Fully wrote init packet" << endl; } void ContPacket::writePacket() { Packet::writePacket(); - auto hostStream = getHostStream().get(); auto devStream = getComDevStream().get(); memcpy(this->buf + 4, &seq, 1); memcpy(this->buf + 5, data.data(), data.size()); - fwrite(this->buf, packetSize, 1, hostStream); - fwrite(this->buf, packetSize, 1, devStream); - - if (errno != 0) - perror("perror " __FILE__ " 107"); + write(this->buf, HID_RPT_SIZE); + fwrite(this->buf, HID_RPT_SIZE, 1, devStream); auto dPStream = getDevPacketStream().get(); diff --git a/Packet.hpp b/Packet.hpp index 9a20d51..0d701bc 100644 --- a/Packet.hpp +++ b/Packet.hpp @@ -2,12 +2,13 @@ #include #include #include -#include "Constants.hpp" +#include "u2f.hpp" struct Packet { - uint32_t cid; - uint8_t buf[packetSize]; + public: + uint32_t cid; + uint8_t buf[HID_RPT_SIZE]; protected: Packet() = default; @@ -20,10 +21,11 @@ struct Packet struct InitPacket : Packet { - uint8_t cmd; - uint8_t bcnth; - uint8_t bcntl; - std::array data{}; + public: + uint8_t cmd; + uint8_t bcnth; + uint8_t bcntl; + std::array data{}; public: InitPacket() = default; @@ -33,8 +35,9 @@ struct InitPacket : Packet struct ContPacket : Packet { - uint8_t seq; - std::array data{}; + public: + uint8_t seq; + std::array data{}; public: ContPacket() = default; diff --git a/Streams.cpp b/Streams.cpp index 2015dfe..06b6cde 100644 --- a/Streams.cpp +++ b/Streams.cpp @@ -1,21 +1,29 @@ #include "Streams.hpp" #include +#include +#include +#include +#include +#include using namespace std; FILE* initHTML(FILE *fPtr, const string &title); void closeHTML(FILE *fPtr); -shared_ptr getHostStream() +shared_ptr getHostDescriptor() { - static shared_ptr stream{ fopen("/dev/hidg0", "ab+"), [](FILE *f){ - fclose(f); - } }; + static shared_ptr descriptor{}; + + descriptor.reset(new int{ open("/dev/hidg0", O_RDWR | O_NONBLOCK | O_APPEND) }, [](int* fd){ + close(*fd); + delete fd; + }); - if (!stream) - clog << "Stream is unavailable" << endl; + if (*descriptor == -1) + throw runtime_error{ "Descriptor is unavailable" }; - return stream; + return descriptor; } shared_ptr getComHostStream() diff --git a/Streams.hpp b/Streams.hpp index c0d87e3..9b306ec 100644 --- a/Streams.hpp +++ b/Streams.hpp @@ -2,7 +2,7 @@ #include #include -std::shared_ptr getHostStream(); +std::shared_ptr getHostDescriptor(); std::shared_ptr getComHostStream(); std::shared_ptr getHostPacketStream(); std::shared_ptr getHostAPDUStream(); diff --git a/U2FMessage.cpp b/U2FMessage.cpp index e8ae197..1ab8e45 100644 --- a/U2FMessage.cpp +++ b/U2FMessage.cpp @@ -8,40 +8,83 @@ using namespace std; -U2FMessage U2FMessage::read() +shared_ptr U2FMessage::readNonBlock() { - auto fPack = dynamic_pointer_cast(Packet::getPacket()); + static size_t currSeq = -1; + static uint16_t messageSize; + static uint32_t cid; + static uint8_t cmd; + static vector dataBytes; - if (!fPack) - throw runtime_error{ "Failed to receive Init packet" }; + shared_ptr p{}; - const uint16_t messageSize = ((static_cast(fPack->bcnth) << 8u) + fPack->bcntl); - - clog << "Message on channel 0x" << hex << fPack->cid << dec << " has size: " << messageSize << endl; - - const uint16_t copyByteCount = min(static_cast(fPack->data.size()), messageSize); - - U2FMessage message{ fPack-> cid, fPack->cmd }; - message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount); - - uint8_t currSeq = 0; - - while (message.data.size() < messageSize) + if (currSeq == -1u) { - auto newPack = dynamic_pointer_cast(Packet::getPacket()); + cid = 0; + cmd = 0; + messageSize = 0; + dataBytes = {}; + + shared_ptr initPack{}; + do + { + p = Packet::getPacket(); - if (!newPack) - throw runtime_error{ "Failed to receive Cont packet" }; - else if (newPack->seq != currSeq) - throw runtime_error{ "Packet out of sequence" }; + if (!p) + return {}; - const uint16_t remainingBytes = messageSize - message.data.size(); - const uint16_t copyBytes = min(static_cast(newPack->data.size()), remainingBytes); - message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes); + initPack = dynamic_pointer_cast(p); + } while (!initPack); //Spurious cont. packet - spec states ignore + messageSize = ((static_cast(initPack->bcnth) << 8u) + initPack->bcntl); + const uint16_t copyByteCount = min(static_cast(initPack->data.size()), messageSize); + + cid = initPack->cid; + cmd = initPack->cmd; + + copy(initPack->data.begin(), initPack->data.begin() + copyByteCount, back_inserter(dataBytes)); currSeq++; } + while (messageSize > dataBytes.size() && static_cast(p = Packet::getPacket())) //While there is a packet + { + auto contPack = dynamic_pointer_cast(p); + + if (!contPack) //Spurious init. packet + { + currSeq = -1; //Reset + return {}; + } + + if (contPack->cid != cid) //Cont. packet of different CID + { + cerr << "Invalid CID: was handling channel 0x" << hex << cid << " and received packet from channel 0x" << contPack->cid << dec << endl; + U2FMessage::error(contPack->cid, ERR_CHANNEL_BUSY); + currSeq = -1; + return {}; + } + + if (contPack->seq != currSeq) + { + cerr << "Invalid packet seq. value" << endl; + U2FMessage::error(cid, ERR_INVALID_SEQ); + currSeq = -1; + return {}; + } + + const uint16_t remainingBytes = messageSize - dataBytes.size(); + const uint16_t copyBytes = min(static_cast(contPack->data.size()), remainingBytes); + dataBytes.insert(dataBytes.end(), contPack->data.begin(), contPack->data.begin() + copyBytes); + currSeq++; + } + + if (messageSize != dataBytes.size()) + return {}; + + auto message = make_shared(cid, cmd); + message->data.assign(dataBytes.begin(), dataBytes.end()); + currSeq = -1u; + std::clog << "Read all of message" << std::endl; return message; @@ -49,9 +92,6 @@ U2FMessage U2FMessage::read() void U2FMessage::write() { - clog << "Flushing host stream" << endl; - fflush(getHostStream().get()); - clog << "Flushed host stream" << endl; const uint16_t bytesToWrite = this->data.size(); uint16_t bytesWritten = 0; @@ -88,8 +128,7 @@ void U2FMessage::write() bytesWritten += newByteCount; } - auto stream = getHostStream().get(); - fflush(stream); + //auto stream = *getHostStream(); if (cmd == U2FHID_MSG) { @@ -120,3 +159,16 @@ void U2FMessage::write() "\t\t
", err); } } + +U2FMessage::U2FMessage(const uint32_t nCID, const uint8_t nCMD) + : cid{ nCID }, cmd{ nCMD } +{} + +void U2FMessage::error(const uint32_t tCID, const uint16_t tErr) +{ + U2FMessage msg{}; + msg.cid = tCID; + msg.cmd = U2FHID_ERROR; + msg.data.push_back((tErr >> 8) & 0xFF); + msg.data.push_back(tErr & 0xFF); +} diff --git a/U2FMessage.hpp b/U2FMessage.hpp index a5b1784..869b36f 100644 --- a/U2FMessage.hpp +++ b/U2FMessage.hpp @@ -1,14 +1,21 @@ #pragma once #include #include +#include struct U2FMessage { - uint32_t cid; - uint8_t cmd; - std::vector data; + public: + uint32_t cid; + uint8_t cmd; + std::vector data; - static U2FMessage read(); + public: + U2FMessage() = default; + U2FMessage(const uint32_t nCID, const uint8_t nCMD); + static std::shared_ptr readNonBlock(); + void write(); - void write(); + protected: + static void error(const uint32_t tCID, const uint16_t tErr); }; diff --git a/U2F_Authenticate_APDU.cpp b/U2F_Authenticate_APDU.cpp index a66d2b1..5a6b409 100644 --- a/U2F_Authenticate_APDU.cpp +++ b/U2F_Authenticate_APDU.cpp @@ -27,10 +27,10 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const vecto clog << "Got U2F_Auth request" << endl; } -void U2F_Authenticate_APDU::respond() +void U2F_Authenticate_APDU::respond(const uint32_t channelID) const { U2FMessage msg{}; - msg.cid = 0xF1D0F1D0; + msg.cid = channelID; msg.cmd = U2FHID_MSG; auto statusCode = APDU_STATUS::SW_NO_ERROR; diff --git a/U2F_Authenticate_APDU.hpp b/U2F_Authenticate_APDU.hpp index 6023c40..8ed51f3 100644 --- a/U2F_Authenticate_APDU.hpp +++ b/U2F_Authenticate_APDU.hpp @@ -12,7 +12,7 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD public: U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const std::vector &data); - void respond(); + virtual void respond(const uint32_t channelID) const override; enum ControlCode { diff --git a/U2F_CMD.cpp b/U2F_CMD.cpp new file mode 100644 index 0000000..b50ca80 --- /dev/null +++ b/U2F_CMD.cpp @@ -0,0 +1,19 @@ +#include "U2F_CMD.hpp" +#include "u2f.hpp" +#include "U2F_Msg_CMD.hpp" +#include "U2F_Init_CMD.hpp" + +using namespace std; + +shared_ptr U2F_CMD::get(const shared_ptr uMsg) +{ + switch (uMsg->cmd) + { + case U2FHID_MSG: + return U2F_Msg_CMD::generate(uMsg); + case U2FHID_INIT: + return make_shared(uMsg); + default: + return {}; + } +} diff --git a/U2F_CMD.hpp b/U2F_CMD.hpp index 0f1a954..31e0560 100644 --- a/U2F_CMD.hpp +++ b/U2F_CMD.hpp @@ -1,4 +1,6 @@ #pragma once +#include +#include "U2FMessage.hpp" struct U2F_CMD { @@ -6,5 +8,7 @@ struct U2F_CMD U2F_CMD() = default; public: - ~U2F_CMD() = default; + virtual ~U2F_CMD() = default; + static std::shared_ptr get(const std::shared_ptr uMsg); + virtual void respond(const uint32_t channelID) const = 0; }; //For polymorphic type casting diff --git a/U2F_Init_CMD.cpp b/U2F_Init_CMD.cpp index ef67671..5cfea69 100644 --- a/U2F_Init_CMD.cpp +++ b/U2F_Init_CMD.cpp @@ -1,24 +1,33 @@ #include "U2F_Init_CMD.hpp" #include -#include "U2FMessage.hpp" #include "u2f.hpp" -#include +#include "Field.hpp" using namespace std; -U2F_Init_CMD U2F_Init_CMD::get() +U2F_Init_CMD::U2F_Init_CMD(const shared_ptr uMsg) { - const auto message = U2FMessage::read(); - - if (message.cmd != U2FHID_INIT) + if (uMsg->cmd != U2FHID_INIT) throw runtime_error{ "Failed to get U2F Init message" }; - else if (message.data.size() != INIT_NONCE_SIZE) + else if (uMsg->data.size() != INIT_NONCE_SIZE) throw runtime_error{ "Init nonce is incorrect size" }; - U2F_Init_CMD cmd; - cmd.nonce = *reinterpret_cast(message.data.data()); - - clog << "Fully read nonce" << endl; - - return cmd; + this->nonce = *reinterpret_cast(uMsg->data.data()); +} + +void U2F_Init_CMD::respond(const uint32_t channelID) const +{ + U2FMessage msg{}; + msg.cid = CID_BROADCAST; + msg.cmd = U2FHID_INIT; + + msg.data.insert(msg.data.end(), FIELD(this->nonce)); + msg.data.insert(msg.data.end(), FIELD(channelID)); + msg.data.push_back(2); //Protocol version + msg.data.push_back(1); //Major device version + msg.data.push_back(0); //Minor device version + msg.data.push_back(1); //Build device version + msg.data.push_back(CAPFLAG_WINK); //Wink capability + + msg.write(); } diff --git a/U2F_Init_CMD.hpp b/U2F_Init_CMD.hpp index 62cf408..d202e9f 100644 --- a/U2F_Init_CMD.hpp +++ b/U2F_Init_CMD.hpp @@ -1,11 +1,14 @@ #pragma once #include +#include #include "U2F_CMD.hpp" +#include "U2FMessage.hpp" struct U2F_Init_CMD : U2F_CMD { uint64_t nonce; public: - static U2F_Init_CMD get(); + U2F_Init_CMD(const std::shared_ptr uMsg); + virtual void respond(const uint32_t channelID) const override; }; diff --git a/U2F_Init_Response.hpp b/U2F_Init_Response.hpp deleted file mode 100644 index bd940cf..0000000 --- a/U2F_Init_Response.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once -#include -#include "U2FMessage.hpp" -#include "Field.hpp" - -struct U2F_Init_Response : U2F_CMD -{ - uint32_t cid; - uint64_t nonce; - uint8_t protocolVer; - uint8_t majorDevVer; - uint8_t minorDevVer; - uint8_t buildDevVer; - uint8_t capabilities; - - void write() - { - std::clog << "Beginning writeout of U2F_Init_Response" << std::endl; - U2FMessage m{}; - m.cid = CID_BROADCAST; - m.cmd = U2FHID_INIT; - - m.data.insert(m.data.begin() + 0, FIELD(nonce)); - m.data.insert(m.data.begin() + 8, FIELD(cid)); - m.data.insert(m.data.begin() + 12, FIELD(protocolVer)); - m.data.insert(m.data.begin() + 13, FIELD(majorDevVer)); - m.data.insert(m.data.begin() + 14, FIELD(minorDevVer)); - m.data.insert(m.data.begin() + 15, FIELD(buildDevVer)); - m.data.insert(m.data.begin() + 16, FIELD(capabilities)); - - std::clog << "Finished inserting U2F_Init_Response fields to data buffer" << std::endl; - - m.write(); - std::clog << "Completed writeout of U2F_Init_Response" << std::endl; - } -}; diff --git a/U2F_Msg_CMD.cpp b/U2F_Msg_CMD.cpp index a85bdfd..7aedb64 100644 --- a/U2F_Msg_CMD.cpp +++ b/U2F_Msg_CMD.cpp @@ -8,16 +8,17 @@ #include "APDU.hpp" #include #include "Streams.hpp" +#include "Field.hpp" using namespace std; uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector bytes) { - if (byteCount > 3) - throw runtime_error{ "Too much data for command" }; if (byteCount != 0) { //Le must be length of data in bytes + clog << "Le must be length of data in bytes" << endl; + clog << "Le has a size of " << byteCount << " bytes" << endl; switch (byteCount) { @@ -27,44 +28,47 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector bytes) //Don't handle non-compliance with spec here //This case is only possible if extended Lc used //CBA - return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : bytes[0] << 8 + bytes[1]); + return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : (bytes[0] << 8) + bytes[1]); case 3: //Don't handle non-compliance with spec here //This case is only possible if extended Lc not used //CBA if (bytes[0] != 0) throw runtime_error{ "First byte of 3-byte Le should be 0"}; - return (bytes[1] == 0 & bytes[2] == 0 ? 65536 : bytes[1] << 8 + bytes[2]); + return (bytes[1] == 0 && bytes[2] == 0 ? 65536 : (bytes[1] << 8) + bytes[2]); + default: + throw runtime_error{ "Too much data for command" }; } } else return 0; } -shared_ptr U2F_Msg_CMD::get() +shared_ptr U2F_Msg_CMD::generate(const shared_ptr uMsg) { - const auto message = U2FMessage::read(); - - if (message.cmd != U2FHID_MSG) - throw runtime_error{ "Failed to get U2F Msg message" }; + if (uMsg->cmd != U2FHID_MSG) + throw runtime_error{ "Failed to get U2F Msg uMsg" }; + else if (uMsg->data.size() < 4) + throw runtime_error{ "Msg data is incorrect size" }; - U2F_Msg_CMD cmd{}; - auto &dat = message.data; + U2F_Msg_CMD cmd; + auto &dat = uMsg->data; cmd.cla = dat[0]; cmd.ins = dat[1]; cmd.p1 = dat[2]; cmd.p2 = dat[3]; - const uint32_t cBCount = dat.size() - 4; + clog << "Loaded U2F_Msg_CMD parameters" << endl; vector data{ dat.begin() + 4, dat.end() }; + const uint32_t cBCount = data.size(); auto startPtr = data.begin(), endPtr = data.end(); + clog << "Loaded iters" << endl; + if (usesData.at(cmd.ins) || data.size() > 3) { - //clog << "First bytes are: " << static_cast(data[0]) << " " << static_cast(data[1]) << " " << static_cast(data[2]) << endl; - if (cBCount == 0) throw runtime_error{ "Invalid command - should have attached data" }; @@ -81,17 +85,22 @@ shared_ptr U2F_Msg_CMD::get() endPtr = startPtr + cmd.lc; + clog << "Getting Le" << endl; cmd.le = getLe(data.end() - endPtr, vector(endPtr, data.end())); } else { cmd.lc = 0; endPtr = startPtr; + + clog << "Getting Le" << endl; cmd.le = getLe(cBCount, data); } const auto dBytes = vector(startPtr, endPtr); + clog << "Determined message format" << endl; + auto hAS = getHostAPDUStream().get(); fprintf(hAS, "\n" @@ -125,6 +134,8 @@ shared_ptr U2F_Msg_CMD::get() "\t\t
\n" "\t\t
", cmd.le); + clog << "Constructing message specialisation" << endl; + switch (cmd.ins) { case APDU::U2F_REG: @@ -138,11 +149,18 @@ shared_ptr U2F_Msg_CMD::get() } } -void U2F_Msg_CMD::respond(){}; - const map U2F_Msg_CMD::usesData = { { U2F_REG, true }, { U2F_AUTH, true }, { U2F_VER, false } }; +void U2F_Msg_CMD::respond(const uint32_t channelID) const +{ + U2FMessage msg{}; + msg.cid = channelID; + msg.cmd = U2FHID_MSG; + auto errorCode = APDU_STATUS::SW_INS_NOT_SUPPORTED; + msg.data.insert(msg.data.end(), FIELD_BE(errorCode)); + msg.write(); +} diff --git a/U2F_Msg_CMD.hpp b/U2F_Msg_CMD.hpp index 24bb64e..aa5ffdc 100644 --- a/U2F_Msg_CMD.hpp +++ b/U2F_Msg_CMD.hpp @@ -18,10 +18,10 @@ struct U2F_Msg_CMD : U2F_CMD protected: static uint32_t getLe(const uint32_t byteCount, std::vector bytes); + U2F_Msg_CMD() = default; public: - static std::shared_ptr get(); - - virtual void respond(); + static std::shared_ptr generate(const std::shared_ptr uMsg); + void respond(const uint32_t channelID) const; }; diff --git a/U2F_Register_APDU.cpp b/U2F_Register_APDU.cpp index 6b2e78c..a67ee81 100644 --- a/U2F_Register_APDU.cpp +++ b/U2F_Register_APDU.cpp @@ -50,10 +50,10 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD &msg, const vectorkeyH)); - auto fakeKeyHBytes = reinterpret_cast(&this->keyH); + auto fakeKeyHBytes = reinterpret_cast(&this->keyH); copy(fakeKeyHBytes, fakeKeyHBytes + sizeof(this->keyH), back_inserter(response)); copy(attestCert, end(attestCert), back_inserter(response)); @@ -83,9 +83,9 @@ void U2F_Register_APDU::respond() mbedtls_sha256_update(&shaContext, reinterpret_cast(appParam.data()), appParam.size()); - mbedtls_sha256_update(&shaContext, reinterpret_cast(challengeP.data()), challengeP.size()); + mbedtls_sha256_update(&shaContext, reinterpret_cast(challengeP.data()), challengeP.size()); - mbedtls_sha256_update(&shaContext, reinterpret_cast(&keyH), sizeof(keyH)); + mbedtls_sha256_update(&shaContext, reinterpret_cast(&keyH), sizeof(keyH)); mbedtls_sha256_update(&shaContext, reinterpret_cast(pubKey.data()), pubKey.size()); diff --git a/U2F_Register_APDU.hpp b/U2F_Register_APDU.hpp index e28faef..3a31858 100644 --- a/U2F_Register_APDU.hpp +++ b/U2F_Register_APDU.hpp @@ -11,6 +11,6 @@ struct U2F_Register_APDU : U2F_Msg_CMD public: U2F_Register_APDU(const U2F_Msg_CMD &msg, const std::vector &data); - void respond(); + void respond(const uint32_t channelID) const override; }; diff --git a/U2F_Version_APDU.cpp b/U2F_Version_APDU.cpp index 2f70779..3d1df44 100644 --- a/U2F_Version_APDU.cpp +++ b/U2F_Version_APDU.cpp @@ -14,12 +14,12 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD &msg) //Don't actually respond yet } -void U2F_Version_APDU::respond() +void U2F_Version_APDU::respond(const uint32_t channelID) const { char ver[]{ 'U', '2', 'F', '_', 'V', '2' }; U2FMessage m{}; - m.cid = 0xF1D0F1D0; + m.cid = channelID; m.cmd = U2FHID_MSG; m.data.insert(m.data.end(), FIELD(ver)); auto sC = APDU_STATUS::SW_NO_ERROR; diff --git a/U2F_Version_APDU.hpp b/U2F_Version_APDU.hpp index fbc41e2..37d0105 100644 --- a/U2F_Version_APDU.hpp +++ b/U2F_Version_APDU.hpp @@ -5,5 +5,5 @@ struct U2F_Version_APDU : U2F_Msg_CMD { public: U2F_Version_APDU(const U2F_Msg_CMD &msg); - void respond(); + void respond(const uint32_t channelID) const override; }; diff --git a/monitor.cpp b/monitor.cpp index 89f7f27..0b9b771 100644 --- a/monitor.cpp +++ b/monitor.cpp @@ -1,60 +1,35 @@ -#include -#include -#include -#include -#include -#include #include -#include -#include -#include "u2f.hpp" -#include "Constants.hpp" -#include "U2FMessage.hpp" -#include "U2F_CMD.hpp" #include "Storage.hpp" -#include "U2F_Init_CMD.hpp" -#include "U2F_Init_Response.hpp" -#include "U2F_Msg_CMD.hpp" +#include "Controller.hpp" +#include +#include using namespace std; -int main(int argc, char** argv) +void signalCallback(int signum); + +volatile bool contProc = true; + +int main() { + signal(SIGINT, signalCallback); Storage::init(); - auto initFrame = U2F_Init_CMD::get(); + Controller ch{ 0xF1D00000 }; - U2F_Init_Response resp{}; - - resp.cid = 0xF1D0F1D0; - resp.nonce = initFrame.nonce; - resp.protocolVer = 2; - resp.majorDevVer = 1; - resp.minorDevVer = 0; - resp.buildDevVer = 1; - resp.capabilities = CAPFLAG_WINK; - - resp.write(); - - size_t msgCount = (argc == 2 ? stoul(argv[1]) : 3u); - - for (size_t i = 0; i < msgCount; i++) + while (contProc) { - auto reg = U2F_Msg_CMD::get(); - - cout << "U2F CMD ins: " << static_cast(reg->ins) << endl; - - reg->respond(); - - clog << "Fully responded" << endl; + ch.handleTransaction(); + usleep(10000); } - /*auto m = U2FMessage::read(); - - cout << "U2F CMD: " << static_cast(m.cmd) << endl; - - for (const auto d : m.data) - clog << static_cast(d) << endl;*/ - Storage::save(); + + return EXIT_SUCCESS; +} + +void signalCallback(int signum) +{ + contProc = false; + clog << "Caught SIGINT signal" << endl; }