找回密码
 立即注册
首页 业界区 安全 PLE模型简洁解读

PLE模型简洁解读

公西颖初 2025-8-7 18:25:09
PLE模型简洁解读

1.png

基础设定


  • 有 2 个任务:CTR、CVR
  • 使用 1 层 PLE(num_levels = 1)
  • 每个任务 2 个任务特定专家(specific_expert_num = 2)
  • 有 1 个共享专家(shared_expert_num = 1)
  • 输入 embedding 是:[batch_size, 64] 的拼接向量
我们来看看“这一层”里的每一个步骤数据是如何流动的。
第 1 步:准备输入
  1. ple_inputs = [x_ctr, x_cvr, x_shared]
复制代码

  • x_ctr = CTR 的输入 = 原始 embedding 向量 [B, 64]
  • x_cvr = CVR 的输入 = 同上
  • x_shared = Shared 的输入 = 同上
注意:这三个向量在第 1 层是一样的,但在后续层会变得不同。
第 2 步:任务专家和共享专家网络

每个任务的 experts:

每个任务有 2 个 specific expert,输入是自己:

  • CTR 的两个专家 → 输入 x_ctr → 输出 [B, 64]
  • CVR 的两个专家 → 输入 x_cvr → 输出 [B, 64]
共享 experts:

只有 1 个共享专家,输入是 x_shared,输出 [B, 64]
第 3 步:Gate 网络

我们看 CTR 任务的 gate 是怎么处理的:
CTR 的 gate 做了什么?


  • 输入: x_ctr → shape [B, 64]
  • 过一个小 DNN: 输出变为 [B, H](比如 H=32)
  • 线性变换 + softmax: 输出为 [B, 3],表示对 3 个专家的权重:

    • expert_1_ctr
    • expert_2_ctr
    • expert_shared

  1. gate_input = DNN(...)(x_ctr)   # [B, 32]
  2. gate_weights = Dense(3, activation='softmax')(gate_input)  # [B, 3]
复制代码
第 4 步:Gate × Experts

将所有专家输出堆叠:
  1. expert_outputs = tf.stack([expert_1_ctr, expert_2_ctr, expert_shared], axis=1)  # [B, 3, 64]
复制代码
将 gate 权重 reshape:
  1. gate_weights = tf.expand_dims(gate_weights, -1)  # [B, 3, 1]
复制代码
点乘加权求和:
  1. fused_output = tf.reduce_sum(expert_outputs * gate_weights, axis=1)  # [B, 64]
复制代码
✅ 这就是 CTR 任务在这一层提取到的特征,来自自己和共享专家的动态组合。
CVR 任务也完全一样,只是换成用 x_cvr 输入,构建自己的 gate 和专家融合。
下一层(若存在):

然后这些输出(fused_output_ctr, fused_output_cvr, fused_output_shared)会作为下一层的输入,继续重复这一机制。
每一层都会重新生成:

  • 专家网络(不同任务分开)
  • gate(使用该层输入为条件)
从而实现「逐层提纯」。
Gate 的本质:

Gate 是一个小的 DNN 网络,输入是当前任务的 embedding,输出是对所有专家的 softmax 权重
决定了“这个任务现在要听谁的话”
PLE的pytorch实现
  1. class Expert(nn.Module):
  2.     def __init__(self, input_dim, expert_dim):
  3.         super(Expert, self).__init__()
  4.         self.layer = nn.Sequential(
  5.             nn.Linear(input_dim, expert_dim),
  6.             nn.ReLU(),
  7.             nn.BatchNorm1d(expert_dim),
  8.             nn.Dropout(0.2)
  9.         )
  10.     def forward(self, x):
  11.         return self.layer(x)
  12. class Gate(nn.Module):
  13.     def __init__(self, input_dim, n_experts):
  14.         super(Gate, self).__init__()
  15.         self.gate = nn.Sequential(
  16.             nn.Linear(input_dim, n_experts),
  17.             nn.Softmax(dim=-1)
  18.         )
  19.     def forward(self, x):
  20.         weights = self.gate(x)  # [B, n_experts]
  21.         return weights.unsqueeze(-1)  # [B, n_experts, 1]
  22. class PLELayer(nn.Module):
  23.     def __init__(self, input_dim, expert_dim, n_tasks, n_task_experts, n_shared_experts):
  24.         super(PLELayer, self).__init__()
  25.         self.n_tasks = n_tasks
  26.         self.task_experts = nn.ModuleList([
  27.             nn.ModuleList([Expert(input_dim, expert_dim) for _ in range(n_task_experts)])
  28.             for _ in range(n_tasks)
  29.         ])
  30.         self.shared_experts = nn.ModuleList([
  31.             Expert(input_dim, expert_dim) for _ in range(n_shared_experts)
  32.         ])
  33.         self.task_gates = nn.ModuleList([
  34.             Gate(input_dim, n_task_experts + n_shared_experts)
  35.             for _ in range(n_tasks)
  36.         ])
  37.         self.shared_gate = Gate(input_dim, n_tasks * n_task_experts + n_shared_experts)
  38.     def forward(self, task_inputs, shared_input):
  39.         # Compute expert outputs
  40.         task_outputs = []
  41.         for i in range(self.n_tasks):
  42.             task_outputs.append([expert(task_inputs[i]) for expert in self.task_experts[i]])
  43.         shared_outputs = [expert(shared_input) for expert in self.shared_experts]
  44.         # Task-specific gate outputs
  45.         next_task_inputs = []
  46.         for i in range(self.n_tasks):
  47.             all_expert_outputs = task_outputs[i] + shared_outputs
  48.             stacked = torch.stack(all_expert_outputs, dim=1)  # [B, n_experts, D]
  49.             weights = self.task_gates[i](task_inputs[i])      # [B, n_experts, 1]
  50.             fused = torch.sum(stacked * weights, dim=1)       # [B, D]
  51.             next_task_inputs.append(fused)
  52.         # Shared gate output (for next layer's shared input)
  53.         flat_all_experts = sum(task_outputs, []) + shared_outputs
  54.         stacked_shared = torch.stack(flat_all_experts, dim=1)
  55.         shared_weights = self.shared_gate(shared_input)
  56.         next_shared_input = torch.sum(stacked_shared * shared_weights, dim=1)  # [B, D]
  57.         return next_task_inputs, next_shared_input
  58. class PLE(nn.Module):
  59.     # 正确处理多层维度
  60.     def __init__(self, input_dim, expert_dim, n_tasks=3, n_layers=2,
  61.                  n_task_experts=2, n_shared_experts=1):
  62.         super(PLE, self).__init__()
  63.         self.n_tasks = n_tasks
  64.         self.ple_layers = nn.ModuleList()
  65.         
  66.         # 为每一层设置正确的输入维度
  67.         for layer_idx in range(n_layers):
  68.             if layer_idx == 0:
  69.                 # 第一层:使用原始输入维度
  70.                 current_input_dim = input_dim
  71.             else:
  72.                 # 后续层:使用expert输出维度作为输入
  73.                 current_input_dim = expert_dim
  74.                
  75.             self.ple_layers.append(
  76.                 PLELayer(
  77.                     input_dim=current_input_dim,  # 动态设置输入维度
  78.                     expert_dim=expert_dim,
  79.                     n_tasks=n_tasks,
  80.                     n_task_experts=n_task_experts,
  81.                     n_shared_experts=n_shared_experts
  82.                 )
  83.             )
  84.     def forward(self, x):
  85.         # Initial input: shared across all tasks and shared experts
  86.         task_inputs = [x for _ in range(self.n_tasks)]
  87.         shared_input = x
  88.         for layer in self.ple_layers:
  89.             task_inputs, shared_input = layer(task_inputs, shared_input)
  90.         return task_inputs  # final task-specific vectors [task1_repr, task2_repr, task3_repr]
复制代码
来源:豆瓜网用户自行投稿发布,如果侵权,请联系站长删除

相关推荐

您需要登录后才可以回帖 登录 | 立即注册