mephisto.abstractions.providers.mock.mock_requester

View Source
#!/usr/bin/env python3

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from mephisto.data_model.requester import Requester, RequesterArgs
from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE

from typing import Optional, Dict, List, Mapping, Any, TYPE_CHECKING

if TYPE_CHECKING:
    from mephisto.abstractions.database import MephistoDB
    from mephisto.data_model.task_run import TaskRun
    from mephisto.abstractions.providers.mock.mock_datastore import MockDatastore
    from argparse import _ArgumentGroup as ArgumentGroup
    from omegaconf import DictConfig

# Fixed amount that MockRequester.get_available_budget() returns for every mock requester.
MOCK_BUDGET = 100000.0


@dataclass
class MockRequesterArgs(RequesterArgs):
    """Arguments for a MockRequester, extending the base RequesterArgs."""

    # Name for the requester; defaults to "MOCK_REQUESTER".
    name: str = field(
        default="MOCK_REQUESTER",
        metadata={
            "help": "Name for the requester in the Mephisto DB.",
            "required": True,
        },
    )
    # Testing hook: MockRequester.register raises when this is True.
    force_fail: bool = field(
        default=False, metadata={"help": "Trigger a failed registration"}
    )


class MockRequester(Requester):
    """
    Requester implementation for the mock crowd provider.

    No real credentials are involved: the registration status is a flag kept
    in the mock provider's datastore, the available budget is the fixed
    MOCK_BUDGET, and every MockRequester reports itself as sandbox.
    """

    ArgsClass = MockRequesterArgs

    def __init__(
        self,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ):
        """
        Initialize through the base Requester, then attach the mock datastore.

        All arguments are forwarded unchanged to the base class initializer.
        """
        super().__init__(db, db_id, row=row, _used_new_call=_used_new_call)
        # Registration state for mock requesters is read from / written to this datastore
        self.datastore: "MockDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE)

    def register(self, args: Optional["DictConfig"] = None) -> None:
        """
        Mark this requester as registered in the mock datastore.

        Mock requesters don't actually register credentials. The only option
        consulted is ``force_fail``: when it is exactly True the call raises
        instead of registering. In every other case, with or without ``args``,
        the requester is flagged as registered.

        Raises:
            Exception: if ``args`` is provided and ``args.get("force_fail")`` is True.
        """
        if args is not None and args.get("force_fail") is True:
            raise Exception("Forced failure test exception was set")
        # Previously this only ran when args was None, so supplying args
        # without force_fail silently left the requester unregistered.
        self.datastore.set_requester_registered(self.db_id, True)

    def is_registered(self) -> bool:
        """Return the registration status stored in the mock datastore"""
        return self.datastore.get_requester_registered(self.db_id)

    def get_available_budget(self) -> float:
        """MockRequesters have $100000 to spend"""
        return MOCK_BUDGET

    @classmethod
    def is_sandbox(cls) -> bool:
        """MockRequesters are for testing only, and are thus treated as sandbox"""
        return True

    @staticmethod
    def new(db: "MephistoDB", requester_name: str) -> "Requester":
        """
        Create a mock requester with the given name.

        Delegates to _register_requester using the mock PROVIDER_TYPE and
        returns its result.
        """
        return MockRequester._register_requester(db, requester_name, PROVIDER_TYPE)
View Source
class MockRequesterArgs(RequesterArgs):
    """Arguments for a MockRequester, extending the base RequesterArgs."""

    # Name for the requester; defaults to "MOCK_REQUESTER".
    name: str = field(
        default="MOCK_REQUESTER",
        metadata={
            "help": "Name for the requester in the Mephisto DB.",
            "required": True,
        },
    )
    # Testing hook: MockRequester.register raises when this is True.
    force_fail: bool = field(
        default=False, metadata={"help": "Trigger a failed registration"}
    )

MockRequesterArgs(name: str = 'MOCK_REQUESTER', force_fail: bool = False)

#   MockRequesterArgs(name: str = 'MOCK_REQUESTER', force_fail: bool = False)
#   name: str = 'MOCK_REQUESTER'
#   force_fail: bool = False
View Source
class MockRequester(Requester):
    """
    Requester implementation for the mock crowd provider.

    No real credentials are involved: the registration status is a flag kept
    in the mock provider's datastore, the available budget is the fixed
    MOCK_BUDGET, and every MockRequester reports itself as sandbox.
    """

    ArgsClass = MockRequesterArgs

    def __init__(
        self,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ):
        """
        Initialize through the base Requester, then attach the mock datastore.

        All arguments are forwarded unchanged to the base class initializer.
        """
        super().__init__(db, db_id, row=row, _used_new_call=_used_new_call)
        # Registration state for mock requesters is read from / written to this datastore
        self.datastore: "MockDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE)

    def register(self, args: Optional["DictConfig"] = None) -> None:
        """
        Mark this requester as registered in the mock datastore.

        Mock requesters don't actually register credentials. The only option
        consulted is ``force_fail``: when it is exactly True the call raises
        instead of registering. In every other case, with or without ``args``,
        the requester is flagged as registered.

        Raises:
            Exception: if ``args`` is provided and ``args.get("force_fail")`` is True.
        """
        if args is not None and args.get("force_fail") is True:
            raise Exception("Forced failure test exception was set")
        # Previously this only ran when args was None, so supplying args
        # without force_fail silently left the requester unregistered.
        self.datastore.set_requester_registered(self.db_id, True)

    def is_registered(self) -> bool:
        """Return the registration status stored in the mock datastore"""
        return self.datastore.get_requester_registered(self.db_id)

    def get_available_budget(self) -> float:
        """MockRequesters have $100000 to spend"""
        return MOCK_BUDGET

    @classmethod
    def is_sandbox(cls) -> bool:
        """MockRequesters are for testing only, and are thus treated as sandbox"""
        return True

    @staticmethod
    def new(db: "MephistoDB", requester_name: str) -> "Requester":
        """
        Create a mock requester with the given name.

        Delegates to _register_requester using the mock PROVIDER_TYPE and
        returns its result.
        """
        return MockRequester._register_requester(db, requester_name, PROVIDER_TYPE)

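High level class representing a requester on some kind of crowd provider. Sets some default initializations, but mostly should be extended by the specific requesters for crowd providers with whatever implementation details are required to get those to work. MockRequester is the mock provider's implementation: its registration status is stored in the mock datastore, its available budget is the fixed MOCK_BUDGET, and it is always treated as sandbox.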

#   MockRequester( db: mephisto.abstractions.database.MephistoDB, db_id: str, row: Union[Mapping[str, Any], NoneType] = None, _used_new_call: bool = False )
View Source
    def __new__(
        cls,
        db: "MephistoDB",
        db_id: str,
        row: Optional[Mapping[str, Any]] = None,
        _used_new_call: bool = False,
    ) -> "Requester":
        """
        Allocate an instance of the Requester class that matches the stored
        provider type.

        Constructing the base Requester directly never yields a base instance:
        the requester's row is fetched from the db when it was not passed in,
        its "provider_type" is resolved to a crowd provider, and an instance of
        that provider's RequesterClass is allocated instead. Constructing any
        other class allocates that class unchanged.
        """
        from mephisto.operations.registry import get_crowd_provider_from_type

        # Any class other than the base Requester is built as requested.
        if cls != Requester:
            return super().__new__(cls)

        # Base Requester requested: work out which concrete class to build.
        if row is None:
            row = db.get_requester(db_id)
        assert row is not None, f"Given db_id {db_id} did not exist in given db"
        provider = get_crowd_provider_from_type(row["provider_type"])
        return super().__new__(provider.RequesterClass)

The `__new__` method is overridden to be able to automatically generate the expected Requester class without needing to specifically find it for a given db_id. As such it is impossible to create a base Requester, as you will instead be returned the correct Requester class according to the crowd provider associated with this Requester.

#   def register( self, args: Union[omegaconf.dictconfig.DictConfig, NoneType] = None ) -> None:
View Source
    def register(self, args: Optional["DictConfig"] = None) -> None:
        """
        Mark this requester as registered in the mock datastore.

        Mock requesters don't actually register credentials. The only option
        consulted is ``force_fail``: when it is exactly True the call raises
        instead of registering. In every other case, with or without ``args``,
        the requester is flagged as registered.

        Raises:
            Exception: if ``args`` is provided and ``args.get("force_fail")`` is True.
        """
        if args is not None and args.get("force_fail") is True:
            raise Exception("Forced failure test exception was set")
        # Previously this only ran when args was None, so supplying args
        # without force_fail silently left the requester unregistered.
        self.datastore.set_requester_registered(self.db_id, True)

Mock requesters don't actually register credentials

#   def is_registered(self) -> bool:
View Source
    def is_registered(self) -> bool:
        """Look up this requester's registration flag in the mock datastore"""
        store = self.datastore
        return store.get_requester_registered(self.db_id)

Return the registration status

#   def get_available_budget(self) -> float:
View Source
    def get_available_budget(self) -> float:
        """Report the fixed MOCK_BUDGET amount as the funds available to spend"""
        budget: float = MOCK_BUDGET
        return budget

MockRequesters have $100000 to spend

#   @classmethod def is_sandbox(self) -> bool:
View Source
    @classmethod
    def is_sandbox(self) -> bool:
        """MockRequesters are for testing only, and are thus treated as sandbox"""
        return True

MockRequesters are for testing only, and are thus treated as sandbox

#   @staticmethod def new( db: mephisto.abstractions.database.MephistoDB, requester_name: str ) -> mephisto.data_model.requester.Requester:
View Source
    @staticmethod
    def new(db: "MephistoDB", requester_name: str) -> "Requester":
        """
        Create a mock requester with the given name.

        Delegates to _register_requester using the mock PROVIDER_TYPE and
        returns its result.
        """
        created = MockRequester._register_requester(db, requester_name, PROVIDER_TYPE)
        return created

Try to create a new requester by this name, raise an exception if the name already exists.

Implementation should call _register_requester(db, requester_name, provider_type) when sure the requester can be successfully created to have it put into the db and return the result.

#   class MockRequester.ArgsClass(mephisto.data_model.requester.RequesterArgs):
View Source
class MockRequesterArgs(RequesterArgs):
    """Arguments for a MockRequester, extending the base RequesterArgs."""

    # Name for the requester; defaults to "MOCK_REQUESTER".
    name: str = field(
        default="MOCK_REQUESTER",
        metadata={
            "help": "Name for the requester in the Mephisto DB.",
            "required": True,
        },
    )
    # Testing hook: MockRequester.register raises when this is True.
    force_fail: bool = field(
        default=False, metadata={"help": "Trigger a failed registration"}
    )

MockRequesterArgs(name: str = 'MOCK_REQUESTER', force_fail: bool = False)