Added resumable loading, along with signal handling.

This version can store state after receiving SIGINT.
This is achieved by polling FIFO read state;
This commit is contained in:
2018-08-09 20:23:23 +00:00
parent 673577a601
commit 48840ad36c
29 changed files with 598 additions and 233 deletions

42
Channel.cpp Normal file
View File

@@ -0,0 +1,42 @@
#include "Channel.hpp"
#include <stdexcept>
#include "u2f.hpp"
#include "U2F_CMD.hpp"
#include <iostream>
using namespace std;
Channel::Channel(const uint32_t channelID)
: cid{ channelID }, initState{ ChannelInitState::Unitialised }, lockedState{ ChannelLockedState::Unlocked }
{}
uint32_t Channel::getCID() const
{
return cid;
}
void Channel::handle(const shared_ptr<U2FMessage> uMsg)
{
if (uMsg->cmd == U2FHID_INIT)
this->initState = ChannelInitState::Initialised;
else if (uMsg->cid != this->cid)
throw runtime_error{ "CID of request invalid for this channel" };
if (this->initState != ChannelInitState::Initialised)
throw runtime_error{ "Channel in incorrect (uninitialized) state to handle request" };
else if (this->lockedState != ChannelLockedState::Unlocked)
throw runtime_error{ "Channel in incorrect (locked) state to handle request" };
clog << "Handling uMsg with CMD: " << static_cast<uint32_t>(uMsg->cmd) << endl;
return U2F_CMD::get(uMsg)->respond(this->cid);
}
void Channel::init(const ChannelInitState newInitState)
{
this->initState = newInitState;
}
void Channel::lock(const ChannelLockedState newLockedState)
{
this->lockedState = newLockedState;
}

32
Channel.hpp Normal file
View File

@@ -0,0 +1,32 @@
#pragma once
#include <cstdint>
#include <memory>
#include "U2FMessage.hpp"
enum class ChannelInitState
{
Unitialised,
Initialised
};
enum class ChannelLockedState
{
Locked,
Unlocked
};
class Channel
{
protected:
uint32_t cid;
ChannelInitState initState;
ChannelLockedState lockedState;
public:
Channel(const uint32_t channelID);
void handle(const std::shared_ptr<U2FMessage> uMsg);
uint32_t getCID() const;
void init(const ChannelInitState newInitState);
void lock(const ChannelLockedState newLockedState);
};

View File

@@ -1,4 +0,0 @@
#pragma once
#include <cstdint>
const constexpr uint16_t packetSize = 64;

48
Controller.cpp Normal file
View File

@@ -0,0 +1,48 @@
#include "Controller.hpp"
#include "u2f.hpp"
#include <iostream>
#include "IO.hpp"
using namespace std;
Controller::Controller(const uint32_t startChannel)
: channels{}, currChannel{ startChannel }
{}
void Controller::handleTransaction()
{
auto msg = U2FMessage::readNonBlock();
if (!msg)
return;
auto opChannel = msg->cid;
clog << "Got msg with cmd of: " << static_cast<uint16_t>(msg->cmd) << endl;
if (msg->cmd == U2FHID_INIT)
{
opChannel = nextChannel();
auto channel = Channel{ opChannel };
try
{
channels.emplace(opChannel, channel); //In case of wrap-around replace existing one
}
catch (...)
{
channels.insert(make_pair(opChannel, channel));
}
}
channels.at(opChannel).handle(msg);
}
const uint32_t Controller::nextChannel()
{
do
currChannel++;
while (currChannel == 0xFFFFFFFF || currChannel == 0);
return currChannel;
}

16
Controller.hpp Normal file
View File

@@ -0,0 +1,16 @@
#pragma once
#include <map>
#include "Channel.hpp"
class Controller
{
protected:
std::map<uint32_t, Channel> channels;
uint32_t currChannel;
public:
Controller(const uint32_t startChannel = 1);
void handleTransaction();
const uint32_t nextChannel();
};

92
IO.cpp
View File

@@ -1,26 +1,94 @@
#include "IO.hpp"
#include "Streams.hpp"
#include <iostream>
#include <unistd.h>
#include <stropts.h>
//#include <sys/ioctl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include "u2f.hpp"
#include "Macro.hpp"
using namespace std;
vector<uint8_t> readBytes(const size_t count)
bool bytesAvailable(const size_t count);
vector<uint8_t>& getBuffer();
vector<uint8_t> readNonBlock(const size_t count)
{
vector<uint8_t> bytes(count);
size_t readByteCount;
do
if (!bytesAvailable(count))
{
readByteCount = fread(bytes.data(), 1, count, getHostStream().get());
fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get());
} while (readByteCount == 0);
//clog << "No bytes available" << endl;
return vector<uint8_t>{};
}
clog << "Read " << readByteCount << " bytes" << endl;
auto &buffer = getBuffer();
auto buffStart = buffer.begin(), buffEnd = buffer.begin() + count;
vector<uint8_t> bytes{ buffStart, buffEnd };
buffer.erase(buffStart, buffEnd);
if (readByteCount != count)
throw runtime_error{ "Failed to read sufficient bytes" };
fwrite(bytes.data(), 1, bytes.size(), getComHostStream().get());
errno = 0;
return bytes;
}
void write(const uint8_t* bytes, const size_t count)
{
size_t totalBytes = 0;
auto hostDescriptor = *getHostDescriptor();
while (totalBytes < count)
{
auto writtenBytes = write(hostDescriptor, bytes + totalBytes, count - totalBytes);
if (writtenBytes > 0)
totalBytes += writtenBytes;
else if (errno != 0 && errno != EAGAIN && errno != EWOULDBLOCK) //Expect file blocking behaviour
ERR();
}
errno = 0;
}
bool bytesAvailable(const size_t count)
{
return getBuffer().size() >= count;
}
vector<uint8_t>& bufferVar()
{
static vector<uint8_t> buffer{};
return buffer;
}
vector<uint8_t>& getBuffer()
{
auto &buff = bufferVar();
array<uint8_t, HID_RPT_SIZE> bytes{};
auto hostDescriptor = *getHostDescriptor();
while (true)
{
auto readByteCount = read(hostDescriptor, bytes.data(), HID_RPT_SIZE);
if (readByteCount > 0)
{
copy(bytes.begin(), bytes.begin() + readByteCount, back_inserter(buff));
}
else if (errno != EAGAIN && errno != EWOULDBLOCK) //Expect read would block
{
ERR();
}
else
{
break; //Escape loop if blocking would occur
}
}
return buff;
}

7
IO.hpp
View File

@@ -3,4 +3,9 @@
#include <cstdint>
#include <vector>
std::vector<uint8_t> readBytes(const size_t count);
//Returns either the number of bytes specified,
//or returns empty vector without discarding bytes from HID stream
std::vector<uint8_t> readNonBlock(const size_t count);
//Blocking write to HID stream - shouldn't block for too long
void write(const uint8_t* bytes, const size_t count);

5
Macro.hpp Normal file
View File

@@ -0,0 +1,5 @@
#pragma once
#include <unistd.h>
#include <string>
#define ERR() if (errno != 0) perror((string{ "(" } + __FILE__ + ":" + to_string(__LINE__) + ")" + " " + __PRETTY_FUNCTION__).c_str()), errno = 0

View File

@@ -6,7 +6,7 @@ LDFLAGS := -lmbedcrypto
CPPFLAGS :=
CXXFLAGS := --std=c++14
CXXFLAGS += -MMD -MP
CXXFLAGS += -MMD -MP -Wall -Wfatal-errors -Wextra
MODULES := $(wildcard $(SRC_DIR)/*.cpp)
OBJECTS := $(MODULES:$(SRC_DIR)/%.cpp=$(OBJ_DIR)/%.o)

View File

@@ -3,20 +3,64 @@
#include "u2f.hpp"
#include <cstring>
#include <iostream>
#include <unistd.h>
#include "Streams.hpp"
using namespace std;
shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD)
{
auto p = make_shared<InitPacket>();
p->cid = rCID;
p->cmd = rCMD;
p->bcnth = readBytes(1)[0];
p->bcntl = readBytes(1)[0];
static size_t bytesRead = 0;
static uint8_t bcnth;
static uint8_t bcntl;
static decltype(InitPacket::data) dataBytes;
vector<uint8_t> bytes{};
const auto dataBytes = readBytes(p->data.size());
copy(dataBytes.begin(), dataBytes.end(), p->data.begin());
switch (bytesRead)
{
case 0:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcnth = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 1:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcntl = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 2:
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());;
bytesRead += bytes.size();
[[fallthrough]];
case 2 + dataBytes.size():
break;
default:
throw runtime_error{ "Unknown stage in InitPacket construction" };
}
auto p = make_shared<InitPacket>();
p->cid = rCID;
p->cmd = rCMD;
p->bcnth = bcnth;
p->bcntl = bcntl;
p->data = dataBytes;
auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n"
@@ -47,17 +91,33 @@ shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />");
clog << "Fully read init packet" << endl;
bytesRead = 0;
return p;
}
shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq)
{
static size_t readBytes = 0;
static decltype(ContPacket::data) dataBytes;
vector<uint8_t> bytes{};
auto p = make_shared<ContPacket>();
if (readBytes != dataBytes.size())
{
dataBytes = {};
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());
readBytes += bytes.size();
}
p->cid = rCID;
p->seq = rSeq;
const auto dataBytes = readBytes(p->data.size());
copy(dataBytes.begin(), dataBytes.end(), p->data.begin());
p->data = dataBytes;
auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n"
@@ -84,49 +144,86 @@ shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />");
//clog << "Fully read cont packet" << endl;
readBytes = 0;
return p;
}
shared_ptr<Packet> Packet::getPacket()
{
const uint32_t cid = *reinterpret_cast<uint32_t*>(readBytes(4).data());
uint8_t b = readBytes(1)[0];
static size_t bytesRead = 0;
vector<uint8_t> bytes{};
//clog << "Packet read 2nd byte as " << static_cast<uint16_t>(b) << endl;
static uint32_t cid;
static uint8_t b;
shared_ptr<Packet> packet{};
if (b & TYPE_MASK)
switch (bytesRead)
{
//Init packet
return InitPacket::getPacket(cid, b);
}
else
{
//Cont packet
return ContPacket::getPacket(cid, b);
case 0:
bytes = readNonBlock(4);
if (bytes.size() == 0)
return {};
cid = *reinterpret_cast<uint32_t*>(bytes.data());
bytesRead += bytes.size();
[[fallthrough]];
case 4:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
b = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 5:
if (b & TYPE_MASK)
{
//Init packet
//clog << "Getting init packet" << endl;
packet = InitPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
else
{
//Cont packet
//clog << "Getting cont packet" << endl;
packet = ContPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
default:
throw runtime_error{ "Unknown stage in Packet construction" };
}
}
void Packet::writePacket()
{
memset(this->buf, 0, packetSize);
memset(this->buf, 0, HID_RPT_SIZE);
memcpy(this->buf, &cid, 4);
}
void InitPacket::writePacket()
{
Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &cmd, 1);
memcpy(this->buf + 5, &bcnth, 1);
memcpy(this->buf + 6, &bcntl, 1);
memcpy(this->buf + 7, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream);
fwrite(this->buf, packetSize, 1, devStream);
if (errno != 0)
perror("perror " __FILE__ " 85");
write(this->buf, sizeof(this->buf));
fwrite(this->buf, 1, sizeof(this->buf), devStream);
auto dPStream = getDevPacketStream().get();
fprintf(dPStream, "\t\t<table>\n"
@@ -155,23 +252,17 @@ void InitPacket::writePacket()
"\t\t\t</tbody>\n"
"\t\t</table>"
"\t\t<br />");
clog << "Fully wrote init packet" << endl;
}
void ContPacket::writePacket()
{
Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &seq, 1);
memcpy(this->buf + 5, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream);
fwrite(this->buf, packetSize, 1, devStream);
if (errno != 0)
perror("perror " __FILE__ " 107");
write(this->buf, HID_RPT_SIZE);
fwrite(this->buf, HID_RPT_SIZE, 1, devStream);
auto dPStream = getDevPacketStream().get();

View File

@@ -2,12 +2,13 @@
#include <cstdint>
#include <memory>
#include <array>
#include "Constants.hpp"
#include "u2f.hpp"
struct Packet
{
uint32_t cid;
uint8_t buf[packetSize];
public:
uint32_t cid;
uint8_t buf[HID_RPT_SIZE];
protected:
Packet() = default;
@@ -20,10 +21,11 @@ struct Packet
struct InitPacket : Packet
{
uint8_t cmd;
uint8_t bcnth;
uint8_t bcntl;
std::array<uint8_t, packetSize - 7> data{};
public:
uint8_t cmd;
uint8_t bcnth;
uint8_t bcntl;
std::array<uint8_t, HID_RPT_SIZE - 7> data{};
public:
InitPacket() = default;
@@ -33,8 +35,9 @@ struct InitPacket : Packet
struct ContPacket : Packet
{
uint8_t seq;
std::array<uint8_t, packetSize - 5> data{};
public:
uint8_t seq;
std::array<uint8_t, HID_RPT_SIZE - 5> data{};
public:
ContPacket() = default;

View File

@@ -1,21 +1,29 @@
#include "Streams.hpp"
#include <iostream>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <cstdio>
#include <unistd.h>
using namespace std;
FILE* initHTML(FILE *fPtr, const string &title);
void closeHTML(FILE *fPtr);
shared_ptr<FILE> getHostStream()
shared_ptr<int> getHostDescriptor()
{
static shared_ptr<FILE> stream{ fopen("/dev/hidg0", "ab+"), [](FILE *f){
fclose(f);
} };
static shared_ptr<int> descriptor{};
if (!stream)
clog << "Stream is unavailable" << endl;
descriptor.reset(new int{ open("/dev/hidg0", O_RDWR | O_NONBLOCK | O_APPEND) }, [](int* fd){
close(*fd);
delete fd;
});
return stream;
if (*descriptor == -1)
throw runtime_error{ "Descriptor is unavailable" };
return descriptor;
}
shared_ptr<FILE> getComHostStream()

View File

@@ -2,7 +2,7 @@
#include <cstdio>
#include <memory>
std::shared_ptr<FILE> getHostStream();
std::shared_ptr<int> getHostDescriptor();
std::shared_ptr<FILE> getComHostStream();
std::shared_ptr<FILE> getHostPacketStream();
std::shared_ptr<FILE> getHostAPDUStream();

View File

@@ -8,40 +8,83 @@
using namespace std;
U2FMessage U2FMessage::read()
shared_ptr<U2FMessage> U2FMessage::readNonBlock()
{
auto fPack = dynamic_pointer_cast<InitPacket>(Packet::getPacket());
static size_t currSeq = -1;
static uint16_t messageSize;
static uint32_t cid;
static uint8_t cmd;
static vector<uint8_t> dataBytes;
if (!fPack)
throw runtime_error{ "Failed to receive Init packet" };
shared_ptr<Packet> p{};
const uint16_t messageSize = ((static_cast<uint16_t>(fPack->bcnth) << 8u) + fPack->bcntl);
clog << "Message on channel 0x" << hex << fPack->cid << dec << " has size: " << messageSize << endl;
const uint16_t copyByteCount = min(static_cast<uint16_t>(fPack->data.size()), messageSize);
U2FMessage message{ fPack-> cid, fPack->cmd };
message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount);
uint8_t currSeq = 0;
while (message.data.size() < messageSize)
if (currSeq == -1u)
{
auto newPack = dynamic_pointer_cast<ContPacket>(Packet::getPacket());
cid = 0;
cmd = 0;
messageSize = 0;
dataBytes = {};
if (!newPack)
throw runtime_error{ "Failed to receive Cont packet" };
else if (newPack->seq != currSeq)
throw runtime_error{ "Packet out of sequence" };
shared_ptr<InitPacket> initPack{};
do
{
p = Packet::getPacket();
const uint16_t remainingBytes = messageSize - message.data.size();
const uint16_t copyBytes = min(static_cast<uint16_t>(newPack->data.size()), remainingBytes);
message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes);
if (!p)
return {};
initPack = dynamic_pointer_cast<InitPacket>(p);
} while (!initPack); //Spurious cont. packet - spec states ignore
messageSize = ((static_cast<uint16_t>(initPack->bcnth) << 8u) + initPack->bcntl);
const uint16_t copyByteCount = min(static_cast<uint16_t>(initPack->data.size()), messageSize);
cid = initPack->cid;
cmd = initPack->cmd;
copy(initPack->data.begin(), initPack->data.begin() + copyByteCount, back_inserter(dataBytes));
currSeq++;
}
while (messageSize > dataBytes.size() && static_cast<bool>(p = Packet::getPacket())) //While there is a packet
{
auto contPack = dynamic_pointer_cast<ContPacket>(p);
if (!contPack) //Spurious init. packet
{
currSeq = -1; //Reset
return {};
}
if (contPack->cid != cid) //Cont. packet of different CID
{
cerr << "Invalid CID: was handling channel 0x" << hex << cid << " and received packet from channel 0x" << contPack->cid << dec << endl;
U2FMessage::error(contPack->cid, ERR_CHANNEL_BUSY);
currSeq = -1;
return {};
}
if (contPack->seq != currSeq)
{
cerr << "Invalid packet seq. value" << endl;
U2FMessage::error(cid, ERR_INVALID_SEQ);
currSeq = -1;
return {};
}
const uint16_t remainingBytes = messageSize - dataBytes.size();
const uint16_t copyBytes = min(static_cast<uint16_t>(contPack->data.size()), remainingBytes);
dataBytes.insert(dataBytes.end(), contPack->data.begin(), contPack->data.begin() + copyBytes);
currSeq++;
}
if (messageSize != dataBytes.size())
return {};
auto message = make_shared<U2FMessage>(cid, cmd);
message->data.assign(dataBytes.begin(), dataBytes.end());
currSeq = -1u;
std::clog << "Read all of message" << std::endl;
return message;
@@ -49,9 +92,6 @@ U2FMessage U2FMessage::read()
void U2FMessage::write()
{
clog << "Flushing host stream" << endl;
fflush(getHostStream().get());
clog << "Flushed host stream" << endl;
const uint16_t bytesToWrite = this->data.size();
uint16_t bytesWritten = 0;
@@ -88,8 +128,7 @@ void U2FMessage::write()
bytesWritten += newByteCount;
}
auto stream = getHostStream().get();
fflush(stream);
//auto stream = *getHostStream();
if (cmd == U2FHID_MSG)
{
@@ -120,3 +159,16 @@ void U2FMessage::write()
"\t\t<br />", err);
}
}
U2FMessage::U2FMessage(const uint32_t nCID, const uint8_t nCMD)
: cid{ nCID }, cmd{ nCMD }
{}
void U2FMessage::error(const uint32_t tCID, const uint16_t tErr)
{
U2FMessage msg{};
msg.cid = tCID;
msg.cmd = U2FHID_ERROR;
msg.data.push_back((tErr >> 8) & 0xFF);
msg.data.push_back(tErr & 0xFF);
}

View File

@@ -1,14 +1,21 @@
#pragma once
#include <cstdint>
#include <vector>
#include <memory>
struct U2FMessage
{
uint32_t cid;
uint8_t cmd;
std::vector<uint8_t> data;
public:
uint32_t cid;
uint8_t cmd;
std::vector<uint8_t> data;
static U2FMessage read();
public:
U2FMessage() = default;
U2FMessage(const uint32_t nCID, const uint8_t nCMD);
static std::shared_ptr<U2FMessage> readNonBlock();
void write();
void write();
protected:
static void error(const uint32_t tCID, const uint16_t tErr);
};

View File

@@ -27,10 +27,10 @@ U2F_Authenticate_APDU::U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const vecto
clog << "Got U2F_Auth request" << endl;
}
void U2F_Authenticate_APDU::respond()
void U2F_Authenticate_APDU::respond(const uint32_t channelID) const
{
U2FMessage msg{};
msg.cid = 0xF1D0F1D0;
msg.cid = channelID;
msg.cmd = U2FHID_MSG;
auto statusCode = APDU_STATUS::SW_NO_ERROR;

View File

@@ -12,7 +12,7 @@ struct U2F_Authenticate_APDU : U2F_Msg_CMD
public:
U2F_Authenticate_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data);
void respond();
virtual void respond(const uint32_t channelID) const override;
enum ControlCode
{

19
U2F_CMD.cpp Normal file
View File

@@ -0,0 +1,19 @@
#include "U2F_CMD.hpp"
#include "u2f.hpp"
#include "U2F_Msg_CMD.hpp"
#include "U2F_Init_CMD.hpp"
using namespace std;
shared_ptr<U2F_CMD> U2F_CMD::get(const shared_ptr<U2FMessage> uMsg)
{
switch (uMsg->cmd)
{
case U2FHID_MSG:
return U2F_Msg_CMD::generate(uMsg);
case U2FHID_INIT:
return make_shared<U2F_Init_CMD>(uMsg);
default:
return {};
}
}

View File

@@ -1,4 +1,6 @@
#pragma once
#include <memory>
#include "U2FMessage.hpp"
struct U2F_CMD
{
@@ -6,5 +8,7 @@ struct U2F_CMD
U2F_CMD() = default;
public:
~U2F_CMD() = default;
virtual ~U2F_CMD() = default;
static std::shared_ptr<U2F_CMD> get(const std::shared_ptr<U2FMessage> uMsg);
virtual void respond(const uint32_t channelID) const = 0;
}; //For polymorphic type casting

View File

@@ -1,24 +1,33 @@
#include "U2F_Init_CMD.hpp"
#include <stdexcept>
#include "U2FMessage.hpp"
#include "u2f.hpp"
#include <iostream>
#include "Field.hpp"
using namespace std;
U2F_Init_CMD U2F_Init_CMD::get()
U2F_Init_CMD::U2F_Init_CMD(const shared_ptr<U2FMessage> uMsg)
{
const auto message = U2FMessage::read();
if (message.cmd != U2FHID_INIT)
if (uMsg->cmd != U2FHID_INIT)
throw runtime_error{ "Failed to get U2F Init message" };
else if (message.data.size() != INIT_NONCE_SIZE)
else if (uMsg->data.size() != INIT_NONCE_SIZE)
throw runtime_error{ "Init nonce is incorrect size" };
U2F_Init_CMD cmd;
cmd.nonce = *reinterpret_cast<const uint64_t*>(message.data.data());
clog << "Fully read nonce" << endl;
return cmd;
this->nonce = *reinterpret_cast<const uint64_t*>(uMsg->data.data());
}
void U2F_Init_CMD::respond(const uint32_t channelID) const
{
U2FMessage msg{};
msg.cid = CID_BROADCAST;
msg.cmd = U2FHID_INIT;
msg.data.insert(msg.data.end(), FIELD(this->nonce));
msg.data.insert(msg.data.end(), FIELD(channelID));
msg.data.push_back(2); //Protocol version
msg.data.push_back(1); //Major device version
msg.data.push_back(0); //Minor device version
msg.data.push_back(1); //Build device version
msg.data.push_back(CAPFLAG_WINK); //Wink capability
msg.write();
}

View File

@@ -1,11 +1,14 @@
#pragma once
#include <cstdint>
#include <memory>
#include "U2F_CMD.hpp"
#include "U2FMessage.hpp"
struct U2F_Init_CMD : U2F_CMD
{
uint64_t nonce;
public:
static U2F_Init_CMD get();
U2F_Init_CMD(const std::shared_ptr<U2FMessage> uMsg);
virtual void respond(const uint32_t channelID) const override;
};

View File

@@ -1,36 +0,0 @@
#pragma once
#include <cstdint>
#include "U2FMessage.hpp"
#include "Field.hpp"
struct U2F_Init_Response : U2F_CMD
{
uint32_t cid;
uint64_t nonce;
uint8_t protocolVer;
uint8_t majorDevVer;
uint8_t minorDevVer;
uint8_t buildDevVer;
uint8_t capabilities;
void write()
{
std::clog << "Beginning writeout of U2F_Init_Response" << std::endl;
U2FMessage m{};
m.cid = CID_BROADCAST;
m.cmd = U2FHID_INIT;
m.data.insert(m.data.begin() + 0, FIELD(nonce));
m.data.insert(m.data.begin() + 8, FIELD(cid));
m.data.insert(m.data.begin() + 12, FIELD(protocolVer));
m.data.insert(m.data.begin() + 13, FIELD(majorDevVer));
m.data.insert(m.data.begin() + 14, FIELD(minorDevVer));
m.data.insert(m.data.begin() + 15, FIELD(buildDevVer));
m.data.insert(m.data.begin() + 16, FIELD(capabilities));
std::clog << "Finished inserting U2F_Init_Response fields to data buffer" << std::endl;
m.write();
std::clog << "Completed writeout of U2F_Init_Response" << std::endl;
}
};

View File

@@ -8,16 +8,17 @@
#include "APDU.hpp"
#include <iostream>
#include "Streams.hpp"
#include "Field.hpp"
using namespace std;
uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
{
if (byteCount > 3)
throw runtime_error{ "Too much data for command" };
if (byteCount != 0)
{
//Le must be length of data in bytes
clog << "Le must be length of data in bytes" << endl;
clog << "Le has a size of " << byteCount << " bytes" << endl;
switch (byteCount)
{
@@ -27,44 +28,47 @@ uint32_t U2F_Msg_CMD::getLe(const uint32_t byteCount, vector<uint8_t> bytes)
//Don't handle non-compliance with spec here
//This case is only possible if extended Lc used
//CBA
return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : bytes[0] << 8 + bytes[1]);
return (bytes[0] == 0 && bytes[1] == 0 ? 65536 : (bytes[0] << 8) + bytes[1]);
case 3:
//Don't handle non-compliance with spec here
//This case is only possible if extended Lc not used
//CBA
if (bytes[0] != 0)
throw runtime_error{ "First byte of 3-byte Le should be 0"};
return (bytes[1] == 0 & bytes[2] == 0 ? 65536 : bytes[1] << 8 + bytes[2]);
return (bytes[1] == 0 && bytes[2] == 0 ? 65536 : (bytes[1] << 8) + bytes[2]);
default:
throw runtime_error{ "Too much data for command" };
}
}
else
return 0;
}
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::generate(const shared_ptr<U2FMessage> uMsg)
{
const auto message = U2FMessage::read();
if (uMsg->cmd != U2FHID_MSG)
throw runtime_error{ "Failed to get U2F Msg uMsg" };
else if (uMsg->data.size() < 4)
throw runtime_error{ "Msg data is incorrect size" };
if (message.cmd != U2FHID_MSG)
throw runtime_error{ "Failed to get U2F Msg message" };
U2F_Msg_CMD cmd{};
auto &dat = message.data;
U2F_Msg_CMD cmd;
auto &dat = uMsg->data;
cmd.cla = dat[0];
cmd.ins = dat[1];
cmd.p1 = dat[2];
cmd.p2 = dat[3];
const uint32_t cBCount = dat.size() - 4;
clog << "Loaded U2F_Msg_CMD parameters" << endl;
vector<uint8_t> data{ dat.begin() + 4, dat.end() };
const uint32_t cBCount = data.size();
auto startPtr = data.begin(), endPtr = data.end();
clog << "Loaded iters" << endl;
if (usesData.at(cmd.ins) || data.size() > 3)
{
//clog << "First bytes are: " << static_cast<uint16_t>(data[0]) << " " << static_cast<uint16_t>(data[1]) << " " << static_cast<uint16_t>(data[2]) << endl;
if (cBCount == 0)
throw runtime_error{ "Invalid command - should have attached data" };
@@ -81,17 +85,22 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
endPtr = startPtr + cmd.lc;
clog << "Getting Le" << endl;
cmd.le = getLe(data.end() - endPtr, vector<uint8_t>(endPtr, data.end()));
}
else
{
cmd.lc = 0;
endPtr = startPtr;
clog << "Getting Le" << endl;
cmd.le = getLe(cBCount, data);
}
const auto dBytes = vector<uint8_t>(startPtr, endPtr);
clog << "Determined message format" << endl;
auto hAS = getHostAPDUStream().get();
fprintf(hAS, "<table>\n"
@@ -125,6 +134,8 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
"\t\t</table>\n"
"\t\t<br />", cmd.le);
clog << "Constructing message specialisation" << endl;
switch (cmd.ins)
{
case APDU::U2F_REG:
@@ -138,11 +149,18 @@ shared_ptr<U2F_Msg_CMD> U2F_Msg_CMD::get()
}
}
void U2F_Msg_CMD::respond(){};
const map<uint8_t, bool> U2F_Msg_CMD::usesData = {
{ U2F_REG, true },
{ U2F_AUTH, true },
{ U2F_VER, false }
};
void U2F_Msg_CMD::respond(const uint32_t channelID) const
{
U2FMessage msg{};
msg.cid = channelID;
msg.cmd = U2FHID_MSG;
auto errorCode = APDU_STATUS::SW_INS_NOT_SUPPORTED;
msg.data.insert(msg.data.end(), FIELD_BE(errorCode));
msg.write();
}

View File

@@ -18,10 +18,10 @@ struct U2F_Msg_CMD : U2F_CMD
protected:
static uint32_t getLe(const uint32_t byteCount, std::vector<uint8_t> bytes);
U2F_Msg_CMD() = default;
public:
static std::shared_ptr<U2F_Msg_CMD> get();
virtual void respond();
static std::shared_ptr<U2F_Msg_CMD> generate(const std::shared_ptr<U2FMessage> uMsg);
void respond(const uint32_t channelID) const;
};

View File

@@ -50,10 +50,10 @@ U2F_Register_APDU::U2F_Register_APDU(const U2F_Msg_CMD &msg, const vector<uint8_
clog << endl << dec << "Got U2F_Reg request" << endl;
}
void U2F_Register_APDU::respond()
void U2F_Register_APDU::respond(const uint32_t channelID) const
{
U2FMessage m{};
m.cid = 0xF1D0F1D0;
m.cid = channelID;
m.cmd = U2FHID_MSG;
auto& response = m.data;
@@ -65,7 +65,7 @@ void U2F_Register_APDU::respond()
copy(pubKey.begin(), pubKey.end(), back_inserter(response));
response.push_back(sizeof(this->keyH));
auto fakeKeyHBytes = reinterpret_cast<uint8_t *>(&this->keyH);
auto fakeKeyHBytes = reinterpret_cast<const uint8_t *>(&this->keyH);
copy(fakeKeyHBytes, fakeKeyHBytes + sizeof(this->keyH), back_inserter(response));
copy(attestCert, end(attestCert), back_inserter(response));
@@ -83,9 +83,9 @@ void U2F_Register_APDU::respond()
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(appParam.data()), appParam.size());
mbedtls_sha256_update(&shaContext, reinterpret_cast<unsigned char*>(challengeP.data()), challengeP.size());
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(challengeP.data()), challengeP.size());
mbedtls_sha256_update(&shaContext, reinterpret_cast<unsigned char*>(&keyH), sizeof(keyH));
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(&keyH), sizeof(keyH));
mbedtls_sha256_update(&shaContext, reinterpret_cast<const unsigned char*>(pubKey.data()), pubKey.size());

View File

@@ -11,6 +11,6 @@ struct U2F_Register_APDU : U2F_Msg_CMD
public:
U2F_Register_APDU(const U2F_Msg_CMD &msg, const std::vector<uint8_t> &data);
void respond();
void respond(const uint32_t channelID) const override;
};

View File

@@ -14,12 +14,12 @@ U2F_Version_APDU::U2F_Version_APDU(const U2F_Msg_CMD &msg)
//Don't actually respond yet
}
void U2F_Version_APDU::respond()
void U2F_Version_APDU::respond(const uint32_t channelID) const
{
char ver[]{ 'U', '2', 'F', '_', 'V', '2' };
U2FMessage m{};
m.cid = 0xF1D0F1D0;
m.cid = channelID;
m.cmd = U2FHID_MSG;
m.data.insert(m.data.end(), FIELD(ver));
auto sC = APDU_STATUS::SW_NO_ERROR;

View File

@@ -5,5 +5,5 @@ struct U2F_Version_APDU : U2F_Msg_CMD
{
public:
U2F_Version_APDU(const U2F_Msg_CMD &msg);
void respond();
void respond(const uint32_t channelID) const override;
};

View File

@@ -1,60 +1,35 @@
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <array>
#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <map>
#include "u2f.hpp"
#include "Constants.hpp"
#include "U2FMessage.hpp"
#include "U2F_CMD.hpp"
#include "Storage.hpp"
#include "U2F_Init_CMD.hpp"
#include "U2F_Init_Response.hpp"
#include "U2F_Msg_CMD.hpp"
#include "Controller.hpp"
#include <signal.h>
#include <unistd.h>
using namespace std;
int main(int argc, char** argv)
void signalCallback(int signum);
volatile bool contProc = true;
int main()
{
signal(SIGINT, signalCallback);
Storage::init();
auto initFrame = U2F_Init_CMD::get();
Controller ch{ 0xF1D00000 };
U2F_Init_Response resp{};
resp.cid = 0xF1D0F1D0;
resp.nonce = initFrame.nonce;
resp.protocolVer = 2;
resp.majorDevVer = 1;
resp.minorDevVer = 0;
resp.buildDevVer = 1;
resp.capabilities = CAPFLAG_WINK;
resp.write();
size_t msgCount = (argc == 2 ? stoul(argv[1]) : 3u);
for (size_t i = 0; i < msgCount; i++)
while (contProc)
{
auto reg = U2F_Msg_CMD::get();
cout << "U2F CMD ins: " << static_cast<uint32_t>(reg->ins) << endl;
reg->respond();
clog << "Fully responded" << endl;
ch.handleTransaction();
usleep(10000);
}
/*auto m = U2FMessage::read();
cout << "U2F CMD: " << static_cast<uint32_t>(m.cmd) << endl;
for (const auto d : m.data)
clog << static_cast<uint16_t>(d) << endl;*/
Storage::save();
return EXIT_SUCCESS;
}
void signalCallback(int signum)
{
contProc = false;
clog << "Caught SIGINT signal" << endl;
}