IBDT.h 1.5 KB
Newer Older
Thiago Santini's avatar
Thiago Santini committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
#ifndef IBDT_H
#define IBDT_H

#include <deque>
#include <vector>
#include <algorithm>

#include <GazeData.h>

#include <opencv2/ml.hpp>

/// Per-class probability triple attached to each sample.
///
/// All three values start at zero. `update()` recomputes `posterior` as the
/// plain product `prior * likelihood`; no normalisation across classes is
/// done here.
class IBDT_Prob {
public:
    IBDT_Prob() = default;

    double prior{0};
    double likelihood{0};
    double posterior{0};

    /// Unnormalised posterior: overwrite `posterior` with `prior * likelihood`.
    void update()
    {
        const double product = prior * likelihood;
        posterior = product;
    }
};

class IBDT_Data : public GazeDataEntry {
public:
    IBDT_Data(GazeDataEntry base) :
        GazeDataEntry(base),
        pursuit(),
        fixation(),
        saccade() { }

    IBDT_Prob pursuit;
    IBDT_Prob fixation;
    IBDT_Prob saccade;
};

/// Gaze-sample classifier. The selected `CLASSIFICATION` picks either a
/// ternary or a binary labelling scheme. Per the IBDT_Data members, the
/// ternary classes are presumably fixation / saccade / pursuit (TODO: confirm
/// the label sets against IBDT.cpp).
class IBDT
{
public:
    enum CLASSIFICATION {
        TERNARY = 0,
        BINARY = 1
    };

    /// @param maxSaccadeDurationMs  presumably the longest duration, in ms, a
    ///                              saccade may last (default 80); confirm in IBDT.cpp.
    /// @param minSampleConfidence   presumably samples below this confidence
    ///                              are not trusted (default 0.5); confirm in IBDT.cpp.
    /// @param classification        which labelling scheme to use (default TERNARY).
    IBDT(const double &maxSaccadeDurationMs=80, const double &minSampleConfidence=0.5, const enum CLASSIFICATION &classification=TERNARY);

    /// Feeds one sample to the classifier. `entry` is a non-const reference,
    /// so it may be modified (e.g. labelled); confirm in IBDT.cpp.
    void addPoint(GazeDataEntry &entry);
    /// Fits the classifier on a batch of samples. `gaze` is a non-const
    /// reference and may be modified.
    void train(std::vector<GazeDataEntry> &gaze);
    /// Velocity estimate between two samples. Units depend on the fields of
    /// GazeDataEntry; TODO confirm.
    double estimateVelocity(const GazeDataEntry &cur, const GazeDataEntry &prev);

private:
    double maxSaccadeDurationMs;
    double minSampleConfidence;
    CLASSIFICATION classification;
    // Recent samples together with their per-class probability state.
    std::deque<IBDT_Data> window;

    // The default member initializers below only take effect when the
    // constructor (defined out of line) does not initialize the member itself.
    // They prevent reads of indeterminate values; a constructor
    // mem-initializer always takes precedence.
    //
    // `cur` / `prev` are non-owning and presumably point into `window`.
    // std::deque keeps references to surviving elements valid across
    // push/pop at either end, but not across insert/erase in the middle or
    // clear(). NOTE(review): confirm IBDT.cpp only mutates `window` at its ends.
    IBDT_Data *cur = nullptr;
    IBDT_Data *prev = nullptr;
    bool firstPoint = true;  // presumably "no sample processed yet"; confirm against the ctor
    cv::Ptr<cv::ml::EM> model;  // OpenCV Expectation-Maximization (Gaussian mixture) model; null until created
    unsigned int fIdx = 0, sIdx = 0;
    double fMean = 0.0, sMean = 0.0;

    void updatePursuitPrior();
    void updatePursuitLikelihood();
    void updatePursuitLikelihoodNew();
    void updateFixationAndSaccadeLikelihood();

    void binaryClassification();
    void ternaryClassification();
};

#endif // IBDT_H