From 322fa8af22f8880d58506fc18f4205ac4d3f937a Mon Sep 17 00:00:00 2001 From: Théo de la Hogue Date: Tue, 17 Oct 2023 15:58:55 +0200 Subject: adding gaze_position_calibrator to ArFrame. --- src/argaze/ArFeatures.py | 40 +++++++++++++++++++++++++---- src/argaze/GazeAnalysis/LinearRegression.py | 27 ++++++++++++++----- src/argaze/GazeFeatures.py | 19 +++++++++----- 3 files changed, 69 insertions(+), 17 deletions(-) (limited to 'src') diff --git a/src/argaze/ArFeatures.py b/src/argaze/ArFeatures.py index a1c7349..cb1b2f6 100644 --- a/src/argaze/ArFeatures.py +++ b/src/argaze/ArFeatures.py @@ -523,7 +523,8 @@ class ArFrame(): Parameters: name: name of the frame - size: defines the dimension of the rectangular area where gaze positions are projected. + size: defines the dimension of the rectangular area where gaze positions are projected + gaze_position_calibrator: gaze position calibration algoritm gaze_movement_identifier: gaze movement identification algorithm filter_in_progress_identification: ignore in progress gaze movement identification scan_path: scan path object @@ -537,6 +538,7 @@ class ArFrame(): name: str size: tuple[int] = field(default=(1, 1)) + gaze_position_calibrator: GazeFeatures.GazePositionCalibrator = field(default_factory=GazeFeatures.GazePositionCalibrator) gaze_movement_identifier: GazeFeatures.GazeMovementIdentifier = field(default_factory=GazeFeatures.GazeMovementIdentifier) filter_in_progress_identification: bool = field(default=True) scan_path: GazeFeatures.ScanPath = field(default_factory=GazeFeatures.ScanPath) @@ -600,6 +602,24 @@ class ArFrame(): new_frame_size = (0, 0) + # Load gaze position calibrator + try: + + gaze_position_calibrator_value = frame_data.pop('gaze_position_calibrator') + + gaze_position_calibrator_module_path, gaze_position_calibrator_parameters = gaze_position_calibrator_value.popitem() + + # Prepend argaze.GazeAnalysis path when a single name is provided + if len(gaze_position_calibrator_module_path.split('.')) == 1: + gaze_position_calibrator_module_path = f'argaze.GazeAnalysis.{gaze_position_calibrator_module_path}' + + gaze_position_calibrator_module = importlib.import_module(gaze_position_calibrator_module_path) + new_gaze_position_calibrator = gaze_position_calibrator_module.GazePositionCalibrator(**gaze_position_calibrator_parameters) + + except KeyError: + + new_gaze_position_calibrator = None + # Load gaze movement identifier try: @@ -756,6 +776,7 @@ class ArFrame(): # Create frame return ArFrame(new_frame_name, \ new_frame_size, \ + new_gaze_position_calibrator, \ new_gaze_movement_identifier, \ filter_in_progress_identification, \ new_scan_path, \ @@ -815,6 +836,7 @@ class ArFrame(): gaze_position: gaze position to project Returns: + current_gaze_position: calibrated gaze position if gaze_position_calibrator is instanciated else, given gaze position. identified_gaze_movement: identified gaze movement from incoming consecutive timestamped gaze positions if gaze_movement_identifier is instanciated. Current gaze movement if filter_in_progress_identification is False. scan_path_analysis: scan path analysis at each new scan step if scan_path is instanciated. layers_analysis: aoi scan path analysis at each new aoi scan step for each instanciated layers aoi scan path. @@ -828,9 +850,6 @@ class ArFrame(): # Store look execution start date look_start = time.perf_counter() - # Update current gaze position - self.__gaze_position = gaze_position - # No gaze movement identified by default identified_gaze_movement = GazeFeatures.UnvalidGazeMovement() @@ -853,6 +872,16 @@ class ArFrame(): try: + # Apply gaze position calibration + if self.gaze_position_calibrator is not None: + + self.__gaze_position = self.gaze_position_calibrator.apply(gaze_position) + + # Or update gaze position at least + else: + + self.__gaze_position = gaze_position + # Identify gaze movement if self.gaze_movement_identifier is not None: @@ -942,6 +971,7 @@ class ArFrame(): print('Warning: the following error occurs in ArFrame.look method:', e) + self.__gaze_position = GazeFeatures.UnvalidGazePosition() identified_gaze_movement = GazeFeatures.UnvalidGazeMovement() scan_step_analysis = {} layer_analysis = {} @@ -954,7 +984,7 @@ class ArFrame(): self.__look_lock.release() # Return look data - return identified_gaze_movement, scan_step_analysis, layer_analysis, execution_times, exception + return self.__gaze_position, identified_gaze_movement, scan_step_analysis, layer_analysis, execution_times, exception def __image(self, background_weight: float = None, heatmap_weight: float = None, draw_scan_path: dict = None, draw_layers: dict = None, draw_gaze_positions: dict = None, draw_fixations: dict = None, draw_saccades: dict = None) -> numpy.array: """ diff --git a/src/argaze/GazeAnalysis/LinearRegression.py b/src/argaze/GazeAnalysis/LinearRegression.py index 5a92048..de7725d 100644 --- a/src/argaze/GazeAnalysis/LinearRegression.py +++ b/src/argaze/GazeAnalysis/LinearRegression.py @@ -31,9 +31,11 @@ class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): """Linear regression intercept value""" def __post_init__(self): - """Init calibration data.""" + """Init calibration.""" - self.reset() + self.__linear_regression = LinearRegression() + self.__linear_regression.coef_ = numpy.array(self.coefficients) + self.__linear_regression.intercept_ = numpy.array(self.intercept) def store(self, timestamp: int|float, observed_gaze_position: GazeFeatures.GazePosition, expected_gaze_position: GazeFeatures.GazePosition): """Store observed and expected gaze positions.""" @@ -57,12 +59,25 @@ class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): self.__linear_regression = LinearRegression().fit(self.__observed_positions, self.__expected_positions) + # Update frozen coefficients attribute + object.__setattr__(self, 'coefficients', self.__linear_regression.coef_) + + # Update frozen intercept attribute + object.__setattr__(self, 'intercept', self.__linear_regression.intercept_) + + # Return calibrated gaze position return self.__linear_regression.score(self.__observed_positions, self.__expected_positions) def apply(self, gaze_position: GazeFeatures.GazePosition) -> GazePositionType: """Apply calibration onto observed gaze position.""" - return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision) + if not self.calibrating: + + return GazeFeatures.GazePosition(self.__linear_regression.predict(numpy.array([gaze_position.value]))[0], precision=gaze_position.precision) + + else: + + return gaze_position def draw(self, image: numpy.array): """Draw calibration into image. @@ -74,7 +89,7 @@ class GazePositionCalibrator(GazeFeatures.GazePositionCalibrator): raise NotImplementedError('draw() method not implemented') @property - def ready(self) -> bool: - """Is the calibrator ready?""" + def calibrating(self) -> bool: + """Is the calibration running?""" - return self.__linear_regression is not None \ No newline at end of file + return self.__linear_regression is None \ No newline at end of file diff --git a/src/argaze/GazeFeatures.py b/src/argaze/GazeFeatures.py index b918256..eddd01d 100644 --- a/src/argaze/GazeFeatures.py +++ b/src/argaze/GazeFeatures.py @@ -201,6 +201,13 @@ class TimeStampedGazePositions(DataStructures.TimeStampedBuffer): return TimeStampedGazePositions(df.to_dict('index')) +class GazePositionCalibrationFailed(Exception): + """Exception raised by GazePositionCalibrator.""" + + def __init__(self, message): + + super().__init__(message) + @dataclass class GazePositionCalibrator(): """Abstract class to define what should provide a gaze position calibrator algorithm.""" @@ -237,10 +244,10 @@ class GazePositionCalibrator(): observed_gaze_position: where gaze position actually is Returns: - expected_gaze_position: where gaze position should be + expected_gaze_position: where gaze position should be if the calibrator is ready else, observed gaze position """ - raise NotImplementedError('process() method not implemented') + raise NotImplementedError('apply() method not implemented') def draw(self, image: numpy.array): """Draw calibration into image. @@ -252,8 +259,8 @@ class GazePositionCalibrator(): raise NotImplementedError('draw() method not implemented') @property - def ready(self) -> bool: - """Is the calibrator ready?""" + def calibrating(self) -> bool: + """Is the calibration running?""" raise NotImplementedError('ready getter not implemented') @@ -601,7 +608,7 @@ ScanStepType = TypeVar('ScanStep', bound="ScanStep") # Type definition for type annotation convenience class ScanStepError(Exception): - """Exception raised at ScanStepError creation if a aoi scan step doesn't start by a fixation or doesn't end by a saccade.""" + """Exception raised at ScanStep creation if a aoi scan step doesn't start by a fixation or doesn't end by a saccade.""" def __init__(self, message): @@ -811,7 +818,7 @@ AOIScanStepType = TypeVar('AOIScanStep', bound="AOIScanStep") # Type definition for type annotation convenience class AOIScanStepError(Exception): - """Exception raised at AOIScanStepError creation if a aoi scan step doesn't start by a fixation or doesn't end by a saccade.""" + """Exception raised at AOIScanStep creation if a aoi scan step doesn't start by a fixation or doesn't end by a saccade.""" def __init__(self, message, aoi=''): -- cgit v1.1