aboutsummaryrefslogtreecommitdiff
path: root/src/argaze/GazeAnalysis/LinearRegression.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/argaze/GazeAnalysis/LinearRegression.py')
-rw-r--r--src/argaze/GazeAnalysis/LinearRegression.py107
1 files changed, 107 insertions, 0 deletions
diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py
new file mode 100644
index 0000000..0e10b87
--- /dev/null
+++ b/src/argaze/GazeAnalysis/LinearRegression.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+"""Module for gaze position calibration based on linear regression.
+"""
+
+__author__ = "Théo de la Hogue"
+__credits__ = []
+__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)"
+__license__ = "BSD"
+
+from typing import TypeVar, Tuple
+from dataclasses import dataclass, field
+
+from argaze import GazeFeatures
+
+from sklearn.linear_model import LinearRegression
+import numpy
+import cv2
+
+GazePositionType = TypeVar('GazePositionType', bound="GazePositionType")
+# Type definition for type annotation convenience
+
+@dataclass
+class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator):
+ """Calibration algorithm based on linear regression."""
+
+ coefficients: numpy.array = field(default_factory=lambda : numpy.array([[1., 0.], [0., 1.]]))
+ """Linear regression coefficients"""
+
+ intercept: numpy.array = field(default_factory=lambda : numpy.array([0., 0.]))
+ """Linear regression intercept value"""
+
+ def __post_init__(self):
+ """Init calibration."""
+
+ self.__linear_regression = LinearRegression()
+ self.__linear_regression.coef_ = numpy.array(self.coefficients)
+ self.__linear_regression.intercept_ = numpy.array(self.intercept)
+
+ def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition):
+ """Store observed and expected gaze positions."""
+
+ self.__observed_positions.append(observed_gaze_position.value)
+ self.__expected_positions.append(expected_gaze_position.value)
+
+ def reset(self):
+ """Reset observed and expected gaze positions."""
+
+ self.__observed_positions = []
+ self.__expected_positions = []
+ self.__linear_regression = None
+
+ def calibrate(self) -> float:
+ """Process calibration from observed and expected gaze positions.
+
+ Returns:
+ score: the score of linear regression
+ """
+
+ self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions)
+
+ # Update frozen coefficients attribute
+ object.__setattr__(self, 'coefficients', self.__linear_regression.coef_)
+
+ # Update frozen intercept attribute
+ object.__setattr__(self, 'intercept', self.__linear_regression.intercept_)
+
+ # Return calibrated gaze position
+ return self.__linear_regression.score(self.__observed_positions, self.__expected_positions)
+
+ def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType:
+ """Apply calibration onto observed gaze position."""
+
+ if not self.calibrating:
+
+ return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision)
+
+ else:
+
+ return gaze_position
+
+ def draw(self, image: numpy.array, size: tuple, resolution: tuple, line_color: tuple = (0, 0, 0), thickness: int = 1):
+ """Draw calibration field."""
+
+ width, height = size
+
+ if width * height > 0:
+
+ rx, ry = resolution
+ lx = numpy.linspace(0, width, rx)
+ ly = numpy.linspace(0, height, ry)
+ xv, yv = numpy.meshgrid(lx, ly, indexing='ij')
+
+ for i in range(rx):
+
+ for j in range(ry):
+
+ start = (xv[i][j], yv[i][j])
+ end = self.apply(GazeFeatures.GazePosition(start)).value
+
+ cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), line_color, thickness)
+
+ @property
+ def calibrating(self) -> bool:
+ """Is the calibration running?"""
+
+ return self.__linear_regression is None \ No newline at end of file