Policies

Greedy Policy

GraphGreedyPolicy

Bases: BasePolicy

Greedy policy for graph environment using shortest path.

Compatible only with graph_collector environment.

Note

This is very slow for non-static graphs, as shortest paths are recomputed in O(V^3) time for every action search. For static graphs, cached shortest paths are used instead.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.
shortest_len_paths (dict): Cached shortest paths, including path lengths, for all node pairs.
cur_goals (dict): Cached goals for each agent, stored as (path, collected, point_idx) tuples keyed by agent name.

Source code in datadynamics/policies/greedy_policy/greedy_policy.py
class GraphGreedyPolicy(BasePolicy):
    """Greedy policy for graph environment using shortest path.

    Compatible only with graph_collector environment.

    Note:
        This is very slow for non-static graphs as shortest paths are
        computed in O(V^3) time for every action search.
        Static graphs use cached shortest paths.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
        shortest_len_paths (dict): Cached shortest paths including path
            lengths for all node pairs.
        cur_goals (dict): Cached goals for each agent consisting of
            (path, collected, point_idx) tuples keyed by agent name.
    """

    def __init__(self, env):
        assert env.metadata["name"] == "graph_collector", (
            f"{self.__class__.__name__} is only compatible with "
            "graph_collector."
        )

        gymnasium.logger.info("Initializing GraphGreedyPolicy...")
        self.env = env
        if self.env.static_graph:
            gymnasium.logger.info(
                " - Computing and caching shortest paths. This runs in O(V^3) "
                "and may take a while..."
            )
            self.shortest_len_paths = dict(
                nx.all_pairs_dijkstra(self.env.graph)
            )
            # Shortest len paths is a dict with node labels as keys and values
            # consisting of a (length dict, path dict) tuple containing
            # shortest paths between all pairs of nodes.
            self.point_labels = set()
            # cur_goals consist of (path, collected, point_idx) keyed by agent.
            self.cur_goals = {}
        gymnasium.logger.info("Completed initialization.")

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            return None

        if not self.env.static_graph:
            gymnasium.logger.info(
                "Recomputing shortest paths in O(V^3) for non-static graph..."
            )
            self.shortest_len_paths = dict(
                nx.all_pairs_dijkstra(self.env.graph)
            )
            self.cur_goals = {}
            self.point_labels = set()

        if not self.point_labels:
            # For static graphs, points neither appear nor change position,
            # so we only need to compute the labels once.
            self.point_labels = set(observation["point_labels"])

        agent_idx = int(agent[-1])
        cur_node = observation["collector_labels"][agent_idx]
        goal_path, goal_collected, goal_point_idx = self.cur_goals.get(
            agent, ([], None, None)
        )

        # Update goal if we completed the goal (goal_path is empty) or if
        # the goal was collected by another agent meanwhile.
        if (
            not goal_path
            or goal_collected != observation["collected"][goal_point_idx]
        ):
            best_reward = -np.inf

            for i, point_label in enumerate(observation["point_labels"]):
                path = self.shortest_len_paths.get(cur_node, ({}, {}))[1].get(
                    point_label, []
                )

                if not path:
                    continue

                collected = observation["collected"][i]
                reward = -self.shortest_len_paths.get(cur_node, ({}, {}))[
                    0
                ].get(point_label, np.inf)
                if "collection_reward" in observation:
                    reward += observation["collection_reward"][i]
                if "cheating_cost" in observation and collected > 0:
                    reward -= observation["cheating_cost"][i]

                if reward > best_reward:
                    best_reward = reward
                    # Trim current node and add a `collect` action.
                    goal_path = path[1:] + [-1]
                    goal_collected = collected
                    goal_point_idx = i

            self.cur_goals[agent] = (goal_path, goal_collected, goal_point_idx)

        action = goal_path.pop(0) if goal_path else None

        if action is None:
            gymnasium.logger.warn(
                f"{agent} cannot reach any points and will issue None "
                "actions."
            )

        return action
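
The cache built in __init__ is a dictionary mapping each source node to a (lengths, paths) pair, exactly as returned by nx.all_pairs_dijkstra, and the action search looks up paths and lengths in that structure. A minimal standalone sketch of those lookups, using a small hypothetical graph in place of env.graph:

import networkx as nx

# Small hypothetical graph standing in for env.graph.
G = nx.Graph()
G.add_weighted_edges_from([(0, 1, 1.0), (1, 2, 1.0), (0, 2, 3.0), (2, 3, 1.0)])

# Same cache layout as GraphGreedyPolicy.shortest_len_paths:
# {source: (length_dict, path_dict)} covering all reachable targets.
shortest_len_paths = dict(nx.all_pairs_dijkstra(G))

cur_node, point_label = 0, 3
lengths, paths = shortest_len_paths.get(cur_node, ({}, {}))
print(lengths.get(point_label, float("inf")))  # 3.0 via 0 -> 1 -> 2 -> 3
print(paths.get(point_label, []))              # [0, 1, 2, 3]

# As in the policy, the agent's own node is trimmed from the path and a
# -1 (collect) action is appended to form the action sequence.
goal_path = paths.get(point_label, [])[1:] + [-1]
print(goal_path)                               # [1, 2, 3, -1]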

GreedyPolicy

Bases: BasePolicy

Greedy policy for collector environment.

This policy computes the expected reward of every action at every step and chooses the action with the highest expected reward.

Compatible only with collector environment.

Note

This policy is only locally optimal, not globally optimal: the best route for collecting all points is disregarded, and only the next best point is searched for at every step. The policy may degenerate and repeatedly sample the same point if the cost of cheating is lower than the reward for collecting a point.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.

Source code in datadynamics/policies/greedy_policy/greedy_policy.py
class GreedyPolicy(BasePolicy):
    """Greedy policy for collector environment.

    This policy computes the expected reward for every action in every step
    and chooses the one with the highest expected reward.

    Compatible only with collector environment.

    Note:
        This is only locally optimal and not globally. Best routes to collect
        all points are disregarded and we only search for the next best point
        for every step.
        The policy may degenerate and always sample the same point if the cost
        of cheating is lower than the reward for collecting a point.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
    """

    def __init__(self, env):
        assert env.metadata["name"] == "collector", (
            f"{self.__class__.__name__} is only compatible with " "collector."
        )

        self.env = env

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            return None

        agent_idx = int(agent[-1])
        cur_position = observation["collector_positions"][agent_idx]

        best_reward = -np.inf
        best_action = None

        for i, position in enumerate(observation["point_positions"]):
            cheating = observation["collected"][i] > 0
            reward = -np.linalg.norm(cur_position - position)

            if "collection_reward" in observation:
                reward += observation["collection_reward"][i]
            if "cheating_cost" in observation and cheating:
                reward -= observation["cheating_cost"][i]

            if reward > best_reward:
                best_reward = reward
                best_action = i

        if best_action is None:
            gymnasium.logger.warn(
                f"{agent} cannot reach any points and will issue None "
                "actions."
            )

        return best_action
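
The per-step scoring is simple: each point's candidate reward is the negative Euclidean distance from the collector, plus the collection reward and minus the cheating cost when those entries are present in the observation, and the chosen action is the index of the best point. A minimal sketch with hypothetical observation values:

import numpy as np

# Hypothetical observation for a single collector and three points.
observation = {
    "collector_positions": np.array([[0.0, 0.0]]),
    "point_positions": np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]),
    "collected": np.array([0, 1, 0]),
    "collection_reward": np.array([1.0, 1.0, 1.0]),
    "cheating_cost": np.array([5.0, 5.0, 5.0]),
}

cur_position = observation["collector_positions"][0]
rewards = []
for i, position in enumerate(observation["point_positions"]):
    reward = -np.linalg.norm(cur_position - position)
    reward += observation["collection_reward"][i]
    if observation["collected"][i] > 0:
        reward -= observation["cheating_cost"][i]  # Penalize re-collecting.
    rewards.append(reward)

print(rewards)                   # [0.0, -6.0, -4.0]
print(int(np.argmax(rewards)))   # 0 -> the action is the index of the best point.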

policy(**kwargs)

Creates a suitable greedy policy for a given environment.

Returns:

BasePolicy: Greedy policy.

Source code in datadynamics/policies/greedy_policy/greedy_policy.py
def policy(**kwargs):
    """Creates a suitable greedy policy for a given environment.

    Returns:
        BasePolicy: Greedy policy.
    """
    if kwargs["env"].metadata["name"] == "graph_collector":
        policy = GraphGreedyPolicy(**kwargs)
    else:
        policy = GreedyPolicy(**kwargs)
    return policy
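
The factory dispatches on env.metadata["name"]. A hedged usage sketch: the import path below is assumed from the "Source code in datadynamics/policies/greedy_policy/greedy_policy.py" note above, and the stub environment is a stand-in, since the factory and GreedyPolicy.__init__ only inspect the metadata before storing the env:

from datadynamics.policies.greedy_policy.greedy_policy import policy


class _StubCollectorEnv:
    # Stand-in for a real collector environment (constructor not shown here).
    metadata = {"name": "collector"}


greedy = policy(env=_StubCollectorEnv())
print(type(greedy).__name__)  # GreedyPolicy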

BFS-Greedy Policy

BFSGraphGreedyPolicy

Bases: BasePolicy

Greedy policy using a breadth-first search for every action retrieval.

This policy runs in O(V + E) time when finding a new goal for an agent and may therefore slow down stepping through the environment.

Compatible only with the graph_collector environment, and only for graphs with equal edge weights between nodes.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.
graph (nx.Graph): Graph used by environment.
cur_goals (dict): Cached goals for each agent, stored as (path, collected, point_idx) tuples keyed by agent name.

Source code in datadynamics/policies/bfs_greedy_policy/bfs_greedy_policy.py
class BFSGraphGreedyPolicy(BasePolicy):
    """Greedy policy using a breadth-first search for every action retrieval.

    This policy runs in O(V + E) time when finding a new goal for an agent and
    as such may slow down stepping through the environment.

    Compatible only with graph_collector environment for graphs with equal
    edge weights between nodes.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
        graph (nx.Graph): Graph used by environment.
        cur_goals (dict): Cached goals for each agent consisting of
            (path, collected, point_idx) tuples keyed by agent name.
    """

    def __init__(self, env, graph):
        """Initialize policy from environment.

        Args:
            env (pettingzoo.utils.env.AECEnv): Environment on which to base
                policy.
            graph (nx.Graph): Graph used by environment.
        """
        assert env.metadata["name"] == "graph_collector", (
            f"{self.__class__.__name__} is only compatible with "
            "graph_collector."
        )

        self.env = env
        self.graph = graph
        self.cur_goals = {}

    def _bfs_shortest_paths(self, source_node, graph):
        """Runs breadth-first search to find the shortest paths from a source.

        Note:
            This runs in O(V + E) time.

        Args:
            source_node (int): Label of source node.
            graph (nx.Graph): Graph to search.

        Returns:
            dict: Dictionary of predecessors and depth keyed by node label.
        """
        discovered = {source_node}
        predecessors_and_depth = {source_node: (None, 0)}
        queue = deque([(source_node, 0)])

        while queue:
            node, depth = queue.popleft()

            for neighbor in graph.neighbors(node):
                if neighbor not in discovered:
                    discovered.add(neighbor)
                    queue.append((neighbor, depth + 1))
                    predecessors_and_depth[neighbor] = (node, depth)

        return predecessors_and_depth

    def _find_goal_full(self, observation, agent, graph):
        """Finds the point with the highest reward for an agent.

        Args:
            observation (dict): Observation for agent.
            agent (str): Name of agent.
            graph (nx.Graph): Graph used by environment.

        Returns:
            tuple: Tuple of (path, collected, point_idx) where path is the
                shortest path to the point, collected is the number of times
                the point has been collected, and point_idx is the index of
                the point in the observation.
        """
        agent_idx = int(agent[-1])
        agent_node = observation["collector_labels"][agent_idx]
        predecessors_and_depth = self._bfs_shortest_paths(agent_node, graph)
        best_reward = -np.inf
        best_point_label = None
        best_point_idx = None
        best_point_collected = None

        for i, point_label in enumerate(observation["point_labels"]):
            if point_label not in predecessors_and_depth:
                # Skip unreachable points.
                continue
            collected = observation["collected"][i]
            reward = -predecessors_and_depth[point_label][1]
            if "collection_reward" in observation:
                reward += observation["collection_reward"][i]
            if "cheating_cost" in observation and collected > 0:
                reward -= observation["cheating_cost"][i]

            if reward > best_reward:
                best_reward = reward
                best_point_label = point_label
                best_point_collected = collected
                best_point_idx = i

        if best_point_label is None:
            return None

        # Backtrack to find the path for the best point.
        path = [best_point_label]
        while path[-1] != agent_node and path[-1] is not None:
            path.append(predecessors_and_depth[path[-1]][0])
        path.reverse()
        path = path[1:] + [-1]

        return path, best_point_collected, best_point_idx

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            return None

        goal_path, goal_collected, goal_point_idx = self.cur_goals.get(
            agent, ([], None, None)
        )

        if (
            not goal_path
            or goal_collected != observation["collected"][goal_point_idx]
        ):
            goal_path, goal_collected, goal_point_idx = self._find_goal_full(
                observation, agent, self.graph
            )
            self.cur_goals[agent] = (goal_path, goal_collected, goal_point_idx)

        action = goal_path.pop(0) if goal_path else None

        if action is None:
            gymnasium.logger.warn(
                f"{agent} cannot reach any points and will issue None "
                "actions."
            )

        return action
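
_bfs_shortest_paths records, for every reachable node, its predecessor on a shortest (hop-count) path from the source, and _find_goal_full backtracks through that map to build the action path. A minimal standalone sketch of the same predecessor/backtrack idea on a small hypothetical unweighted graph:

from collections import deque

import networkx as nx

# Hypothetical unweighted graph standing in for the environment's graph.
graph = nx.path_graph(5)  # nodes 0-1-2-3-4

source, target = 0, 3
discovered = {source}
predecessors = {source: None}
queue = deque([source])

while queue:
    node = queue.popleft()
    for neighbor in graph.neighbors(node):
        if neighbor not in discovered:
            discovered.add(neighbor)
            predecessors[neighbor] = node
            queue.append(neighbor)

# Backtrack from the target to the source, reverse, trim the source node,
# and append -1 as the collect action (mirroring _find_goal_full).
path = [target]
while predecessors[path[-1]] is not None:
    path.append(predecessors[path[-1]])
path.reverse()
print(path[1:] + [-1])  # [1, 2, 3, -1]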

__init__(env, graph)

Initialize policy from environment.

Parameters:

env (pettingzoo.utils.env.AECEnv): Environment on which to base policy. Required.
graph (nx.Graph): Graph used by environment. Required.
Source code in datadynamics/policies/bfs_greedy_policy/bfs_greedy_policy.py
def __init__(self, env, graph):
    """Initialize policy from environment.

    Args:
        env (pettingzoo.utils.env.AECEnv): Environment on which to base
            policy.
        graph (nx.Graph): Graph used by environment.
    """
    assert env.metadata["name"] == "graph_collector", (
        f"{self.__class__.__name__} is only compatible with "
        "graph_collector."
    )

    self.env = env
    self.graph = graph
    self.cur_goals = {}

policy(**kwargs)

Creates a suitable BFS-based greedy policy for a given environment.

Returns:

BasePolicy: BFS-based greedy policy.

Source code in datadynamics/policies/bfs_greedy_policy/bfs_greedy_policy.py
def policy(**kwargs):
    """Creates a suitable BFS-based greedy policy for a given environment.

    Returns:
        BasePolicy: BFS-based greedy policy.
    """
    policy = BFSGraphGreedyPolicy(**kwargs)
    return policy

Premade Policy

PremadePolicy

Bases: BasePolicy

Policy using a premade list of goals for each agent.

This policy runs in O(V + E) time whenever an agent needs a new goal, since the shortest path to that goal must be searched for.

Compatible only with the graph_collector environment, and only for graphs with equal edge weights between nodes.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.
graph (nx.Graph): Graph used by environment.
cur_goals (dict): Cached shortest path to each agent's current goal, keyed by agent name.
goal_dict (dict): Remaining goal node labels for each agent, keyed by agent name.

Source code in datadynamics/policies/premade_policy/premade_policy.py
class PremadePolicy(BasePolicy):
    """Policy using a premade list of goals for each agent.

    This policy runs in O(V + E) time when reaching a goal due to having to
    search for the shortest path to the given goal.

    Compatible only with graph_collector environment for graphs with equal
    edge weights between nodes.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
        graph (nx.Graph): Graph used by environment.
        cur_goals (dict): Cached shortest path to each agent's current goal,
            keyed by agent name.
        goal_dict (dict): Remaining goal node labels for each agent, keyed
            by agent name.
    """

    def __init__(self, env, graph, goal_dict):
        """Initialize policy from environment.

        Args:
            env (pettingzoo.utils.env.AECEnv): Environment on which to base
                policy.
            graph (nx.Graph): Graph used by environment.
            goal_dict (dict): Dictionary of goals for each agent (keys can be
                arbitrary).
        """
        assert env.metadata["name"] == "graph_collector", (
            f"{self.__class__.__name__} is only compatible with "
            "graph_collector."
        )

        self.env = env
        self.graph = graph
        self.cur_goals = {}
        self.goal_dict = {}

        assert len(goal_dict) == len(env.possible_agents), (
            f"You must provide only one list of goals for each agent. "
            f"Provided {len(goal_dict)} goals for {len(env.possible_agents)} "
            "agents."
        )
        for agent, key in zip(env.possible_agents, goal_dict):
            # Copy goals since we pop from them.
            self.goal_dict[agent] = goal_dict[key][:]

    def _bfs_shortest_path(self, source_node, target, graph):
        """Runs BFS to find the shortest path from source to target.

        Note:
            This runs in O(V + E) time in the worst case.

        Args:
            source_node (int): Label of source node.
            target (int): Label of target node.
            graph (nx.Graph): Graph to search.

        Returns:
            list: Shortest path from source to target

        Raises:
            ValueError: If no path exists from source to target.
        """
        discovered = {source_node}
        predecessors_and_depth = {source_node: (None, 0)}
        queue = deque([(source_node, 0)])

        while queue:
            node, depth = queue.popleft()

            if node == target:
                # Reconstruct path
                path = [target]
                while path[-1] != source_node and path[-1] is not None:
                    path.append(predecessors_and_depth[path[-1]][0])
                path.reverse()
                path = path[1:] + [-1]
                return path

            for neighbor in graph.neighbors(node):
                if neighbor not in discovered:
                    discovered.add(neighbor)
                    queue.append((neighbor, depth + 1))
                    predecessors_and_depth[neighbor] = (node, depth)

        raise ValueError(f"No path exists from {source_node} to {target}.")

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            if self.goal_dict[agent]:
                gymnasium.logger.warn(
                    f"Agent {agent} is dead but has goals remaining."
                )
            return None

        goal_path = self.cur_goals.get(agent, ([]))

        if not goal_path:
            agent_idx = int(agent[-1])
            agent_node = observation["collector_labels"][agent_idx]
            goal_path = self._bfs_shortest_path(
                agent_node, self.goal_dict[agent].pop(0), self.graph
            )
            self.cur_goals[agent] = goal_path

        action = goal_path.pop(0)

        return action
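
goal_dict maps arbitrary keys to per-agent lists of goal node labels; the constructor pairs those lists with env.possible_agents in order, and the policy consumes each list front to back as goals are reached. A hedged sketch of how such a dictionary might be shaped (agent names and node labels below are hypothetical):

# Hypothetical goal lists for a two-agent graph_collector environment.
# Keys are arbitrary; lists are paired with env.possible_agents in order.
goal_dict = {
    "route_a": [4, 17, 3],  # node labels visited, in order, by the first agent
    "route_b": [8, 2],      # node labels visited, in order, by the second agent
}

possible_agents = ["collector_0", "collector_1"]  # stand-in for env.possible_agents

# Mirrors the pairing in PremadePolicy.__init__ (lists are copied because
# the policy pops from them as goals are reached).
assignments = {
    agent: goal_dict[key][:] for agent, key in zip(possible_agents, goal_dict)
}
print(assignments)
# {'collector_0': [4, 17, 3], 'collector_1': [8, 2]}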

__init__(env, graph, goal_dict)

Initialize policy from environment.

Parameters:

env (pettingzoo.utils.env.AECEnv): Environment on which to base policy. Required.
graph (nx.Graph): Graph used by environment. Required.
goal_dict (dict): Dictionary of goals for each agent (keys can be arbitrary). Required.
Source code in datadynamics/policies/premade_policy/premade_policy.py
def __init__(self, env, graph, goal_dict):
    """Initialize policy from environment.

    Args:
        env (pettingzoo.utils.env.AECEnv): Environment on which to base
            policy.
        graph (nx.Graph): Graph used by environment.
        goal_dict (dict): Dictionary of goals for each agent (keys can be
            arbitrary).
    """
    assert env.metadata["name"] == "graph_collector", (
        f"{self.__class__.__name__} is only compatible with "
        "graph_collector."
    )

    self.env = env
    self.graph = graph
    self.cur_goals = {}
    self.goal_dict = {}

    assert len(goal_dict) == len(env.possible_agents), (
        f"You must provide only one list of goals for each agent. "
        f"Provided {len(goal_dict)} goals for {len(env.possible_agents)} "
        "agents."
    )
    for agent, key in zip(env.possible_agents, goal_dict):
        # Copy goals since we pop from them.
        self.goal_dict[agent] = goal_dict[key][:]

policy(**kwargs)

Creates a premade policy for a given environment.

Returns:

BasePolicy: Premade policy.

Source code in datadynamics/policies/premade_policy/premade_policy.py
def policy(**kwargs):
    """Creates a premade policy for a given environment

    Returns:
        BasePolicy: Premade policy.
    """
    policy = PremadePolicy(**kwargs)
    return policy

Random Policy

RandomPolicy

Bases: BasePolicy

Policy that returns a random action.

Compatible with all environments.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.

Source code in datadynamics/policies/random_policy/random_policy.py
class RandomPolicy(BasePolicy):
    """Policy that returns a random action.

    Compatible with all environments.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
    """

    def __init__(self, env):
        self.env = env

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            return None
        action = self.env.action_space(agent).sample()
        return action
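
The policy simply delegates to the agent's action space via env.action_space(agent).sample(). A minimal sketch with a standalone gymnasium space as a stand-in for the environment's real action space:

import gymnasium

# Stand-in for env.action_space(agent); a real environment defines this.
action_space = gymnasium.spaces.Discrete(5)
action_space.seed(0)          # Seed for reproducible sampling.
print(action_space.sample())  # A random action in {0, ..., 4}.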

policy(**kwargs)

Creates a RandomPolicy for a given environment.

Returns:

BasePolicy: Random policy.

Source code in datadynamics/policies/random_policy/random_policy.py
def policy(**kwargs):
    """Creates a RandomPolicy for a given environment.

    Returns:
        BasePolicy: Random policy.
    """
    policy = RandomPolicy(**kwargs)
    return policy

Dummy Policy

DummyPolicy

Bases: BasePolicy

Dummy policy that cycles through all actions.

Compatible with all environments.

Attributes:

env (pettingzoo.utils.env.AECEnv): Environment used by policy.

Source code in datadynamics/policies/dummy_policy/dummy_policy.py
class DummyPolicy(BasePolicy):
    """Dummy policy that cycles through all actions.

    Compatible with all environments.

    Attributes:
        env (pettingzoo.utils.env.AECEnv): Environment used by policy.
    """

    def __init__(self, env):
        self.env = env
        self._actions = self._get_possible_actions(env)

    def _get_possible_actions(self, env):
        """Retrieve all possible actions for given environment

        Args:
            env (pettingzoo.utils.env.AECEnv): Environment for which to
                retrieve actions.

        Returns:
            dict: Dictionary of action iterators keyed by agent.
        """
        actions = {}
        for agent in env.possible_agents:
            action_space = env.action_spaces[agent]
            actions[agent] = cycle(
                range(action_space.start, action_space.n + action_space.start)
            )
        return actions

    def action(self, observation, agent):
        if self.env.terminations[agent] or self.env.truncations[agent]:
            # Agent is dead, the only valid action is None.
            return None
        action = next(self._actions[agent])
        return action
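
_get_possible_actions wraps each agent's discrete action range in itertools.cycle, so successive next() calls yield actions in a fixed rotation. A minimal sketch for a single hypothetical Discrete space standing in for env.action_spaces[agent]:

from itertools import cycle

import gymnasium

# Stand-in for env.action_spaces[agent].
action_space = gymnasium.spaces.Discrete(3, start=0)
actions = cycle(range(action_space.start, action_space.n + action_space.start))

print([next(actions) for _ in range(5)])  # [0, 1, 2, 0, 1]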

policy(**kwargs)

Creates a dummy policy for a given environment.

Returns:

BasePolicy: Dummy policy.

Source code in datadynamics/policies/dummy_policy/dummy_policy.py
def policy(**kwargs):
    """Creates a dummy policy for a given environment.

    Returns:
        BasePolicy: Dummy policy.
    """
    policy = DummyPolicy(**kwargs)
    return policy