No description
  • Jupyter Notebook 66.3%
  • Python 33.7%
Find a file
2026-03-30 19:29:14 +08:00
configs feat: add watch time 2026-03-30 19:29:14 +08:00
eda fix:remove softmax in DIN 2026-03-30 10:23:54 +08:00
kuairand_baseline feat: add watch time 2026-03-30 19:29:14 +08:00
.gitignore init:baseline with din 2026-03-29 21:21:06 +08:00
.python-version init:baseline with din 2026-03-29 21:21:06 +08:00
main.py init:baseline with din 2026-03-29 21:21:06 +08:00
pyproject.toml feat: wandb and gauc support 2026-03-30 12:07:56 +08:00
README.md feat: add watch time 2026-03-30 19:29:14 +08:00
uv.lock feat: wandb and gauc support 2026-03-30 12:07:56 +08:00

KuaiRand-1K DIN Multi-Task Baseline

KuaiRand-1K 标准曝光日志上搭建一个可复现的多任务推荐 baseline联合预测

  • is_click
  • is_like
  • is_forward
  • play_time_ms(可选的 GEM watch-time 塔)

第一版面向简历项目叙事重点是把“数据切分、严格时序、DIN 用户兴趣建模、可切换多任务骨干、完整训练评估链路”做扎实。默认骨干仍然是 MMoE,并额外提供一层标准 PLE 作为可选实验分支。

Pipeline

  1. preprocess
    • 使用 2022-04-08 ~ 2022-04-21 作为历史冷启动窗口。
    • 2022-04-22 ~ 2022-05-08 的标准日志构造训练样本。
    • train: 2022-04-22 ~ 2022-05-04
    • val: 2022-05-05 ~ 2022-05-06
    • test: 2022-05-07 ~ 2022-05-08
    • DIN 行为序列只保留历史 is_click=1video_id,最大长度 50,并额外保存该历史行为的 is_like / is_forward 标记。
    • 标签默认包含 is_click / is_like / is_forward / play_time_ms 四列;未启用 watch-time 塔的配置会自动忽略第 4 列。
    • 预处理输出为按 split 存放的 numpy memmap 文件和 feature_meta.json
  2. train
    • 原生 PyTorch 实现共享 DIN 编码器,并通过配置切换 MMoE 或一层 PLE 多任务骨干。
    • 三任务联合训练;默认 likeforward 使用加权 BCE也提供一个“全部标签等权、关闭正样本重加权”的 MMoE baseline 配置。
    • 可选的 watch-time 版本会额外启用一个 1 个指数分布 + 8 个高斯分布的 GEM 塔,用 NLL + entropy + L1 联合训练。
    • 基于验证集三任务平均 gAUC 早停并保存最佳 checkpoint如果某个 split 上 gAUC 不可用,则自动回退到平均 ROC-AUC
  3. evaluate
    • 输出二分类任务的 ROC-AUCPR-AUCLogLoss、标签基准率。
    • 如果启用 watch-time 塔,额外输出 mae_msrmse_msnll
    • 可生成测试集预测文件。

Project Layout

configs/
  din_mmoe_baseline.toml
  din_mmoe_attention_baseline.toml
  din_mmoe_equal_weight_baseline.toml
  din_mmoe_watch_time_baseline.toml
  din_mmoe_smoke.toml
  din_mmoe_attention_smoke.toml
  din_mmoe_equal_weight_smoke.toml
  din_mmoe_watch_time_smoke.toml
  din_ple_baseline.toml
  din_ple_watch_time_baseline.toml
  din_ple_smoke.toml
  din_ple_watch_time_smoke.toml
kuairand_baseline/
  config.py
  constants.py
  dataset.py
  evaluate.py
  metrics.py
  model.py
  preprocess.py
  torch_utils.py
  train.py
main.py

Install

当前项目尽量减少依赖:

  • pandas
  • numpy
  • torch
  • wandb(可选)

推荐做法:

uv sync

如果你的环境还没有 PyTorch也可以手动安装合适版本后再运行训练和评估。

W&B

项目现在默认开启 Weights & Biases 记录,默认模式是 offline;如果不想使用,可以把配置里的 wandb.enabled = false

在配置里打开:

[wandb]
enabled = true
project = "kuairand"
mode = "online"
watch_model = false
log_model = false

训练时会记录:

  • 每个 epoch 的 train.loss
  • 每个任务的验证集指标,包括 ROC-AUCgAUC
  • 最终 val/test 指标

如果你希望同步到线上面板,把 mode 改成 online 并完成 wandb login

Metrics

当前每个任务都会输出:

  • ROC-AUC
  • gAUC
  • PR-AUC
  • LogLoss
  • base_rate

其中 gAUCuser_id 分组计算,对每个有正负样本同时存在的用户单独计算 AUC再按该用户样本数加权平均。

训练阶段默认使用三任务平均 gAUC 作为 early stopping 指标。

如果启用 watch-time 塔,还会额外输出:

  • mae_ms
  • rmse_ms
  • nll

单独执行 evaluate 时,也会为对应 split 记录一次评估结果。

Run

全量 baseline

python main.py preprocess --config configs/din_mmoe_baseline.toml
python main.py train --config configs/din_mmoe_baseline.toml
python main.py evaluate --config configs/din_mmoe_baseline.toml --split test

全量 MMoE + enriched attention

python main.py preprocess --config configs/din_mmoe_attention_baseline.toml
python main.py train --config configs/din_mmoe_attention_baseline.toml
python main.py evaluate --config configs/din_mmoe_attention_baseline.toml --split test

全量 MMoE + equal-weight loss baseline

python main.py preprocess --config configs/din_mmoe_equal_weight_baseline.toml
python main.py train --config configs/din_mmoe_equal_weight_baseline.toml
python main.py evaluate --config configs/din_mmoe_equal_weight_baseline.toml --split test

全量 MMoE + GEM watch-time

python main.py preprocess --config configs/din_mmoe_watch_time_baseline.toml
python main.py train --config configs/din_mmoe_watch_time_baseline.toml
python main.py evaluate --config configs/din_mmoe_watch_time_baseline.toml --split test

全量 PLE

python main.py preprocess --config configs/din_ple_baseline.toml
python main.py train --config configs/din_ple_baseline.toml
python main.py evaluate --config configs/din_ple_baseline.toml --split test

全量 PLE + GEM watch-time

python main.py preprocess --config configs/din_ple_watch_time_baseline.toml
python main.py train --config configs/din_ple_watch_time_baseline.toml
python main.py evaluate --config configs/din_ple_watch_time_baseline.toml --split test

先做 smoke run

python main.py preprocess --config configs/din_mmoe_smoke.toml
python main.py train --config configs/din_mmoe_smoke.toml

MMoE + enriched attention smoke run

python main.py preprocess --config configs/din_mmoe_attention_smoke.toml
python main.py train --config configs/din_mmoe_attention_smoke.toml

MMoE + equal-weight smoke run

python main.py preprocess --config configs/din_mmoe_equal_weight_smoke.toml
python main.py train --config configs/din_mmoe_equal_weight_smoke.toml

MMoE + watch-time smoke run

python main.py preprocess --config configs/din_mmoe_watch_time_smoke.toml
python main.py train --config configs/din_mmoe_watch_time_smoke.toml

PLE smoke run

python main.py preprocess --config configs/din_ple_smoke.toml
python main.py train --config configs/din_ple_smoke.toml

PLE + watch-time smoke run

python main.py preprocess --config configs/din_ple_watch_time_smoke.toml
python main.py train --config configs/din_ple_watch_time_smoke.toml

Outputs

默认 MMoE 输出目录:artifacts/din_mmoe_baseline/

  • preprocessed/{train,val,test}/
    • sparse.npy
    • dense.npy
    • history.npy
    • labels.npy
    • info.npy
  • preprocessed/feature_meta.json
  • checkpoints/best_model.pt
  • metrics/metrics.json
  • metrics/predictions_test.csv

Features

用户侧:

  • 稀疏特征:user_id、活跃度/范围分桶字段、onehot_feat0-17
  • 稠密特征:follow_user_numfans_user_numfriend_user_numregister_days

物品侧:

  • 稀疏特征:video_idauthor_idmusic_idmusic_typevideo_typeupload_typevisible_statustag
  • 稠密特征:video_durationserver_widthserver_heightupload_age_days
  • 视频统计特征:video_features_statistic_1k.csv 中全部统计列

Modeling Notes

  • video_id 的当前样本 embedding 与历史点击序列 embedding 共用一套参数。
  • baseline MMoEPLE 默认都只使用点击序列中的 video_id 表示做 DIN attention。
  • 可选的 enriched-attention 配置会把 item 静态特征拼进 target/history item 表示,并把历史 item 的 is_like / is_forward 标记送进 attention scorer。
  • 默认 MMoE 使用 4 个 experts3 个 task-specific gates 和 3 个 task towers。
  • 可选 PLE 使用一层标准结构shared experts + task-specific experts + task gates + task towers。
  • watch-time 配置会在共享编码器之后额外打开第 4 个 task route并用 GEM 头预测 play_time_ms 的 mixture distribution。
  • train.loss.positive_weighting = "sqrt_ratio_capped" 是默认口径;如果设为 "none",三个任务仍然等权平均,但不再对 like / forward 的正样本做额外放大。
  • 第一版不引入随机曝光日志,不做对比基线和消融实验。

Known Limits

  • 全量预处理会产生较大的中间文件,建议预留充足磁盘空间。
  • video_idauthor_idmusic_id 都是高基数特征,训练资源需求不低。
  • 当前仓库没有内置实验追踪系统,第一版以可复现脚本和结果文件为主。