Skip to content

Configuration



Parse environment and configuration variables.

YAMLParser

Bases: BaseModel

Base class on top of BaseModel that adds the ability to parse and validate a config from a YAML file.

Source code in src/variables.py
class YAMLParser(BaseModel):
    """
    Base class on top of BaseModel that can build a validated config from a YAML file.
    """

    @classmethod
    def parse_yaml(cls, file_path, encoding="utf-8"):
        """
        Parse and validate arguments from a YAML file

        Parameters
        ----------
        file_path : str, Path
            path to the yaml file (a relative path is resolved against the current
            working directory)
        encoding : str
            encoding to use when opening the file

        Returns
        -------
        out : cls
            validated instance of the class this method is called on
            (e.g. ``ServiceConfig.parse_yaml(...)`` returns a ``ServiceConfig``)

        Raises
        ------
        OSError
            if the file cannot be opened; errors raised by ``yaml.safe_load`` and
            ``cls.model_validate`` propagate unchanged
        """
        # Open read-only: the file is never written, and "r+" would additionally require
        # write permission, failing on read-only config files / mounts.
        with open(file_path, "r", encoding=encoding) as file:
            obj = yaml.safe_load(file)
        return cls.model_validate(obj)

parse_yaml(file_path, encoding='utf-8') classmethod

Parse and validate arguments from a YAML file

Parameters:

Name Type Description Default
file_path (str, Path)

path to the YAML file (relative paths are resolved against the current working directory)

required
encoding str

encoding to use when opening the file

'utf-8'

Returns:

Name Type Description
out cls

validated instance of the class the method is called on

Source code in src/variables.py
@classmethod
def parse_yaml(cls, file_path, encoding="utf-8"):
    """
    Parse and validate arguments from a YAML file

    Parameters
    ----------
    file_path : str, Path
        path to the yaml file (a relative path is resolved against the current
        working directory)
    encoding : str
        encoding to use when opening the file

    Returns
    -------
    out : cls
        validated instance of the class this method is called on

    Raises
    ------
    OSError
        if the file cannot be opened; errors raised by ``yaml.safe_load`` and
        ``cls.model_validate`` propagate unchanged
    """
    # Open read-only: the file is never written, and "r+" would additionally require
    # write permission, failing on read-only config files / mounts.
    with open(file_path, "r", encoding=encoding) as file:
        obj = yaml.safe_load(file)
    return cls.model_validate(obj)

ClientConfig

Bases: BaseModel

Settings for the aiohttp client

Source code in src/variables.py
class ClientConfig(BaseModel):
    """Options controlling the aiohttp client that calls the final endpoint."""

    # Concurrency limit for outgoing asynchronous calls; must be > 0, defaults to 16.
    semaphore: conint(gt=0) = Field(
        default=16,
        description="Value of the semaphore param to use when making "
                    "asynchronous calls to the final endpoint",
    )

PredictorConfig

Bases: BaseModel

Settings for a single predictor

Source code in src/variables.py
class PredictorConfig(BaseModel):
    """
    Settings for a single predictor, i.e. a network reachable under ``host_url``.
    """
    name: constr(min_length=1) = Field(description="Name of the network")
    host_url: str = Field(default=AWS_SETTINGS.TF_HOST_URL,
                          description="URL for accessing the served network")
    version: int = Field(default=1, description="Dummy variable for the version of the network. "
                                                "Isn't yet used since we serve 1 version per net ")

    @property
    def predict_url(self):
        """
        Build the prediction URL of the network from its config.

        Returns
        -------
        out : str
            ``<host_url>/<name>/versions/<version>:predict``
        """
        # URLs always use "/" as separator, so join with plain strings instead of
        # os.path.join: that would insert backslashes on Windows and would silently drop
        # host_url if name started with "/". A trailing "/" on host_url is tolerated.
        base_url = self.host_url.rstrip("/")
        return f"{base_url}/{self.name}/versions/{self.version}:predict"

predict_url property

Build and return the network's prediction URL (`<host_url>/<name>/versions/<version>:predict`) from its config.

EncoderConfig

Bases: PredictorConfig

Settings for a single encoder network

Source code in src/variables.py
class EncoderConfig(PredictorConfig):
    """
    Configuration of one encoder network.

    Adds nothing on top of PredictorConfig: name, host_url, version and
    predict_url are all inherited unchanged.
    """

RemoveAllKeysConfig

Bases: BaseModel

Config for removing all items (except the listed exceptions) from the predictions dictionary when the condition is present

Source code in src/variables.py
class RemoveAllKeysConfig(BaseModel):
    """
    Config for removing items from the predictions dictionary when the condition is present.

    Keys listed in ``exceptions`` are not removed.
    """
    condition: str = Field(description="The name of key that should trigger postprocessing. If "
                                       "its present, all the other classes will be removed")
    # None (the default) means "no exceptions". The annotation is Optional so that it agrees
    # with the default and an explicit `null` in the YAML config is accepted too.
    exceptions: Optional[List[str]] = Field(default=None,
                                            description="exceptions to not remove from the dict")

RenameKeyConfig

Bases: BaseModel

Config for renaming an item in the predictions dictionary when the conditions are present

Source code in src/variables.py
class RenameKeyConfig(BaseModel):
    """
    Config for renaming an item in the predictions dictionary when the conditions are present.

    The key ``to_rename`` is renamed to ``new_name``.
    """
    # The description states what this rule actually does (rename to_rename), not the
    # "remove all other classes" behaviour of RemoveAllKeysConfig.
    conditions: List[str] = Field(description="The names of keys that should trigger "
                                              "postprocessing. If present, the key to_rename "
                                              "will be renamed to new_name")
    to_rename: str = Field(description="Name of the field to rename")
    new_name: str = Field(description="New name of the field to rename")

RemoveSpecifiedKeysConfig

Bases: BaseModel

Config for removing specific keys from the prediction dictionary if the condition is present

Source code in src/variables.py
class RemoveSpecifiedKeysConfig(BaseModel):
    """Config for removing specific keys from the prediction dictionary
    if the condition is present.

    Only the keys listed in ``to_remove`` are removed; every other key is kept.
    """
    # The description states that only the to_remove keys are dropped, not all other classes.
    condition: str = Field(description="The name of key that should trigger postprocessing. If "
                                       "it's present, the keys listed in to_remove will be "
                                       "removed")
    to_remove: List[str] = Field(description="The names of the fields to remove from the "
                                             "dictionary if condition is present")

KeepOnlyByConditionsConfig

Bases: BaseModel

Config for keeping certain classes only if any of the conditions is present

Source code in src/variables.py
class KeepOnlyByConditionsConfig(BaseModel):
    """
    Rule that keeps the ``to_filter`` classes only when one of ``conditions`` is present.

    If none of the ``conditions`` keys occurs in the predictions, every key listed in
    ``to_filter`` is dropped.
    """

    conditions: List[str] = Field(
        description="The names of keys to check. If none is in the "
                    "predictions, elements from to_filter will be "
                    "removed"
    )
    to_filter: List[str] = Field(
        description="Items to filter out if none of the conditions is "
                    "present"
    )

PostprocessingConfig

Bases: BaseModel

Configuration to postprocess the output of a Classifier

Source code in src/variables.py
class PostprocessingConfig(BaseModel):
    """
    Postprocessing rules applied to the output of a Classifier.

    Each rule group is optional (None when not configured) and otherwise holds a list of rules.
    """

    remove_all_keys_except: Optional[List[RemoveAllKeysConfig]] = Field(
        default=None,
        description="Config to remove all classes from predictions for "
                    "specified classes",
    )
    remove_keys_by_condition: Optional[List[RemoveSpecifiedKeysConfig]] = Field(
        default=None,
        description="Config to remove specific classes from predictions "
                    "containing specific condition",
    )
    rename_key_by_condition: Optional[List[RenameKeyConfig]] = Field(
        default=None,
        description="Config to rename a class if they are coupled with "
                    "specified condition",
    )
    keep_only_by_conditions: Optional[List[KeepOnlyByConditionsConfig]] = Field(
        default=None,
        description="Config to filter some elements out if they are not "
                    "paired with certain conditions",
    )

ClassifierConfig

Bases: PredictorConfig

Settings for a single servable network

Source code in src/variables.py
class ClassifierConfig(PredictorConfig):
    """
    Settings for a single servable network

    On construction, ``thresholds`` is completed so every name in ``output_names`` has a
    threshold (``default_threshold`` unless explicitly overridden), and ``all_classes`` is
    computed from ``output_names`` plus the ``new_name`` of every rename rule.
    """
    output_names: conlist(item_type=constr(min_length=1)) = \
        Field(description="List of the class names the model was trained on. Must be of the same "
                          "length as the final layer of network")
    postprocessing_config: Optional[PostprocessingConfig] = Field(default=None,
                                                                  description="Configuration for "
                                                                              "postprocessing the"
                                                                              " predictors output")
    # Computed in __init__: any value supplied by the caller is overwritten.
    all_classes: conlist(item_type=constr(min_length=1)) = \
        Field(default=None, description="All possible outputs the network can give after "
                                        "postprocessing")

    thresholds: Optional[Dict[str, confloat(ge=0, le=1)]] = \
        Field(default_factory=dict, description="dictionary of per-class thresholds to use when "
                                                "extracting tags. If any class isn't specified "
                                                "here, default_threshold will be used for it")
    default_threshold: confloat(ge=0, le=1) = \
        Field(default=0.5, description="The default threshold to use when extracting tags from "
                                       "multilabel network")

    is_multilabel: bool = Field(default=False, description="Whether the net is multilabel, "
                                                           "i.e. outputs independent scores for "
                                                           "each class")

    def __init__(self, **kwargs):
        """
        Initialize the config, complete the per-class thresholds and compute all_classes.

        Parameters
        ----------
        **kwargs
            field values, forwarded to the parent ``__init__`` for validation
        """
        super().__init__(**kwargs)
        # Every output class starts at default_threshold; explicitly configured thresholds
        # override it. `thresholds` is Optional, so an explicit None means "no overrides"
        # instead of crashing dict.update(None).
        thresholds = dict.fromkeys(self.output_names, self.default_threshold)
        thresholds.update(self.thresholds or {})
        self.thresholds = thresholds

        # Rename rules can produce class names the network itself never outputs. Add them,
        # skipping names already listed so each class appears only once.
        all_classes = list(self.output_names)
        if self.postprocessing_config is not None:
            for rule in self.postprocessing_config.rename_key_by_condition or []:
                if rule.new_name not in all_classes:
                    all_classes.append(rule.new_name)
        self.all_classes = all_classes

__init__(**kwargs)

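Initialize the config, fill in a threshold for every output class (default_threshold unless overridden in thresholds), and compute all_classes (the output names plus the new names introduced by rename rules).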

Source code in src/variables.py
def __init__(self, **kwargs):
    """
    Initialize the config, complete the per-class thresholds and compute all_classes.

    Parameters
    ----------
    **kwargs
        field values, forwarded to the parent ``__init__`` for validation
    """
    super().__init__(**kwargs)
    # Every output class starts at default_threshold; explicitly configured thresholds
    # override it. `thresholds` is Optional, so an explicit None means "no overrides"
    # instead of crashing dict.update(None).
    thresholds = dict.fromkeys(self.output_names, self.default_threshold)
    thresholds.update(self.thresholds or {})
    self.thresholds = thresholds

    # Rename rules can produce class names the network itself never outputs. Add them,
    # skipping names already listed so each class appears only once.
    all_classes = list(self.output_names)
    if self.postprocessing_config is not None:
        for rule in self.postprocessing_config.rename_key_by_condition or []:
            if rule.new_name not in all_classes:
                all_classes.append(rule.new_name)
    self.all_classes = all_classes

ServiceConfig

Bases: YAMLParser

Settings for the whole service consisting of several networks

Source code in src/variables.py
class ServiceConfig(YAMLParser):
    """
    Top-level settings of the service, which is composed of several networks.

    As a YAMLParser subclass it can be loaded with ``ServiceConfig.parse_yaml(path)``.
    """

    tagger: ClassifierConfig = Field(description="Configuration for the Tagger network")
    exterior_styles: ClassifierConfig = Field(
        description="Configuration for the Exterior Styles network",
    )
    encoder: EncoderConfig = Field(description="Configuration for the Encoder network")
    # NOTE(review): the default is one ClientConfig() instance built at class-definition
    # time; default_factory=ClientConfig would avoid relying on the library to copy it.
    client: ClientConfig = Field(
        default=ClientConfig(),
        description="Configuration for the async Client making calls to "
                    "the final endpoints",
    )