From 1d46c5816ba603105dfaa1b5a79f3a167fdc99d8 Mon Sep 17 00:00:00 2001 From: Théo de la Hogue Date: Tue, 17 Oct 2023 12:56:41 +0200 Subject: Adding GazePositionCalibrator class. Adding LinearRegression module. --- setup.py | 2 +- src/argaze/GazeAnalysis/LinearRegression.py | 80 +++++++++++++++++++++++++++++ src/argaze/GazeAnalysis/__init__.py | 2 +- src/argaze/GazeFeatures.py | 56 ++++++++++++++++++++ 4 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 src/argaze/GazeAnalysis/LinearRegression.py diff --git a/setup.py b/setup.py index 358c19e..706f414 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setup( packages=find_packages(where='src'), python_requires='>=3.11', - install_requires=['opencv-python>=4.7.0', 'opencv-contrib-python>=4.7.0', 'numpy', 'pandas', 'matplotlib', 'shapely', 'lempel_ziv_complexity', 'scipy'], + install_requires=['opencv-python>=4.7.0', 'opencv-contrib-python>=4.7.0', 'numpy', 'pandas', 'matplotlib', 'shapely', 'lempel_ziv_complexity', 'scipy', 'scikit-learn'], project_urls={ 'Bug Reports': 'https://git.recherche.enac.fr/projects/argaze/issues', diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py new file mode 100644 index 0000000..5a92048 --- /dev/null +++ b/src/argaze/GazeAnalysis/LinearRegression.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python + +"""Module for gaze position calibration based on linear regression. +""" + +__author__ = "Théo de la Hogue" +__credits__ = [] +__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)" +__license__ = "BSD" + +from typing import TypeVar, Tuple +from dataclasses import dataclass, field + +from argaze import GazeFeatures + +from sklearn.linear_model import LinearRegression +import numpy +import cv2 + +GazePositionType = TypeVar('GazePositionType', bound="GazePositionType") +# Type definition for type annotation convenience + +@dataclass +class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): + """Calibration algorithm based on linear regression.""" + + coefficients: numpy.array = field(default_factory=lambda : numpy.array([[1., 0.], [0., 1.]])) + """Linear regression coefficients""" + + intercept: numpy.array = field(default_factory=lambda : numpy.array([0., 0.])) + """Linear regression intercept value""" + + def __post_init__(self): + """Init calibration data.""" + + self.reset() + + def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition): + """Store observed and expected gaze positions.""" + + self.__observed_positions.append(observed_gaze_position.value) + self.__expected_positions.append(expected_gaze_position.value) + + def reset(self): + """Reset observed and expected gaze positions.""" + + self.__observed_positions = [] + self.__expected_positions = [] + self.__linear_regression = None + + def calibrate(self) -> float: + """Process calibration from observed and expected gaze positions. + + Returns: + score: the score of linear regression + """ + + self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions) + + return self.__linear_regression.score(self.__observed_positions, self.__expected_positions) + + def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType: + """Apply calibration onto observed gaze position.""" + + return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision) + + def draw(self, image: numpy.array): + """Draw calibration into image. + + Parameters: + image: where to draw + """ + + raise NotImplementedError('draw() method not implemented') + + @property + def ready(self) -> bool: + """Is the calibrator ready?""" + + return self.__linear_regression is not None \ No newline at end of file diff --git a/src/argaze/GazeAnalysis/__init__.py b/src/argaze/GazeAnalysis/__init__.py index 62e0823..c110eb1 100644 --- a/src/argaze/GazeAnalysis/__init__.py +++ b/src/argaze/GazeAnalysis/__init__.py @@ -1,4 +1,4 @@ """ Various gaze movement identification, AOI matching and scan path analysis algorithms. """ -__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio'] \ No newline at end of file +__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio', 'LinearRegression'] \ No newline at end of file diff --git a/src/argaze/GazeFeatures.py b/src/argaze/GazeFeatures.py index bd1a3da..b918256 100644 --- a/src/argaze/GazeFeatures.py +++ b/src/argaze/GazeFeatures.py @@ -201,6 +201,62 @@ class TimeStampedGazePositions(DataStructures.TimeStampedBuffer): return TimeStampedGazePositions(df.to_dict('index')) +@dataclass +class GazePositionCalibrator(): + """Abstract class to define what should provide a gaze position calibrator algorithm.""" + + def store(self, timestamp: int|float, observed_gaze_position: GazePosition, expected_gaze_position: GazePosition): + """Store observed and expected gaze positions. + + Parameters: + timestamp: time of observed gaze position + observed_gaze_position: where gaze position actually is + expected_gaze_position: where gaze position should be + """ + + raise NotImplementedError('calibrate() method not implemented') + + def reset(self): + """Reset observed and expected gaze positions.""" + + raise NotImplementedError('reset() method not implemented') + + def calibrate(self) -> Any: + """Process calibration from observed and expected gaze positions. + + Returns: + calibration outputs: any data returned to assess calibration + """ + + raise NotImplementedError('terminate() method not implemented') + + def apply(self, observed_gaze_position: GazePosition) -> GazePositionType: + """Apply calibration onto observed gaze position. + + Parameters: + observed_gaze_position: where gaze position actually is + + Returns: + expected_gaze_position: where gaze position should be + """ + + raise NotImplementedError('process() method not implemented') + + def draw(self, image: numpy.array): + """Draw calibration into image. + + Parameters: + image: where to draw + """ + + raise NotImplementedError('draw() method not implemented') + + @property + def ready(self) -> bool: + """Is the calibrator ready?""" + + raise NotImplementedError('ready getter not implemented') + GazeMovementType = TypeVar('GazeMovement', bound="GazeMovement") # Type definition for type annotation convenience -- cgit v1.1