Observations

Mathy Env observations contain a rich set of features that can be represented in multiple formats to suit different neural network architectures. All observation types can optionally be normalized to the range [0, 1].

Let's quickly review how we go from AsciiMath text inputs to different observation formats for neural networks to consume:

Text -> Tree -> Observations

Mathy Core processes an input problem by parsing its text into a tree, then Mathy Envs convert that tree into different observation formats depending on your model architecture's needs.

Text to Tree

A problem text is encoded into tokens, then parsed into a tree that preserves the order of operations while removing parentheses and whitespace. Consider the tokens and tree that result from the input: -3 * (4 + 7)

Tokens (each token pairs its text with a numeric type code)

Value | Type
- | 8
3 | 1
* | 16
( | 256
4 | 1
+ | 4
7 | 1
) | 512
(empty) | 8192
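The token listing pairs each token's text with a numeric type code, and every code is a distinct power of two. As a minimal sketch (not Mathy Core's tokenizer), the toy tokenizer below reproduces the listing for this one input. The constant names are hypothetical labels for the codes that appear in the listing, and the function only understands integers and the five symbols used here.

```python
from typing import List, Tuple

# Hypothetical names for the type codes that appear in the token listing.
CONSTANT = 1
PLUS = 4
MINUS = 8
MULTIPLY = 16
OPEN_PAREN = 256
CLOSE_PAREN = 512
END = 8192  # code of the final, empty-valued token in the listing

SYMBOLS = {"+": PLUS, "-": MINUS, "*": MULTIPLY, "(": OPEN_PAREN, ")": CLOSE_PAREN}


def tokenize(text: str) -> List[Tuple[str, int]]:
    """Toy tokenizer: split text into (value, type) pairs, skipping whitespace."""
    tokens: List[Tuple[str, int]] = []
    i = 0
    while i < len(text):
        char = text[i]
        if char.isspace():
            i += 1
        elif char.isdigit():
            start = i
            while i < len(text) and text[i].isdigit():
                i += 1
            tokens.append((text[start:i], CONSTANT))
        elif char in SYMBOLS:
            tokens.append((char, SYMBOLS[char]))
            i += 1
        else:
            raise ValueError(f"unsupported character: {char!r}")
    tokens.append(("", END))  # terminate the stream with an empty-valued token
    return tokens


tokens = tokenize("-3 * (4 + 7)")
print(" ".join(f"{value} {kind}".strip() for value, kind in tokens))
# - 8 3 1 * 16 ( 256 4 1 + 4 7 1 ) 512 8192
```

Note that the tokenizer emits the leading minus as its own token; folding it into the constant -3 happens later, when the tree is built.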

Tree

      *
     / \
   -3   +
       / \
      4   7

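Notice that the tree is more concise than the token array: it has 5 nodes instead of 9 tokens. The parentheses have no nodes of their own because the grouping they expressed is captured by the tree's shape, with the addition sitting beneath the multiplication.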
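To see why the parentheses disappear, here is a toy recursive-descent parser that consumes the token listing above. It is a minimal sketch, not Mathy Core's ExpressionParser; the names ToyParser, Node, and count_nodes are hypothetical. Each grammar level handles one precedence tier, so the parenthesized sum ends up as a subtree and the parenthesis tokens are simply consumed.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Token = Tuple[str, int]

# The token listing above, written out as (value, type) pairs.
TOKENS: List[Token] = [
    ("-", 8), ("3", 1), ("*", 16), ("(", 256), ("4", 1),
    ("+", 4), ("7", 1), (")", 512), ("", 8192),
]


@dataclass
class Node:
    value: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


class ToyParser:
    """Recursive-descent parser: one method per precedence tier."""

    def __init__(self, tokens: List[Token]) -> None:
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> str:
        return self.tokens[self.pos][0]

    def take(self) -> str:
        value = self.tokens[self.pos][0]
        self.pos += 1
        return value

    def parse_sum(self) -> Node:
        # Lowest precedence: a + b - c ...
        node = self.parse_product()
        while self.peek() in ("+", "-"):
            op = self.take()
            node = Node(op, node, self.parse_product())
        return node

    def parse_product(self) -> Node:
        # Higher precedence: a * b * c ...
        node = self.parse_factor()
        while self.peek() == "*":
            op = self.take()
            node = Node(op, node, self.parse_factor())
        return node

    def parse_factor(self) -> Node:
        if self.peek() == "(":
            self.take()  # consume "("
            node = self.parse_sum()
            self.take()  # consume ")": the grouping now lives in the tree shape
            return node
        sign = self.take() if self.peek() == "-" else ""
        return Node(sign + self.take())  # fold a leading "-" into the constant


def count_nodes(node: Optional[Node]) -> int:
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)


tree = ToyParser(TOKENS).parse_sum()
print(tree.value, tree.left.value, tree.right.value)  # * -3 +
print(count_nodes(tree), "nodes from", len(TOKENS), "tokens")  # 5 nodes from 9 tokens
```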

Converting text to trees is accomplished with the expression parser:

from typing import List

from mathy_core import ExpressionParser, MathExpression, Token, VariableExpression

problem = "4 + 2x"
parser = ExpressionParser()
tokens: List[Token] = parser.tokenize(problem)
expression: MathExpression = parser.parse(problem)
assert len(expression.find_type(VariableExpression)) == 1

Observation Types

Mathy supports four different observation formats to accommodate different neural network architectures:

1. Flat Observations (Default)

This is the original observation format, which represents the expression as a flat sequence of features. It first converts the tree to ordered lists, then concatenates all of the information into a single 1D array.

Use cases: traditional MLPs, RNNs, and transformers that expect sequential input

Tree to List Conversion

Rather than expose tree structures directly, flat observations traverse them to produce node/value lists.

Tree List Ordering

You might have noticed that the tree features are not listed in the natural order in which we read. As observed by Lample and Charton, trees must be visited in an order that preserves the order of operations so that the model can pick up on the hierarchical features of the input.

For this reason, we visit trees in pre-order for serialization.
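Why pre-order rather than reading order? Two different expressions can produce the same in-order (reading order) list once parentheses are gone, while their pre-order lists differ. The sketch below uses hand-built trees and a hypothetical Node class; it does not call Mathy, whose own list conversion is the helper shown next.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    value: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def preorder(node: Optional[Node]) -> List[str]:
    """Visit the parent first, then its left and right subtrees."""
    if node is None:
        return []
    return [node.value] + preorder(node.left) + preorder(node.right)


def inorder(node: Optional[Node]) -> List[str]:
    """Visit the left subtree, the parent, then the right subtree (reading order)."""
    if node is None:
        return []
    return inorder(node.left) + [node.value] + inorder(node.right)


# "4 + 2x" (2x is an implicit multiplication) versus "(4 + 2) * x"
plain = Node("+", Node("4"), Node("*", Node("2"), Node("x")))
grouped = Node("*", Node("+", Node("4"), Node("2")), Node("x"))

print(inorder(plain), inorder(grouped))    # both ['4', '+', '2', '*', 'x']
print(preorder(plain), preorder(grouped))  # ['+', '4', '*', '2', 'x'] vs ['*', '+', '4', '2', 'x']
```

Because each operator in a pre-order list precedes its operands, the grouping can be recovered from the list alone; the in-order lists above are identical, so the grouping is lost.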

Converting math expression trees to lists is done with the to_list() helper:

from typing import List

from mathy_core import ExpressionParser, MathExpression

parser = ExpressionParser()
expression: MathExpression = parser.parse("4 + 2x")
nodes: List[MathExpression] = expression.to_list()
# "4 + 2x" has five nodes: 4, 2, x, and the + and * operators
assert len(nodes) == 5

Flat Observation Structure

Features included:

  • Problem type hash (2 values)
  • Relative episode time (1 value)
  • Node types (padded sequence)
  • Node values (padded sequence)
  • Action mask (flattened)
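The feature list above can be read as a concatenation recipe. The following is a minimal sketch under stated assumptions: build_flat_observation and pad are hypothetical helpers, padding uses zeros, the action mask is flattened one rule at a time, and the type ids and hash values are made up. It is not Mathy's implementation, but it shows how the vector length follows from max_seq_len.

```python
from typing import List, Sequence


def pad(values: Sequence[float], length: int, fill: float = 0.0) -> List[float]:
    """Right-pad (or truncate) a sequence to a fixed length."""
    return list(values[:length]) + [fill] * max(0, length - len(values))


def build_flat_observation(
    type_hash: Sequence[float],
    time: float,
    node_types: Sequence[float],
    node_values: Sequence[float],
    action_mask: Sequence[Sequence[float]],
    max_seq_len: int,
) -> List[float]:
    """Concatenate the features in the listed order into one 1D list."""
    flat: List[float] = list(type_hash) + [time]
    flat += pad(node_types, max_seq_len)
    flat += pad(node_values, max_seq_len)
    for rule_row in action_mask:  # one row of per-node flags per rule (assumed layout)
        flat += pad(rule_row, max_seq_len)
    return flat


obs = build_flat_observation(
    type_hash=[0.1, 0.2],  # hypothetical hash values
    time=0.0,
    node_types=[3, 1, 4, 1, 2],  # hypothetical type ids for five nodes
    node_values=[0.0, 4.0, 0.0, 2.0, 0.0],
    action_mask=[[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]],  # 3 rules
    max_seq_len=8,
)
print(len(obs))  # 43 = 2 (hash) + 1 (time) + 8 (types) + 8 (values) + 3 * 8 (mask)
```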

from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

# Flat observation (default)
flat_obs = state.to_observation(obs_type=ObservationType.FLAT, max_seq_len=100)
print(f"Flat observation shape: {flat_obs.shape}")
# Contains: [type_hash, time, nodes..., values..., action_mask...]

2. Graph Observations

Represents the mathematical expression as an adjacency matrix with node features, suitable for Graph Convolutional Networks (GCNs) and similar architectures. This format works directly with the tree structure.

Use cases: Graph Convolutional Networks, Graph Attention Networks, predictive coding models

Structure:

  • node_features: [type_id, value, time, is_leaf] for each node
  • adjacency: Binary matrix encoding parent-child relationships
  • action_mask: Valid actions at each node
  • num_nodes: Actual number of nodes (before padding)

Edge semantics: Parent nodes connect to their children, preserving mathematical precedence
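The structure above can be illustrated with a small sketch that turns a tree into padded node features and a padded adjacency matrix. It assumes directed parent-to-child edges (many GCN setups use a symmetric matrix instead, and the text does not say which Mathy produces), pre-order node indexing, and made-up type ids. The Node class and to_graph function are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Node:
    type_id: int  # hypothetical integer type code
    value: float = 0.0  # constants carry a value; operators use 0.0
    children: List["Node"] = field(default_factory=list)


def to_graph(
    root: Node, time: float, max_seq_len: int
) -> Tuple[List[List[float]], List[List[int]], int]:
    """Return padded node features, a padded adjacency matrix, and the real node count."""
    # Pre-order walk assigns each node its row/column index.
    nodes: List[Node] = []
    index: Dict[int, int] = {}
    stack = [root]
    while stack:
        node = stack.pop()
        index[id(node)] = len(nodes)
        nodes.append(node)
        stack.extend(reversed(node.children))  # so the left child is popped first
    if len(nodes) > max_seq_len:
        raise ValueError("expression has more nodes than max_seq_len")

    features = [[0.0] * 4 for _ in range(max_seq_len)]  # zero rows act as padding
    adjacency = [[0] * max_seq_len for _ in range(max_seq_len)]
    for node in nodes:
        row = index[id(node)]
        is_leaf = 0.0 if node.children else 1.0
        features[row] = [float(node.type_id), node.value, time, is_leaf]
        for child in node.children:
            adjacency[row][index[id(child)]] = 1  # directed edge: parent -> child
    return features, adjacency, len(nodes)


# "4 + 2x" with made-up type ids: 1 = constant, 2 = variable, 10 = +, 11 = *.
root = Node(10, children=[Node(1, 4.0), Node(11, children=[Node(1, 2.0), Node(2)])])
features, adjacency, num_nodes = to_graph(root, time=0.0, max_seq_len=8)
print(num_nodes)                 # 5
print(sum(map(sum, adjacency)))  # 4: a tree with n nodes has n - 1 edges
print(features[1])               # [1.0, 4.0, 0.0, 1.0] (the constant 4, a leaf)
```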

# graph_observations.py
from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyGraphObservation, ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

graph_obs = state.to_observation(obs_type=ObservationType.GRAPH, max_seq_len=100)
assert isinstance(graph_obs, MathyGraphObservation)

print(f"Node features shape: {graph_obs.node_features.shape}")  # (100, 4)
print(f"Adjacency shape: {graph_obs.adjacency.shape}")         # (100, 100)
print(f"Action mask length: {len(graph_obs.action_mask)}")     # num_rules * 100
print(f"Actual nodes: {graph_obs.num_nodes}")

# Check adjacency connections
actual_adj = graph_obs.adjacency[:graph_obs.num_nodes, :graph_obs.num_nodes]
print(f"Graph has {actual_adj.sum()} connections")

3. Hierarchical Observations

Groups nodes by their depth in the expression tree, enabling models to process expressions level by level. This format preserves the tree structure while organizing nodes by hierarchy.

Use cases: Hierarchical processing models, predictive coding architectures, models that benefit from depth-aware processing

Structure:

  • node_features: [type_id, value, time, is_leaf] for each node
  • level_indices: Tree depth for each node
  • action_mask: Valid actions at each node
  • max_depth: Maximum tree depth
  • num_nodes: Actual number of nodes

Organization: Nodes are ordered by tree level, allowing models to process expressions hierarchically
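Level-by-level ordering is what a breadth-first walk produces. The sketch below assigns each node its depth (root at depth 0, which matches the 0-based level loop in the example that follows) and groups nodes by level for the running example -3 * (4 + 7). The Node class and level_indices function are hypothetical, not part of Mathy's API.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Node:
    label: str
    children: List["Node"] = field(default_factory=list)


def level_indices(root: Node) -> List[Tuple[str, int]]:
    """Breadth-first walk returning (label, depth) pairs, one tree level at a time."""
    result: List[Tuple[str, int]] = []
    frontier = [(root, 0)]
    while frontier:
        next_frontier = []
        for node, depth in frontier:
            result.append((node.label, depth))
            next_frontier.extend((child, depth + 1) for child in node.children)
        frontier = next_frontier
    return result


tree = Node("*", [Node("-3"), Node("+", [Node("4"), Node("7")])])
levels = level_indices(tree)
by_level: Dict[int, List[str]] = defaultdict(list)
for label, depth in levels:
    by_level[depth].append(label)
print(dict(by_level))  # {0: ['*'], 1: ['-3', '+'], 2: ['4', '7']}
print(max(depth for _, depth in levels))  # max_depth = 2
```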

from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyHierarchicalObservation, ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

hier_obs = state.to_observation(obs_type=ObservationType.HIERARCHICAL, max_seq_len=100)
assert isinstance(hier_obs, MathyHierarchicalObservation)

print(f"Node features shape: {hier_obs.node_features.shape}")  # (100, 4)
print(f"Level indices shape: {hier_obs.level_indices.shape}")  # (100,)
print(f"Max depth: {hier_obs.max_depth}")
print(f"Actual nodes: {hier_obs.num_nodes}")

# Show node distribution by level
actual_levels = hier_obs.level_indices[: hier_obs.num_nodes]
for level in range(hier_obs.max_depth + 1):
    count = sum(1 for lvl in actual_levels if lvl == level)
    print(f"Level {level}: {count} nodes")

4. Message Passing Observations

Formats expressions for PyTorch Geometric-style Graph Neural Networks with explicit edge lists and edge types. This format also works directly with the tree structure.

Use cases: PyTorch Geometric models, Graph Neural Networks with edge features, message passing architectures

Structure:

  • node_features: [type_id, value, time, is_leaf] for each node
  • edge_index: PyG-format edge list (2, num_edges)
  • edge_types: Edge type for each edge (0=left child, 1=right child)
  • action_mask: Valid actions at each node
  • num_nodes: Actual number of nodes
  • num_edges: Actual number of edges

Edge types:

  • 0: Parent → left child relationship
  • 1: Parent → right child relationship

from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyMessagePassingObservation, ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

mp_obs = state.to_observation(obs_type=ObservationType.MESSAGE_PASSING, max_seq_len=100)
assert isinstance(mp_obs, MathyMessagePassingObservation)

print(f"Node features shape: {mp_obs.node_features.shape}")  # (100, 4)
print(f"Edge index shape: {mp_obs.edge_index.shape}")  # (2, 200)
print(f"Edge types shape: {mp_obs.edge_types.shape}")  # (200,)
print(f"Actual nodes: {mp_obs.num_nodes}")
print(f"Actual edges: {mp_obs.num_edges}")

# Show edge type distribution
actual_edge_types = mp_obs.edge_types[: mp_obs.num_edges]
left_edges = sum(1 for t in actual_edge_types if t == 0)
right_edges = sum(1 for t in actual_edge_types if t == 1)
print(f"Left child edges: {left_edges}, Right child edges: {right_edges}")

Unified Observation Interface

All observation types are created through the same state.to_observation() method; only the obs_type argument changes:

from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

# All observation types through unified interface
obs_types = [
    ObservationType.FLAT,
    ObservationType.GRAPH,
    ObservationType.HIERARCHICAL,
    ObservationType.MESSAGE_PASSING,
]

for obs_type in obs_types:
    obs = state.to_observation(obs_type=obs_type, max_seq_len=100, normalize=True)
    print(f"{obs_type.value}: {type(obs).__name__}")

Common Features Across All Types

All observation formats share these characteristics:

Node Features

Every observation type uses consistent node features: [type_id, value, time, is_leaf]

  • type_id: Integer representing the node's mathematical operation or value type
  • value: Floating-point value for constants, 0.0 for operators
  • time: Normalized episode progress (0.0 = start, 1.0 = end)
  • is_leaf: Binary indicator (1.0 for leaf nodes, 0.0 for operators)

Action Mask

The action mask format is identical across all observation types:

  • Flattened array of size num_rules × max_seq_len
  • Binary values: 1.0 = valid action, 0.0 = invalid action
  • Represents which transformation rules can be applied to which nodes
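A flat mask of size num_rules × max_seq_len encodes (rule, node) pairs. The sketch below assumes a rule-major layout (all node slots for rule 0, then all for rule 1, and so on); if the actual layout differs, the decoding arithmetic changes accordingly. The helpers flatten_mask and decode_action are hypothetical.

```python
from typing import List, Tuple


def flatten_mask(mask_2d: List[List[int]]) -> List[int]:
    """Flatten a [num_rules][max_seq_len] grid row by row (rule-major)."""
    return [flag for rule_row in mask_2d for flag in rule_row]


def decode_action(action: int, max_seq_len: int) -> Tuple[int, int]:
    """Map a flat action index back to (rule_index, node_index)."""
    return divmod(action, max_seq_len)


max_seq_len = 4
mask_2d = [
    [0, 1, 0, 0],  # rule 0 applies at node 1
    [0, 0, 0, 0],  # rule 1 applies nowhere
    [1, 0, 0, 1],  # rule 2 applies at nodes 0 and 3
]
flat = flatten_mask(mask_2d)
valid = [i for i, flag in enumerate(flat) if flag == 1]
print(valid)  # [1, 8, 11]
print([decode_action(i, max_seq_len) for i in valid])  # [(0, 1), (2, 0), (2, 3)]
```

Decoding with divmod keeps the mapping cheap and exactly invertible: rule * max_seq_len + node recovers the flat index.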

Normalization

When normalize=True (default), all features are scaled to the range [0, 1]:

  • Node types and values are min-max normalized
  • Time features are already normalized (episode progress)
  • Action masks remain binary (0/1)
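Min-max normalization maps the smallest value to 0.0 and the largest to 1.0 with a linear rescale. The text does not say which bounds Mathy uses (the values in the current expression or fixed global limits); this sketch assumes the observed values, and min_max is a hypothetical helper. Applied to the constants of -3 * (4 + 7), negative values also land inside [0, 1].

```python
from typing import List, Sequence


def min_max(values: Sequence[float]) -> List[float]:
    """Scale values linearly so the smallest maps to 0.0 and the largest to 1.0."""
    low, high = min(values), max(values)
    if high == low:  # constant input: avoid division by zero
        return [0.0 for _ in values]
    return [(v - low) / (high - low) for v in values]


print(min_max([-3.0, 4.0, 7.0]))  # [0.0, 0.7, 1.0]
print(min_max([7.0, 7.0]))        # [0.0, 0.0]
```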

Padding

All observations are padded to max_seq_len to enable batching:

  • Node features padded with zeros
  • Adjacency matrices padded to square matrices
  • Edge lists padded with dummy edges
  • Action masks padded to consistent size
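Because padding rows are zeros, a model should use num_nodes (or num_edges) to ignore them; otherwise pooled features are diluted by padding. The sketch below contrasts a masked mean over node features with a naive mean over all rows. The masked_mean helper and the feature values are hypothetical.

```python
from typing import List, Sequence


def masked_mean(features: Sequence[Sequence[float]], num_nodes: int) -> List[float]:
    """Average only the first num_nodes rows, ignoring zero padding."""
    if num_nodes <= 0:
        raise ValueError("num_nodes must be positive")
    real_rows = features[:num_nodes]
    width = len(features[0])
    return [sum(row[col] for row in real_rows) / num_nodes for col in range(width)]


# Two real nodes padded to max_seq_len = 4 with zero rows.
padded = [
    [1.0, 0.5, 0.0, 1.0],
    [3.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
print(masked_mean(padded, num_nodes=2))  # [2.0, 0.25, 0.0, 0.5]
naive = [sum(col) / len(padded) for col in zip(*padded)]
print(naive)  # [1.0, 0.125, 0.0, 0.25], diluted by the padding rows
```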

Choosing an Observation Type

Architecture | Recommended Type | Reason
MLP, RNN, Transformer | Flat | Sequential processing of flattened features
Graph Convolutional Network | Graph | Native adjacency matrix representation
Graph Attention Network | Graph | Node features + adjacency for attention
PyTorch Geometric GNN | Message Passing | Optimized edge list format
Hierarchical/Tree LSTM | Hierarchical | Depth-aware processing
Predictive Coding Models | Hierarchical or Graph | Level-based or parent-child predictions

Example Usage

Basic Observation Creation

from mathy_envs import MathyEnv, MathyEnvState, envs
from mathy_envs.state import MathyGraphObservation, ObservationType

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]

# Create graph observation with consistent features
observation = state.to_observation(
    obs_type=ObservationType.GRAPH, max_seq_len=100, normalize=True
)

# Narrow down to the specific observation type
assert isinstance(observation, MathyGraphObservation)

# Check the consistent feature format: [type_id, value, time, is_leaf]
print(f"Node features shape: {observation.node_features.shape}")  # (100, 4)
print("Feature dimensions: type_id, value, time, is_leaf")

# Check actual number of nodes
assert observation.num_nodes > 0
assert observation.num_nodes <= 100  # should be within max_seq_len

# Check the observation structure
assert observation.adjacency is not None, "Adjacency matrix should not be None"
assert observation.node_features is not None, "Node features should not be None"
assert observation.action_mask is not None, "Action mask should not be None"

# Show actual node features for debugging
actual_features = observation.node_features[: observation.num_nodes]
print(
    f"First node features: {actual_features[0] if len(actual_features) > 0 else 'No nodes'}"
)

Environment Integration

from mathy_envs import MathyEnv, envs
from mathy_envs.state import ObservationType

# Initialize environment
env: MathyEnv = envs.PolySimplify()

# Get initial state and valid moves
state, problem = env.get_initial_state()
valid_moves = env.get_valid_moves(state)

# Create observation with action mask
obs = state.to_observation(
    move_mask=valid_moves,
    obs_type=ObservationType.GRAPH,
    max_seq_len=env.max_seq_len,
    normalize=True,
)

print(f"Problem: {problem}")
print(f"Observation type: {type(obs).__name__}")
print(f"Valid actions available: {obs.action_mask.sum()}")

The observation system provides flexibility to match your model architecture while maintaining consistent feature representations across all formats.