aboutsummaryrefslogtreecommitdiff
path: root/src/argaze/GazeAnalysis/LinearRegression.py
blob: 747d5f39f356f5fc00d5428d9073110d483fb46e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Module for gaze position calibration based on linear regression.
"""

__author__ = "Théo de la Hogue"
__credits__ = []
__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)"
__license__ = "BSD"

from typing import TypeVar, Tuple

from argaze import DataFeatures, GazeFeatures

from sklearn.linear_model import LinearRegression
import numpy
import cv2

# Type definition for type annotation convenience: gaze position type returned by apply().
GazePositionType = TypeVar('GazePositionType', bound="GazeFeatures.GazePosition")

class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator):
    """Implementation of linear regression algorithm as described in:

        **Drewes, H., Pfeuffer, K., & Alt, F. (2019, June).**  
        *Time- and space-efficient eye tracker calibration.*  
        Proceedings of the 11th ACM Symposium on Eye Tracking Research & Applications (ETRA'19, 1-8).  
        [https://dl.acm.org/doi/pdf/10.1145/3314111.3319818](https://dl.acm.org/doi/pdf/10.1145/3314111.3319818)

    The calibrator starts with a model built from the given coefficients and intercept
    (identity mapping by default). A new calibration is run by calling reset(), then store()
    for each observed/expected position pair, then calibrate(). Between reset() and
    calibrate() the calibrator is calibrating and apply() returns positions unchanged.

    Parameters:
        coefficients: linear regression coefficients.
        intercept: linear regression intercept value.
    """

    def __init__(self, coefficients: list = [[1., 0.], [0., 1.]], intercept: list = [0., 0.]):

        super().__init__()

        # Build a model directly from the given parameters so apply() can be used without calibrating.
        # numpy.array copies its input, so the default lists are never mutated.
        self.__linear_regression = LinearRegression()
        self.__linear_regression.coef_ = numpy.array(coefficients)
        self.__linear_regression.intercept_ = numpy.array(intercept)

        # Calibration samples: initialized here so store() and calibrate() work even before reset().
        self.__observed_positions = []
        self.__expected_positions = []

    @property
    def coefficients(self) -> list:
        """Get linear regression coefficients."""
        return self.__linear_regression.coef_.tolist()

    @property
    def intercept(self) -> list:
        """Get linear regression intercept value."""
        return self.__linear_regression.intercept_.tolist()

    def is_calibrating(self) -> bool:
        """Is the calibration running?

        True between reset() and a successful calibrate(): the model is cleared meanwhile.
        """
        return self.__linear_regression is None

    def store(self, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition):
        """Store an observed gaze position together with the position it was expected at."""
        self.__observed_positions.append(observed_gaze_position)
        self.__expected_positions.append(expected_gaze_position)

    def reset(self):
        """Clear stored gaze positions and the model, entering calibrating state."""
        self.__observed_positions = []
        self.__expected_positions = []
        self.__linear_regression = None

    def calibrate(self) -> float:
        """Fit a linear regression mapping observed gaze positions onto expected ones.

        Returns:
            score: the score of the linear regression on the stored positions.

        Raises:
            ValueError: if no gaze positions have been stored (or if sklearn rejects the data).
        """
        if not self.__observed_positions:

            raise ValueError('No gaze positions stored: call store() before calibrate().')

        self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions)

        # Return the score of the fitted model evaluated on its own calibration data
        return self.__linear_regression.score(self.__observed_positions, self.__expected_positions)

    def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType:
        """Apply calibration onto observed gaze position.

        While calibrating, the gaze position is returned unchanged. Otherwise a new gaze
        position is built from the model prediction, keeping the input precision and timestamp.
        """
        if self.is_calibrating():

            return gaze_position

        # predict() expects a 2D array of samples: wrap the single position and take the first row
        predicted = self.__linear_regression.predict(numpy.array([gaze_position]))[0]

        return GazeFeatures.GazePosition(predicted, precision=gaze_position.precision, timestamp=gaze_position.timestamp)

    def draw(self, image: numpy.ndarray, size: tuple, resolution: tuple, line_color: tuple = (0, 0, 0), thickness: int = 1):
        """Draw calibration field.

        Samples a regular grid over an area of the given size and, for each grid point, draws
        a line from the point to its calibrated position.

        Parameters:
            image: image to draw into (drawn on by cv2.line).
            size: (width, height) of the sampled area; nothing is drawn when width * height <= 0.
            resolution: (rx, ry) number of grid points along x and y.
            line_color: line color.
            thickness: line thickness.
        """
        width, height = size

        if width * height > 0:

            rx, ry = resolution
            lx = numpy.linspace(0, width, rx)
            ly = numpy.linspace(0, height, ry)
            xv, yv = numpy.meshgrid(lx, ly, indexing='ij')

            for i in range(rx):

                for j in range(ry):

                    start = (xv[i][j], yv[i][j])
                    end = self.apply(GazeFeatures.GazePosition(start)).value

                    cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), line_color, thickness)