diff options
author | Théo de la Hogue | 2023-10-17 12:56:41 +0200 |
---|---|---|
committer | Théo de la Hogue | 2023-10-17 12:56:41 +0200 |
commit | 1d46c5816ba603105dfaa1b5a79f3a167fdc99d8 (patch) | |
tree | cc667f9e526b84f7051b335e11d82a75c793e4b5 /src | |
parent | 77914e2aa25623a237a58b7c80f712129cbb2b55 (diff) | |
download | argaze-1d46c5816ba603105dfaa1b5a79f3a167fdc99d8.zip argaze-1d46c5816ba603105dfaa1b5a79f3a167fdc99d8.tar.gz argaze-1d46c5816ba603105dfaa1b5a79f3a167fdc99d8.tar.bz2 argaze-1d46c5816ba603105dfaa1b5a79f3a167fdc99d8.tar.xz |
Adding GazePositionCalibrator class. Adding LinearRegression module.
Diffstat (limited to 'src')
-rw-r--r-- | src/argaze/GazeAnalysis/LinearRegression.py | 80 | ||||
-rw-r--r-- | src/argaze/GazeAnalysis/__init__.py | 2 | ||||
-rw-r--r-- | src/argaze/GazeFeatures.py | 56 |
3 files changed, 137 insertions, 1 deletions
diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py new file mode 100644 index 0000000..5a92048 --- /dev/null +++ b/src/argaze/GazeAnalysis/LinearRegression.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python + +"""Module for gaze position calibration based on linear regression. +""" + +__author__ = "Théo de la Hogue" +__credits__ = [] +__copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)" +__license__ = "BSD" + +from typing import TypeVar, Tuple +from dataclasses import dataclass, field + +from argaze import GazeFeatures + +from sklearn.linear_model import LinearRegression +import numpy +import cv2 + +GazePositionType = TypeVar('GazePositionType', bound="GazePositionType") +# Type definition for type annotation convenience + +@dataclass +class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): + """Calibration algorithm based on linear regression.""" + + coefficients: numpy.array = field(default_factory=lambda : numpy.array([[1., 0.], [0., 1.]])) + """Linear regression coefficients""" + + intercept: numpy.array = field(default_factory=lambda : numpy.array([0., 0.])) + """Linear regression intercept value""" + + def __post_init__(self): + """Init calibration data.""" + + self.reset() + + def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition): + """Store observed and expected gaze positions.""" + + self.__observed_positions.append(observed_gaze_position.value) + self.__expected_positions.append(expected_gaze_position.value) + + def reset(self): + """Reset observed and expected gaze positions.""" + + self.__observed_positions = [] + self.__expected_positions = [] + self.__linear_regression = None + + def calibrate(self) -> float: + """Process calibration from observed and expected gaze positions. + + Returns: + score: the score of linear regression + """ + + self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions) + + return self.__linear_regression.score(self.__observed_positions, self.__expected_positions) + + def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType: + """Apply calibration onto observed gaze position.""" + + return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision) + + def draw(self, image: numpy.array): + """Draw calibration into image. + + Parameters: + image: where to draw + """ + + raise NotImplementedError('draw() method not implemented') + + @property + def ready(self) -> bool: + """Is the calibrator ready?""" + + return self.__linear_regression is not None
\ No newline at end of file diff --git a/src/argaze/GazeAnalysis/__init__.py b/src/argaze/GazeAnalysis/__init__.py index 62e0823..c110eb1 100644 --- a/src/argaze/GazeAnalysis/__init__.py +++ b/src/argaze/GazeAnalysis/__init__.py @@ -1,4 +1,4 @@ """ Various gaze movement identification, AOI matching and scan path analysis algorithms. """ -__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio']
\ No newline at end of file +__all__ = ['Basic', 'DispersionThresholdIdentification', 'VelocityThresholdIdentification', 'TransitionMatrix', 'KCoefficient', 'LempelZivComplexity', 'NGram', 'Entropy', 'NearestNeighborIndex', 'ExploreExploitRatio', 'LinearRegression']
\ No newline at end of file diff --git a/src/argaze/GazeFeatures.py b/src/argaze/GazeFeatures.py index bd1a3da..b918256 100644 --- a/src/argaze/GazeFeatures.py +++ b/src/argaze/GazeFeatures.py @@ -201,6 +201,62 @@ class TimeStampedGazePositions(DataStructures.TimeStampedBuffer): return TimeStampedGazePositions(df.to_dict('index')) +@dataclass +class GazePositionCalibrator(): + """Abstract class to define what should provide a gaze position calibrator algorithm.""" + + def store(self, timestamp: int|float, observed_gaze_position: GazePosition, expected_gaze_position: GazePosition): + """Store observed and expected gaze positions. + + Parameters: + timestamp: time of observed gaze position + observed_gaze_position: where gaze position actually is + expected_gaze_position: where gaze position should be + """ + + raise NotImplementedError('calibrate() method not implemented') + + def reset(self): + """Reset observed and expected gaze positions.""" + + raise NotImplementedError('reset() method not implemented') + + def calibrate(self) -> Any: + """Process calibration from observed and expected gaze positions. + + Returns: + calibration outputs: any data returned to assess calibration + """ + + raise NotImplementedError('terminate() method not implemented') + + def apply(self, observed_gaze_position: GazePosition) -> GazePositionType: + """Apply calibration onto observed gaze position. + + Parameters: + observed_gaze_position: where gaze position actually is + + Returns: + expected_gaze_position: where gaze position should be + """ + + raise NotImplementedError('process() method not implemented') + + def draw(self, image: numpy.array): + """Draw calibration into image. + + Parameters: + image: where to draw + """ + + raise NotImplementedError('draw() method not implemented') + + @property + def ready(self) -> bool: + """Is the calibrator ready?""" + + raise NotImplementedError('ready getter not implemented') + GazeMovementType = TypeVar('GazeMovement', bound="GazeMovement") # Type definition for type annotation convenience |