Source code for fed.config
"""This module should be cached locally because all configurations
are mutable.
"""
import json
from dataclasses import dataclass, fields
from typing import Dict, List, Optional
import cloudpickle
import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants
class ClusterConfig:
    """A local cache of cluster configuration items.

    The instance wraps the object obtained by unpickling ``raw_bytes`` and
    exposes selected entries of it as read-only properties. Every property
    uses a strict ``[]`` lookup, so a missing entry raises ``KeyError``.

    NOTE(review): ``cloudpickle.loads`` can execute arbitrary code while
    deserializing; this presumably relies on the payload coming from a
    trusted store -- confirm against the caller.
    """

    def __init__(self, raw_bytes: bytes) -> None:
        """Deserialize ``raw_bytes`` with cloudpickle and keep the result."""
        self._data = cloudpickle.loads(raw_bytes)

    def _lookup(self, key):
        """Return the cached entry stored under ``key`` (strict lookup)."""
        return self._data[key]

    @property
    def cluster_addresses(self):
        """Entry stored under ``KEY_OF_CLUSTER_ADDRESSES``."""
        return self._lookup(fed_constants.KEY_OF_CLUSTER_ADDRESSES)

    @property
    def current_party(self):
        """Entry stored under ``KEY_OF_CURRENT_PARTY_NAME``."""
        return self._lookup(fed_constants.KEY_OF_CURRENT_PARTY_NAME)

    @property
    def tls_config(self):
        """Entry stored under ``KEY_OF_TLS_CONFIG``."""
        return self._lookup(fed_constants.KEY_OF_TLS_CONFIG)
class JobConfig:
    """A local cache of job configuration items.

    Wraps the object obtained by unpickling the raw job configuration. When
    no raw configuration is available (``None``), an empty mapping is used
    so that property lookups fall back to their defaults.

    NOTE(review): ``cloudpickle.loads`` can execute arbitrary code while
    deserializing; this presumably relies on the payload coming from a
    trusted store -- confirm against the caller.
    """

    def __init__(self, raw_bytes: Optional[bytes]) -> None:
        """Build the cache from a cloudpickle payload.

        Args:
            raw_bytes: The serialized configuration, or ``None`` when no job
                configuration has been stored.
        """
        if raw_bytes is None:
            self._data = {}
        else:
            self._data = cloudpickle.loads(raw_bytes)

    @property
    def cross_silo_comm_config_dict(self) -> Dict:
        """Entry stored under ``KEY_OF_CROSS_SILO_COMM_CONFIG_DICT``.

        Returns an empty dict when the entry is absent.
        """
        return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT, {})
# Module-level caches for the cluster and job configurations. Both start as
# None and are filled in on first use by get_cluster_config() and
# get_job_config(), which rebind them through `global`.
_cluster_config: Optional["ClusterConfig"] = None
_job_config: Optional["JobConfig"] = None
def get_cluster_config(job_name: Optional[str] = None) -> ClusterConfig:
    """Return the cached cluster configuration, loading it on first use.

    On the first call the internal KV store is initialized for ``job_name``,
    the raw cluster configuration is fetched from it, and the resulting
    ``ClusterConfig`` is stored in the module-level cache. Later calls return
    the cached object and ignore ``job_name``.

    This function is not thread safe to use.

    Args:
        job_name: Name of the job used to initialize the internal KV store.
            Required only while the cache is still empty.

    Returns:
        The module-wide ``ClusterConfig`` instance.

    Raises:
        AssertionError: If the cache is empty and ``job_name`` is ``None``.
    """
    global _cluster_config
    if _cluster_config is not None:
        return _cluster_config

    if job_name is None:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept unchanged
        # for existing callers.
        raise AssertionError("Initializing internal kv need to provide job_name.")

    compatible_utils._init_internal_kv(job_name)
    raw_bytes = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG)
    _cluster_config = ClusterConfig(raw_bytes)
    return _cluster_config
def get_job_config(job_name: Optional[str] = None) -> JobConfig:
    """Return the cached job configuration, loading it on first use.

    This config still acts like cluster config for now: on the first call the
    internal KV store is initialized for ``job_name``, the raw job
    configuration is fetched from it, and the resulting ``JobConfig`` is
    stored in the module-level cache. Later calls return the cached object
    and ignore ``job_name``.

    Args:
        job_name: Name of the job used to initialize the internal KV store.
            Required only while the cache is still empty.

    Returns:
        The module-wide ``JobConfig`` instance.

    Raises:
        AssertionError: If the cache is empty and ``job_name`` is ``None``.
    """
    global _job_config
    if _job_config is not None:
        return _job_config

    if job_name is None:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept unchanged
        # for existing callers.
        raise AssertionError("Initializing internal kv need to provide job_name.")

    compatible_utils._init_internal_kv(job_name)
    raw_bytes = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG)
    _job_config = JobConfig(raw_bytes)
    return _job_config
@dataclass
class CrossSiloMessageConfig:
    """A class to store parameters used for Proxy Actor.

    Attributes:
        proxy_max_restarts:
            The max restart times for the send proxy.
        serializing_allowed_list:
            The package or class list allowed for
            serializing(deserializing) cross silos. It's used for avoiding pickle
            deserializing execution attack when crossing silos.
        send_resource_label:
            Customized resource label, the SenderProxyActor
            will be scheduled based on the declared resource label. For example,
            when setting to `{"my_label": 1}`, then the sender proxy actor will be
            started only on nodes with `{"resource": {"my_label": $NUM}}` where
            $NUM >= 1.
        recv_resource_label:
            Customized resource label, the ReceiverProxyActor
            will be scheduled based on the declared resource label. For example,
            when setting to `{"my_label": 1}`, then the receiver proxy actor will be
            started only on nodes with `{"resource": {"my_label": $NUM}}` where
            $NUM >= 1.
        exit_on_sending_failure:
            Whether to exit on a cross-silo sending failure. If True, a SIGINT
            will be signaled to self if sending cross-silo data fails, and the
            process exits then.
        continue_waiting_for_data_sending_on_error:
            Whether to continue waiting for data sending if an error occurs, including
            data-sending errors and receiving errors from the peer. If True, wait until
            all data has been sent.
        messages_max_size_in_bytes:
            The maximum length in bytes of cross-silo messages. If None, the default
            value of 500 MB is specified.
        timeout_in_ms:
            The timeout in milliseconds of a cross-silo RPC call. It's 60000 by
            default.
        http_header:
            The HTTP header, e.g. metadata in grpc, sent with the RPC request.
            This won't override basic tcp headers, such as `user-agent`, but concat
            them together.
        max_concurrency:
            The max_concurrency of the sender/receiver proxy actor.
        expose_error_trace:
            Defaults to False. NOTE(review): presumably controls whether error
            traces are exposed to other parties -- confirm against the code
            that consumes this flag.
        use_global_proxy:
            Whether using the global proxy actor or create new proxy actor for current
            job.
    """

    # Field order is part of the positional constructor signature generated
    # by @dataclass; do not reorder.
    proxy_max_restarts: Optional[int] = None
    timeout_in_ms: int = 60000
    messages_max_size_in_bytes: Optional[int] = None
    exit_on_sending_failure: Optional[bool] = False
    continue_waiting_for_data_sending_on_error: Optional[bool] = False
    serializing_allowed_list: Optional[Dict[str, str]] = None
    send_resource_label: Optional[Dict[str, str]] = None
    recv_resource_label: Optional[Dict[str, str]] = None
    http_header: Optional[Dict[str, str]] = None
    max_concurrency: Optional[int] = None
    expose_error_trace: Optional[bool] = False
    use_global_proxy: Optional[bool] = True

    def __json__(self):
        """Return the instance attributes serialized as a JSON string."""
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, json_str):
        """Build an instance from a JSON object string.

        Every key of the decoded object is passed to the constructor, so a
        key that is not a field of the class raises ``TypeError``.
        """
        data = json.loads(json_str)
        return cls(**data)

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> 'CrossSiloMessageConfig':
        """Initialize CrossSiloMessageConfig from a dictionary.

        Keys that are not fields of ``cls`` are ignored. ``None`` or an empty
        dictionary yields an instance with all default values.

        Args:
            data (Dict): Dictionary with keys as member variable names.

        Returns:
            CrossSiloMessageConfig: An instance of CrossSiloMessageConfig
            (or of the subclass the method is called on).
        """
        data = data or {}
        # `fields(cls)` also covers fields declared by dataclass subclasses.
        known_names = {field.name for field in fields(cls)}
        accepted = {
            key: value for key, value in data.items() if key in known_names
        }
        return cls(**accepted)
@dataclass
class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig):
    """A class to store parameters used for GRPC communication

    Attributes:
        grpc_retry_policy:
            A dict that describes the retry policy for cross silo rpc call. If
            None, the following default retry policy will be used. More details
            please refer to
            `retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa

            .. code:: python

                {
                    "maxAttempts": 4,
                    "initialBackoff": "0.1s",
                    "maxBackoff": "1s",
                    "backoffMultiplier": 2,
                    "retryableStatusCodes": [
                        "UNAVAILABLE"
                    ]
                }
        grpc_channel_options: A list of tuples to store GRPC channel options, e.g.

            .. code:: python

                [
                    ('grpc.enable_retries', 1),
                    ('grpc.max_send_message_length', 50 * 1024 * 1024)
                ]
    """

    grpc_channel_options: Optional[List] = None
    # Values are not only strings (see the documented default policy, which
    # holds ints and a list), so the value type is left unconstrained.
    grpc_retry_policy: Optional[Dict] = None