def transform_data(data: List[Tuple[Image, str]]):
"""
作为dataloader的collate_fn,用于预处理函数。
将输入的mnist图片调整大小、归一化后,变为torch.Tensor返回。
"""
xs, ys = [], []
for x, y in data:
xs.append(np.array(x).reshape((1, 28, 28)))
ys.append(int(y))
imgs = torch.tensor(xs)
label = torch.tensor(ys)
imgs = imgs / 255 - 0.5
return imgs, label
class Example(HorizontalLearning):
def __init__(self) -> None:
super().__init__(
name="example", # 任务名称,用于在Deltaboard中的展示
max_rounds=2, # 任务训练的总轮次,每聚合更新一次权重,代表一轮
validate_interval=1, # 验证的轮次间隔,1表示每完成一轮,进行一次验证
validate_frac=0.1, # 验证集的比例,范围(0,1)
strategy=FaultTolerantFedAvg( # 安全聚合的策略,可选策略目前包含 FedAvg和FaultTolerantFedAvg,都位于delta.task.learning包下
min_clients=2, # 算法所需的最少客户端数,至少为2
max_clients=3, # 算法所支持的最大客户端数,必须大雨等于min_clients
merge_epoch=1, # 聚合更新的间隔,merge_interval_epoch表示每多少个epoch聚合更新一次权重
wait_timeout=30, # 等待超时时间,用来控制一轮计算的超时时间
connection_timeout=10 # 连接超时时间,用来控制流程中每个阶段的超时时间
)
)
self.model = LeNet()
self.loss_func = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=0.1,
momentum=0.9,
weight_decay=1e-3,
nesterov=True,
)
def dataset(self) -> delta.dataset.Dataset:
"""
定义任务所需要的数据集。
return: 一个delta.dataset.Dataset
"""
return delta.dataset.Dataset(dataset="mnist")
def make_train_dataloader(self, dataset: Dataset) -> DataLoader:
"""
定义训练集Dataloader,可以对dataset进行各种变换、预处理等操作。
dataset: 训练集的Dataset
return: 训练集的Dataloader
"""
return DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=transform_data) # type: ignore
def make_validate_dataloader(self, dataset: Dataset) -> DataLoader:
"""
定义验证集Dataloader,可以对dataset进行各种变换、预处理等操作。
dataset: 验证集的Dataset
return: 验证集的Dataloader
"""
return DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False, collate_fn=transform_data) # type: ignore
def train(self, dataloader: Iterable):
"""
训练步骤
dataloader: 训练数据集对应的dataloader
return: None
"""
for batch in dataloader:
x, y = batch
y_pred = self.model(x)
loss = self.loss_func(y_pred, y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def validate(self, dataloader: Iterable) -> Dict[str, Any]:
"""
验证步骤,输出验证的指标值
dataloader: 验证集对应的dataloader
return: Dict[str, float],一个字典,键为指标的名称(str),值为对应的指标值(float)
"""
total_loss = 0
count = 0
ys = []
y_s = []
for batch in dataloader:
x, y = batch
y_pred = self.model(x)
loss = self.loss_func(y_pred, y)
total_loss += loss.item()
count += 1
y_ = torch.argmax(y_pred, dim=1)
y_s.extend(y_.tolist())
ys.extend(y.tolist())
avg_loss = total_loss / count
tp = len([1 for i in range(len(ys)) if ys[i] == y_s[i]])
precision = tp / len(ys)
return {"loss": avg_loss, "precision": precision}
def state_dict(self) -> Dict[str, torch.Tensor]:
"""
需要训练、更新的模型参数
在聚合更新、保存结果时,只会更新、保存get_params返回的参数
return: List[torch.Tensor], 模型参数列表
"""
return self.model.state_dict()