- Jupyter Notebook 66.3%
- Python 33.7%
| configs | ||
| eda | ||
| kuairand_baseline | ||
| .gitignore | ||
| .python-version | ||
| main.py | ||
| pyproject.toml | ||
| README.md | ||
| uv.lock | ||
KuaiRand-1K DIN Multi-Task Baseline
在 KuaiRand-1K 标准曝光日志上搭建一个可复现的多任务推荐 baseline,联合预测:
is_clickis_likeis_forwardplay_time_ms(可选的 GEM watch-time 塔)
第一版面向简历项目叙事,重点是把“数据切分、严格时序、DIN 用户兴趣建模、可切换多任务骨干、完整训练评估链路”做扎实。默认骨干仍然是 MMoE,并额外提供一层标准 PLE 作为可选实验分支。
Pipeline
preprocess- 使用
2022-04-08 ~ 2022-04-21作为历史冷启动窗口。 - 对
2022-04-22 ~ 2022-05-08的标准日志构造训练样本。 train: 2022-04-22 ~ 2022-05-04val: 2022-05-05 ~ 2022-05-06test: 2022-05-07 ~ 2022-05-08- DIN 行为序列只保留历史
is_click=1的video_id,最大长度50,并额外保存该历史行为的is_like/is_forward标记。 - 标签默认包含
is_click / is_like / is_forward / play_time_ms四列;未启用 watch-time 塔的配置会自动忽略第 4 列。 - 预处理输出为按 split 存放的
numpy memmap文件和feature_meta.json。
- 使用
train- 原生 PyTorch 实现共享
DIN编码器,并通过配置切换MMoE或一层PLE多任务骨干。 - 三任务联合训练;默认
like和forward使用加权 BCE,也提供一个“全部标签等权、关闭正样本重加权”的MMoEbaseline 配置。 - 可选的 watch-time 版本会额外启用一个
1个指数分布 +8个高斯分布的 GEM 塔,用NLL + entropy + L1联合训练。 - 基于验证集三任务平均
gAUC早停并保存最佳 checkpoint;如果某个 split 上gAUC不可用,则自动回退到平均ROC-AUC。
- 原生 PyTorch 实现共享
evaluate- 输出二分类任务的
ROC-AUC、PR-AUC、LogLoss、标签基准率。 - 如果启用 watch-time 塔,额外输出
mae_ms、rmse_ms、nll。 - 可生成测试集预测文件。
- 输出二分类任务的
Project Layout
configs/
din_mmoe_baseline.toml
din_mmoe_attention_baseline.toml
din_mmoe_equal_weight_baseline.toml
din_mmoe_watch_time_baseline.toml
din_mmoe_smoke.toml
din_mmoe_attention_smoke.toml
din_mmoe_equal_weight_smoke.toml
din_mmoe_watch_time_smoke.toml
din_ple_baseline.toml
din_ple_watch_time_baseline.toml
din_ple_smoke.toml
din_ple_watch_time_smoke.toml
kuairand_baseline/
config.py
constants.py
dataset.py
evaluate.py
metrics.py
model.py
preprocess.py
torch_utils.py
train.py
main.py
Install
当前项目尽量减少依赖:
pandasnumpytorchwandb(可选)
推荐做法:
uv sync
如果你的环境还没有 PyTorch,也可以手动安装合适版本后再运行训练和评估。
W&B
项目现在默认开启 Weights & Biases 记录,默认模式是 offline;如果不想使用,可以把配置里的 wandb.enabled = false。
在配置里打开:
[wandb]
enabled = true
project = "kuairand"
mode = "online"
watch_model = false
log_model = false
训练时会记录:
- 每个 epoch 的
train.loss - 每个任务的验证集指标,包括
ROC-AUC和gAUC - 最终
val/test指标
如果你希望同步到线上面板,把 mode 改成 online 并完成 wandb login。
Metrics
当前每个任务都会输出:
ROC-AUCgAUCPR-AUCLogLossbase_rate
其中 gAUC 按 user_id 分组计算,对每个有正负样本同时存在的用户单独计算 AUC,再按该用户样本数加权平均。
训练阶段默认使用三任务平均 gAUC 作为 early stopping 指标。
如果启用 watch-time 塔,还会额外输出:
mae_msrmse_msnll
单独执行 evaluate 时,也会为对应 split 记录一次评估结果。
Run
全量 baseline:
python main.py preprocess --config configs/din_mmoe_baseline.toml
python main.py train --config configs/din_mmoe_baseline.toml
python main.py evaluate --config configs/din_mmoe_baseline.toml --split test
全量 MMoE + enriched attention:
python main.py preprocess --config configs/din_mmoe_attention_baseline.toml
python main.py train --config configs/din_mmoe_attention_baseline.toml
python main.py evaluate --config configs/din_mmoe_attention_baseline.toml --split test
全量 MMoE + equal-weight loss baseline:
python main.py preprocess --config configs/din_mmoe_equal_weight_baseline.toml
python main.py train --config configs/din_mmoe_equal_weight_baseline.toml
python main.py evaluate --config configs/din_mmoe_equal_weight_baseline.toml --split test
全量 MMoE + GEM watch-time:
python main.py preprocess --config configs/din_mmoe_watch_time_baseline.toml
python main.py train --config configs/din_mmoe_watch_time_baseline.toml
python main.py evaluate --config configs/din_mmoe_watch_time_baseline.toml --split test
全量 PLE:
python main.py preprocess --config configs/din_ple_baseline.toml
python main.py train --config configs/din_ple_baseline.toml
python main.py evaluate --config configs/din_ple_baseline.toml --split test
全量 PLE + GEM watch-time:
python main.py preprocess --config configs/din_ple_watch_time_baseline.toml
python main.py train --config configs/din_ple_watch_time_baseline.toml
python main.py evaluate --config configs/din_ple_watch_time_baseline.toml --split test
先做 smoke run:
python main.py preprocess --config configs/din_mmoe_smoke.toml
python main.py train --config configs/din_mmoe_smoke.toml
MMoE + enriched attention smoke run:
python main.py preprocess --config configs/din_mmoe_attention_smoke.toml
python main.py train --config configs/din_mmoe_attention_smoke.toml
MMoE + equal-weight smoke run:
python main.py preprocess --config configs/din_mmoe_equal_weight_smoke.toml
python main.py train --config configs/din_mmoe_equal_weight_smoke.toml
MMoE + watch-time smoke run:
python main.py preprocess --config configs/din_mmoe_watch_time_smoke.toml
python main.py train --config configs/din_mmoe_watch_time_smoke.toml
PLE smoke run:
python main.py preprocess --config configs/din_ple_smoke.toml
python main.py train --config configs/din_ple_smoke.toml
PLE + watch-time smoke run:
python main.py preprocess --config configs/din_ple_watch_time_smoke.toml
python main.py train --config configs/din_ple_watch_time_smoke.toml
Outputs
默认 MMoE 输出目录:artifacts/din_mmoe_baseline/
preprocessed/{train,val,test}/sparse.npydense.npyhistory.npylabels.npyinfo.npy
preprocessed/feature_meta.jsoncheckpoints/best_model.ptmetrics/metrics.jsonmetrics/predictions_test.csv
Features
用户侧:
- 稀疏特征:
user_id、活跃度/范围分桶字段、onehot_feat0-17 - 稠密特征:
follow_user_num、fans_user_num、friend_user_num、register_days
物品侧:
- 稀疏特征:
video_id、author_id、music_id、music_type、video_type、upload_type、visible_status、tag - 稠密特征:
video_duration、server_width、server_height、upload_age_days - 视频统计特征:
video_features_statistic_1k.csv中全部统计列
Modeling Notes
video_id的当前样本 embedding 与历史点击序列 embedding 共用一套参数。- baseline
MMoE与PLE默认都只使用点击序列中的video_id表示做 DIN attention。 - 可选的 enriched-attention 配置会把 item 静态特征拼进 target/history item 表示,并把历史 item 的
is_like/is_forward标记送进 attention scorer。 - 默认
MMoE使用 4 个 experts,3 个 task-specific gates 和 3 个 task towers。 - 可选
PLE使用一层标准结构:shared experts + task-specific experts + task gates + task towers。 - watch-time 配置会在共享编码器之后额外打开第 4 个 task route,并用 GEM 头预测
play_time_ms的 mixture distribution。 train.loss.positive_weighting = "sqrt_ratio_capped"是默认口径;如果设为"none",三个任务仍然等权平均,但不再对like/forward的正样本做额外放大。- 第一版不引入随机曝光日志,不做对比基线和消融实验。
Known Limits
- 全量预处理会产生较大的中间文件,建议预留充足磁盘空间。
video_id、author_id、music_id都是高基数特征,训练资源需求不低。- 当前仓库没有内置实验追踪系统,第一版以可复现脚本和结果文件为主。