#!/usr/bin/env python

from typing import Iterator, TypeVar, Tuple
from dataclasses import dataclass, field
import math
import ast
import json

from argaze import DataStructures

import numpy
import pandas
import cv2 as cv


@dataclass(frozen=True)
class GazePosition():
    """Define gaze position as a tuple of coordinates with precision."""

    value: tuple[int | float, ...] = field(default=(0, 0))
    """Position's value."""

    precision: float = field(default=0., kw_only=True)
    """Position's precision represents the radius of a circle around this gaze
    position value where other same gaze position measurements could be."""

    def __getitem__(self, axis: int) -> int | float:
        """Get position value along a particular axis."""

        return self.value[axis]

    def __iter__(self) -> Iterator[int | float]:
        """Iterate over each position value axis."""

        return iter(self.value)

    def __len__(self) -> int:
        """Number of axis in position value."""

        return len(self.value)

    def __repr__(self) -> str:
        """String representation: the instance attributes dumped as JSON."""

        return json.dumps(self, ensure_ascii=False, default=vars)

    def __array__(self, dtype=None, copy=None):
        """Cast as numpy array.

        *dtype* is forwarded to numpy.array. *copy* is accepted for
        compatibility with the NumPy array protocol only: a new array is
        always built from the value tuple.
        """

        return numpy.array(self.value, dtype=dtype)

    @property
    def valid(self) -> bool:
        """Is the precision not None?"""

        return self.precision is not None

    def overlap(self, gaze_position, both=False) -> bool:
        """Does this gaze position overlap another gaze position considering its precision?

        Only the first two axes of each position value are used to compute
        the distance between positions.

        Set both to True to test if the other gaze position overlaps this one too.
        """

        dist = math.hypot(
            self.value[0] - gaze_position.value[0],
            self.value[1] - gaze_position.value[1],
        )

        if both:
            # Being inside both precision circles means being inside the smallest one
            return bool(dist < min(self.precision, gaze_position.precision))

        return bool(dist < self.precision)

    def draw(self, frame, color=(0, 255, 255), draw_precision=True):
        """Draw gaze position point and precision circle.

        Nothing is drawn when the position is not valid.
        """

        if not self.valid:
            return

        int_value = (int(self.value[0]), int(self.value[1]))

        # Draw point at position
        cv.circle(frame, int_value, 2, color, -1)

        # Draw precision circle
        if self.precision > 0 and draw_precision:
            cv.circle(frame, int_value, round(self.precision), color, 1)


class UnvalidGazePosition(GazePosition):
    """Unvalid gaze position."""

    def __init__(self, message=None):
        """Store an optional message and mark the position as not valid (precision is None)."""

        # message is not a dataclass field, so it can be set on this subclass
        # even though the base dataclass is frozen
        self.message = message

        super().__init__((None, None), precision=None)


TimeStampedGazePositionsType = TypeVar('TimeStampedGazePositions', bound="TimeStampedGazePositions")
# Type definition for type annotation convenience


class TimeStampedGazePositions(DataStructures.TimeStampedBuffer):
    """Define timestamped buffer to store gaze positions."""

    def __setitem__(self, key, value: GazePosition | dict):
        """Force GazePosition storage.

        A dict value must have 'value' and 'precision' keys. It is converted
        into an UnvalidGazePosition when it also has a 'message' key, into a
        GazePosition otherwise.
        """

        # Convert dict into GazePosition
        if isinstance(value, dict):

            assert {'value', 'precision'}.issubset(value.keys())

            if 'message' in value:
                value = UnvalidGazePosition(value['message'])
            else:
                value = GazePosition(value['value'], precision=value['precision'])

        assert isinstance(value, GazePosition)

        super().__setitem__(key, value)

    @classmethod
    def from_json(cls, json_filepath: str) -> TimeStampedGazePositionsType:
        """Create a TimeStampedGazePositionsType from .json file.

        Each key of the JSON object is parsed as a Python literal to get the
        timestamp back.
        """

        with open(json_filepath, encoding='utf-8') as ts_buffer_file:
            json_buffer = json.load(ts_buffer_file)

        return cls({ast.literal_eval(ts_str): position for ts_str, position in json_buffer.items()})


GazeMovementType = TypeVar('GazeMovement', bound="GazeMovement")
# Type definition for type annotation convenience


@dataclass(frozen=True)
class GazeMovement():
    """Define abstract gaze movement class as a buffer of timestamped positions."""

    positions: TimeStampedGazePositions
    """All timestamp gaze positions."""

    duration: float = field(init=False)
    """Inferred duration from first and last timestamps."""

    distance: float = field(init=False)
    """Inferred distance from first and last positions."""

    def __post_init__(self):
        """Infer duration and distance from first and last timestamped positions."""

        start_position_ts, start_position = self.positions.first
        end_position_ts, end_position = self.positions.last

        # Update frozen duration attribute
        object.__setattr__(self, 'duration', end_position_ts - start_position_ts)

        distance = numpy.linalg.norm(numpy.array(start_position.value) - numpy.array(end_position.value))

        # Update frozen distance attribute
        object.__setattr__(self, 'distance', distance)

    def __str__(self) -> str:
        """String display"""

        parts = [f'{type(self)}:\n\tduration={self.duration}\n\tsize={len(self.positions)}']

        parts.extend(
            f'\n\t{ts}:\n\t\tvalue={position.value},\n\t\tprecision={position.precision}'
            for ts, position in self.positions.items()
        )

        return ''.join(parts)


class Fixation(GazeMovement):
    """Define abstract fixation as gaze movement."""

    def __post_init__(self):

        super().__post_init__()


class Saccade(GazeMovement):
    """Define abstract saccade as gaze movement."""

    def __post_init__(self):

        super().__post_init__()


TimeStampedGazeMovementsType = TypeVar('TimeStampedGazeMovements', bound="TimeStampedGazeMovements")
# Type definition for type annotation convenience


class TimeStampedGazeMovements(DataStructures.TimeStampedBuffer):
    """Define timestamped buffer to store gaze movements."""

    def __setitem__(self, key, value: GazeMovement):
        """Force value to be or inherit from GazeMovement."""

        # Fixation and Saccade subclasses are GazeMovement instances too
        assert isinstance(value, GazeMovement)

        super().__setitem__(key, value)

    def __str__(self) -> str:
        """String display: each stored movement on its own line(s)."""

        return ''.join(f'\n{item}' for _, item in self.items())


GazeStatusType = TypeVar('GazeStatus', bound="GazeStatus")
# Type definition for type annotation convenience


@dataclass(frozen=True)
class GazeStatus(GazePosition):
    """Define gaze status as a gaze position belonging to an identified and indexed gaze movement."""

    movement_type: str = field(kw_only=True)
    """GazeMovement type to which gaze position belongs."""

    movement_index: int = field(kw_only=True)
    """GazeMovement index to which gaze position belongs."""

    @classmethod
    def from_position(cls, gaze_position: GazePosition, movement_type: str, movement_index: int) -> GazeStatusType:
        """Initialize from a gaze position instance."""

        return cls(
            gaze_position.value,
            precision=gaze_position.precision,
            movement_type=movement_type,
            movement_index=movement_index,
        )


TimeStampedGazeStatusType = TypeVar('TimeStampedGazeStatus', bound="TimeStampedGazeStatus")
# Type definition for type annotation convenience


class TimeStampedGazeStatus(DataStructures.TimeStampedBuffer):
    """Define timestamped buffer to store gaze status."""

    def __setitem__(self, key, value: GazeStatus):
        """Store value as is: no type check is done here."""

        super().__setitem__(key, value)


class GazeMovementIdentifier():
    """Abstract class to define what should provide a gaze movement identifier."""

    def identify(self, ts, gaze_position, terminate=False) -> GazeMovementType:
        """Identify gaze movement from successive timestamped gaze positions.

        The optional *terminate* argument allows to notify identification algorithm that given gaze position will be the last one.
        """

        raise NotImplementedError('identify() method not implemented')

    def browse(self, ts_gaze_positions: TimeStampedGazePositions) -> Tuple[TimeStampedGazeMovementsType, TimeStampedGazeMovementsType, TimeStampedGazeStatusType]:
        """Identify fixations and saccades browsing timestamped gaze positions.

        Returns the identified fixations, the identified saccades and the
        status of each gaze position belonging to one of them. Movements are
        stored at their first position timestamp and status indexes start
        from 1.
        """

        assert isinstance(ts_gaze_positions, TimeStampedGazePositions)

        ts_fixations = TimeStampedGazeMovements()
        ts_saccades = TimeStampedGazeMovements()
        ts_status = TimeStampedGazeStatus()

        # Nothing to identify without any gaze position
        if not ts_gaze_positions:
            return ts_fixations, ts_saccades, ts_status

        # Get last ts to terminate identification on last gaze position
        last_ts, _ = ts_gaze_positions.last

        # Iterate on gaze positions
        for ts, gaze_position in ts_gaze_positions.items():

            gaze_movement = self.identify(ts, gaze_position, terminate=(ts == last_ts))

            if isinstance(gaze_movement, Fixation):
                movement_type, ts_movements = 'Fixation', ts_fixations
            elif isinstance(gaze_movement, Saccade):
                movement_type, ts_movements = 'Saccade', ts_saccades
            else:
                continue

            start_ts, _ = gaze_movement.positions.first
            ts_movements[start_ts] = gaze_movement

            # Index is the movements count once the new movement is stored
            movement_index = len(ts_movements)

            # Distinct names here so the outer loop variables are not shadowed
            for position_ts, position in gaze_movement.positions.items():
                ts_status[position_ts] = GazeStatus.from_position(position, movement_type, movement_index)

        return ts_fixations, ts_saccades, ts_status