Added resumable loading, along with signal handling.

This version can store state after receiving SIGINT.
This is achieved by polling FIFO read state;
This commit is contained in:
2018-08-09 20:23:23 +00:00
parent 673577a601
commit 48840ad36c
29 changed files with 598 additions and 233 deletions

View File

@@ -3,20 +3,64 @@
#include "u2f.hpp"
#include <cstring>
#include <iostream>
#include <unistd.h>
#include "Streams.hpp"
using namespace std;
shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t rCMD)
{
auto p = make_shared<InitPacket>();
p->cid = rCID;
p->cmd = rCMD;
p->bcnth = readBytes(1)[0];
p->bcntl = readBytes(1)[0];
static size_t bytesRead = 0;
static uint8_t bcnth;
static uint8_t bcntl;
static decltype(InitPacket::data) dataBytes;
vector<uint8_t> bytes{};
const auto dataBytes = readBytes(p->data.size());
copy(dataBytes.begin(), dataBytes.end(), p->data.begin());
switch (bytesRead)
{
case 0:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcnth = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 1:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
bcntl = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 2:
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());;
bytesRead += bytes.size();
[[fallthrough]];
case 2 + dataBytes.size():
break;
default:
throw runtime_error{ "Unknown stage in InitPacket construction" };
}
auto p = make_shared<InitPacket>();
p->cid = rCID;
p->cmd = rCMD;
p->bcnth = bcnth;
p->bcntl = bcntl;
p->data = dataBytes;
auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n"
@@ -47,17 +91,33 @@ shared_ptr<InitPacket> InitPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />");
clog << "Fully read init packet" << endl;
bytesRead = 0;
return p;
}
shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t rSeq)
{
static size_t readBytes = 0;
static decltype(ContPacket::data) dataBytes;
vector<uint8_t> bytes{};
auto p = make_shared<ContPacket>();
if (readBytes != dataBytes.size())
{
dataBytes = {};
bytes = readNonBlock(dataBytes.size());
if (bytes.size() == 0)
return {};
copy(bytes.begin(), bytes.end(), dataBytes.begin());
readBytes += bytes.size();
}
p->cid = rCID;
p->seq = rSeq;
const auto dataBytes = readBytes(p->data.size());
copy(dataBytes.begin(), dataBytes.end(), p->data.begin());
p->data = dataBytes;
auto hPStream = getHostPacketStream().get();
fprintf(hPStream, "\t\t<table>\n"
@@ -84,49 +144,86 @@ shared_ptr<ContPacket> ContPacket::getPacket(const uint32_t rCID, const uint8_t
"\t\t<br />");
//clog << "Fully read cont packet" << endl;
readBytes = 0;
return p;
}
shared_ptr<Packet> Packet::getPacket()
{
const uint32_t cid = *reinterpret_cast<uint32_t*>(readBytes(4).data());
uint8_t b = readBytes(1)[0];
static size_t bytesRead = 0;
vector<uint8_t> bytes{};
//clog << "Packet read 2nd byte as " << static_cast<uint16_t>(b) << endl;
static uint32_t cid;
static uint8_t b;
shared_ptr<Packet> packet{};
if (b & TYPE_MASK)
switch (bytesRead)
{
//Init packet
return InitPacket::getPacket(cid, b);
}
else
{
//Cont packet
return ContPacket::getPacket(cid, b);
case 0:
bytes = readNonBlock(4);
if (bytes.size() == 0)
return {};
cid = *reinterpret_cast<uint32_t*>(bytes.data());
bytesRead += bytes.size();
[[fallthrough]];
case 4:
bytes = readNonBlock(1);
if (bytes.size() == 0)
return {};
b = bytes[0];
bytesRead += bytes.size();
[[fallthrough]];
case 5:
if (b & TYPE_MASK)
{
//Init packet
//clog << "Getting init packet" << endl;
packet = InitPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
else
{
//Cont packet
//clog << "Getting cont packet" << endl;
packet = ContPacket::getPacket(cid, b);
if (packet)
bytesRead = 0;
return packet;
}
default:
throw runtime_error{ "Unknown stage in Packet construction" };
}
}
void Packet::writePacket()
{
memset(this->buf, 0, packetSize);
memset(this->buf, 0, HID_RPT_SIZE);
memcpy(this->buf, &cid, 4);
}
void InitPacket::writePacket()
{
Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &cmd, 1);
memcpy(this->buf + 5, &bcnth, 1);
memcpy(this->buf + 6, &bcntl, 1);
memcpy(this->buf + 7, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream);
fwrite(this->buf, packetSize, 1, devStream);
if (errno != 0)
perror("perror " __FILE__ " 85");
write(this->buf, sizeof(this->buf));
fwrite(this->buf, 1, sizeof(this->buf), devStream);
auto dPStream = getDevPacketStream().get();
fprintf(dPStream, "\t\t<table>\n"
@@ -155,23 +252,17 @@ void InitPacket::writePacket()
"\t\t\t</tbody>\n"
"\t\t</table>"
"\t\t<br />");
clog << "Fully wrote init packet" << endl;
}
void ContPacket::writePacket()
{
Packet::writePacket();
auto hostStream = getHostStream().get();
auto devStream = getComDevStream().get();
memcpy(this->buf + 4, &seq, 1);
memcpy(this->buf + 5, data.data(), data.size());
fwrite(this->buf, packetSize, 1, hostStream);
fwrite(this->buf, packetSize, 1, devStream);
if (errno != 0)
perror("perror " __FILE__ " 107");
write(this->buf, HID_RPT_SIZE);
fwrite(this->buf, HID_RPT_SIZE, 1, devStream);
auto dPStream = getDevPacketStream().get();