Working through the problem
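To get data into FATE, a Reader first has to read it in before any training can happen. So the first thing to confirm is that the Reader can load the data at all, and I don't know whether it can read the data in batches.
Once the data is loaded, FATE needs a trainer to train on it. Is there a batch-training mode?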
Inspecting the Reader class
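Notably, the Reader class does not live in the federatedml library; it is a component of the separate pipeline library. Reading through the source, I found that Reader inherits from the Output class, and Output takes a data_type keyword argument: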
```python
class Output(object):
    def __init__(self, name, data_type='single', has_data=True, has_model=True, has_cache=False, output_unit=1):
        if has_model:
            self.model = Model(name).model
            self.model_output = Model(name).get_all_output()
        if has_data:
            if data_type == "single":
                self.data = SingleOutputData(name).data
                self.data_output = SingleOutputData(name).get_all_output()
            elif data_type == "multi":
                self.data = TraditionalMultiOutputData(name)
                self.data_output = TraditionalMultiOutputData(name).get_all_output()
            else:
                self.data = NoLimitOutputData(name, output_unit)
                self.data_output = NoLimitOutputData(name, output_unit).get_all_output()
        if has_cache:
            self.cache = Cache(name).cache
            self.cache_output = Cache(name).get_all_output()
```
The three corresponding data-type classes only build names for the data outputs; none of them contains any batching step:
```python
class SingleOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        return ".".join([self.prefix, IODataType.SINGLE])

    @staticmethod
    def get_all_output():
        return ["data"]


class TraditionalMultiOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def train_data(self):
        return ".".join([self.prefix, IODataType.TRAIN])

    @property
    def test_data(self):
        return ".".join([self.prefix, IODataType.TEST])

    @property
    def validate_data(self):
        return ".".join([self.prefix, IODataType.VALIDATE])

    @staticmethod
    def get_all_output():
        return [IODataType.TRAIN,
                IODataType.VALIDATE,
                IODataType.TEST]


class NoLimitOutputData(object):
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]

    def get_all_output(self):
        return ["data_" + str(i) for i in range(self.output_unit)]
```
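To make this concrete, here is a minimal stand-alone sketch of the naming logic quoted above. The IODataType string values are assumptions for illustration (not checked against FATE's actual constants), and output_names is a hypothetical helper that folds the three data_type branches of Output.__init__ into one function. Every branch only assembles strings; none of them reads, splits or batches the data.

```python
# Assumption: these IODataType values are illustrative stand-ins and may not
# match the exact constants defined in FATE's pipeline package.
class IODataType:
    SINGLE = "data"
    TRAIN = "train_data"
    VALIDATE = "validate_data"
    TEST = "test_data"


def output_names(prefix, data_type="single", output_unit=1):
    """Hypothetical helper: the data-output names a component exposes for a data_type.

    Mirrors the branches of Output.__init__ and the three *OutputData classes.
    Each branch only joins strings; the data itself is never touched.
    """
    if data_type == "single":
        return [".".join([prefix, IODataType.SINGLE])]
    if data_type == "multi":
        return [".".join([prefix, t])
                for t in (IODataType.TRAIN, IODataType.VALIDATE, IODataType.TEST)]
    # Any other data_type falls through to the "no limit" case.
    return [prefix + ".data_" + str(i) for i in range(output_unit)]


print(output_names("reader_0"))                                     # ['reader_0.data']
print(output_names("split_0", data_type="multi"))
print(output_names("comp_0", data_type="nolimit", output_unit=2))  # ['comp_0.data_0', 'comp_0.data_1']
```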
So the Reader presumably can only ingest the whole dataset in a single pass; it cannot read it in batches.
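For contrast, this is roughly what reading "in batches" would mean at the reader level. Both helpers are hypothetical and not part of FATE: read_all keeps every row in memory, which is how the Reader appears to behave, while read_in_chunks streams rows from a CSV file-like object and holds only one chunk in memory at a time.

```python
import csv
import io
from itertools import islice


def read_all(f):
    """Hypothetical: load every CSV row into memory at once."""
    return list(csv.reader(f))


def read_in_chunks(f, chunk_size):
    """Hypothetical: yield lists of at most chunk_size rows, one chunk at a time."""
    reader = csv.reader(f)
    while True:
        chunk = list(islice(reader, chunk_size))
        if not chunk:
            return
        yield chunk


data = "1,a\n2,b\n3,c\n4,d\n5,e\n"
print(len(read_all(io.StringIO(data))))                        # 5
print([len(c) for c in read_in_chunks(io.StringIO(data), 2)])  # [2, 2, 1]
```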
Inspecting the Trainer
All training-related parameters go into TrainerParam, but TrainerParam itself is just a wrapper class that stores parameters; there is no logic inside it.
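Eventually I found a job submitter, which likewise runs the Task by passing parameters and calling a service. All of these are thin wrappers with no actual training code.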
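The wrapper pattern described above can be sketched as follows. TrainerParamSketch, FedAvgTrainerSketch, TRAINER_REGISTRY and build_trainer are hypothetical names, not FATE's API: the param object only records a trainer name plus keyword arguments, the submitted config is a plain dict, and the actual behaviour lives in whichever class that name resolves to. That is why reading the param and submitter code reveals nothing about batching.

```python
# Hypothetical sketch of a "parameter wrapper" design; NOT FATE's actual API.

class TrainerParamSketch:
    def __init__(self, trainer_name, **kwargs):
        self.trainer_name = trainer_name
        self.kwargs = kwargs  # stored verbatim, no training logic here

    def to_dict(self):
        # The serialisable config a job submitter would forward to the service.
        return {"trainer_name": self.trainer_name, "param": dict(self.kwargs)}


class FedAvgTrainerSketch:
    # The real logic (epochs, batching, aggregation) would live in a class like this.
    def __init__(self, epochs=1, batch_size=-1):
        self.epochs = epochs
        self.batch_size = batch_size


TRAINER_REGISTRY = {"fedavg_trainer": FedAvgTrainerSketch}


def build_trainer(conf):
    """Resolve a submitted config dict to a concrete trainer instance."""
    cls = TRAINER_REGISTRY[conf["trainer_name"]]
    return cls(**conf["param"])


conf = TrainerParamSketch("fedavg_trainer", epochs=5, batch_size=32).to_dict()
trainer = build_trainer(conf)
print(type(trainer).__name__, trainer.epochs, trainer.batch_size)  # FedAvgTrainerSketch 5 32
```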
Finally, I found FedAVGTrainer in federatedml.nn.homo.trainer.fedavg_trainer. Its docstring lists its parameters, and batch_size is among them:
```python
class FedAVGTrainer(TrainerBase):
    """
    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
        mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
        if True, According to origin paper, weight of a client is: n_local / n_global, where n_local
        is the sample number locally and n_global is the sample number of all clients.
        if False, simply averaging these models.
    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
        two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
        stop training
    tol: float, tol value for early stop
    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate
        every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
        if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
        if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
        if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
        this option decides the return format of this trainer, and the evaluation type when running validation.
        if auto, will automatically infer your task type from labels and predict results.
    """
```
The issue I filed on the FATE repository: https://github.com/FederatedAI/FATE/issues/4832
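The FedAVGTrainer docstring describes its weighted-aggregation and early-stop rules precisely enough to sketch them. The following is a plain-Python illustration under stated assumptions (model parameters as flat lists of floats, and the 'diff' rule taken as the absolute loss difference between consecutive epochs); it is not FATE's implementation and it ignores the secure-aggregation masks.

```python
# Illustration of two documented FedAVGTrainer rules; NOT FATE's implementation.

def aggregate(local_models, sample_counts, weighted=True):
    """FedAvg aggregation of equally shaped parameter vectors.

    weighted=True : client weight = n_local / n_global
    weighted=False: plain average of the local models
    """
    n_clients = len(local_models)
    n_global = sum(sample_counts)
    if weighted:
        weights = [n / n_global for n in sample_counts]
    else:
        weights = [1.0 / n_clients] * n_clients
    dim = len(local_models[0])
    return [sum(w * m[i] for w, m in zip(weights, local_models)) for i in range(dim)]


def should_stop(early_stop, tol, prev_loss, cur_loss):
    """None disables early stop; 'diff' compares consecutive epoch losses
    (absolute difference assumed); 'abs' compares the current loss to tol."""
    if early_stop is None:
        return False
    if early_stop == "diff":
        return prev_loss is not None and abs(prev_loss - cur_loss) < tol
    if early_stop == "abs":
        return cur_loss < tol
    raise ValueError("early_stop must be None, 'diff' or 'abs'")


models = [[1.0, 2.0], [3.0, 6.0]]
print(aggregate(models, [100, 300]))                  # [2.5, 5.0]
print(aggregate(models, [100, 300], weighted=False))  # [2.0, 4.0]
```

With 100 and 300 local samples the client weights are 0.25 and 0.75, so the larger client pulls the aggregate toward its own parameters; the unweighted variant treats both clients equally.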
Conclusion
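In homo training with a custom neural network, the FedAVG trainer (FedAVGTrainer) can do batch training. Whether the Reader can load the data at all, however, depends on the machine, because the Reader appears to read the entire dataset in one go.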
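A minimal sketch of why batch training and whole-dataset reading are separate questions: once the Reader has loaded everything, batching only happens when the trainer iterates over samples already in memory, so a small batch_size does not reduce the Reader's memory needs. iter_batches is a hypothetical helper that follows the documented convention that batch_size=-1 means full batch; in FedAVGTrainer itself this role appears to be played by the PyTorch DataLoader its docstring mentions.

```python
import random


def iter_batches(n_samples, batch_size, shuffle=False, seed=0):
    """Hypothetical: yield lists of sample indices for one epoch.

    batch_size == -1 means full batch (one batch holding every sample).
    """
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    if batch_size == -1:
        batch_size = max(n_samples, 1)  # full batch; guard against an empty dataset
    if batch_size <= 0:
        raise ValueError("batch_size must be positive or -1")
    for start in range(0, n_samples, batch_size):
        yield indices[start:start + batch_size]


print([len(b) for b in iter_batches(10, 4)])   # [4, 4, 2]
print([len(b) for b in iter_batches(10, -1)])  # [10]
```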