diff options
-rw-r--r-- | src/argaze/GazeFeatures.py | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/src/argaze/GazeFeatures.py b/src/argaze/GazeFeatures.py index bc1b91e..a7f0f11 100644 --- a/src/argaze/GazeFeatures.py +++ b/src/argaze/GazeFeatures.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field import math +import json from argaze import DataStructures from argaze.AreaOfInterest import AOIFeatures @@ -10,26 +11,21 @@ import numpy import pandas import cv2 as cv -@dataclass +@dataclass(frozen=True) class GazePosition(): """Define gaze position as a tuple of coordinates with accuracy.""" - value: tuple[int | float] = (0, 0) - """Positon value.""" + value: tuple[int | float] = field(default=(0, 0)) + """Position's value.""" + + accuracy: float = field(default=0., kw_only=True) + """Position's accuracy.""" - accuracy: float = 0. - """Positon accuracy.""" - def __getitem__(self, axis: int) -> int | float: """Get position value along a particular axis.""" return self.value[axis] - def __setitem__(self, axis, axis_value: int | float): - """Set position value along a particular axis.""" - - self.value[axis] = axis_value - def __iter__(self) -> iter: """Iterate over each position value axis.""" @@ -40,15 +36,20 @@ class GazePosition(): return len(self.value) + def __repr__(self): + """String representation""" + + return json.dumps(self, ensure_ascii = False, default=vars) + def __array__(self): """Cast as numpy array.""" return numpy.array(self.value) @property def valid(self) -> bool: - """Is the accuracy greater than 0 ?""" + """Is the accuracy not None?""" - return self.accuracy >= 0 + return self.accuracy is not None def draw(self, frame, color=(0, 255, 255)): """Draw gaze position point and accuracy circle.""" @@ -67,12 +68,16 @@ class UnvalidGazePosition(GazePosition): def __init__(self): - super().__init__((math.nan, math.nan), accuracy=-1) + super().__init__((None, None), accuracy=None) class TimeStampedGazePositions(DataStructures.TimeStampedBuffer): """Define timestamped buffer to store gaze positions.""" def __setitem__(self, key, value: GazePosition): + """Force value to be GazePosition.""" + + assert(type(value) == GazePosition or type(value) == UnvalidGazePosition) + super().__setitem__(key, value) @dataclass @@ -102,6 +107,10 @@ class TimeStampedMovements(DataStructures.TimeStampedBuffer): """Define timestamped buffer to store movements.""" def __setitem__(self, key, value: Movement): + """Force value to inherit from Movement.""" + + assert(type(value).__bases__[0] == Movement) + super().__setitem__(key, value) @dataclass |