aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--setup.py2
-rw-r--r--src/argaze/GazeAnalysis/LinearRegression.py80
-rw-r--r--src/argaze/GazeAnalysis/__init__.py2
-rw-r--r--src/argaze/GazeFeatures.py56
4 files changed, 138 insertions, 2 deletions
diff --git a/setup.py b/setup.py
index 358c19e..706f414 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@ setup(
packages=find_packages(where='src'),
python_requires='>=3.11',
- install_requires=['opencv-python>=4.7.0', 'opencv-contrib-python>=4.7.0', 'numpy', 'pandas', 'matplotlib', 'shapely', 'lempel_ziv_complexity', 'scipy'],
+ install_requires=['opencv-python>=4.7.0', 'opencv-contrib-python>=4.7.0', 'numpy', 'pandas', 'matplotlib', 'shapely', 'lempel_ziv_complexity', 'scipy', 'scikit-learn'],
project_urls={
'Bug Reports': 'https://git.recherche.enac.fr/projects/argaze/issues',
diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py
new file mode 100644
index 0000000..5a92048
--- /dev/null
+++ b/src/argaze/GazeAnalysis/LinearRegression.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+
+"""Module for gaze position calibration based on linear regression.
+"""
+
+__author__ = "Théo de la Hogue"
+__credits__ = []
+__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)"
+__license__ = "BSD"
+
+from typing import TypeVar, Tuple
+from dataclasses import dataclass, field
+
+from argaze import GazeFeatures
+
+from sklearn.linear_model import LinearRegression
+import numpy
+import cv2
+
+GazePositionType = TypeVar('GazePositionType', bound="GazePositionType")
+# Type definition for type annotation convenience
+
+@dataclass
+class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator):
+ """Calibration algorithm based on linear regression."""
+
+ coefficients: numpy.array = field(default_factory=lambda : numpy.array([[1., 0.], [0., 1.]]))
+ """Linear regression coefficients"""
+
+ intercept: numpy.array = field(default_factory=lambda : numpy.array([0., 0.]))
+ """Linear regression intercept value"""
+
+ def __post_init__(self):
+ """Init calibration data."""
+
+ self.reset()
+
+ def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition):
+ """Store observed and expected gaze positions."""
+
+ self.__observed_positions.append(observed_gaze_position.value)
+ self.__expected_positions.append(expected_gaze_position.value)
+
+ def reset(self):
+ """Reset observed and expected gaze positions."""
+
+ self.__observed_positions = []
+ self.__expected_positions = []
+ self.__linear_regression = None
+
+ def calibrate(self) -> float:
+ """Process calibration from observed and expected gaze positions.
+
+ Returns:
+ score: the score of linear regression
+ """
+
+ self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions)
+
+ return self.__linear_regression.score(self.__observed_positions, self.__expected_positions)
+
+ def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType:
+ """Apply calibration onto observed gaze position."""
+
+ return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision)
+
+ def draw(self, image: numpy.array):
+ """Draw calibration into image.
+
+ Parameters:
+ image: where to draw
+ """
+
+ raise NotImplementedError('draw() method not implemented')
+
+ @property
+ def ready(self) -> bool:
+ """Is the calibrator ready?"""
+
+ return self.__linear_regression is not None \ No newline at end of file
diff --git a/src/argaze/GazeAnalysis/__init__.py b/src/argaze/GazeAnalysis/__init__.py
index 62e0823..c110eb1 100644
--- a/src/argaze/GazeAnalysis/__init__.py
+++ b/src/argaze/GazeAnalysis/__init__.py
@@ -1,4 +1,4 @@
"""
Various gaze movement identification, AOI matching and scan path analysis algorithms.
"""
-__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio'] \ No newline at end of file
+__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio', 'LinearRegression'] \ No newline at end of file
diff --git a/src/argaze/GazeFeatures.py b/src/argaze/GazeFeatures.py
index bd1a3da..b918256 100644
--- a/src/argaze/GazeFeatures.py
+++ b/src/argaze/GazeFeatures.py
@@ -201,6 +201,62 @@ class TimeStampedGazePositions(DataStructures.TimeStampedBuffer):
return TimeStampedGazePositions(df.to_dict('index'))
+@dataclass
+class GazePositionCalibrator():
+ """Abstract class to define what should provide a gaze position calibrator algorithm."""
+
+ def store(self, timestamp: int|float, observed_gaze_position: GazePosition, expected_gaze_position: GazePosition):
+ """Store observed and expected gaze positions.
+
+ Parameters:
+ timestamp: time of observed gaze position
+ observed_gaze_position: where gaze position actually is
+ expected_gaze_position: where gaze position should be
+ """
+
+ raise NotImplementedError('calibrate() method not implemented')
+
+ def reset(self):
+ """Reset observed and expected gaze positions."""
+
+ raise NotImplementedError('reset() method not implemented')
+
+ def calibrate(self) -> Any:
+ """Process calibration from observed and expected gaze positions.
+
+ Returns:
+ calibration outputs: any data returned to assess calibration
+ """
+
+ raise NotImplementedError('terminate() method not implemented')
+
+ def apply(self, observed_gaze_position: GazePosition) -> GazePositionType:
+ """Apply calibration onto observed gaze position.
+
+ Parameters:
+ observed_gaze_position: where gaze position actually is
+
+ Returns:
+ expected_gaze_position: where gaze position should be
+ """
+
+ raise NotImplementedError('process() method not implemented')
+
+ def draw(self, image: numpy.array):
+ """Draw calibration into image.
+
+ Parameters:
+ image: where to draw
+ """
+
+ raise NotImplementedError('draw() method not implemented')
+
+ @property
+ def ready(self) -> bool:
+ """Is the calibrator ready?"""
+
+ raise NotImplementedError('ready getter not implemented')
+
GazeMovementType = TypeVar('GazeMovement', bound="GazeMovement")
# Type definition for type annotation convenience