mephisto.abstractions.blueprints.mock.mock_task_runner

View Source
#!/usr/bin/env python3

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from mephisto.abstractions.blueprint import TaskRunner, SharedTaskState
from mephisto.data_model.assignment import InitializationData

import os
import time

from typing import ClassVar, List, Type, Any, Dict, Union, TYPE_CHECKING

if TYPE_CHECKING:
    from mephisto.data_model.task_run import TaskRun
    from mephisto.data_model.unit import Unit
    from mephisto.data_model.assignment import Assignment
    from mephisto.data_model.agent import Agent, OnboardingAgent
    from argparse import _ArgumentGroup as ArgumentGroup
    from omegaconf import DictConfig


class MockTaskRunner(TaskRunner):
    """Mock of a task runner, for use in testing.

    Instead of running real task logic, the mock checks that agents are
    correctly assigned, relays at most one live update back to each agent,
    and waits for each agent to submit.
    """

    def __init__(
        self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"
    ):
        super().__init__(task_run, args, shared_state)
        # Timeout handed to every agent wait (live update / submit).
        self.timeout = args.blueprint.timeout_time
        # Units / assignments currently being run by this mock, keyed by db_id.
        self.tracked_tasks: Dict[str, Union["Assignment", "Unit"]] = {}
        # Defaults to True when the blueprint config does not set it.
        self.is_concurrent = args.blueprint.get("is_concurrent", True)

    @staticmethod
    def get_mock_assignment_data() -> InitializationData:
        """Return placeholder init data: empty shared state, two empty unit entries."""
        return InitializationData(shared={}, unit_data=[{}, {}])

    @staticmethod
    def get_data_for_assignment(assignment: "Assignment") -> InitializationData:
        """
        Mock tasks have no data unless given during testing.

        ``assignment`` is ignored; the result of ``get_mock_assignment_data``
        is returned instead.
        """
        return MockTaskRunner.get_mock_assignment_data()

    def get_init_data_for_agent(self, agent: "Agent") -> Dict[str, Any]:
        """
        Return the data for an agent already assigned to a particular unit.

        Mock agents receive no initialization data, so this returns an empty
        dict (honouring the ``Dict[str, Any]`` return annotation rather than
        implicitly returning ``None``).
        """
        # TODO(#97) implement real per-agent init data for mocks
        return {}

    def run_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """
        Mock runners simply wait for the onboarding agent to submit (up to
        ``self.timeout``); whether onboarding passed is presumably carried by
        that submission — TODO confirm.
        """
        onboarding_agent.await_submit(self.timeout)

    def run_unit(self, unit: "Unit", agent: "Agent"):
        """
        Drive a single unit with the mock protocol.

        Checks that ``agent`` is the agent assigned to ``unit``, relays one
        live update from the agent back to it (if one arrives within
        ``self.timeout``), then waits for the agent to submit. The unit is
        tracked in ``self.tracked_tasks`` for the duration of the call and is
        always removed afterwards, even on failure.

        Raises:
            AssertionError: if ``unit`` has no assigned agent or the assigned
                agent is not ``agent``.
        """
        self.tracked_tasks[unit.db_id] = unit
        try:
            # Fixed short delay before inspecting assignment state
            # (presumably lets the agent finish registering — TODO confirm).
            time.sleep(0.3)
            assigned_agent = unit.get_assigned_agent()
            assert assigned_agent is not None, "No agent was assigned"
            assert (
                assigned_agent.db_id == agent.db_id
            ), "Task was not given to assigned agent"
            packet = agent.get_live_update(timeout=self.timeout)
            if packet is not None:
                agent.observe(packet)
            agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(unit.db_id, None)

    def run_assignment(self, assignment: "Assignment", agents: List["Agent"]):
        """
        Drive every unit of ``assignment`` with the mock protocol.

        Matches each unit's assigned agent to one of ``agents`` (by db_id),
        then relays one live update per agent back to that agent, and finally
        waits for every agent to submit. The assignment is tracked in
        ``self.tracked_tasks`` for the duration of the call and is always
        removed afterwards, even on failure.

        Raises:
            AssertionError: if a unit has no assigned agent, or its assigned
                agent is not among ``agents``.
        """
        self.tracked_tasks[assignment.db_id] = assignment
        try:
            agents_by_id = {a.db_id: a for a in agents}
            # Fixed short delay before inspecting assignment state
            # (presumably lets agents finish registering — TODO confirm).
            time.sleep(0.3)
            # Agents reordered to follow the assignment's unit order.
            unit_agents = []
            for unit in assignment.get_units():
                assigned_agent = unit.get_assigned_agent()
                assert assigned_agent is not None, "Task was not fully assigned"
                matched_agent = agents_by_id.get(assigned_agent.db_id)
                assert (
                    matched_agent is not None
                ), "Task was not launched with assigned agents"
                unit_agents.append(matched_agent)
            # Relay all live updates before waiting on any submission.
            for agent in unit_agents:
                packet = agent.get_live_update(timeout=self.timeout)
                if packet is not None:
                    agent.observe(packet)
            for agent in unit_agents:
                agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(assignment.db_id, None)

    def cleanup_assignment(self, assignment: "Assignment"):
        """No cleanup required yet for ending mock runs"""
        pass

    def cleanup_unit(self, unit: "Unit"):
        """No cleanup required yet for ending mock runs"""
        pass

    def cleanup_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """No cleanup required yet for ending onboarding in mocks"""
        pass
#   class MockTaskRunner(mephisto.abstractions._subcomponents.task_runner.TaskRunner):
View Source
class MockTaskRunner(TaskRunner):
    """Mock of a task runner, for use in testing.

    Instead of running real task logic, the mock checks that agents are
    correctly assigned, relays at most one live update back to each agent,
    and waits for each agent to submit.
    """

    def __init__(
        self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"
    ):
        super().__init__(task_run, args, shared_state)
        # Timeout handed to every agent wait (live update / submit).
        self.timeout = args.blueprint.timeout_time
        # Units / assignments currently being run by this mock, keyed by db_id.
        self.tracked_tasks: Dict[str, Union["Assignment", "Unit"]] = {}
        # Defaults to True when the blueprint config does not set it.
        self.is_concurrent = args.blueprint.get("is_concurrent", True)

    @staticmethod
    def get_mock_assignment_data() -> InitializationData:
        """Return placeholder init data: empty shared state, two empty unit entries."""
        return InitializationData(shared={}, unit_data=[{}, {}])

    @staticmethod
    def get_data_for_assignment(assignment: "Assignment") -> InitializationData:
        """
        Mock tasks have no data unless given during testing.

        ``assignment`` is ignored; the result of ``get_mock_assignment_data``
        is returned instead.
        """
        return MockTaskRunner.get_mock_assignment_data()

    def get_init_data_for_agent(self, agent: "Agent") -> Dict[str, Any]:
        """
        Return the data for an agent already assigned to a particular unit.

        Mock agents receive no initialization data, so this returns an empty
        dict (honouring the ``Dict[str, Any]`` return annotation rather than
        implicitly returning ``None``).
        """
        # TODO(#97) implement real per-agent init data for mocks
        return {}

    def run_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """
        Mock runners simply wait for the onboarding agent to submit (up to
        ``self.timeout``); whether onboarding passed is presumably carried by
        that submission — TODO confirm.
        """
        onboarding_agent.await_submit(self.timeout)

    def run_unit(self, unit: "Unit", agent: "Agent"):
        """
        Drive a single unit with the mock protocol.

        Checks that ``agent`` is the agent assigned to ``unit``, relays one
        live update from the agent back to it (if one arrives within
        ``self.timeout``), then waits for the agent to submit. The unit is
        tracked in ``self.tracked_tasks`` for the duration of the call and is
        always removed afterwards, even on failure.

        Raises:
            AssertionError: if ``unit`` has no assigned agent or the assigned
                agent is not ``agent``.
        """
        self.tracked_tasks[unit.db_id] = unit
        try:
            # Fixed short delay before inspecting assignment state
            # (presumably lets the agent finish registering — TODO confirm).
            time.sleep(0.3)
            assigned_agent = unit.get_assigned_agent()
            assert assigned_agent is not None, "No agent was assigned"
            assert (
                assigned_agent.db_id == agent.db_id
            ), "Task was not given to assigned agent"
            packet = agent.get_live_update(timeout=self.timeout)
            if packet is not None:
                agent.observe(packet)
            agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(unit.db_id, None)

    def run_assignment(self, assignment: "Assignment", agents: List["Agent"]):
        """
        Drive every unit of ``assignment`` with the mock protocol.

        Matches each unit's assigned agent to one of ``agents`` (by db_id),
        then relays one live update per agent back to that agent, and finally
        waits for every agent to submit. The assignment is tracked in
        ``self.tracked_tasks`` for the duration of the call and is always
        removed afterwards, even on failure.

        Raises:
            AssertionError: if a unit has no assigned agent, or its assigned
                agent is not among ``agents``.
        """
        self.tracked_tasks[assignment.db_id] = assignment
        try:
            agents_by_id = {a.db_id: a for a in agents}
            # Fixed short delay before inspecting assignment state
            # (presumably lets agents finish registering — TODO confirm).
            time.sleep(0.3)
            # Agents reordered to follow the assignment's unit order.
            unit_agents = []
            for unit in assignment.get_units():
                assigned_agent = unit.get_assigned_agent()
                assert assigned_agent is not None, "Task was not fully assigned"
                matched_agent = agents_by_id.get(assigned_agent.db_id)
                assert (
                    matched_agent is not None
                ), "Task was not launched with assigned agents"
                unit_agents.append(matched_agent)
            # Relay all live updates before waiting on any submission.
            for agent in unit_agents:
                packet = agent.get_live_update(timeout=self.timeout)
                if packet is not None:
                    agent.observe(packet)
            for agent in unit_agents:
                agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(assignment.db_id, None)

    def cleanup_assignment(self, assignment: "Assignment"):
        """No cleanup required yet for ending mock runs"""
        pass

    def cleanup_unit(self, unit: "Unit"):
        """No cleanup required yet for ending mock runs"""
        pass

    def cleanup_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """No cleanup required yet for ending onboarding in mocks"""
        pass

Mock of a task runner, for use in testing

#   MockTaskRunner( task_run: mephisto.data_model.task_run.TaskRun, args: omegaconf.dictconfig.DictConfig, shared_state: mephisto.abstractions.blueprint.SharedTaskState )
View Source
    def __new__(
        cls, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"
    ) -> "TaskRunner":
        """
        Get the correct TaskRunner for this task run.

        Constructing the base ``TaskRunner`` directly dispatches to the
        ``TaskRunnerClass`` of the blueprint looked up for
        ``task_run.task_type``. Constructing any subclass directly builds
        that subclass as usual.
        """
        if cls == TaskRunner:
            # Local import, presumably to avoid an import cycle with the
            # registry — TODO confirm.
            from mephisto.operations.registry import get_blueprint_from_type

            # We are trying to construct a generic TaskRunner, find what type
            # to use and create that instead
            correct_class = get_blueprint_from_type(task_run.task_type).TaskRunnerClass
            return super().__new__(correct_class)
        else:
            # We are constructing another instance directly
            return super().__new__(cls)

Get the correct TaskRunner for this task run

#  
@staticmethod
def get_mock_assignment_data() -> mephisto.data_model.assignment.InitializationData:
View Source
    @staticmethod
    def get_mock_assignment_data() -> InitializationData:
        """Build placeholder init data: no shared state and two empty unit payloads."""
        per_unit_payloads = [{}, {}]
        return InitializationData(shared={}, unit_data=per_unit_payloads)
#  
@staticmethod
def get_data_for_assignment( assignment: mephisto.data_model.assignment.Assignment ) -> mephisto.data_model.assignment.InitializationData:
View Source
    @staticmethod
    def get_data_for_assignment(assignment: "Assignment") -> InitializationData:
        """
        Produce initialization data for ``assignment``.

        Mock tasks have no real data, so the assignment is not inspected;
        the standard mock payload from ``get_mock_assignment_data`` is
        returned instead (tests may supply their own by overriding it).
        """
        placeholder = MockTaskRunner.get_mock_assignment_data()
        return placeholder

Mock tasks have no data unless given during testing

#   def get_init_data_for_agent(self, agent: mephisto.data_model.agent.Agent) -> Dict[str, Any]:
View Source
    def get_init_data_for_agent(self, agent: "Agent") -> Dict[str, Any]:
        """
        Return the data for an agent already assigned to a particular unit.

        Mock agents receive no initialization data, so this returns an empty
        dict (honouring the ``Dict[str, Any]`` return annotation rather than
        implicitly returning ``None``).
        """
        # TODO(#97) implement real per-agent init data for mocks
        return {}

Return the data for an agent already assigned to a particular unit

#   def run_onboarding(self, onboarding_agent: mephisto.data_model.agent.OnboardingAgent):
View Source
    def run_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """
        Block until the onboarding agent submits, waiting at most
        ``self.timeout``.

        The mock performs no onboarding logic of its own; whether onboarding
        passed is presumably carried in the agent's submission — TODO confirm.
        """
        wait_limit = self.timeout
        onboarding_agent.await_submit(wait_limit)

Mock runners simply wait for an act to come in with whether or not onboarding is complete

#   def run_unit(self, unit: mephisto.data_model.unit.Unit, agent: mephisto.data_model.agent.Agent):
View Source
    def run_unit(self, unit: "Unit", agent: "Agent"):
        """
        Drive a single unit with the mock protocol.

        Checks that ``agent`` is the agent assigned to ``unit``, relays one
        live update from the agent back to it (if one arrives within
        ``self.timeout``), then waits for the agent to submit. The unit is
        tracked in ``self.tracked_tasks`` for the duration of the call and is
        always removed afterwards, even on failure.

        Raises:
            AssertionError: if ``unit`` has no assigned agent or the assigned
                agent is not ``agent``.
        """
        self.tracked_tasks[unit.db_id] = unit
        try:
            # Fixed short delay before inspecting assignment state
            # (presumably lets the agent finish registering — TODO confirm).
            time.sleep(0.3)
            assigned_agent = unit.get_assigned_agent()
            assert assigned_agent is not None, "No agent was assigned"
            assert (
                assigned_agent.db_id == agent.db_id
            ), "Task was not given to assigned agent"
            packet = agent.get_live_update(timeout=self.timeout)
            if packet is not None:
                agent.observe(packet)
            agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(unit.db_id, None)

Mock runners will pass the agent for the given unit all of the required messages to finish a task.

#   def run_assignment( self, assignment: mephisto.data_model.assignment.Assignment, agents: list[mephisto.data_model.agent.Agent] ):
View Source
    def run_assignment(self, assignment: "Assignment", agents: List["Agent"]):
        """
        Drive every unit of ``assignment`` with the mock protocol.

        Matches each unit's assigned agent to one of ``agents`` (by db_id),
        then relays one live update per agent back to that agent, and finally
        waits for every agent to submit. The assignment is tracked in
        ``self.tracked_tasks`` for the duration of the call and is always
        removed afterwards, even on failure.

        Raises:
            AssertionError: if a unit has no assigned agent, or its assigned
                agent is not among ``agents``.
        """
        self.tracked_tasks[assignment.db_id] = assignment
        try:
            agents_by_id = {a.db_id: a for a in agents}
            # Fixed short delay before inspecting assignment state
            # (presumably lets agents finish registering — TODO confirm).
            time.sleep(0.3)
            # Agents reordered to follow the assignment's unit order.
            unit_agents = []
            for unit in assignment.get_units():
                assigned_agent = unit.get_assigned_agent()
                assert assigned_agent is not None, "Task was not fully assigned"
                matched_agent = agents_by_id.get(assigned_agent.db_id)
                assert (
                    matched_agent is not None
                ), "Task was not launched with assigned agents"
                unit_agents.append(matched_agent)
            # Relay all live updates before waiting on any submission.
            for agent in unit_agents:
                packet = agent.get_live_update(timeout=self.timeout)
                if packet is not None:
                    agent.observe(packet)
            for agent in unit_agents:
                agent.await_submit(self.timeout)
        finally:
            # Stop tracking regardless of outcome so failed runs don't leak.
            self.tracked_tasks.pop(assignment.db_id, None)

Mock runners will pass the agents for the given assignment all of the required messages to finish a task.

#   def cleanup_assignment(self, assignment: mephisto.data_model.assignment.Assignment):
View Source
    def cleanup_assignment(self, assignment: "Assignment"):
        """Intentionally a no-op: ending a mock assignment needs no cleanup yet."""
        return None

No cleanup required yet for ending mock runs

#   def cleanup_unit(self, unit: mephisto.data_model.unit.Unit):
View Source
    def cleanup_unit(self, unit: "Unit"):
        """Intentionally a no-op: ending a mock unit needs no cleanup yet."""
        return None

No cleanup required yet for ending mock runs

#   def cleanup_onboarding(self, onboarding_agent: mephisto.data_model.agent.OnboardingAgent):
View Source
    def cleanup_onboarding(self, onboarding_agent: "OnboardingAgent"):
        """Intentionally a no-op: ending mock onboarding needs no cleanup yet."""
        return None

No cleanup required yet for ending onboarding in mocks

Inherited Members
mephisto.abstractions._subcomponents.task_runner.TaskRunner
execute_onboarding
execute_unit
execute_assignment
filter_units_for_worker
shutdown