Added resumable loading, along with signal handling.
This version can store state after receiving SIGINT. This is achieved by polling FIFO read state;
This commit is contained in:
110
U2FMessage.cpp
110
U2FMessage.cpp
@@ -8,40 +8,83 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
U2FMessage U2FMessage::read()
|
||||
shared_ptr<U2FMessage> U2FMessage::readNonBlock()
|
||||
{
|
||||
auto fPack = dynamic_pointer_cast<InitPacket>(Packet::getPacket());
|
||||
static size_t currSeq = -1;
|
||||
static uint16_t messageSize;
|
||||
static uint32_t cid;
|
||||
static uint8_t cmd;
|
||||
static vector<uint8_t> dataBytes;
|
||||
|
||||
if (!fPack)
|
||||
throw runtime_error{ "Failed to receive Init packet" };
|
||||
shared_ptr<Packet> p{};
|
||||
|
||||
const uint16_t messageSize = ((static_cast<uint16_t>(fPack->bcnth) << 8u) + fPack->bcntl);
|
||||
|
||||
clog << "Message on channel 0x" << hex << fPack->cid << dec << " has size: " << messageSize << endl;
|
||||
|
||||
const uint16_t copyByteCount = min(static_cast<uint16_t>(fPack->data.size()), messageSize);
|
||||
|
||||
U2FMessage message{ fPack-> cid, fPack->cmd };
|
||||
message.data.assign(fPack->data.begin(), fPack->data.begin() + copyByteCount);
|
||||
|
||||
uint8_t currSeq = 0;
|
||||
|
||||
while (message.data.size() < messageSize)
|
||||
if (currSeq == -1u)
|
||||
{
|
||||
auto newPack = dynamic_pointer_cast<ContPacket>(Packet::getPacket());
|
||||
cid = 0;
|
||||
cmd = 0;
|
||||
messageSize = 0;
|
||||
dataBytes = {};
|
||||
|
||||
shared_ptr<InitPacket> initPack{};
|
||||
do
|
||||
{
|
||||
p = Packet::getPacket();
|
||||
|
||||
if (!newPack)
|
||||
throw runtime_error{ "Failed to receive Cont packet" };
|
||||
else if (newPack->seq != currSeq)
|
||||
throw runtime_error{ "Packet out of sequence" };
|
||||
if (!p)
|
||||
return {};
|
||||
|
||||
const uint16_t remainingBytes = messageSize - message.data.size();
|
||||
const uint16_t copyBytes = min(static_cast<uint16_t>(newPack->data.size()), remainingBytes);
|
||||
message.data.insert(message.data.end(), newPack->data.begin(), newPack->data.begin() + copyBytes);
|
||||
initPack = dynamic_pointer_cast<InitPacket>(p);
|
||||
} while (!initPack); //Spurious cont. packet - spec states ignore
|
||||
|
||||
messageSize = ((static_cast<uint16_t>(initPack->bcnth) << 8u) + initPack->bcntl);
|
||||
const uint16_t copyByteCount = min(static_cast<uint16_t>(initPack->data.size()), messageSize);
|
||||
|
||||
cid = initPack->cid;
|
||||
cmd = initPack->cmd;
|
||||
|
||||
copy(initPack->data.begin(), initPack->data.begin() + copyByteCount, back_inserter(dataBytes));
|
||||
currSeq++;
|
||||
}
|
||||
|
||||
while (messageSize > dataBytes.size() && static_cast<bool>(p = Packet::getPacket())) //While there is a packet
|
||||
{
|
||||
auto contPack = dynamic_pointer_cast<ContPacket>(p);
|
||||
|
||||
if (!contPack) //Spurious init. packet
|
||||
{
|
||||
currSeq = -1; //Reset
|
||||
return {};
|
||||
}
|
||||
|
||||
if (contPack->cid != cid) //Cont. packet of different CID
|
||||
{
|
||||
cerr << "Invalid CID: was handling channel 0x" << hex << cid << " and received packet from channel 0x" << contPack->cid << dec << endl;
|
||||
U2FMessage::error(contPack->cid, ERR_CHANNEL_BUSY);
|
||||
currSeq = -1;
|
||||
return {};
|
||||
}
|
||||
|
||||
if (contPack->seq != currSeq)
|
||||
{
|
||||
cerr << "Invalid packet seq. value" << endl;
|
||||
U2FMessage::error(cid, ERR_INVALID_SEQ);
|
||||
currSeq = -1;
|
||||
return {};
|
||||
}
|
||||
|
||||
const uint16_t remainingBytes = messageSize - dataBytes.size();
|
||||
const uint16_t copyBytes = min(static_cast<uint16_t>(contPack->data.size()), remainingBytes);
|
||||
dataBytes.insert(dataBytes.end(), contPack->data.begin(), contPack->data.begin() + copyBytes);
|
||||
currSeq++;
|
||||
}
|
||||
|
||||
if (messageSize != dataBytes.size())
|
||||
return {};
|
||||
|
||||
auto message = make_shared<U2FMessage>(cid, cmd);
|
||||
message->data.assign(dataBytes.begin(), dataBytes.end());
|
||||
currSeq = -1u;
|
||||
|
||||
std::clog << "Read all of message" << std::endl;
|
||||
|
||||
return message;
|
||||
@@ -49,9 +92,6 @@ U2FMessage U2FMessage::read()
|
||||
|
||||
void U2FMessage::write()
|
||||
{
|
||||
clog << "Flushing host stream" << endl;
|
||||
fflush(getHostStream().get());
|
||||
clog << "Flushed host stream" << endl;
|
||||
const uint16_t bytesToWrite = this->data.size();
|
||||
uint16_t bytesWritten = 0;
|
||||
|
||||
@@ -88,8 +128,7 @@ void U2FMessage::write()
|
||||
bytesWritten += newByteCount;
|
||||
}
|
||||
|
||||
auto stream = getHostStream().get();
|
||||
fflush(stream);
|
||||
//auto stream = *getHostStream();
|
||||
|
||||
if (cmd == U2FHID_MSG)
|
||||
{
|
||||
@@ -120,3 +159,16 @@ void U2FMessage::write()
|
||||
"\t\t<br />", err);
|
||||
}
|
||||
}
|
||||
|
||||
U2FMessage::U2FMessage(const uint32_t nCID, const uint8_t nCMD)
|
||||
: cid{ nCID }, cmd{ nCMD }
|
||||
{}
|
||||
|
||||
void U2FMessage::error(const uint32_t tCID, const uint16_t tErr)
|
||||
{
|
||||
U2FMessage msg{};
|
||||
msg.cid = tCID;
|
||||
msg.cmd = U2FHID_ERROR;
|
||||
msg.data.push_back((tErr >> 8) & 0xFF);
|
||||
msg.data.push_back(tErr & 0xFF);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user