Observations
Mathy Envs observations contain a rich set of features that can be represented in multiple formats to suit different neural network architectures. All observation types can optionally be normalized to the range [0, 1].
Let's quickly review how we go from AsciiMath text inputs to different observation formats for neural networks to consume:
Text -> Tree -> Observations¶
Mathy Core parses an input problem's text into a tree, and Mathy Envs then converts that tree into whichever observation format suits your model architecture.
Text to Tree¶
A problem text is encoded into tokens, then parsed into a tree that preserves the order of operations while removing parentheses and whitespace. Consider the tokens and tree that result from the input: -3 * (4 + 7)
Tokens
Tree
Note that the tree representation is more compact than the token array: structural features like parentheses are encoded in the tree's shape rather than stored as nodes.
Converting text to trees is accomplished with the expression parser:
from typing import List
from mathy_core import ExpressionParser, MathExpression, Token, VariableExpression
problem = "4 + 2x"
parser = ExpressionParser()
tokens: List[Token] = parser.tokenize(problem)
expression: MathExpression = parser.parse(problem)
assert len(expression.find_type(VariableExpression)) == 1
Observation Types¶
Mathy supports four different observation formats to accommodate different neural network architectures:
1. Flat Observations (Default)¶
The original observation format that represents the expression as a flat sequence of features. This format first converts the tree to ordered lists, then concatenates all information into a single 1D array.
Use cases: Traditional MLPs, RNNs, transformers expecting sequential input
Tree to List Conversion¶
Rather than expose tree structures directly, flat observations traverse them to produce node/value lists.
tree list ordering
You might have noticed that the tree features are not listed in the natural order we would read them. As Lample and Charton observed, trees must be visited in an order that preserves the order of operations so that the model can pick up on the hierarchical features of the input.
For this reason, we serialize trees with a pre-order traversal.
Converting math expression trees to lists is done with a helper:
from typing import List
from mathy_core import ExpressionParser, MathExpression
parser = ExpressionParser()
expression: MathExpression = parser.parse("4 + 2x")
nodes: List[MathExpression] = expression.to_list()
# len([4,+,2,*,x])
assert len(nodes) == 5
Flat Observation Structure¶
Features included:
- Problem type hash (2 values)
- Relative episode time (1 value)
- Node types (padded sequence)
- Node values (padded sequence)
- Action mask (flattened)
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
# Flat observation (default)
flat_obs = state.to_observation(obs_type=ObservationType.FLAT, max_seq_len=100)
print(f"Flat observation shape: {flat_obs.shape}")
# Contains: [type_hash, time, nodes..., values..., action_mask...]
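If you need to split a flat observation back into its parts, the layout in the comment gives the slice boundaries. A sketch with toy numbers (the field order and sizes here are assumptions based on that comment, not produced by the env):

```python
import numpy as np

# Toy sizes for a hypothetical flat layout matching the comment above.
max_seq_len = 4
num_rules = 3

type_hash = np.array([0.1, 0.9])               # problem type hash (2 values)
time = np.array([0.25])                        # relative episode time (1 value)
node_types = np.array([3.0, 1.0, 2.0, 0.0])    # padded node type ids
node_values = np.array([4.0, 0.0, 2.0, 0.0])   # padded node values
action_mask = np.zeros(num_rules * max_seq_len)

flat = np.concatenate([type_hash, time, node_types, node_values, action_mask])

# Slice the 1D array back into its components.
i = 0
hash_part, i = flat[i : i + 2], i + 2
time_part, i = flat[i : i + 1], i + 1
types_part, i = flat[i : i + max_seq_len], i + max_seq_len
values_part, i = flat[i : i + max_seq_len], i + max_seq_len
mask_part = flat[i:]

assert flat.shape == (2 + 1 + 2 * max_seq_len + num_rules * max_seq_len,)
assert mask_part.shape == (num_rules * max_seq_len,)
```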
2. Graph Observations¶
Represents the mathematical expression as an adjacency matrix with node features, suitable for Graph Convolutional Networks (GCNs) and similar architectures. This format works directly with the tree structure.
Use cases: Graph Convolutional Networks, Graph Attention Networks, predictive coding models
Structure:
- node_features: [type_id, value, time, is_leaf] for each node
- adjacency: Binary matrix encoding parent-child relationships
- action_mask: Valid actions at each node
- num_nodes: Actual number of nodes (before padding)
Edge semantics: Parent nodes connect to their children, preserving mathematical precedence
# graph_observations.py
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyGraphObservation, ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
graph_obs = state.to_observation(obs_type=ObservationType.GRAPH, max_seq_len=100)
assert isinstance(graph_obs, MathyGraphObservation)
print(f"Node features shape: {graph_obs.node_features.shape}") # (100, 4)
print(f"Adjacency shape: {graph_obs.adjacency.shape}") # (100, 100)
print(f"Action mask length: {len(graph_obs.action_mask)}") # num_rules * 100
print(f"Actual nodes: {graph_obs.num_nodes}")
# Check adjacency connections
actual_adj = graph_obs.adjacency[:graph_obs.num_nodes, :graph_obs.num_nodes]
print(f"Graph has {actual_adj.sum()} connections")
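If a downstream tool wants explicit (parent, child) pairs instead of a dense matrix, they can be recovered with np.nonzero. A sketch using a hand-built adjacency for 4 + 2x (toy data, not produced by the env):

```python
import numpy as np

# Hand-built adjacency for "4 + 2x": nodes in pre-order are [+, 4, *, 2, x].
num_nodes = 5
adjacency = np.zeros((num_nodes, num_nodes))
adjacency[0, 1] = 1  # "+" -> "4"
adjacency[0, 2] = 1  # "+" -> "*"
adjacency[2, 3] = 1  # "*" -> "2"
adjacency[2, 4] = 1  # "*" -> "x"

# Recover (parent, child) pairs from the dense matrix.
parents, children = np.nonzero(adjacency)
edges = list(zip(parents.tolist(), children.tolist()))

assert edges == [(0, 1), (0, 2), (2, 3), (2, 4)]
assert adjacency.sum() == 4  # one connection per parent-child link
```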
3. Hierarchical Observations¶
Groups nodes by their depth in the expression tree, enabling models to process expressions level by level. This format preserves the tree structure while organizing nodes by hierarchy.
Use cases: Hierarchical processing models, predictive coding architectures, models that benefit from depth-aware processing
Structure:
- node_features: [type_id, value, time, is_leaf] for each node
- level_indices: Tree depth for each node
- action_mask: Valid actions at each node
- max_depth: Maximum tree depth
- num_nodes: Actual number of nodes
Organization: Nodes are ordered by tree level, allowing models to process expressions hierarchically
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyHierarchicalObservation, ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
hier_obs = state.to_observation(obs_type=ObservationType.HIERARCHICAL, max_seq_len=100)
assert isinstance(hier_obs, MathyHierarchicalObservation)
print(f"Node features shape: {hier_obs.node_features.shape}") # (100, 4)
print(f"Level indices shape: {hier_obs.level_indices.shape}") # (100,)
print(f"Max depth: {hier_obs.max_depth}")
print(f"Actual nodes: {hier_obs.num_nodes}")
# Show node distribution by level
actual_levels = hier_obs.level_indices[: hier_obs.num_nodes]
for level in range(hier_obs.max_depth + 1):
    count = sum(1 for lvl in actual_levels if lvl == level)
    print(f"Level {level}: {count} nodes")
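Level-by-level processing then amounts to bucketing node indices by their level index. A minimal sketch with hypothetical level indices for 4 + 2x:

```python
from collections import defaultdict

# Hypothetical level indices for "4 + 2x": "+" at depth 0,
# "4" and "*" at depth 1, "2" and "x" at depth 2.
level_indices = [0, 1, 1, 2, 2]

# Group node indices by their tree depth.
levels = defaultdict(list)
for node_idx, level in enumerate(level_indices):
    levels[level].append(node_idx)

# Process the expression one tree level at a time, root first.
for level in sorted(levels):
    print(f"Level {level}: nodes {levels[level]}")

assert levels[0] == [0]
assert levels[1] == [1, 2]
```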
4. Message Passing Observations¶
Formats expressions for PyTorch Geometric-style Graph Neural Networks with explicit edge lists and edge types. This format also works directly with the tree structure.
Use cases: PyTorch Geometric models, Graph Neural Networks with edge features, message passing architectures
Structure:
- node_features: [type_id, value, time, is_leaf] for each node
- edge_index: PyG-format edge list of shape (2, num_edges)
- edge_types: Edge type for each edge (0=left child, 1=right child)
- action_mask: Valid actions at each node
- num_nodes: Actual number of nodes
- num_edges: Actual number of edges
Edge types:
- 0: Parent → left child relationship
- 1: Parent → right child relationship
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyMessagePassingObservation, ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
mp_obs = state.to_observation(obs_type=ObservationType.MESSAGE_PASSING, max_seq_len=100)
assert isinstance(mp_obs, MathyMessagePassingObservation)
print(f"Node features shape: {mp_obs.node_features.shape}") # (100, 4)
print(f"Edge index shape: {mp_obs.edge_index.shape}") # (2, 200)
print(f"Edge types shape: {mp_obs.edge_types.shape}") # (200,)
print(f"Actual nodes: {mp_obs.num_nodes}")
print(f"Actual edges: {mp_obs.num_edges}")
# Show edge type distribution
actual_edge_types = mp_obs.edge_types[: mp_obs.num_edges]
left_edges = sum(1 for t in actual_edge_types if t == 0)
right_edges = sum(1 for t in actual_edge_types if t == 1)
print(f"Left child edges: {left_edges}, Right child edges: {right_edges}")
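To see how the edge-list format relates to the dense adjacency used by graph observations, the edges can be scattered back into a matrix. A sketch with hand-built edges for 4 + 2x (toy data, not produced by the env):

```python
import numpy as np

# Hand-built PyG-style edges for "4 + 2x": nodes in pre-order are [+, 4, *, 2, x].
edge_index = np.array([[0, 0, 2, 2],   # parent node indices
                       [1, 2, 3, 4]])  # child node indices
edge_types = np.array([0, 1, 0, 1])    # 0 = left child, 1 = right child

# Scatter the edge list into a dense adjacency matrix.
num_nodes = 5
adjacency = np.zeros((num_nodes, num_nodes))
adjacency[edge_index[0], edge_index[1]] = 1.0

assert adjacency.sum() == edge_index.shape[1]  # one entry per edge
# In this toy tree every parent has one left and one right child.
assert (edge_types == 0).sum() == (edge_types == 1).sum()
```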
Unified Observation Interface¶
All observation types are accessible through a single interface:
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
# All observation types through unified interface
obs_types = [
    ObservationType.FLAT,
    ObservationType.GRAPH,
    ObservationType.HIERARCHICAL,
    ObservationType.MESSAGE_PASSING,
]
for obs_type in obs_types:
    obs = state.to_observation(obs_type=obs_type, max_seq_len=100, normalize=True)
    print(f"{obs_type.value}: {type(obs).__name__}")
Common Features Across All Types¶
All observation formats share these characteristics:
Node Features¶
Every observation type uses consistent node features: [type_id, value, time, is_leaf]
- type_id: Integer representing the node's mathematical operation or value type
- value: Floating-point value for constants, 0.0 for operators
- time: Normalized episode progress (0.0 = start, 1.0 = end)
- is_leaf: Binary indicator (1.0 for leaf nodes, 0.0 for operators)
Action Mask¶
The action mask format is identical across all observation types:
- Flattened array of size num_rules × max_seq_len
- Binary values: 1.0 = valid action, 0.0 = invalid action
- Represents which transformation rules can be applied to which nodes
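Assuming a rule-major flattening (index = rule * max_seq_len + node, which matches the num_rules × max_seq_len sizing above but is an assumption about the ordering), decoding the mask into (rule, node) pairs is a single reshape:

```python
import numpy as np

num_rules, max_seq_len = 3, 4

# Toy flattened mask: rule 1 applies at node 0, rule 2 at node 3.
flat_mask = np.zeros(num_rules * max_seq_len)
flat_mask[1 * max_seq_len + 0] = 1.0
flat_mask[2 * max_seq_len + 3] = 1.0

# Reshape to (num_rules, max_seq_len) and list valid (rule, node) pairs.
mask = flat_mask.reshape(num_rules, max_seq_len)
valid = [(int(rule), int(node)) for rule, node in zip(*np.nonzero(mask))]

assert valid == [(1, 0), (2, 3)]
```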
Normalization¶
When normalize=True (default), all features are scaled to the range [0, 1]:
- Node types and values are min-max normalized
- Time features are already normalized (episode progress)
- Action masks remain binary (0/1)
Padding¶
All observations are padded to max_seq_len to enable batching:
- Node features padded with zeros
- Adjacency matrices padded to square matrices
- Edge lists padded with dummy edges
- Action masks padded to consistent size
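One consequence of padding is that models should ignore the padded rows; num_nodes makes that a simple boolean mask. A minimal sketch with toy data:

```python
import numpy as np

max_seq_len = 6
num_nodes = 4  # real nodes; the remaining rows are zero padding

node_features = np.zeros((max_seq_len, 4))
node_features[:num_nodes] = np.random.rand(num_nodes, 4)

# Boolean mask: True for real nodes, False for padding.
valid = np.arange(max_seq_len) < num_nodes
real_rows = node_features[valid]

assert valid.sum() == num_nodes
assert real_rows.shape == (num_nodes, 4)
```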
Choosing an Observation Type¶
| Architecture | Recommended Type | Reason |
|---|---|---|
| MLP, RNN, Transformer | Flat | Sequential processing of flattened features |
| Graph Convolutional Network | Graph | Native adjacency matrix representation |
| Graph Attention Network | Graph | Node features + adjacency for attention |
| PyTorch Geometric GNN | Message Passing | Optimized edge list format |
| Hierarchical/Tree LSTM | Hierarchical | Depth-aware processing |
| Predictive Coding Models | Hierarchical or Graph | Level-based or parent-child predictions |
Example Usage¶
Basic Observation Creation¶
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyGraphObservation, ObservationType
env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
# Create graph observation with consistent features
observation = state.to_observation(
    obs_type=ObservationType.GRAPH, max_seq_len=100, normalize=True
)
# Narrow down to the specific observation type
assert isinstance(observation, MathyGraphObservation)
# Check the consistent feature format: [type_id, value, time, is_leaf]
print(f"Node features shape: {observation.node_features.shape}") # (100, 4)
print(f"Feature dimensions: type_id, value, time, is_leaf")
# Check actual number of nodes
assert observation.num_nodes > 0
assert observation.num_nodes <= 100 # should be within max_seq_len
# Check the observation structure
assert observation.adjacency is not None, "Adjacency matrix should not be None"
assert observation.node_features is not None, "Node features should not be None"
assert observation.action_mask is not None, "Action mask should not be None"
# Show actual node features for debugging
actual_features = observation.node_features[: observation.num_nodes]
print(
    f"First node features: {actual_features[0] if len(actual_features) > 0 else 'No nodes'}"
)
Environment Integration¶
from mathy_envs import MathyEnv, envs
from mathy_envs.state import ObservationType
# Initialize environment
env: MathyEnv = envs.PolySimplify()
# Get initial state and valid moves
state, problem = env.get_initial_state()
valid_moves = env.get_valid_moves(state)
# Create observation with action mask
obs = state.to_observation(
    move_mask=valid_moves,
    obs_type=ObservationType.GRAPH,
    max_seq_len=env.max_seq_len,
    normalize=True,
)
print(f"Problem: {problem}")
print(f"Observation type: {type(obs).__name__}")
print(f"Valid actions available: {obs.action_mask.sum()}")
The observation system provides flexibility to match your model architecture while maintaining consistent feature representations across all formats.