2026/4/17 14:21:26
网站建设
项目流程
如何给网站开发挂,想找个专业做网站公司,饰品行业网站开发,cdn网站#x1f6e0;️ 动手实战#xff1a;环境配置 代码实现 避坑指南 #x1f3af; 目标#xff1a;抛开晦涩的公式#xff0c;手把手教你在自己的电脑上搭建并运行第一个联邦学习模拟系统 #x1f4a1; 核心#xff1a;从安装软件到编写“数据切分、客户端训练、服务器聚…️动手实战环境配置 代码实现 避坑指南目标抛开晦涩的公式手把手教你在自己的电脑上搭建并运行第一个联邦学习模拟系统核心从安装软件到编写“数据切分、客户端训练、服务器聚合”的全流程 目录1. 准备工作硬件与软件环境配置2. 实战思维如何在单机上模拟联邦3. 步骤一数据准备与切分 (Data Splitting)4. 步骤二定义共享模型与客户端逻辑5. 步骤三定义服务器逻辑 (FedAvg)6. 步骤四组合运行与结果展示1. 准备工作硬件与软件环境配置在写代码之前我们先确保你的电脑哪怕是普通笔记本已经准备好了“厨房”。1.1 硬件要求对于本教程的 MNIST手写数字案例无需昂贵的显卡。硬件最低要求推荐配置说明CPU任意双核i5 / R5 及以上代码量小CPU 运行更方便无需配置显卡驱动。内存8 GB16 GB联邦学习需在内存中暂存多个模型副本内存过小易报错。1.2 软件安装 (Anaconda PyTorch)为了避免复杂的环境配置我们采用最稳妥的方案安装 Anaconda(你的工具箱)访问 Anaconda 官网 下载安装。安装后打开Anaconda Prompt(黑色终端窗口)。创建虚拟环境与安装 PyTorch(你的发动机)在终端中依次输入以下指令每行输入后按回车# 1. 创建一个名为 fl_demo 的环境conda create -n fl_demopython3.9-y# 2. 激活环境 (最重要的一步看到左侧括号变了才算成功)conda activate fl_demo# 3. 安装 PyTorch (CPU版) 和 Jupyter (编辑器)pipinstalltorch torchvision numpy jupyter启动编辑器在终端输入jupyter notebook浏览器会自动弹出一个网页点击右上角New-Python 3即可开始写代码。2. 实战思维如何在单机上模拟联邦在真实世界中联邦学习涉及 10 台手机和 1 台服务器通过网络传输。但在学习阶段我们在一台电脑上通过for 循环来模拟这个过程。数据隔离模拟我们将一份大数据集强行切分成 10 份分配给 10 个变量假装它们互不通气。通信模拟变量之间的赋值server_model client_model代替了网络传输。3. 步骤一数据准备与切分 (Data Splitting)我们需要编写代码将 MNIST 数据集切分给 N 个客户端。操作将以下代码复制到 Jupyter 的第一个单元格并运行。importtorchfromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataLoader,Subsetimportcopydefget_dataset(num_clients5): 下载并切分 MNIST 数据集 # 1. 下载数据# 注意如果网络报错请将 downloadTrue 改为 False并手动下载数据集到 ./data 目录train_datadatasets.MNIST(root./data,trainTrue,downloadTrue,transformtransforms.ToTensor())# 2. 模拟数据切分 (IID切分每个用户拿到的数据分布相似)data_lenlen(train_data)indiceslist(range(data_len))split_sizedata_len//num_clients client_loaders[]foriinrange(num_clients):# 截取属于第 i 个用户的数据索引subset_indicesindices[i*split_size:(i1)*split_size]subsetSubset(train_data,subset_indices)# 封装成 DataLoaderloaderDataLoader(subset,batch_size32,shuffleTrue)client_loaders.append(loader)returnclient_loaders# 测试一下client_loadersget_dataset(num_clients5)print(f数据准备完毕成功创建{len(client_loaders)}个客户端数据源。)4. 步骤二定义共享模型与客户端逻辑所有参与方必须使用相同的神经网络结构。客户端负责“接收模型 - 训练 - 返回参数”。操作复制到第二个单元格并运行。importtorch.nnasnnimporttorch.nn.functionalasF# --- 1. 定义网络结构 ---classSimpleCNN(nn.Module):def__init__(self):super(SimpleCNN,self).__init__()self.conv1nn.Conv2d(1,10,kernel_size5)self.conv2nn.Conv2d(10,20,kernel_size5)self.fcnn.Linear(320,10)defforward(self,x):xF.relu(F.max_pool2d(self.conv1(x),2))xF.relu(F.max_pool2d(self.conv2(x),2))xx.view(-1,320)xself.fc(x)returnx# --- 2. 定义客户端 (Client) ---classClient:def__init__(self,client_id,data_loader,devicecpu):self.client_idclient_id self.data_loaderdata_loader self.devicedevice self.modelSimpleCNN().to(self.device)deflocal_train(self,global_weights,epochs1):# 加载服务器发来的参数self.model.load_state_dict(global_weights)# 本地训练 (常规的 PyTorch 训练流程)optimizertorch.optim.SGD(self.model.parameters(),lr0.01,momentum0.5)self.model.train()forepochinrange(epochs):fordata,targetinself.data_loader:data,targetdata.to(self.device),target.to(self.device)optimizer.zero_grad()outputself.model(data)lossF.cross_entropy(output,target)loss.backward()optimizer.step()# 关键只返回参数 (state_dict)不返回数据returncopy.deepcopy(self.model.state_dict())5. 步骤三定义服务器逻辑 (FedAvg)服务器通过FedAvg (联邦平均算法)将收集到的参数进行加权平均。操作复制到第三个单元格并运行。classServer:def__init__(self,devicecpu):self.global_modelSimpleCNN().to(device)self.devicedevicedefaggregate(self,client_weights_list): FedAvg 核心对参数取平均 # 拿出第一个客户端的参数作为基准avg_weightscopy.deepcopy(client_weights_list[0])# 逐层累加其他客户端的参数forkeyinavg_weights.keys():foriinrange(1,len(client_weights_list)):avg_weights[key]client_weights_list[i][key]# 取平均值avg_weights[key]torch.div(avg_weights[key],len(client_weights_list))# 更新全局模型self.global_model.load_state_dict(avg_weights)defget_weights(self):returnself.global_model.state_dict()6. 步骤四组合运行与结果展示这是最激动人心的时刻我们将启动训练循环。操作复制到第四个单元格并运行。# --- 初始化 ---devicetorch.device(cpu)serverServer(device)clients[Client(i,client_loaders[i],device)foriinrange(5)]# 5个客户端print(启动)# --- 主循环 (3轮为例) ---forround_idxinrange(3):print(f\n--- Round{round_idx1}---)# 1. 服务器下发参数global_weightsserver.get_weights()client_updates[]# 2. 客户端并行训练forclientinclients:w_localclient.local_train(global_weights,epochs1)client_updates.append(w_local)print(fClient{client.client_id}已上传参数)# 3. 服务器聚合server.aggregate(client_updates)print(Server 完成参数聚合 (FedAvg))print(\n训练结束全局模型已更新。)预期输出如果一切顺利你将看到如下输出 联邦学习系统启动... --- Round 1 --- Client 0 已上传参数 ... Server 完成参数聚合 (FedAvg) --- Round 2 --- ... 训练结束全局模型已更新。常见报错 (Troubleshooting)ModuleNotFoundError说明环境没激活。请检查命令行左侧是否有(fl_demo)字样。HTTP Error 503MNIST 下载失败。请检查网络或手动下载数据集放入data文件夹。RuntimeError: CUDA error请确保代码中写的是device cpu。祝你天天开心我将更新更多有意思的内容欢迎关注最后更新2026年1月作者Echo