Customization
Because algebra problems represent only a tiny sliver of the uses for math expression trees, Mathy Envs has customization points to alter or create entirely new environments with little effort.
Let's consider a few examples:
New Problems
Generating a new problem type while subclassing a base environment is likely the simplest way to create a custom challenge for the agent.
You can inherit from a base environment such as `PolySimplify`, whose win conditions require that all like terms be combined and all complex terms be simplified. From there, you can provide any valid input expression:
from mathy_envs import MathyEnv, MathyEnvProblem, MathyEnvProblemArgs, envs
class CustomSimplifyEnv(envs.PolySimplify):
    """Poly Simplify variant that always serves the same fixed problem.

    Inheriting from ``envs.PolySimplify`` (rather than the bare ``MathyEnv``)
    keeps its polynomial-simplification win conditions. Only the problem
    generation and the environment namespace are customized here.
    """

    def get_env_namespace(self) -> str:
        """Return the namespace string that identifies this environment."""
        return "custom.polynomial.simplify"

    def problem_fn(self, params: MathyEnvProblemArgs) -> MathyEnvProblem:
        """Return a fixed polynomial problem.

        ``params`` is ignored because the problem never varies. The arguments
        to ``MathyEnvProblem`` are the problem text, its complexity (3) and
        this environment's namespace.
        """
        return MathyEnvProblem("4x + y + 13x", 3, self.get_env_namespace())
# Create the environment and request a starting state plus its problem.
env: MathyEnv = CustomSimplifyEnv()
initial_state, problem = env.get_initial_state()

# Both values come straight from CustomSimplifyEnv.problem_fn.
expected_text = "4x + y + 13x"
expected_complexity = 3
assert problem.text == expected_text
assert problem.complexity == expected_complexity
New Actions
Build your own tree transformation actions (rules) and use them with the built-in environments:
"""Environment with user-defined actions"""
from mathy_core import AddExpression, BaseRule, NegateExpression, SubtractExpression
from mathy_envs import MathyEnv, MathyEnvState, envs
class PlusNegationRule(BaseRule):
    """Rewrite ``a - b`` as ``a + -b`` so that the terms can be commuted.

    The rule matches subtraction nodes that are either the root of the tree
    (no parent) or a direct child of an addition node.
    """

    @property
    def name(self) -> str:
        """Human-readable name of this rule."""
        return "Plus Negation"

    @property
    def code(self) -> str:
        """Short identifier code for this rule."""
        return "PN"

    def can_apply_to(self, node) -> bool:
        """Return True for a subtraction at the root or directly under an add."""
        parent = node.parent
        if not isinstance(node, SubtractExpression):
            return False
        return parent is None or isinstance(parent, AddExpression)

    def apply_to(self, node):
        """Replace ``node`` with ``node.left + -(node.right)``."""
        change = super().apply_to(node)
        # Remember the parent so the replacement is connected where node was.
        change.save_parent()
        replacement = AddExpression(node.left, NegateExpression(node.right))
        # Flag the new node as changed so visualizations highlight it.
        replacement.set_changed()
        return change.done(replacement)
class CustomActionEnv(envs.PolySimplify):
    """Poly Simplify environment whose rule set also includes PlusNegationRule."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Overwrite the rule list with the core rules followed by the custom rule.
        extra_rules = [PlusNegationRule()]
        self.rules = MathyEnv.core_rules() + extra_rules
# Start from "4x - 2x" and apply the custom rule once.
env = CustomActionEnv()
state = MathyEnvState(problem="4x - 2x")
parsed = env.parser.parse(state.agent.problem)
plus_negation_action = env.random_action(parsed, PlusNegationRule)
out_state, transition, _ = env.get_next_state(state, plus_negation_action)

# The subtraction has been rewritten as addition of a negated term.
assert out_state.agent.problem == "4x + -2x"
Custom Win Conditions
Environments can implement custom logic for win conditions or inherit them from a base class:
"""Custom environment with win conditions that are met whenever
two nodes are adjacent to each other that can have the distributive
property applied to factor out a common term """
from typing import Optional
from mathy_core import MathExpression, rules
from mathy_envs import (
MathyEnv,
MathyEnvState,
MathyObservation,
is_terminal_transition,
time_step,
)
class CustomWinConditions(MathyEnv):
    """Environment that is won as soon as distributive factoring can be applied."""

    rule = rules.DistributiveFactorOutRule()

    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        """Return a terminal win transition if the rule matches any node.

        Returns None, meaning no custom transition, when the rule cannot be
        applied anywhere in ``expression``.
        """
        factorable = self.rule.find_node(expression)
        if factorable is None:
            # Nothing can be factored yet, so there is no custom transition.
            return None
        # A factorable pair exists: end the episode with the win reward.
        return time_step.termination(features, self.get_win_signal(env_state))
env = CustomWinConditions()

# Not terminal: "4x" and "2x" are separated by "y", so no adjacent nodes can
# have the distributive factoring rule applied to them.
state_one = MathyEnvState(problem="4x + y + 2x")
transition = env.get_state_transition(state_one)
is_done = is_terminal_transition(transition)
assert is_done is False

# Terminal: the adjacent nodes "4x + 2x" can have the distributive factoring
# rule applied to them.
state_two = MathyEnvState(problem="4x + 2x + y")
transition = env.get_state_transition(state_two)
is_done = is_terminal_transition(transition)
assert is_done is True
Custom Timestep Rewards
Specify which actions give the agent positive and negative rewards:
"""Environment with user-defined rewards per-timestep based on the
rule that was applied by the agent."""
from typing import List, Type
from mathy_core import BaseRule, rules
from mathy_envs import MathyEnv, MathyEnvState
class CustomTimestepRewards(MathyEnv):
    """Reward associative swaps and penalize commutative swaps at each step."""

    def get_rewarding_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
        """Return the rule types whose application earns a positive reward."""
        rewarded: List[Type[BaseRule]] = [rules.AssociativeSwapRule]
        return rewarded

    def get_penalizing_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
        """Return the rule types whose application earns a negative reward."""
        penalized: List[Type[BaseRule]] = [rules.CommutativeSwapRule]
        return penalized
env = CustomTimestepRewards()
problem = "4x + y + 2x"
expression = env.parser.parse(problem)
state = MathyEnvState(problem=problem)

# Applying a rewarding rule (associative swap) to the starting state.
action = env.random_action(expression, rules.AssociativeSwapRule)
_, transition, _ = env.get_next_state(state, action)
# Expect positive reward
assert transition.reward > 0.0

# Applying a penalizing rule (commutative swap) to the same starting state.
action = env.random_action(expression, rules.CommutativeSwapRule)
_, transition, _ = env.get_next_state(state, action)
# Expect negative reward
assert transition.reward < 0.0
Custom Episode Rewards
Specify (or calculate) custom floating-point episode rewards:
"""Environment with user-defined terminal rewards"""
from mathy_core.rules import ConstantsSimplifyRule
from mathy_envs import MathyEnvState, envs, is_terminal_transition
class CustomEpisodeRewards(envs.PolySimplify):
    """Poly Simplify with fixed terminal rewards: +20.0 on a win, -20.0 on a loss."""

    def get_win_signal(self, env_state: MathyEnvState) -> float:
        """Return the reward given when an episode is won."""
        win_reward = 20.0
        return win_reward

    def get_lose_signal(self, env_state: MathyEnvState) -> float:
        """Return the reward given when an episode is lost."""
        lose_reward = -20.0
        return lose_reward
env = CustomEpisodeRewards()


def _simplify_constants_once(start):
    """Apply ConstantsSimplifyRule to ``start``; return (out_state, transition)."""
    tree = env.parser.parse(start.agent.problem)
    move = env.random_action(tree, ConstantsSimplifyRule)
    next_state, step, _ = env.get_next_state(start, move)
    return next_state, step


# Win by simplifying constants and yielding a single simple term form
out_state, transition = _simplify_constants_once(
    MathyEnvState(problem="(4 + 2) * x")
)
assert is_terminal_transition(transition) is True
assert transition.reward == 20.0
assert out_state.agent.problem == "6x"

# Lose by applying a rule with only 1 move remaining
out_state, transition = _simplify_constants_once(
    MathyEnvState(problem="2x + (4 + 2) + 4x", max_moves=1)
)
assert is_terminal_transition(transition) is True
assert transition.reward == -20.0
assert out_state.agent.problem == "2x + 6 + 4x"