Machine learning: does FATE federated learning support batch training?


Outlining the approach

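To get data into FATE, a reader first has to read it in before any training can happen. So the first question is whether the reader can load the data at all, and in particular whether it can read it in batches.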

Once the data is uploaded, FATE needs a trainer to do the training. Does a batch training mode exist?

Inspecting the Reader class

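Note that the Reader class does not live in the federatedml library; it is a component of the separate pipeline library. Reading through the source shows that Reader inherits from the Output class, and Output takes a keyword argument data_type: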

class Output(object):
    def __init__(self, name, data_type='single', has_data=True, has_model=True, has_cache=False, output_unit=1):
        if has_model:
            self.model = Model(name).model
            self.model_output = Model(name).get_all_output()

        if has_data:
            if data_type == "single":
                self.data = SingleOutputData(name).data
                self.data_output = SingleOutputData(name).get_all_output()
            elif data_type == "multi":
                self.data = TraditionalMultiOutputData(name)
                self.data_output = TraditionalMultiOutputData(name).get_all_output()
            else:
                self.data = NoLimitOutputData(name, output_unit)
                self.data_output = NoLimitOutputData(name, output_unit).get_all_output()

        if has_cache:
            self.cache = Cache(name).cache
            self.cache_output = Cache(name).get_all_output()

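The three classes that correspond to the data_type values only decide how the output data is split up and named. None of them contains any step related to batching: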

class SingleOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        return ".".join([self.prefix, IODataType.SINGLE])

    @staticmethod
    def get_all_output():
        return ["data"]


class TraditionalMultiOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def train_data(self):
        return ".".join([self.prefix, IODataType.TRAIN])

    @property
    def test_data(self):
        return ".".join([self.prefix, IODataType.TEST])

    @property
    def validate_data(self):
        return ".".join([self.prefix, IODataType.VALIDATE])

    @staticmethod
    def get_all_output():
        return [IODataType.TRAIN,
                IODataType.VALIDATE,
                IODataType.TEST]


class NoLimitOutputData(object):
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]

    def get_all_output(self):
        return ["data_" + str(i) for i in range(self.output_unit)]

So the Reader most likely can only ingest the whole dataset in a single pass and cannot read it in batches.
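To make the point above concrete, here is a minimal standalone sketch that re-declares two of the output classes outside FATE and prints what they produce. The `IODataType` stand-in and its string values are assumptions made for illustration (the real constants live in the pipeline library), and `reader_0` is just an example component name.

```python
# Stand-in for the pipeline library's IODataType.
# The string value below is an ASSUMPTION made for this illustration.
class IODataType:
    SINGLE = "data"


class SingleOutputData:
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        # Builds a dotted name such as "<component>.data"; no data is touched.
        return ".".join([self.prefix, IODataType.SINGLE])


class NoLimitOutputData:
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        # One dotted name per output unit; still only strings.
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]


# "reader_0" is a hypothetical component name used only for this example.
single = SingleOutputData("reader_0").data
multi = NoLimitOutputData("reader_0", output_unit=3).data
print(single)
print(multi)
```

Under these assumptions the classes return nothing but name strings that tell downstream components which output to consume. They carry no batch size, iterator, or chunking logic, which is consistent with the reading that the Reader hands over the dataset as a whole.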

Inspecting the Trainer

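All training-related parameters live in TrainerParam. However, TrainerParam itself is only a wrapper class that stores parameters and contains no logic.
I eventually found a job submitter, which also runs the Task by passing parameters along and calling a service. These are all wrappers with no actual training code.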

Finally, in federatedml.nn.homo.trainer.fedavg_trainer I found FedAVGTrainer. Its docstring lists the parameters, and batch_size is among them:

class FedAVGTrainer(TrainerBase):
    """

    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                            mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                         if True, According to origin paper, weight of a client is: n_local / n_global, where n_local
                         is the sample number locally and n_global is the sample number of all clients.
                         if False, simply averaging these models.

    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop

    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running validation.
               if auto, will automatically infer your task type from labels and predict results.
    """

The issue I opened on the FATE repository: https://github.com/FederatedAI/FATE/issues/4832
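The batch_size semantics documented in the FedAVGTrainer docstring (a positive int, with -1 meaning full batch) can be illustrated with a small pure-Python sketch. This is not FATE's implementation: according to the docstring the trainer relies on a PyTorch DataLoader, and `iter_batches` below is a hypothetical helper that only shows how one epoch splits into batches.

```python
def iter_batches(n_samples, batch_size):
    """Yield (start, stop) index pairs covering n_samples.

    Hypothetical helper for illustration only.
    batch_size == -1 is treated as "full batch", mirroring the docstring.
    """
    if n_samples <= 0:
        return
    if batch_size == -1:
        batch_size = n_samples  # one batch that spans the whole dataset
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive int or -1")
    for start in range(0, n_samples, batch_size):
        # The last batch may be shorter than batch_size.
        yield start, min(start + batch_size, n_samples)


full = list(iter_batches(10, -1))  # a single batch over all 10 samples
mini = list(iter_batches(10, 4))   # three batches, the last one shorter
print(full)
print(mini)
```

Note that batching of this kind only controls how many samples go into each optimization step. The dataset still has to be loaded before the trainer can slice it, so it does not by itself lower the memory needed at read time.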

Final conclusion

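For homo training with a custom neural network, using the FedAvg trainer makes batch training possible. Whether the Reader can load the data in the first place, however, depends on the machine, because the Reader appears to read the entire dataset in one go.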
