#include #include #include #include #include #include #include #include "u2f.hpp" using namespace std; const constexpr uint16_t packetSize = 32; struct Packet { uint32_t cid; protected: Packet() = default; virtual void writePacket(); public: static shared_ptr getPacket(); virtual ~Packet() = default; }; struct InitPacket : Packet { uint8_t cmd; uint8_t bcnth; uint8_t bcntl; array data{}; public: InitPacket() = default; static shared_ptr getPacket(const uint32_t rCID, const uint8_t rCMD); void writePacket() override; }; struct ContPacket : Packet { uint8_t seq; array data{}; public: ContPacket() = default; static shared_ptr getPacket(const uint32_t rCID, const uint8_t rSeq); void writePacket() override; }; shared_ptr getStream() { static shared_ptr stream{ fopen("/dev/hidg0", "ab+"), [](FILE *f){ fclose(f); } }; if (!stream) clog << "Stream is unavailable" << endl; return stream; } vector readBytes(const size_t count) { vector bytes(count); size_t readByteCount; do { readByteCount = fread(bytes.data(), 1, count, getStream().get()); } while (readByteCount == 0); clog << "Read " << readByteCount << " bytes" << endl; if (readByteCount != count) throw runtime_error{ "Failed to read sufficient bytes" }; return bytes; } shared_ptr InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD) { auto p = make_shared(); p->cid = rCID; p->cmd = rCMD; p->bcnth = readBytes(1)[0]; p->bcntl = readBytes(1)[0]; /*uint16_t pLen = p->bcnth; p->bcnth <<= 8; p->bcnth += p->bcntl; */ const auto dataBytes = readBytes(p->data.size()); copy(dataBytes.begin(), dataBytes.end(), p->data.data()); clog << "Fully read init packet" << endl; return p; } shared_ptr ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq) { auto p = make_shared(); p->cid = rCID; p->seq = rSeq; const auto dataBytes = readBytes(p->data.size()); copy(dataBytes.begin(), dataBytes.end(), p->data.data()); clog << "Fully read cont packet" << endl; return p; } shared_ptr Packet::getPacket() { const uint32_t cid = *reinterpret_cast(readBytes(4).data()); uint8_t b = readBytes(1)[0]; clog << "Packet read 2nd byte as " << static_cast(b) << endl; if (b && TYPE_MASK) { //Init packet return InitPacket::getPacket(cid, b); } else { //Cont packet return ContPacket::getPacket(cid, b); } } void Packet::writePacket() { //auto stream = getStream().get(); auto stream = stdout; fwrite(&cid, 4, 1, stream); } void InitPacket::writePacket() { Packet::writePacket(); //auto stream = getStream().get(); auto stream = stdout; fwrite(&cmd, 1, 1, stream); fwrite(&bcnth, 1, 1, stream); fwrite(&bcntl, 1, 1, stream); fwrite(data.data(), data.size(), 1, stream); clog << "Fully wrote init packet" << endl; } void ContPacket::writePacket() { Packet::writePacket(); //auto stream = getStream().get(); auto stream = stdout; fwrite(&seq, 1, 1, stream); fwrite(data.data(), data.size(), 1, stream); clog << "Fully wrote cont packet" << endl; } struct U2FMessage { uint32_t cid; uint8_t cmd; vector data; static U2FMessage read() { auto fPack = dynamic_pointer_cast(Packet::getPacket()); if (!fPack) throw runtime_error{ "Failed to receive Init packet" }; const uint16_t messageSize = ((static_cast(fPack->bcnth) << 8u) + fPack->bcntl); clog << "Message has size: " << messageSize << endl; const uint16_t copyByteCount = min(static_cast(fPack->data.size()), messageSize); U2FMessage message{ fPack-> cid, fPack->cmd }; message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount); uint8_t currSeq = 0; while (message.data.size() < messageSize) { auto newPack = dynamic_pointer_cast(Packet::getPacket()); if (!newPack) throw runtime_error{ "Failed to receive Cont packet" }; else if (newPack->seq != currSeq) throw runtime_error{ "Packet out of sequence" }; const uint16_t remainingBytes = messageSize - message.data.size(); const uint16_t copyBytes = min(static_cast(newPack->data.size()), remainingBytes); message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes); currSeq++; } return message; } void write() { const uint16_t bytesToWrite = this->data.size(); uint16_t bytesWritten = 0; { const uint8_t bcnth = bytesToWrite >> 8; const uint8_t bcntl = bytesToWrite - (bcnth << 8); InitPacket p{}; p.cid = cid; p.cmd = cmd; p.bcnth = bcnth; p.bcntl = bcntl; { uint16_t initialByteCount = min(static_cast(p.data.size()), static_cast(bytesToWrite - bytesWritten)); copy(data.begin(), data.begin() + initialByteCount, p.data.begin()); bytesWritten += initialByteCount; } p.writePacket(); } uint8_t seq = 0; while (bytesWritten != bytesToWrite) { ContPacket p{}; p.cid = cid; p.seq = seq; uint16_t newByteCount = min(static_cast(p.data.size()), static_cast(bytesToWrite - bytesWritten)); copy(data.begin() + bytesWritten, data.begin() + bytesWritten + newByteCount, p.data.begin()); p.writePacket(); seq++; } auto stream = getStream().get(); fflush(stream); } }; struct U2F_CMD { protected: U2F_CMD() = default; public: ~U2F_CMD() = default; }; //For polymorphic type casting struct U2F_Init_CMD : U2F_CMD { uint64_t nonce; public: static U2F_Init_CMD get() { const auto message = U2FMessage::read(); if (message.cmd != U2FHID_INIT) throw runtime_error{ "Failed to get U2F Init message" }; else if (message.data.size() != INIT_NONCE_SIZE) throw runtime_error{ "Init nonce is incorrect size" }; U2F_Init_CMD cmd; cmd.nonce = *reinterpret_cast(message.data.data()); return cmd; } }; #define FIELD(name) reinterpret_cast(&name), (reinterpret_cast(&name) + sizeof(name)) struct U2F_Init_Response : U2F_CMD { uint32_t cid; uint64_t nonce; uint8_t protocolVer; uint8_t majorDevVer; uint8_t minorDevVer; uint8_t buildDevVer; uint8_t capabilities; void write() { U2FMessage m{}; m.cid = CID_BROADCAST; m.cmd = U2FHID_INIT; m.data.insert(m.data.begin() + 0, FIELD(nonce)); m.data.insert(m.data.begin() + 8, FIELD(cid)); m.data.insert(m.data.begin() + 12, FIELD(protocolVer)); m.data.insert(m.data.begin() + 13, FIELD(majorDevVer)); m.data.insert(m.data.begin() + 14, FIELD(minorDevVer)); m.data.insert(m.data.begin() + 15, FIELD(buildDevVer)); m.data.insert(m.data.begin() + 16, FIELD(capabilities)); m.write(); } }; int main() { auto initFrame = U2F_Init_CMD::get(); U2F_Init_Response resp{}; resp.cid = 0xF1D0F1D0; resp.nonce = initFrame.nonce; resp.protocolVer = 2; resp.majorDevVer = 1; resp.minorDevVer = 0; resp.buildDevVer = 1; resp.capabilities = CAPFLAG_WINK; resp.write(); U2FMessage m = m.read(); for (const auto d : m.data) clog << static_cast(d) << endl; }