From addc275da0fe6edcf3ae8d73ffb4a3d5e3d375f9 Mon Sep 17 00:00:00 2001 From: Théo de la Hogue Date: Thu, 25 Apr 2024 12:18:12 +0200 Subject: Making IDT faster. --- .../DispersionThresholdIdentification.py | 204 +++++++++++---------- 1 file changed, 109 insertions(+), 95 deletions(-) diff --git a/src/argaze/GazeAnalysis/DispersionThresholdIdentification.py b/src/argaze/GazeAnalysis/DispersionThresholdIdentification.py index cd31b04..111d1c3 100644 --- a/src/argaze/GazeAnalysis/DispersionThresholdIdentification.py +++ b/src/argaze/GazeAnalysis/DispersionThresholdIdentification.py @@ -17,34 +17,26 @@ __credits__ = [] __copyright__ = "Copyright 2023, Ecole Nationale de l'Aviation Civile (ENAC)" __license__ = "GPLv3" -import cv2 -import numpy +import math from argaze import GazeFeatures, DataFeatures +import cv2 +import numpy class Fixation(GazeFeatures.Fixation): """Define dispersion based fixation.""" - def __init__(self, positions: GazeFeatures.TimeStampedGazePositions = (), finished: bool = False, message: str = None, **kwargs): - - super().__init__(positions, finished, message, **kwargs) + def __init__(self, focus: tuple = (), deviation_max: float = math.nan, **kwargs): - if positions: + super().__init__(**kwargs) - positions_array = numpy.asarray(self.values()) - centroid = numpy.mean(positions_array, axis=0) - deviations_array = numpy.sqrt(numpy.sum((positions_array - centroid)**2, axis=1)) - - # Set focus as positions centroid - self.focus = (centroid[0], centroid[1]) - - # Set deviation_max attribute - self.__deviation_max = deviations_array.max() + self._focus = focus + self.__deviation_max = deviation_max @property - def deviation_max(self): - """Get fixation's maximal deviation.""" + def deviation_max(self) -> float: + """Fixation's maximal deviation.""" return self.__deviation_max def is_overlapping(self, fixation: GazeFeatures.Fixation) -> bool: @@ -86,11 +78,11 @@ class Fixation(GazeFeatures.Fixation): class Saccade(GazeFeatures.Saccade): """Define dispersion based saccade.""" - def __init__(self, positions: GazeFeatures.TimeStampedGazePositions = (), finished: bool = False, message: str = None, **kwargs): + def __init__(self, positions: GazeFeatures.TimeStampedGazePositions = None, **kwargs): - super().__init__(positions, finished, message, **kwargs) + super().__init__(positions, **kwargs) - def draw(self, image: numpy.array, line_color: tuple = None): + def draw(self, image: numpy.array, line_color: tuple = None, draw_positions: dict = None): """Draw saccade into image. Parameters: @@ -106,6 +98,11 @@ class Saccade(GazeFeatures.Saccade): cv2.line(image, (int(start_position[0]), int(start_position[1])), (int(last_position[0]), int(last_position[1])), line_color, 2) + # Draw positions if required + if draw_positions is not None: + + self.draw_positions(image, **draw_positions) + class GazeMovementIdentifier(GazeFeatures.GazeMovementIdentifier): """Implementation of the I-DT algorithm as described in: @@ -125,8 +122,12 @@ class GazeMovementIdentifier(GazeFeatures.GazeMovementIdentifier): self.__duration_min_threshold = 0 self.__valid_positions = GazeFeatures.TimeStampedGazePositions() - self.__fixation_positions = GazeFeatures.TimeStampedGazePositions() - self.__saccade_positions = GazeFeatures.TimeStampedGazePositions() + + self.__centroid = () + self.__deviations = [] + + self.__fixation = Fixation() + self.__saccade = Saccade() @property def deviation_max_threshold(self) -> int|float: @@ -151,127 +152,140 @@ class GazeMovementIdentifier(GazeFeatures.GazeMovementIdentifier): @DataFeatures.PipelineStepMethod def identify(self, gaze_position, terminate=False) -> GazeFeatures.GazeMovement: - - # Ignore empty gaze position - if not gaze_position: - return GazeFeatures.GazeMovement() if not terminate else self.current_fixation().finish() - - # Check if too much time elapsed since last valid gaze position + # When too much time elapsed since the last valid gaze position if self.__valid_positions: - ts_last = self.__valid_positions[-1].timestamp + elapsed_time = gaze_position.timestamp - self.__valid_positions[-1].timestamp + + if elapsed_time > self.__duration_min_threshold: - if (gaze_position.timestamp - ts_last) > self.__duration_min_threshold: + try: - # Get last movement - last_movement = self.current_gaze_movement().finish() + # Finish and return current gaze movement + return self.current_gaze_movement().finish() - # Clear all former gaze positions - self.__valid_positions = GazeFeatures.TimeStampedGazePositions() - self.__fixation_positions = GazeFeatures.TimeStampedGazePositions() - self.__saccade_positions = GazeFeatures.TimeStampedGazePositions() + finally: - # Store valid gaze position - self.__valid_positions.append(gaze_position) + # Reset valid gaze positions + self.__valid_positions = GazeFeatures.TimeStampedGazePositions() - # Return last valid movement if exist - return last_movement + # Reset centroid and deviations + self.__centroid = () + self.__deviations = [] - # Store gaze positions until a minimal duration - self.__valid_positions.append(gaze_position) + # Clear gaze movements + self.__fixation = Fixation() + self.__saccade = Saccade() - # Once the minimal duration is reached - if self.__valid_positions.duration >= self.__duration_min_threshold: + # Consider only valid gaze position + if gaze_position: - # Calculate the deviation of valid gaze positions - deviation = Fixation(self.__valid_positions).deviation_max + # Update centroid and deviations + if not self.__valid_positions: - # Valid gaze positions deviation small enough - if deviation <= self.__deviation_max_threshold: + self.__centroid = gaze_position + self.__deviations = [] - last_saccade = GazeFeatures.GazeMovement() + else: - # Is there saccade positions? - if self.__saccade_positions: + self.__centroid = self.__centroid + (gaze_position - self.__centroid) / (len(self.__valid_positions) + 1) + self.__deviations.append(gaze_position.distance(self.__centroid)) - # Copy oldest valid position into saccade positions - self.__saccade_positions.append(self.__valid_positions[0]) + # Store valid gaze position + self.__valid_positions.append(gaze_position) - # Finish last saccade - last_saccade = self.current_saccade().finish() + # Once the minimal duration is reached + if self.__valid_positions.duration >= self.__duration_min_threshold: - # Clear saccade positions - self.__saccade_positions = GazeFeatures.TimeStampedGazePositions() + deviation_max = max(self.__deviations) - # Copy valid gaze positions into fixation positions - self.__fixation_positions = self.__valid_positions.copy() + # Maximal deviation small enough + if deviation_max <= self.__deviation_max_threshold: - # Output last saccade - return last_saccade if not terminate else self.current_fixation().finish() - - # Valid gaze positions deviation too wide - else: + # Make valid gaze positions as current fixation + self.__fixation = Fixation(positions=self.__valid_positions, focus=self.__centroid, deviation_max=deviation_max) - last_fixation = GazeFeatures.GazeMovement() + # Is there a current saccade? + if self.__saccade: - # Is there fixation positions? - if self.__fixation_positions: + try: - # Copy most recent fixation position into saccade positions - self.__saccade_positions.append(self.__fixation_positions[-1]) + # Share first fixation position with current saccade + self.__saccade.append(self.__fixation[0]) - # Finish last fixation - last_fixation = self.current_fixation().finish() + # Finish and return the current saccade + return self.__saccade.finish() - # Clear fixation positions - self.__fixation_positions = GazeFeatures.TimeStampedGazePositions() + # Clear saccade after the return + finally: - # Clear valid positions - self.__valid_positions = GazeFeatures.TimeStampedGazePositions() + self.__saccade = Saccade() + + # Maximal deviation too wide + else: - # Store current gaze position - self.__valid_positions.append(gaze_position) + # Is there a current fixation? + if self.__fixation: - # Output last fixation - return last_fixation if not terminate else self.current_saccade().finish() + try: - # Move oldest valid position into saccade positions - self.__saccade_positions.append(self.__valid_positions.pop(0)) + # Share last fixation position with current saccade + self.__saccade.append(self.__fixation[-1]) - # Always return empty gaze movement at least - return GazeFeatures.GazeMovement() + # Clear valid positions + self.__valid_positions = GazeFeatures.TimeStampedGazePositions() + + # Finish and return the current fixation + return self.__fixation.finish() + + # Clear fixation after the return + finally: + + self.__fixation = Fixation() + + # No fixation case: + + # Remove oldest valid position + old_gaze_position = self.__valid_positions.pop(0) + + # Move oldest valid position into current saccade + self.__saccade.append(old_gaze_position) + + # Update centroid and deviations + self.__centroid = self.__centroid - (old_gaze_position - self.__centroid) / (len(self.__valid_positions) + 1) + self.__deviations.pop(0) + + # Return current gaze movement + return self.current_gaze_movement() if not terminate else self.current_gaze_movement().finish() def current_gaze_movement(self) -> GazeFeatures.GazeMovement: - # It shouldn't have a current fixation and a current saccade at the same time - assert(not (self.__fixation_positions and len(self.__saccade_positions) > 1)) + if self.__fixation: - if self.__fixation_positions: + return self.__fixation - return Fixation(self.__fixation_positions) + if len(self.__saccade) > 1: - if len(self.__saccade_positions) > 1: - - return Saccade(self.__saccade_positions) + return self.__saccade # Always return empty gaze movement at least return GazeFeatures.GazeMovement() def current_fixation(self) -> GazeFeatures.GazeMovement: - if self.__fixation_positions: + if self.__fixation: - return Fixation(self.__fixation_positions) + return self.__fixation # Always return empty gaze movement at least return GazeFeatures.GazeMovement() def current_saccade(self) -> GazeFeatures.GazeMovement: - if len(self.__saccade_positions) > 1: - - return Saccade(self.__saccade_positions) + if len(self.__saccade) > 1: + + return self.__saccade # Always return empty gaze movement at least return GazeFeatures.GazeMovement() -- cgit v1.1