Added resumable loading, along with signal handling.

This version can store state after receiving SIGINT.
This is achieved by polling FIFO read state;
This commit is contained in:
2018-08-09 20:23:23 +00:00
parent 673577a601
commit 48840ad36c
29 changed files with 598 additions and 233 deletions

View File

@@ -8,40 +8,83 @@
using namespace std;
U2FMessage U2FMessage::read()
shared_ptr<U2FMessage> U2FMessage::readNonBlock()
{
auto fPack = dynamic_pointer_cast<InitPacket>(Packet::getPacket());
static size_t currSeq = -1;
static uint16_t messageSize;
static uint32_t cid;
static uint8_t cmd;
static vector<uint8_t> dataBytes;
if (!fPack)
throw runtime_error{ "Failed to receive Init packet" };
shared_ptr<Packet> p{};
const uint16_t messageSize = ((static_cast<uint16_t>(fPack->bcnth) << 8u) + fPack->bcntl);
clog << "Message on channel 0x" << hex << fPack->cid << dec << " has size: " << messageSize << endl;
const uint16_t copyByteCount = min(static_cast<uint16_t>(fPack->data.size()), messageSize);
U2FMessage message{ fPack-> cid, fPack->cmd };
message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount);
uint8_t currSeq = 0;
while (message.data.size() < messageSize)
if (currSeq == -1u)
{
auto newPack = dynamic_pointer_cast<ContPacket>(Packet::getPacket());
cid = 0;
cmd = 0;
messageSize = 0;
dataBytes = {};
shared_ptr<InitPacket> initPack{};
do
{
p = Packet::getPacket();
if (!newPack)
throw runtime_error{ "Failed to receive Cont packet" };
else if (newPack->seq != currSeq)
throw runtime_error{ "Packet out of sequence" };
if (!p)
return {};
const uint16_t remainingBytes = messageSize - message.data.size();
const uint16_t copyBytes = min(static_cast<uint16_t>(newPack->data.size()), remainingBytes);
message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes);
initPack = dynamic_pointer_cast<InitPacket>(p);
} while (!initPack); //Spurious cont. packet - spec states ignore
messageSize = ((static_cast<uint16_t>(initPack->bcnth) << 8u) + initPack->bcntl);
const uint16_t copyByteCount = min(static_cast<uint16_t>(initPack->data.size()), messageSize);
cid = initPack->cid;
cmd = initPack->cmd;
copy(initPack->data.begin(), initPack->data.begin() + copyByteCount, back_inserter(dataBytes));
currSeq++;
}
while (messageSize > dataBytes.size() && static_cast<bool>(p = Packet::getPacket())) //While there is a packet
{
auto contPack = dynamic_pointer_cast<ContPacket>(p);
if (!contPack) //Spurious init. packet
{
currSeq = -1; //Reset
return {};
}
if (contPack->cid != cid) //Cont. packet of different CID
{
cerr << "Invalid CID: was handling channel 0x" << hex << cid << " and received packet from channel 0x" << contPack->cid << dec << endl;
U2FMessage::error(contPack->cid, ERR_CHANNEL_BUSY);
currSeq = -1;
return {};
}
if (contPack->seq != currSeq)
{
cerr << "Invalid packet seq. value" << endl;
U2FMessage::error(cid, ERR_INVALID_SEQ);
currSeq = -1;
return {};
}
const uint16_t remainingBytes = messageSize - dataBytes.size();
const uint16_t copyBytes = min(static_cast<uint16_t>(contPack->data.size()), remainingBytes);
dataBytes.insert(dataBytes.end(), contPack->data.begin(), contPack->data.begin() + copyBytes);
currSeq++;
}
if (messageSize != dataBytes.size())
return {};
auto message = make_shared<U2FMessage>(cid, cmd);
message->data.assign(dataBytes.begin(), dataBytes.end());
currSeq = -1u;
std::clog << "Read all of message" << std::endl;
return message;
@@ -49,9 +92,6 @@ U2FMessage U2FMessage::read()
void U2FMessage::write()
{
clog << "Flushing host stream" << endl;
fflush(getHostStream().get());
clog << "Flushed host stream" << endl;
const uint16_t bytesToWrite = this->data.size();
uint16_t bytesWritten = 0;
@@ -88,8 +128,7 @@ void U2FMessage::write()
bytesWritten += newByteCount;
}
auto stream = getHostStream().get();
fflush(stream);
//auto stream = *getHostStream();
if (cmd == U2FHID_MSG)
{
@@ -120,3 +159,16 @@ void U2FMessage::write()
"\t\t<br />", err);
}
}
U2FMessage::U2FMessage(const uint32_t nCID, const uint8_t nCMD)
: cid{ nCID }, cmd{ nCMD }
{}
void U2FMessage::error(const uint32_t tCID, const uint16_t tErr)
{
U2FMessage msg{};
msg.cid = tCID;
msg.cmd = U2FHID_ERROR;
msg.data.push_back((tErr >> 8) & 0xFF);
msg.data.push_back(tErr & 0xFF);
}