mephisto.abstractions.providers.mock.mock_unit

View Source
#!/usr/bin/env python3

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from mephisto.data_model.unit import Unit
from mephisto.data_model.constants.assignment_state import AssignmentState
from mephisto.abstractions.blueprint import AgentState

from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE
from typing import List, Optional, Tuple, Dict, Mapping, Any, Type, TYPE_CHECKING

if TYPE_CHECKING:
    from mephisto.abstractions.database import MephistoDB
    from mephisto.data_model.assignment import Assignment
    from mephisto.abstractions.providers.mock.mock_datastore import MockDatastore

from mephisto.utils.logger_core import get_logger

# Module-level logger named after this module; MockUnit.launch logs launch links through it.
logger = get_logger(name=__name__)


class MockUnit(Unit):
    """
    Unit implementation for the mock crowd provider.

    This class tracks the status of an individual worker's contribution to a
    higher-level assignment. It is the smallest 'unit' of work needed to
    complete the assignment, and this class is only responsible for checking
    the status of that work being done.

    Mock units contact no external service. Launching only updates the
    database status and prints/logs local links. Expiration is recorded in
    the mock provider's datastore.
    """

    def __init__(
        self,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ):
        """
        Load this unit and attach the mock provider's datastore.

        Args:
            db: Database the unit is stored in.
            db_id: Database id of this unit.
            row: Optional database row for this unit (presumably pre-fetched),
                forwarded to ``Unit.__init__``.
            _used_new_call: Forwarded unchanged to ``Unit.__init__``.
        """
        super().__init__(db, db_id, row=row, _used_new_call=_used_new_call)
        self.datastore: "MockDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE)

    def launch(self, task_url: str) -> None:
        """
        Mark this unit as launched and print/log local links for working on it.

        Mock launches do nothing beyond updating state; no external service
        is contacted. The port is extracted from ``task_url`` and used to
        build ``localhost`` links.

        Args:
            task_url: URL the task was deployed to, with a scheme
                (e.g. ``http://localhost:3000/``) or without one
                (e.g. ``localhost:3000/``).
        """
        from urllib.parse import urlsplit

        self.set_db_status(status=AssignmentState.LAUNCHED)

        # TODO(OWN) get this link to the frontend
        # Parse the port with urlsplit. The naive ``task_url.split(":")[1]``
        # picks up the host, and so an empty port, whenever the URL carries a
        # scheme. It also raised IndexError when the URL had no ":" at all.
        try:
            parsed_port = urlsplit(task_url).port
        except ValueError:
            # Malformed netloc (e.g. a non-numeric port); use the fallback.
            parsed_port = None
        if parsed_port is not None:
            port = str(parsed_port)
        else:
            # Scheme-less "host:port/..." form, or a URL without a port.
            url_parts = task_url.split(":")
            port = url_parts[1].split("/")[0] if len(url_parts) > 1 else ""

        print(task_url)
        launch_message = (
            f"Mock task launched: localhost:{port} for preview, "
            f"localhost:{port}/?worker_id=x&assignment_id={self.db_id}"
        )
        print(launch_message)
        logger.info(f"{launch_message} for assignment {self.assignment_id}")

        return None

    def expire(self) -> float:
        """
        Expire this unit immediately (expiration is immediate on mocks).

        The database status becomes EXPIRED unless the unit is already
        EXPIRED or COMPLETED. The mock datastore always records the unit as
        expired.

        Returns:
            Always ``0.0``.
        """
        if self.get_status() not in [
            AssignmentState.EXPIRED,
            AssignmentState.COMPLETED,
        ]:
            self.set_db_status(AssignmentState.EXPIRED)
        self.datastore.set_unit_expired(self.db_id, True)
        return 0.0

    def is_expired(self) -> bool:
        """Determine whether this unit is expired according to the mock datastore."""
        return self.datastore.get_unit_expired(self.db_id)

    @staticmethod
    def new(
        db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float
    ) -> "Unit":
        """
        Create a mock Unit for the given assignment.

        Args:
            db: Database to register the unit in.
            assignment: Assignment the unit belongs to.
            index: Index of the unit within the assignment.
            pay_amount: Amount to pay for this unit.

        Returns:
            The unit produced by ``MockUnit._register_unit`` for the mock provider.
        """
        return MockUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE)
View Source
class MockUnit(Unit):
    """
    Unit implementation for the mock crowd provider.

    This class tracks the status of an individual worker's contribution to a
    higher-level assignment. It is the smallest 'unit' of work needed to
    complete the assignment, and this class is only responsible for checking
    the status of that work being done.

    Mock units contact no external service. Launching only updates the
    database status and prints/logs local links. Expiration is recorded in
    the mock provider's datastore.
    """

    def __init__(
        self,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ):
        """
        Load this unit and attach the mock provider's datastore.

        Args:
            db: Database the unit is stored in.
            db_id: Database id of this unit.
            row: Optional database row for this unit (presumably pre-fetched),
                forwarded to ``Unit.__init__``.
            _used_new_call: Forwarded unchanged to ``Unit.__init__``.
        """
        super().__init__(db, db_id, row=row, _used_new_call=_used_new_call)
        self.datastore: "MockDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE)

    def launch(self, task_url: str) -> None:
        """
        Mark this unit as launched and print/log local links for working on it.

        Mock launches do nothing beyond updating state; no external service
        is contacted. The port is extracted from ``task_url`` and used to
        build ``localhost`` links.

        Args:
            task_url: URL the task was deployed to, with a scheme
                (e.g. ``http://localhost:3000/``) or without one
                (e.g. ``localhost:3000/``).
        """
        from urllib.parse import urlsplit

        self.set_db_status(status=AssignmentState.LAUNCHED)

        # TODO(OWN) get this link to the frontend
        # Parse the port with urlsplit. The naive ``task_url.split(":")[1]``
        # picks up the host, and so an empty port, whenever the URL carries a
        # scheme. It also raised IndexError when the URL had no ":" at all.
        try:
            parsed_port = urlsplit(task_url).port
        except ValueError:
            # Malformed netloc (e.g. a non-numeric port); use the fallback.
            parsed_port = None
        if parsed_port is not None:
            port = str(parsed_port)
        else:
            # Scheme-less "host:port/..." form, or a URL without a port.
            url_parts = task_url.split(":")
            port = url_parts[1].split("/")[0] if len(url_parts) > 1 else ""

        print(task_url)
        launch_message = (
            f"Mock task launched: localhost:{port} for preview, "
            f"localhost:{port}/?worker_id=x&assignment_id={self.db_id}"
        )
        print(launch_message)
        logger.info(f"{launch_message} for assignment {self.assignment_id}")

        return None

    def expire(self) -> float:
        """
        Expire this unit immediately (expiration is immediate on mocks).

        The database status becomes EXPIRED unless the unit is already
        EXPIRED or COMPLETED. The mock datastore always records the unit as
        expired.

        Returns:
            Always ``0.0``.
        """
        if self.get_status() not in [
            AssignmentState.EXPIRED,
            AssignmentState.COMPLETED,
        ]:
            self.set_db_status(AssignmentState.EXPIRED)
        self.datastore.set_unit_expired(self.db_id, True)
        return 0.0

    def is_expired(self) -> bool:
        """Determine whether this unit is expired according to the mock datastore."""
        return self.datastore.get_unit_expired(self.db_id)

    @staticmethod
    def new(
        db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float
    ) -> "Unit":
        """
        Create a mock Unit for the given assignment.

        Args:
            db: Database to register the unit in.
            assignment: Assignment the unit belongs to.
            index: Index of the unit within the assignment.
            pay_amount: Amount to pay for this unit.

        Returns:
            The unit produced by ``MockUnit._register_unit`` for the mock provider.
        """
        return MockUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE)

This class tracks the status of an individual worker's contribution to a higher-level assignment. It is the smallest 'unit' of work needed to complete the assignment, and this class is only responsible for checking the status of that work being done.

It should be extended for usage with a specific crowd provider.

#   MockUnit( db: mephisto.abstractions.database.MephistoDB, db_id: str, row: Union[Mapping[str, Any], NoneType] = None, _used_new_call: bool = False )
View Source
    def __new__(
        cls,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ) -> "Unit":
        """
        Allocate an instance of the provider-specific Unit subclass.

        Overriding ``__new__`` lets callers construct ``Unit(db, db_id)``
        without knowing which crowd provider owns the unit. When ``cls`` is
        ``Unit`` itself, the row's ``provider_type`` selects the crowd
        provider, and an instance of that provider's ``UnitClass`` is
        allocated instead. Any other ``cls`` is allocated directly.

        Args:
            db: Database used to fetch the unit's row when ``row`` is None.
            db_id: Database id of the unit.
            row: Optional already-loaded row for the unit.
            _used_new_call: Not used here; accepted to match ``__init__``.

        Returns:
            A new, uninitialized instance of the resolved class.
        """
        if cls != Unit:
            # A concrete subclass was requested explicitly; allocate it as-is.
            return super().__new__(cls)

        # A generic Unit was requested: resolve the provider-specific class.
        from mephisto.operations.registry import get_crowd_provider_from_type

        unit_row = db.get_unit(db_id) if row is None else row
        assert unit_row is not None, f"Given db_id {db_id} did not exist in given db"
        crowd_provider = get_crowd_provider_from_type(unit_row["provider_type"])
        return super().__new__(crowd_provider.UnitClass)

The `__new__` method is overridden so that the expected Unit subclass is generated automatically, without needing to look it up for a given db_id. As such, it is impossible to create a bare Unit: you will instead be returned an instance of the correct Unit subclass for the crowd provider associated with this Unit.

#   def launch(self, task_url: str) -> None:
View Source
    def launch(self, task_url: str) -> None:
        """
        Mark this unit as launched and print/log local links for working on it.

        Mock launches do nothing beyond updating state; no external service
        is contacted. The port is extracted from ``task_url`` and used to
        build ``localhost`` links.

        Args:
            task_url: URL the task was deployed to, with a scheme
                (e.g. ``http://localhost:3000/``) or without one
                (e.g. ``localhost:3000/``).
        """
        from urllib.parse import urlsplit

        self.set_db_status(status=AssignmentState.LAUNCHED)

        # TODO(OWN) get this link to the frontend
        # Parse the port with urlsplit. The naive ``task_url.split(":")[1]``
        # picks up the host, and so an empty port, whenever the URL carries a
        # scheme. It also raised IndexError when the URL had no ":" at all.
        try:
            parsed_port = urlsplit(task_url).port
        except ValueError:
            # Malformed netloc (e.g. a non-numeric port); use the fallback.
            parsed_port = None
        if parsed_port is not None:
            port = str(parsed_port)
        else:
            # Scheme-less "host:port/..." form, or a URL without a port.
            url_parts = task_url.split(":")
            port = url_parts[1].split("/")[0] if len(url_parts) > 1 else ""

        print(task_url)
        launch_message = (
            f"Mock task launched: localhost:{port} for preview, "
            f"localhost:{port}/?worker_id=x&assignment_id={self.db_id}"
        )
        print(launch_message)
        logger.info(f"{launch_message} for assignment {self.assignment_id}")

        return None

Mock launches currently do nothing beyond updating state.

#   def expire(self) -> float:
View Source
    def expire(self) -> float:
        """
        Expire this unit right away; mock expiration is immediate.

        The database status is moved to EXPIRED only when the unit is not
        already EXPIRED or COMPLETED. The mock datastore is always told that
        the unit has expired.

        Returns:
            ``0.0`` in every case.
        """
        current_status = self.get_status()
        already_final = current_status in (
            AssignmentState.EXPIRED,
            AssignmentState.COMPLETED,
        )
        if not already_final:
            self.set_db_status(AssignmentState.EXPIRED)
        self.datastore.set_unit_expired(self.db_id, True)
        return 0.0

Expiration is immediate on Mocks

#   def is_expired(self) -> bool:
View Source
    def is_expired(self) -> bool:
        """Report whether the mock datastore has this unit recorded as expired."""
        expired = self.datastore.get_unit_expired(self.db_id)
        return expired

Determine whether this unit is expired according to the vendor.

#  
@staticmethod
def new( db: mephisto.abstractions.database.MephistoDB, assignment: mephisto.data_model.assignment.Assignment, index: int, pay_amount: float ) -> mephisto.data_model.unit.Unit:
View Source
    @staticmethod
    def new(
        db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float
    ) -> "Unit":
        """
        Register and return a new mock Unit for the given assignment.

        Args:
            db: Database to register the unit in.
            assignment: Assignment the unit belongs to.
            index: Index of the unit within the assignment.
            pay_amount: Amount to pay for this unit.

        Returns:
            Whatever ``MockUnit._register_unit`` produces for the mock provider.
        """
        provider_type = PROVIDER_TYPE
        return MockUnit._register_unit(
            db, assignment, index, pay_amount, provider_type
        )

Create a Unit for the given assignment