Observations

Mathy Env observations contain a rich set of features, and can optionally be normalized to the range 0-1.

Let's quickly review how we go from AsciiMath text inputs to a set of features for neural networks to consume:

Text -> Observation

Mathy Core parses the input problem text into a tree, Mathy Envs then convert that tree into a sequence of nodes and values, and finally those features are concatenated with the current environment time, type, and valid action mask.
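As a rough illustration of that final concatenation step, here is a minimal pure-Python sketch. Everything in it is hypothetical and is not the mathy_envs implementation: the function names, the SHA-256 hashing of the namespace with seed prefixes, the moves-used / max-moves time formula, and the namespace string. The type and value numbers match the `-3 * (4 + 7)` feature table later on this page.

```python
import hashlib
from typing import List, Tuple


def hash_type(namespace: str, seeds: Tuple[int, int] = (1, 2)) -> List[float]:
    """Hypothetical: hash a namespace twice, with different seeds, into two 0-1 floats."""
    values = []
    for seed in seeds:
        digest = hashlib.sha256(f"{seed}:{namespace}".encode("utf-8")).digest()
        # Read the first 8 bytes as an unsigned integer and scale it into 0-1.
        values.append(int.from_bytes(digest[:8], "big") / 2**64)
    return values


def relative_time(moves_used: int, max_moves: int) -> float:
    """Hypothetical: 0.0 at the start of an episode, 1.0 when the move budget is spent."""
    return min(moves_used / max_moves, 1.0)


def build_features(
    types: List[float],
    values: List[float],
    namespace: str,
    moves_used: int,
    max_moves: int,
    mask: List[int],
) -> List[float]:
    """Concatenate node types and values with the type pair, time, and action mask."""
    return (
        types
        + values
        + hash_type(namespace)
        + [relative_time(moves_used, max_moves)]
        + [float(m) for m in mask]
    )


features = build_features(
    types=[0.2857142857142857, 1.0, 0.0, 1.0, 1.0],
    values=[0.3, 0.0, 0.3, 0.7, 1.0],
    namespace="example.poly.simplify",  # hypothetical namespace
    moves_used=3,
    max_moves=12,
    mask=[0, 1, 0, 0, 1],
)
print(len(features))  # 5 + 5 + 2 + 1 + 5 = 18
```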

Text to Tree

A problem text is encoded into tokens, then parsed into a tree that preserves the order of operations while removing parentheses and whitespace. Consider the tokens and tree that result from the input: -3 * (4 + 7)

Tokens

Token:  -   3   *    (    4   +   7   )    (EOF)
Type:   8   1   16   256  1   4   1   512  8192
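To make the token types concrete, here is a minimal tokenizer sketch that reproduces the type codes listed above for `-3 * (4 + 7)`. The flag values are copied from that example rather than imported from mathy_core. `SimpleToken` and `tokenize` are hypothetical names. Mathy Core's real tokenizer handles much more syntax, such as variables and exponents.

```python
from typing import List, NamedTuple

# Hypothetical token-type flags, copied from the example above (not imported from mathy_core).
CONSTANT, PLUS, MINUS, MULTIPLY = 1, 4, 8, 16
OPEN_PAREN, CLOSE_PAREN, EOF = 256, 512, 8192

SYMBOLS = {"+": PLUS, "-": MINUS, "*": MULTIPLY, "(": OPEN_PAREN, ")": CLOSE_PAREN}


class SimpleToken(NamedTuple):
    value: str
    type: int


def tokenize(text: str) -> List[SimpleToken]:
    """Split text into tokens, skipping whitespace and ending with an EOF token."""
    tokens: List[SimpleToken] = []
    i = 0
    while i < len(text):
        char = text[i]
        if char.isspace():
            i += 1
        elif char.isdigit():
            # Consume a run of digits as one constant token.
            start = i
            while i < len(text) and text[i].isdigit():
                i += 1
            tokens.append(SimpleToken(text[start:i], CONSTANT))
        elif char in SYMBOLS:
            tokens.append(SimpleToken(char, SYMBOLS[char]))
            i += 1
        else:
            raise ValueError(f"unexpected character {char!r}")
    tokens.append(SimpleToken("", EOF))
    return tokens


print([t.type for t in tokenize("-3 * (4 + 7)")])
# [8, 1, 16, 256, 1, 4, 1, 512, 8192]
```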

Tree

      *
     / \
   -3   +
       / \
      4   7

Notice that the tree is more concise than the token array. It has no nodes for hierarchical markers like parentheses, because the tree structure itself encodes the grouping, and the unary minus is folded into the constant -3.
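The sketch below shows why the tree needs fewer nodes. It is a hypothetical recursive-descent parser, not Mathy Core's `ExpressionParser`, and it handles only integers, unary minus on integers, `+`, `*`, and parentheses. Parentheses steer the recursion but never become nodes, and the unary minus is folded into the constant. As a result, the 8 non-EOF tokens collapse into 5 tree nodes.

```python
import re
from typing import List, Tuple, Union

# A node is either an integer constant or an (operator, left, right) tuple.
Node = Union[int, Tuple[str, "Node", "Node"]]


def tokenize(text: str) -> List[str]:
    """Split text into integers and single non-space characters."""
    return re.findall(r"\d+|\S", text)


class Parser:
    """Grammar: expr := term ("+" term)*, term := factor ("*" factor)*."""

    def __init__(self, tokens: List[str]) -> None:
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> str:
        return self.tokens[self.pos] if self.pos < len(self.tokens) else ""

    def take(self) -> str:
        token = self.peek()
        self.pos += 1
        return token

    def expr(self) -> Node:
        node = self.term()
        while self.peek() == "+":
            self.take()
            node = ("+", node, self.term())
        return node

    def term(self) -> Node:
        node = self.factor()
        while self.peek() == "*":
            self.take()
            node = ("*", node, self.factor())
        return node

    def factor(self) -> Node:
        token = self.take()
        if token == "-":
            # Unary minus is folded into the constant, e.g. "-" "3" -> -3.
            return -int(self.take())
        if token == "(":
            # Parentheses steer the recursion but never become tree nodes.
            node = self.expr()
            if self.take() != ")":
                raise ValueError("expected ')'")
            return node
        return int(token)


def count_nodes(node: Node) -> int:
    """Count constants and operators in the tree."""
    if isinstance(node, int):
        return 1
    return 1 + count_nodes(node[1]) + count_nodes(node[2])


tokens = tokenize("-3 * (4 + 7)")
tree = Parser(tokens).expr()
print(tree)  # ('*', -3, ('+', 4, 7))
print(len(tokens), count_nodes(tree))  # 8 5
```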

Converting text to trees is accomplished with the expression parser:

from typing import List

from mathy_core import ExpressionParser, MathExpression, Token, VariableExpression

problem = "4 + 2x"
parser = ExpressionParser()
tokens: List[Token] = parser.tokenize(problem)
expression: MathExpression = parser.parse(problem)
assert len(expression.find_type(VariableExpression)) == 1

Tree to List

Rather than expose tree structures to environments, we traverse them to produce node/value lists.

Tree list ordering

You might have noticed that the tree features in this section are not listed in the natural order we read them. As observed by Lample and Charton, trees must be visited in an order that preserves the order of operations so that the model can pick up on the hierarchical features of the input.

For this reason, we visit trees in pre-order for serialization.
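Here is a small sketch of two traversal orders on a hand-built tree for `-3 * (4 + 7)`. The tuple encoding and function names are illustrative only. Pre-order produces `*, -3, +, 4, 7`, which is the column order of the feature table below. In-order produces the reading order without parentheses, `-3 * 4 + 7`, and that string describes a different expression.

```python
from typing import List, Tuple, Union

Node = Union[str, Tuple[str, "Node", "Node"]]

# The tree for "-3 * (4 + 7)", written by hand as (operator, left, right).
TREE: Node = ("*", "-3", ("+", "4", "7"))


def preorder(node: Node) -> List[str]:
    """Visit the node itself first, then its left and right subtrees."""
    if isinstance(node, str):
        return [node]
    op, left, right = node
    return [op] + preorder(left) + preorder(right)


def inorder(node: Node) -> List[str]:
    """Visit the left subtree, the node, then the right subtree (reading order)."""
    if isinstance(node, str):
        return [node]
    op, left, right = node
    return inorder(left) + [op] + inorder(right)


print(preorder(TREE))  # ['*', '-3', '+', '4', '7']
print(inorder(TREE))   # ['-3', '*', '4', '+', '7']
```

Every operator here is binary, so the pre-order sequence decodes back into exactly one tree without needing parentheses (Polish notation). The in-order sequence does not have this property.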

Converting math expression trees to lists is done with a helper:

from typing import List

from mathy_core import ExpressionParser, MathExpression

parser = ExpressionParser()
expression: MathExpression = parser.parse("4 + 2x")
nodes: List[MathExpression] = expression.to_list()
# len([4,+,2,*,x])
assert len(nodes) == 5

Lists to Observations

Mathy turns a list of math expression nodes into a feature list that captures the input characteristics. Specifically, mathy converts a node list into two lists, one with node types and another with node values:

Node:   *                    -3    +     4     7
Value:  0.3                  0.0   0.3   0.7   1.0
Type:   0.2857142857142857   1.0   0.0   1.0   1.0

  • The first row shows each node's text in pre-order, with whitespace and parentheses removed.
  • The second row is the sequence of floating-point node values for the tree, with each non-constant node represented by a mask value (shown here normalized to 0-1).
  • The third row is the node type, an integer identifying the node's class in the tree (also shown normalized to 0-1).
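The values row is consistent with min-max scaling under one assumption: each operator node first gets a placeholder value of 0.0, and the whole list is then scaled by its minimum (-3) and maximum (7). The sketch below is a hypothetical reconstruction under that assumption, not code taken from mathy_envs. `normalize_values` and `mask_value` are illustrative names.

```python
from typing import List, Optional

# Pre-order node values for "-3 * (4 + 7)"; None marks a non-constant (operator) node.
NODE_VALUES: List[Optional[float]] = [None, -3.0, None, 4.0, 7.0]


def normalize_values(values: List[Optional[float]], mask_value: float = 0.0) -> List[float]:
    """Fill operators with mask_value, then min-max scale every value into 0-1."""
    filled = [mask_value if v is None else v for v in values]
    low, high = min(filled), max(filled)
    if high == low:  # avoid dividing by zero when every value is equal
        return [0.0 for _ in filled]
    return [(v - low) / (high - low) for v in filled]


print(normalize_values(NODE_VALUES))  # [0.3, 0.0, 0.3, 0.7, 1.0]
```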

While feature lists can be passed directly to an ML model, they carry no information about how the problem's state changes over time. To provide that context, mathy agents draw extra information from the environment when building observations. This additional information includes:

  • Environment Problem Type: every environment specifies a namespace, which is converted into a pair of hash values using two different random seeds.
  • Episode Relative Time: each observation includes a 0-1 floating-point value that indicates how close the agent is to running out of moves.
  • Valid Action Mask: mathy produces a weighted estimate for every action at every node. With five possible actions and ten nodes in the tree, there are up to 50 possible (action, node) choices. A mask of 0/1 values with the same size (e.g., 50) is provided so the model can exclude invalid choices when returning probability distributions.
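This hypothetical sketch illustrates how such a mask is commonly applied, using a masked softmax over a flattened 5 × 10 action space. It assumes row-major flattening of a (num_rules, num_nodes) mask, so the flat index is `rule * num_nodes + node`. The logits are random stand-ins for model output, and none of the names come from mathy.

```python
import math
import random
from typing import List, Tuple

NUM_RULES, NUM_NODES = 5, 10  # five actions (rules) and ten tree nodes


def masked_softmax(logits: List[float], mask: List[int]) -> List[float]:
    """Softmax over logits where mask == 1; masked entries get probability 0."""
    valid = [x for x, m in zip(logits, mask) if m]
    if not valid:
        raise ValueError("no valid actions")
    peak = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) if m else 0.0 for x, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def decode(action: int, num_nodes: int = NUM_NODES) -> Tuple[int, int]:
    """Map a flat action index back to (rule_index, node_index)."""
    return divmod(action, num_nodes)


random.seed(0)
logits = [random.uniform(-1.0, 1.0) for _ in range(NUM_RULES * NUM_NODES)]
# Hypothetical mask: only rule 0 at node 2 and rule 3 at node 7 are valid.
mask = [0] * (NUM_RULES * NUM_NODES)
mask[0 * NUM_NODES + 2] = 1
mask[3 * NUM_NODES + 7] = 1

probs = masked_softmax(logits, mask)
best = max(range(len(probs)), key=probs.__getitem__)
print(len(probs), decode(best))
```

Masked entries get exactly zero probability, so sampling or taking the argmax can only choose a valid (rule, node) pair.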

Mathy has utilities for making the conversion:

from mathy_envs import MathyEnv, MathyEnvState, MathyObservation, envs

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
observation: MathyObservation = env.state_to_observation(state)

# As many nodes as values
assert len(observation.nodes) == len(observation.values)
# Mask is a binary validity mask of size (num_rules, num_nodes)
assert len(observation.mask) == len(env.rules)
assert len(observation.mask[0]) == len(observation.nodes)