Added resumable loading, along with signal handling.

This version can store state after receiving SIGINT.
This is achieved by polling FIFO read state;
This commit is contained in:
2018-08-09 20:23:23 +00:00
parent 673577a601
commit 48840ad36c
29 changed files with 598 additions and 233 deletions

42
Channel.cpp Normal file
View File

@@ -0,0 +1,42 @@
#include "Channel.hpp"
#include <stdexcept>
#include "u2f.hpp"
#include "U2F_CMD.hpp"
#include <iostream>
using namespace std;
Channel::Channel(const uint32_t channelID)
: cid{ channelID }, initState{ ChannelInitState::Unitialised }, lockedState{ ChannelLockedState::Unlocked }
{}
uint32_t Channel::getCID() const
{
return cid;
}
void Channel::handle(const shared_ptr<U2FMessage> uMsg)
{
if (uMsg->cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised;
else if (uMsg->cid != this->cid)
throw runtime_error{ "CID of request invalid for this channel" };
if (this->initState != ChannelInitState::Initialised)
throw runtime_error{ "Channel in incorrect (uninitialized) state to handle request" };
else if (this->lockedState != ChannelLockedState::Unlocked)
throw runtime_error{ "Channel in incorrect (locked) state to handle request" };
clog << "Handling uMsg with CMD: " << static_cast<uint32_t>(uMsg->cmd) << endl;
return U2F_CMD::get(uMsg)->respond(this->cid);
}
void Channel::init(const ChannelInitState newInitState)
{
this->initState = newInitState;
}
void Channel::lock(const ChannelLockedState newLockedState)
{
this->lockedState = newLockedState;
}

32
Channel.hpp Normal file
View File

@@ -0,0 +1,32 @@
#pragma once
#include <cstdint>
#include <memory>
#include "U2FMessage.hpp"
enum class ChannelInitState
{
Unitialised,
Initialised
};
enum class ChannelLockedState
{
Locked,
Unlocked
};
class Channel
{
protected:
uint32_t cid;
ChannelInitState initState;
ChannelLockedState lockedState;
public:
Channel(const uint32_t channelID);
void handle(const std::shared_ptr<U2FMessage> uMsg);
uint32_t getCID() const;
void init(const ChannelInitState newInitState);
void lock(const ChannelLockedState newLockedState);
};

View File

@@ -1,4 +0,0 @@
#pragma once
#include <cstdint>
const constexpr uint16_t packetSize = 64;

48
Controller.cpp Normal file
View File

@@ -0,0 +1,48 @@
#include "Controller.hpp"
#include "u2f.hpp"
#include <iostream>
#include "IO.hpp"
using namespace std;
Controller::Controller(const uint32_t startChannel)
: channels{}, currChannel{ startChannel }
{}
void Controller::handleTransaction()
{
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
auto opChannel = msg->cid;
clog << "Got msg with cmd of: " << static_cast<uint16_t>(msg->cmd) << endl;
if (msg->cmd == U2FHID_INIT)
{
opChannel = nextChannel();
auto channel = Channel{ opChannel };
try
{
channels.emplace(opChannel, channel); //In case of wrap-around replace existing one
}
catch (...)
{
channels.insert(make_pair(opChannel, channel));
}
}
channels.at(opChannel).handle(msg);
}
const uint32_t Controller::nextChannel()
{
do
currChannel++;
while (currChannel == 0xFFFFFFFF || currChannel == 0);
return currChannel;
}

16
Controller.hpp Normal file
View File

@@ -0,0 +1,16 @@
#pragma once
#include <map>
#include "Channel.hpp"
class Controller
{
protected:
std::map<uint32_t, Channel> channels;
uint32_t currChannel;
public:
Controller(const uint32_t startChannel = 1);
void handleTransaction();
const uint32_t nextChannel();
};

92
IO.cpp
View File

@@ -1,26 +1,94 @@
#include "IO.hpp" #include "IO.hpp"
#include "Streams.hpp" #include "Streams.hpp"
#include <iostream> #include <iostream>
#include <unistd.h>
#include <stropts.h>
//#include <sys/ioctl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include "u2f.hpp"
#include "Macro.hpp"
using namespace std; using namespace std;
vector<uint8_t> readBytes(const size_t count) bool bytesAvailable(const size_t count);
vector<uint8_t>& getBuffer();
vector<uint8_t> readNonBlock(const size_t count)
{ {
vector<uint8_t> bytes(count); if (!bytesAvailable(count))
size_t readByteCount;
do
{ {
readByteCount = fread(bytes.data(), 1, count, getHostStream().get()); //clog << "No bytes available" << endl;
fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get()); return vector<uint8_t>{};
} while (readByteCount == 0); }
clog << "Read " << readByteCount << " bytes" << endl; auto &buffer = getBuffer();
auto buffStart = buffer.begin(), buffEnd = buffer.begin() + count;
vector<uint8_t> bytes{ buffStart, buffEnd };
buffer.erase(buffStart, buffEnd);
fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get());
if (readByteCount != count) errno = 0;
throw runtime_error{ "Failed to read sufficient bytes" };
return bytes; return bytes;
} }
void write(const uint8_t* bytes, const size_t count)
{
size_t totalBytes = 0;
auto hostDescriptor = *getHostDescriptor();
while (totalBytes < count)
{
auto writtenBytes = write(hostDescriptor, bytes + totalBytes, count - totalBytes);
if (writtenBytes > 0)
totalBytes += writtenBytes;
else if (errno != 0 && errno != EAGAIN && errno != EWOULDBLOCK) //Expect file blocking behaviour
ERR();
}
errno = 0;
}
bool bytesAvailable(const size_t count)
{
return getBuffer().size() >= count;
}
vector<uint8_t>& bufferVar()
{
static vector<uint8_t> buffer{};
return buffer;
}
vector<uint8_t>& getBuffer()
{
auto &buff = bufferVar();
array<uint8_t, HID_RPT_SIZE> bytes{};
auto hostDescriptor = *getHostDescriptor();
while (true)
{
auto readByteCount = read(hostDescriptor, bytes.data(), HID_RPT_SIZE);
if (readByteCount > 0)
{
copy(bytes.begin(), bytes.begin() + readByteCount, back_inserter(buff));
}
else if (errno != EAGAIN && errno != EWOULDBLOCK) //Expect read would block
{
ERR();
}
else
{
break; //Escape loop if blocking would occur
}
}
return buff;
}

7
IO.hpp
View File

@@ -3,4 +3,9 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
std::vector<uint8_t> readBytes(const size_t count); //Returns either the number of bytes specified,
//or returns empty vector without discarding bytes from HID stream
std::vector<uint8_t> readNonBlock(const size_t count);
//Blocking write to HID stream - shouldn't block for too long
void write(const uint8_t* bytes, const size_t count);

5
Macro.hpp Normal file
View File

@@ -0,0 +1,5 @@
#pragma once
#include <unistd.h>
#include <string>
#define ERR() if (errno != 0) perror((string{ "(" } + __FILE__ + ":" + to_string(__LINE__) + ")" + " " + __PRETTY_FUNCTION__).c_str()), errno = 0

View File

@@ -6,7 +6,7 @@ LDFLAGS := -lmbedcrypto
CPPFLAGS := CPPFLAGS :=
CXXFLAGS := --std=c++14 CXXFLAGS := --std=c++14
CXXFLAGS += -MMD -MP CXXFLAGS += -MMD -MP -Wall -Wfatal-errors -Wextra
MODULES := $(wildcard $(SRC_DIR)/*.cpp) MODULES := $(wildcard $(SRC_DIR)/*.cpp)
OBJECTS := $(MODULES:$(SRC_DIR)/%.cpp=$(OBJ_DIR)/%.o) OBJECTS := $(MODULES:$(SRC_DIR)/%.cpp=$(OBJ_DIR)/%.o)

View File

@@ -3,20 +3,64 @@
#include "u2f.hpp" #include "u2f.hpp"
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <unistd.h>
#include "Streams.hpp" #include "Streams.hpp"
using namespace std; using namespace std;
shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD) shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD)
{ {
auto p = make_shared<InitPacket>(); static size_t bytesRead = 0;
p->cid = rCID; static uint8_t bcnth;
p->cmd = rCMD; static uint8_t bcntl;
p->bcnth = readBytes(1)[0]; static decltype(InitPacket::data) dataBytes;
p->bcntl = readBytes(1)[0]; vector<uint8_t> bytes{};
const auto dataBytes = readBytes(p->data.size()); switch (bytesRead)
copy(dataBytes.begin(), dataBytes.end(), p->data.begin()); {
case 0:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcnth = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 1:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcntl = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 2:
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());;
bytesRead += bytes.size();
[[fallthrough]];
case 2 + dataBytes.size():
break;
default:
throw runtime_error{ "Unknown stage in InitPacket construction" };
}
auto p = make_shared<InitPacket>();
p->cid = rCID;
p->cmd = rCMD;
p->bcnth = bcnth;
p->bcntl = bcntl;
p->data = dataBytes;
auto hPStream = getHostPacketStream().get(); auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n" fprintf(hPStream, "\t\t<table>\n"
@@ -47,17 +91,33 @@ shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />"); "\t\t<br />");
clog << "Fully read init packet" << endl; clog << "Fully read init packet" << endl;
bytesRead = 0;
return p; return p;
} }
shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq) shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq)
{ {
static size_t readBytes = 0;
static decltype(ContPacket::data) dataBytes;
vector<uint8_t> bytes{};
auto p = make_shared<ContPacket>(); auto p = make_shared<ContPacket>();
if (readBytes != dataBytes.size())
{
dataBytes = {};
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());
readBytes += bytes.size();
}
p->cid = rCID; p->cid = rCID;
p->seq = rSeq; p->seq = rSeq;
p->data = dataBytes;
const auto dataBytes = readBytes(p->data.size());
copy(dataBytes.begin(), dataBytes.end(), p->data.begin());
auto hPStream = getHostPacketStream().get(); auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n" fprintf(hPStream, "\t\t<table>\n"
@@ -84,49 +144,86 @@ shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />"); "\t\t<br />");
//clog << "Fully read cont packet" << endl; //clog << "Fully read cont packet" << endl;
readBytes = 0;
return p; return p;
} }
shared_ptr<Packet> Packet::getPacket() shared_ptr<Packet> Packet::getPacket()
{ {
const uint32_t cid = *reinterpret_cast<uint32_t*>(readBytes(4).data()); static size_t bytesRead = 0;
uint8_t b = readBytes(1)[0]; vector<uint8_t> bytes{};
//clog << "Packet read 2nd byte as " << static_cast<uint16_t>(b) << endl; static uint32_t cid;
static uint8_t b;
shared_ptr<Packet> packet{};
if (b & TYPE_MASK) switch (bytesRead)
{ {
//Init packet case 0:
return InitPacket::getPacket(cid, b); bytes = readNonBlock(4);
}
else if (bytes.size() == 0)
{ return {};
//Cont packet
return ContPacket::getPacket(cid, b); cid = *reinterpret_cast<uint32_t*>(bytes.data());
bytesRead += bytes.size();
[[fallthrough]];
case 4:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
b = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 5:
if (b & TYPE_MASK)
{
//Init packet
//clog << "Getting init packet" << endl;
packet = InitPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
else
{
//Cont packet
//clog << "Getting cont packet" << endl;
packet = ContPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
default:
throw runtime_error{ "Unknown stage in Packet construction" };
} }
} }
void Packet::writePacket() void Packet::writePacket()
{ {
memset(this->buf, 0, packetSize); memset(this->buf, 0, HID_RPT_SIZE);
memcpy(this->buf, &cid, 4); memcpy(this->buf, &cid, 4);
} }
void InitPacket::writePacket() void InitPacket::writePacket()
{ {
Packet::writePacket(); Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get(); auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &cmd, 1); memcpy(this->buf + 4, &cmd, 1);
memcpy(this->buf + 5, &bcnth, 1); memcpy(this->buf + 5, &bcnth, 1);
memcpy(this->buf + 6, &bcntl, 1); memcpy(this->buf + 6, &bcntl, 1);
memcpy(this->buf + 7, data.data(), data.size()); memcpy(this->buf + 7, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream); write(this->buf, sizeof(this->buf));
fwrite(this->buf, packetSize, 1, devStream); fwrite(this->buf, 1, sizeof(this->buf), devStream);
if (errno != 0)
perror("perror " __FILE__ " 85");
auto dPStream = getDevPacketStream().get(); auto dPStream = getDevPacketStream().get();
fprintf(dPStream, "\t\t<table>\n" fprintf(dPStream, "\t\t<table>\n"
@@ -155,23 +252,17 @@ void InitPacket::writePacket()
"\t\t\t</tbody>\n" "\t\t\t</tbody>\n"
"\t\t</table>" "\t\t</table>"
"\t\t<br />"); "\t\t<br />");
clog << "Fully wrote init packet" << endl;
} }
void ContPacket::writePacket() void ContPacket::writePacket()
{ {
Packet::writePacket(); Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get(); auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &seq, 1); memcpy(this->buf + 4, &seq, 1);
memcpy(this->buf + 5, data.data(), data.size()); memcpy(this->buf + 5, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream); write(this->buf, HID_RPT_SIZE);
fwrite(this->buf, packetSize, 1, devStream); fwrite(this->buf, HID_RPT_SIZE, 1, devStream);
if (errno != 0)
perror("perror " __FILE__ " 107");
auto dPStream = getDevPacketStream().get(); auto dPStream = getDevPacketStream().get();

View File

@@ -2,12 +2,13 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <array> #include <array>
#include "Constants.hpp" #include "u2f.hpp"
struct Packet struct Packet
{ {
uint32_t cid; public:
uint8_t buf[packetSize]; uint32_t cid;
uint8_t buf[HID_RPT_SIZE];
protected: protected:
Packet() = default; Packet() = default;
@@ -20,10 +21,11 @@ struct Packet
struct InitPacket : Packet struct InitPacket : Packet
{ {
uint8_t cmd; public:
uint8_t bcnth; uint8_t cmd;
uint8_t bcntl; uint8_t bcnth;
std::array<uint8_t, packetSize - 7> data{}; uint8_t bcntl;
std::array<uint8_t, HID_RPT_SIZE - 7> data{};
public: public:
InitPacket() = default; InitPacket() = default;
@@ -33,8 +35,9 @@ struct InitPacket : Packet
struct ContPacket : Packet struct ContPacket : Packet
{ {
uint8_t seq; public:
std::array<uint8_t, packetSize - 5> data{}; uint8_t seq;
std::array<uint8_t, HID_RPT_SIZE - 5> data{};
public: public:
ContPacket() = default; ContPacket() = default;

View File

@@ -1,21 +1,29 @@
#include "Streams.hpp" #include "Streams.hpp"
#include <iostream> #include <iostream>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <cstdio>
#include <unistd.h>
using namespace std; using namespace std;
FILE* initHTML(FILE *fPtr, const string &title); FILE* initHTML(FILE *fPtr, const string &title);
void closeHTML(FILE *fPtr); void closeHTML(FILE *fPtr);
shared_ptr<FILE> getHostStream() shared_ptr<int> getHostDescriptor()
{ {
static shared_ptr<FILE> stream{ fopen("/dev/hidg0", "ab+"), [](FILE *f){ static shared_ptr<int> descriptor{};
fclose(f);
} }; descriptor.reset(new int{ open("/dev/hidg0", O_RDWR | O_NONBLOCK | O_APPEND) }, [](int* fd){
close(*fd);
delete fd;
});
if (!stream) if (*descriptor == -1)
clog << "Stream is unavailable" << endl; throw runtime_error{ "Descriptor is unavailable" };
return stream; return descriptor;
} }
shared_ptr<FILE> getComHostStream() shared_ptr<FILE> getComHostStream()

View File

@@ -2,7 +2,7 @@
#include <cstdio> #include <cstdio>
#include <memory> #include <memory>
std::shared_ptr<FILE> getHostStream(); std::shared_ptr<int> getHostDescriptor();
std::shared_ptr<FILE> getComHostStream(); std::shared_ptr<FILE> getComHostStream();
std::shared_ptr<FILE> getHostPacketStream(); std::shared_ptr<FILE> getHostPacketStream();
std::shared_ptr<FILE> getHostAPDUStream(); std::shared_ptr<FILE> getHostAPDUStream();

View File

@@ -8,40 +8,83 @@
using namespace std; using namespace std;
U2FMessage U2FMessage::read() shared_ptr<U2FMessage> U2FMessage::readNonBlock()
{ {
auto fPack = dynamic_pointer_cast<InitPacket>(Packet::getPacket()); static size_t currSeq = -1;
static uint16_t messageSize;
static uint32_t cid;
static uint8_t cmd;
static vector<uint8_t> dataBytes;
if (!fPack) shared_ptr<Packet> p{};
throw runtime_error{ "Failed to receive Init packet" };
const uint16_t messageSize = ((static_cast<uint16_t>(fPack->bcnth) << 8u) + fPack->bcntl); if (currSeq == -1u)
clog << "Message on channel 0x" << hex << fPack->cid << dec << " has size: " << messageSize << endl;
const uint16_t copyByteCount = min(static_cast<uint16_t>(fPack->data.size()), messageSize);
U2FMessage message{ fPack-> cid, fPack->cmd };
message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount);
uint8_t currSeq = 0;
while (message.data.size() < messageSize)
{ {
auto newPack = dynamic_pointer_cast<ContPacket>(Packet::getPacket()); cid = 0;
cmd = 0;
messageSize = 0;
dataBytes = {};
shared_ptr<InitPacket> initPack{};
do
{
p = Packet::getPacket();
if (!newPack) if (!p)
throw runtime_error{ "Failed to receive Cont packet" }; return {};
else if (newPack->seq != currSeq)
throw runtime_error{ "Packet out of sequence" };
const uint16_t remainingBytes = messageSize - message.data.size(); initPack = dynamic_pointer_cast<InitPacket>(p);
const uint16_t copyBytes = min(static_cast<uint16_t>(newPack->data.size()), remainingBytes); } while (!initPack); //Spurious cont. packet - spec states ignore
message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes);
messageSize = ((static_cast<uint16_t>(initPack->bcnth) << 8u) + initPack->bcntl);
const uint16_t copyByteCount = min(static_cast<uint16_t>(initPack->data.size()), messageSize);
cid = initPack->cid;
cmd = initPack->cmd;
copy(initPack->data.begin(), initPack->data.begin() + copyByteCount, back_inserter(dataBytes));
currSeq++; currSeq++;
} }
while (messageSize > dataBytes.size() && static_cast<bool>(p = Packet::getPacket())) //While there is a packet
{
auto contPack = dynamic_pointer_cast<ContPacket>(p);
if (!contPack) //Spurious init. packet
{
currSeq = -1; //Reset
return {};
}
if (contPack->cid != cid) //Cont. packet of different CID
{
cerr << "Invalid CID: was handling channel 0x" << hex << cid << " and received packet from channel 0x" << contPack->cid << dec << endl;
U2FMessage::error(contPack->cid, ERR_CHANNEL_BUSY);
currSeq = -1;
return {};
}
if (contPack->seq != currSeq)
{
cerr << "Invalid packet seq. value" << endl;
U2FMessage::error(cid, ERR_INVALID_SEQ);
currSeq = -1;
return {};
}
const uint16_t remainingBytes = messageSize - dataBytes.size();
const uint16_t copyBytes = min(static_cast<uint16_t>(contPack->data.size()), remainingBytes);
dataBytes.insert(dataBytes.end(), contPack->data.begin(), contPack->data.begin() + copyBytes);
currSeq++;
}
if (messageSize != dataBytes.size())
return {};
auto message = make_shared<U2FMessage>(cid, cmd);
message->data.assign(dataBytes.begin(), dataBytes.end());
currSeq = -1u;
std::clog << "Read all of message" << std::endl; std::clog << "Read all of message" << std::endl;
return message; return message;
@@ -49,9 +92,6 @@ U2FMessage U2FMessage::read()
void U2FMessage::write() void U2FMessage::write()
{ {
clog << "Flushing host stream" << endl;
fflush(getHostStream().get());
clog << "Flushed host stream" << endl;
const uint16_t bytesToWrite = this->data.size(); const uint16_t bytesToWrite = this->data.size();
uint16_t bytesWritten = 0; uint16_t bytesWritten = 0;
@@ -88,8 +128,7 @@ void U2FMessage::write()
bytesWritten += newByteCount; bytesWritten += newByteCount;
} }
auto stream = getHostStream().get(); //auto stream = *getHostStream();
fflush(stream);
if (cmd == U2FHID_MSG) if (cmd == U2FHID_MSG)
{ {
@@ -120,3 +159,16 @@ void U2FMessage::write()
"\t\t<br />", err); "\t\t<br />", err);
} }
} }
U2FMessage::U2FMessage(const uint32_t nCID, const uint8_t nCMD)
: cid{ nCID }, cmd{ nCMD }
{}
void U2FMessage::error(const uint32_t tCID, const uint16_t tErr)
{
U2FMessage msg{};
msg.cid = tCID;
msg.cmd = U2FHID_ERROR;
msg.data.push_back((tErr >> 8) & 0xFF);
msg.data.push_back(tErr & 0xFF);
}

View File

@@ -1,14 +1,21 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include <memory>
struct U2FMessage struct U2FMessage
{ {
uint32_t cid; public:
uint8_t cmd; uint32_t cid;
std::vector<uint8_t> data; uint8_t cmd;
std::vector<uint8_t> data;
static U2FMessage read(); public:
U2FMessage() = default;
U2FMessage(const uint32_t nCID, const uint8_t nCMD);
static std::shared_ptr<U2FMessage> readNonBlock();
void write();
void write(); protected:
static void error(const uint32_t tCID, const uint16_t tErr);
}; };

View File

@@ -27,10 +27,10 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const vecto
clog << "Got U2F_Auth request" << endl; clog << "Got U2F_Auth request" << endl;
} }
void U2F_Authenticate_APDU::respond() void U2F_Authenticate_APDU::respond(const uint32_t channelID) const
{ {
U2FMessage msg{}; U2FMessage msg{};
msg.cid = 0xF1D0F1D0; msg.cid = channelID;
msg.cmd = U2FHID_MSG; msg.cmd = U2FHID_MSG;
auto statusCode = APDU_STATUS::SW_NO_ERROR; auto statusCode = APDU_STATUS::SW_NO_ERROR;

View File

@@ -12,7 +12,7 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD
public: public:
U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data); U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data);
void respond(); virtual void respond(const uint32_t channelID) const override;
enum ControlCode enum ControlCode
{ {

19
U2F_CMD.cpp Normal file
View File

@@ -0,0 +1,19 @@
#include "U2F_CMD.hpp"
#include "u2f.hpp"
#include "U2F_Msg_CMD.hpp"
#include "U2F_Init_CMD.hpp"
using namespace std;
shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg)
{
switch (uMsg->cmd)
{
case U2FHID_MSG:
return U2F_Msg_CMD::generate(uMsg);
case U2FHID_INIT:
return make_shared<U2F_Init_CMD>(uMsg);
default:
return {};
}
}

View File

@@ -1,4 +1,6 @@
#pragma once #pragma once
#include <memory>
#include "U2FMessage.hpp"
struct U2F_CMD struct U2F_CMD
{ {
@@ -6,5 +8,7 @@ struct U2F_CMD
U2F_CMD() = default; U2F_CMD() = default;
public: public:
~U2F_CMD() = default; virtual ~U2F_CMD() = default;
static std::shared_ptr<U2F_CMD> get(const std::shared_ptr<U2FMessage> uMsg);
virtual void respond(const uint32_t channelID) const = 0;
}; //For polymorphic type casting }; //For polymorphic type casting

View File

@@ -1,24 +1,33 @@
#include "U2F_Init_CMD.hpp" #include "U2F_Init_CMD.hpp"
#include <stdexcept> #include <stdexcept>
#include "U2FMessage.hpp"
#include "u2f.hpp" #include "u2f.hpp"
#include <iostream> #include "Field.hpp"
using namespace std; using namespace std;
U2F_Init_CMD U2F_Init_CMD::get() U2F_Init_CMD::U2F_Init_CMD(const shared_ptr<U2FMessage> uMsg)
{ {
const auto message = U2FMessage::read(); if (uMsg->cmd != U2FHID_INIT)
if (message.cmd != U2FHID_INIT)
throw runtime_error{ "Failed to get U2F Init message" }; throw runtime_error{ "Failed to get U2F Init message" };
else if (message.data.size() != INIT_NONCE_SIZE) else if (uMsg->data.size() != INIT_NONCE_SIZE)
throw runtime_error{ "Init nonce is incorrect size" }; throw runtime_error{ "Init nonce is incorrect size" };
U2F_Init_CMD cmd; this->nonce = *reinterpret_cast<const uint64_t*>(uMsg->data.data());
cmd.nonce = *reinterpret_cast<const uint64_t*>(message.data.data()); }
clog << "Fully read nonce" << endl; void U2F_Init_CMD::respond(const uint32_t channelID) const
{
return cmd; U2FMessage msg{};
msg.cid = CID_BROADCAST;
msg.cmd = U2FHID_INIT;
msg.data.insert(msg.data.end(), FIELD(this->nonce));
msg.data.insert(msg.data.end(), FIELD(channelID));
msg.data.push_back(2); //Protocol version
msg.data.push_back(1); //Major device version
msg.data.push_back(0); //Minor device version
msg.data.push_back(1); //Build device version
msg.data.push_back(CAPFLAG_WINK); //Wink capability
msg.write();
} }

View File

@@ -1,11 +1,14 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <memory>
#include "U2F_CMD.hpp" #include "U2F_CMD.hpp"
#include "U2FMessage.hpp"
struct U2F_Init_CMD : U2F_CMD struct U2F_Init_CMD : U2F_CMD
{ {
uint64_t nonce; uint64_t nonce;
public: public:
static U2F_Init_CMD get(); U2F_Init_CMD(const std::shared_ptr<U2FMessage> uMsg);
virtual void respond(const uint32_t channelID) const override;
}; };

View File

@@ -1,36 +0,0 @@
#pragma once
#include <cstdint>
#include "U2FMessage.hpp"
#include "Field.hpp"
struct U2F_Init_Response : U2F_CMD
{
uint32_t cid;
uint64_t nonce;
uint8_t protocolVer;
uint8_t majorDevVer;
uint8_t minorDevVer;
uint8_t buildDevVer;
uint8_t capabilities;
void write()
{
std::clog << "Beginning writeout of U2F_Init_Response" << std::endl;
U2FMessage m{};
m.cid = CID_BROADCAST;
m.cmd = U2FHID_INIT;
m.data.insert(m.data.begin() + 0, FIELD(nonce));
m.data.insert(m.data.begin() + 8, FIELD(cid));
m.data.insert(m.data.begin() + 12, FIELD(protocolVer));
m.data.insert(m.data.begin() + 13, FIELD(majorDevVer));
m.data.insert(m.data.begin() + 14, FIELD(minorDevVer));
m.data.insert(m.data.begin() + 15, FIELD(buildDevVer));
m.data.insert(m.data.begin() + 16, FIELD(capabilities));
std::clog << "Finished inserting U2F_Init_Response fields to data buffer" << std::endl;
m.write();
std::clog << "Completed writeout of U2F_Init_Response" << std::endl;
}
};

View File

@@ -8,16 +8,17 @@
#include "APDU.hpp" #include "APDU.hpp"
#include <iostream> #include <iostream>
#include "Streams.hpp" #include "Streams.hpp"
#include "Field.hpp"
using namespace std; using namespace std;
uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes) uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
{ {
if (byteCount > 3)
throw runtime_error{ "Too much data for command" };
if (byteCount != 0) if (byteCount != 0)
{ {
//Le must be length of data in bytes //Le must be length of data in bytes
clog << "Le must be length of data in bytes" << endl;
clog << "Le has a size of " << byteCount << " bytes" << endl;
switch (byteCount) switch (byteCount)
{ {
@@ -27,44 +28,47 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
//Don't handle non-compliance with spec here //Don't handle non-compliance with spec here
//This case is only possible if extended Lc used //This case is only possible if extended Lc used
//CBA //CBA
return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : bytes[0] << 8 + bytes[1]); return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : (bytes[0] << 8) + bytes[1]);
case 3: case 3:
//Don't handle non-compliance with spec here //Don't handle non-compliance with spec here
//This case is only possible if extended Lc not used //This case is only possible if extended Lc not used
//CBA //CBA
if (bytes[0] != 0) if (bytes[0] != 0)
throw runtime_error{ "First byte of 3-byte Le should be 0"}; throw runtime_error{ "First byte of 3-byte Le should be 0"};
return (bytes[1] == 0 & bytes[2] == 0 ? 65536 : bytes[1] << 8 + bytes[2]); return (bytes[1] == 0 && bytes[2] == 0 ? 65536 : (bytes[1] << 8) + bytes[2]);
default:
throw runtime_error{ "Too much data for command" };
} }
} }
else else
return 0; return 0;
} }
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get() shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
{ {
const auto message = U2FMessage::read(); if (uMsg->cmd != U2FHID_MSG)
throw runtime_error{ "Failed to get U2F Msg uMsg" };
if (message.cmd != U2FHID_MSG) else if (uMsg->data.size() < 4)
throw runtime_error{ "Failed to get U2F Msg message" }; throw runtime_error{ "Msg data is incorrect size" };
U2F_Msg_CMD cmd{}; U2F_Msg_CMD cmd;
auto &dat = message.data; auto &dat = uMsg->data;
cmd.cla = dat[0]; cmd.cla = dat[0];
cmd.ins = dat[1]; cmd.ins = dat[1];
cmd.p1 = dat[2]; cmd.p1 = dat[2];
cmd.p2 = dat[3]; cmd.p2 = dat[3];
const uint32_t cBCount = dat.size() - 4; clog << "Loaded U2F_Msg_CMD parameters" << endl;
vector<uint8_t> data{ dat.begin() + 4, dat.end() }; vector<uint8_t> data{ dat.begin() + 4, dat.end() };
const uint32_t cBCount = data.size();
auto startPtr = data.begin(), endPtr = data.end(); auto startPtr = data.begin(), endPtr = data.end();
clog << "Loaded iters" << endl;
if (usesData.at(cmd.ins) || data.size() > 3) if (usesData.at(cmd.ins) || data.size() > 3)
{ {
//clog << "First bytes are: " << static_cast<uint16_t>(data[0]) << " " << static_cast<uint16_t>(data[1]) << " " << static_cast<uint16_t>(data[2]) << endl;
if (cBCount == 0) if (cBCount == 0)
throw runtime_error{ "Invalid command - should have attached data" }; throw runtime_error{ "Invalid command - should have attached data" };
@@ -81,17 +85,22 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
endPtr = startPtr + cmd.lc; endPtr = startPtr + cmd.lc;
clog << "Getting Le" << endl;
cmd.le = getLe(data.end() - endPtr, vector<uint8_t>(endPtr, data.end())); cmd.le = getLe(data.end() - endPtr, vector<uint8_t>(endPtr, data.end()));
} }
else else
{ {
cmd.lc = 0; cmd.lc = 0;
endPtr = startPtr; endPtr = startPtr;
clog << "Getting Le" << endl;
cmd.le = getLe(cBCount, data); cmd.le = getLe(cBCount, data);
} }
const auto dBytes = vector<uint8_t>(startPtr, endPtr); const auto dBytes = vector<uint8_t>(startPtr, endPtr);
clog << "Determined message format" << endl;
auto hAS = getHostAPDUStream().get(); auto hAS = getHostAPDUStream().get();
fprintf(hAS, "<table>\n" fprintf(hAS, "<table>\n"
@@ -125,6 +134,8 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
"\t\t</table>\n" "\t\t</table>\n"
"\t\t<br />", cmd.le); "\t\t<br />", cmd.le);
clog << "Constructing message specialisation" << endl;
switch (cmd.ins) switch (cmd.ins)
{ {
case APDU::U2F_REG: case APDU::U2F_REG:
@@ -138,11 +149,18 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
} }
} }
void U2F_Msg_CMD::respond(){};
const map<uint8_t, bool> U2F_Msg_CMD::usesData = { const map<uint8_t, bool> U2F_Msg_CMD::usesData = {
{ U2F_REG, true }, { U2F_REG, true },
{ U2F_AUTH, true }, { U2F_AUTH, true },
{ U2F_VER, false } { U2F_VER, false }
}; };
void U2F_Msg_CMD::respond(const uint32_t channelID) const
{
U2FMessage msg{};
msg.cid = channelID;
msg.cmd = U2FHID_MSG;
auto errorCode = APDU_STATUS::SW_INS_NOT_SUPPORTED;
msg.data.insert(msg.data.end(), FIELD_BE(errorCode));
msg.write();
}

View File

@@ -18,10 +18,10 @@ struct U2F_Msg_CMD : U2F_CMD
protected: protected:
static uint32_t getLe(const uint32_t byteCount, std::vector<uint8_t> bytes); static uint32_t getLe(const uint32_t byteCount, std::vector<uint8_t> bytes);
U2F_Msg_CMD() = default;
public: public:
static std::shared_ptr<U2F_Msg_CMD> get(); static std::shared_ptr<U2F_Msg_CMD> generate(const std::shared_ptr<U2FMessage> uMsg);
void respond(const uint32_t channelID) const;
virtual void respond();
}; };

View File

@@ -50,10 +50,10 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD &msg, const vector<uint8_
clog << endl << dec << "Got U2F_Reg request" << endl; clog << endl << dec << "Got U2F_Reg request" << endl;
} }
void U2F_Register_APDU::respond() void U2F_Register_APDU::respond(const uint32_t channelID) const
{ {
U2FMessage m{}; U2FMessage m{};
m.cid = 0xF1D0F1D0; m.cid = channelID;
m.cmd = U2FHID_MSG; m.cmd = U2FHID_MSG;
auto& response = m.data; auto& response = m.data;
@@ -65,7 +65,7 @@ void U2F_Register_APDU::respond()
copy(pubKey.begin(), pubKey.end(), back_inserter(response)); copy(pubKey.begin(), pubKey.end(), back_inserter(response));
response.push_back(sizeof(this->keyH)); response.push_back(sizeof(this->keyH));
auto fakeKeyHBytes = reinterpret_cast<uint8_t *>(&this->keyH); auto fakeKeyHBytes = reinterpret_cast<const uint8_t *>(&this->keyH);
copy(fakeKeyHBytes, fakeKeyHBytes + sizeof(this->keyH), back_inserter(response)); copy(fakeKeyHBytes, fakeKeyHBytes + sizeof(this->keyH), back_inserter(response));
copy(attestCert, end(attestCert), back_inserter(response)); copy(attestCert, end(attestCert), back_inserter(response));
@@ -83,9 +83,9 @@ void U2F_Register_APDU::respond()
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(appParam.data()), appParam.size()); mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(appParam.data()), appParam.size());
mbedtls_sha256_update(&shaContext, reinterpret_cast<unsigned char*>(challengeP.data()), challengeP.size()); mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(challengeP.data()), challengeP.size());
mbedtls_sha256_update(&shaContext, reinterpret_cast<unsigned char*>(&keyH), sizeof(keyH)); mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(&keyH), sizeof(keyH));
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(pubKey.data()), pubKey.size()); mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(pubKey.data()), pubKey.size());

View File

@@ -11,6 +11,6 @@ struct U2F_Register_APDU : U2F_Msg_CMD
public: public:
U2F_Register_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data); U2F_Register_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data);
void respond(); void respond(const uint32_t channelID) const override;
}; };

View File

@@ -14,12 +14,12 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD &msg)
//Don't actually respond yet //Don't actually respond yet
} }
void U2F_Version_APDU::respond() void U2F_Version_APDU::respond(const uint32_t channelID) const
{ {
char ver[]{ 'U', '2', 'F', '_', 'V', '2' }; char ver[]{ 'U', '2', 'F', '_', 'V', '2' };
U2FMessage m{}; U2FMessage m{};
m.cid = 0xF1D0F1D0; m.cid = channelID;
m.cmd = U2FHID_MSG; m.cmd = U2FHID_MSG;
m.data.insert(m.data.end(), FIELD(ver)); m.data.insert(m.data.end(), FIELD(ver));
auto sC = APDU_STATUS::SW_NO_ERROR; auto sC = APDU_STATUS::SW_NO_ERROR;

View File

@@ -5,5 +5,5 @@ struct U2F_Version_APDU : U2F_Msg_CMD
{ {
public: public:
U2F_Version_APDU(const U2F_Msg_CMD &msg); U2F_Version_APDU(const U2F_Msg_CMD &msg);
void respond(); void respond(const uint32_t channelID) const override;
}; };

View File

@@ -1,60 +1,35 @@
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <array>
#include <vector>
#include <algorithm>
#include <iostream> #include <iostream>
#include <cstring>
#include <map>
#include "u2f.hpp"
#include "Constants.hpp"
#include "U2FMessage.hpp"
#include "U2F_CMD.hpp"
#include "Storage.hpp" #include "Storage.hpp"
#include "U2F_Init_CMD.hpp" #include "Controller.hpp"
#include "U2F_Init_Response.hpp" #include <signal.h>
#include "U2F_Msg_CMD.hpp" #include <unistd.h>
using namespace std; using namespace std;
int main(int argc, char** argv) void signalCallback(int signum);
volatile bool contProc = true;
int main()
{ {
signal(SIGINT, signalCallback);
Storage::init(); Storage::init();
auto initFrame = U2F_Init_CMD::get(); Controller ch{ 0xF1D00000 };
U2F_Init_Response resp{}; while (contProc)
resp.cid = 0xF1D0F1D0;
resp.nonce = initFrame.nonce;
resp.protocolVer = 2;
resp.majorDevVer = 1;
resp.minorDevVer = 0;
resp.buildDevVer = 1;
resp.capabilities = CAPFLAG_WINK;
resp.write();
size_t msgCount = (argc == 2 ? stoul(argv[1]) : 3u);
for (size_t i = 0; i < msgCount; i++)
{ {
auto reg = U2F_Msg_CMD::get(); ch.handleTransaction();
usleep(10000);
cout << "U2F CMD ins: " << static_cast<uint32_t>(reg->ins) << endl;
reg->respond();
clog << "Fully responded" << endl;
} }
/*auto m = U2FMessage::read();
cout << "U2F CMD: " << static_cast<uint32_t>(m.cmd) << endl;
for (const auto d : m.data)
clog << static_cast<uint16_t>(d) << endl;*/
Storage::save(); Storage::save();
return EXIT_SUCCESS;
}
void signalCallback(int signum)
{
contProc = false;
clog << "Caught SIGINT signal" << endl;
} }