diff options
Diffstat (limited to 'src/argaze/GazeAnalysis/LinearRegression.py')
-rw-r--r-- | src/argaze/GazeAnalysis/LinearRegression.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py new file mode 100644 index 0000000..0e10b87 --- /dev/null +++ b/src/argaze/GazeAnalysis/LinearRegression.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +"""Module for gaze position calibration based on linear regression. +""" + +__author__ = "Théo de la Hogue" +__credits__ = [] +__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)" +__license__ = "BSD" + +from typing import TypeVar, Tuple +from dataclasses import dataclass, field + +from argaze import GazeFeatures + +from sklearn.linear_model import LinearRegression +import numpy +import cv2 + +GazePositionType = TypeVar('GazePositionType', bound="GazePositionType") +# Type definition for type annotation convenience + +@dataclass +class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): + """Calibration algorithm based on linear regression.""" + + coefficients: numpy.array = field(default_factory=lambda : numpy.array([[1., 0.], [0., 1.]])) + """Linear regression coefficients""" + + intercept: numpy.array = field(default_factory=lambda : numpy.array([0., 0.])) + """Linear regression intercept value""" + + def __post_init__(self): + """Init calibration.""" + + self.__linear_regression = LinearRegression() + self.__linear_regression.coef_ = numpy.array(self.coefficients) + self.__linear_regression.intercept_ = numpy.array(self.intercept) + + def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition): + """Store observed and expected gaze positions.""" + + self.__observed_positions.append(observed_gaze_position.value) + self.__expected_positions.append(expected_gaze_position.value) + + def reset(self): + """Reset observed and expected gaze positions.""" + + self.__observed_positions = [] + self.__expected_positions = [] + self.__linear_regression = None + + def calibrate(self) -> float: + """Process calibration from observed and expected gaze positions. + + Returns: + score: the score of linear regression + """ + + self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions) + + # Update frozen coefficients attribute + object.__setattr__(self, 'coefficients', self.__linear_regression.coef_) + + # Update frozen intercept attribute + object.__setattr__(self, 'intercept', self.__linear_regression.intercept_) + + # Return calibrated gaze position + return self.__linear_regression.score(self.__observed_positions, self.__expected_positions) + + def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType: + """Apply calibration onto observed gaze position.""" + + if not self.calibrating: + + return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision) + + else: + + return gaze_position + + def draw(self, image: numpy.array, size: tuple, resolution: tuple, line_color: tuple = (0, 0, 0), thickness: int = 1): + """Draw calibration field.""" + + width, height = size + + if width * height > 0: + + rx, ry = resolution + lx = numpy.linspace(0, width, rx) + ly = numpy.linspace(0, height, ry) + xv, yv = numpy.meshgrid(lx, ly, indexing='ij') + + for i in range(rx): + + for j in range(ry): + + start = (xv[i][j], yv[i][j]) + end = self.apply(GazeFeatures.GazePosition(start)).value + + cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), line_color, thickness) + + @property + def calibrating(self) -> bool: + """Is the calibration running?""" + + return self.__linear_regression is None
\ No newline at end of file |