The Role of Feature Normalization in Ijepa

Original link: https://github.com/theAdamColton/elucidating-featurenorm-ijepa

This project trains a Vision Transformer (a ViT-Small, described as ~300M parameters) on image patches plus "register" tokens that carry non-image data. Dependencies are managed with `uv`, and the datasets require roughly 100GB of storage. Training with a batch size of 320 takes about 116 hours and ~22GB of VRAM. `uv run main.py` with command-line arguments supports resuming training, evaluation (including IN1k validation and feature visualization), and loss plotting. A key implementation detail is a `token_id` system that tracks each token's origin (register, sample, height, width), because each batch contains multiple samples at varied resolutions. Crucially, `model.eval()` must be called before any evaluation to get correct behavior from the special layers. The LiDAR score is computed on a random training subset and may shift when a run is resumed. Only single-GPU training is supported; optional tweaks (PILLOW-SIMD, TOME, alternative position embeddings and normalizations) are available, though several degrade performance. Notably, adding register tokens substantially hurt downstream performance, which remains an open problem.


We use uv for dependency management.

Download the training datasets and NYU-Depth tar files: uv run download_dataset.py

This requires roughly 100GB of storage space.

Run the default training configuration which trains a ~300m parameter ViT-Small with a patch size of 16 and a batch size of 320. This consumes ~22GB of VRAM and takes 116 hours (assuming validation logging is turned off): uv run main.py --config conf/small.yaml

Or resume a training run: uv run main.py --config /path/to/checkpoint/config.yaml --conf.resume_checkpoint_path /path/to/checkpoint/checkpointfile.pt

Or evaluate the IN1k validation performance of a pretrained model: uv run main.py --config /path/to/checkpoint/config.yaml --conf.resume_checkpoint_path /path/to/checkpoint/checkpointfile.pt --conf.mode validate

Or visualize features of a pretrained model: uv run main.py --config /path/to/checkpoint/config.yaml --conf.resume_checkpoint_path /path/to/checkpoint/checkpointfile.pt --conf.mode visualize-embeddings

Or plot the losses of a pretrained model: uv run main.py --config /path/to/checkpoint/config.yaml --conf.resume_checkpoint_path /path/to/checkpoint/checkpointfile.pt --conf.mode plot-sample-losses

Run tests: uv run python -m unittest

The code refers to token_ids: a LongTensor that contains 4 integers for each token: register id, sample id, height id, width id. The register ID is the index of the register if this token is a register (and therefore contains no image data), or MASK_TOKEN_ID otherwise. The sample ID is the unique index of the sample that this patch/register comes from. The height ID and width ID are the row and column indices of the patch in the patched image, or MASK_TOKEN_ID if the token is a register.
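As a rough illustration of this layout (not the repo's actual code — the helper name, the `MASK_TOKEN_ID` value of -1, and the token ordering are all assumptions), the (n_tokens, 4) tensor for one sample could be built like this:

```python
import torch

MASK_TOKEN_ID = -1  # placeholder sentinel; the repo's actual value may differ


def make_token_ids(sample_id: int, h_patches: int, w_patches: int, n_registers: int) -> torch.Tensor:
    """Build a (n_tokens, 4) LongTensor of [register_id, sample_id, height_id, width_id].

    Register tokens carry a register index and masked height/width ids;
    image patches carry a masked register id and their (row, col) position.
    """
    # Register tokens: register id set, spatial ids masked out
    reg = torch.full((n_registers, 4), MASK_TOKEN_ID, dtype=torch.long)
    reg[:, 0] = torch.arange(n_registers)
    reg[:, 1] = sample_id

    # Image patches: register id masked out, spatial ids set
    hh, ww = torch.meshgrid(torch.arange(h_patches), torch.arange(w_patches), indexing="ij")
    n_patches = h_patches * w_patches
    patches = torch.stack(
        [
            torch.full((n_patches,), MASK_TOKEN_ID, dtype=torch.long),
            torch.full((n_patches,), sample_id, dtype=torch.long),
            hh.reshape(-1),
            ww.reshape(-1),
        ],
        dim=1,
    )
    return torch.cat([reg, patches], dim=0)


ids = make_token_ids(sample_id=0, h_patches=2, w_patches=3, n_registers=2)
# 2 register rows followed by 2*3 patch rows, each row [reg, sample, h, w]
```

Concatenating such per-sample tensors is what lets one batch mix images of different resolutions while every token stays attributable to its source.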

We need to keep track of these IDs because, unlike most ViT models, our model processes one or more samples per batch element: each batch contains patches from many images of varied resolution.

Unlike many transformer models, PyTorch's eval mode affects our model's forward pass. Calling eval() causes the DiffMOEMLP layers to use dynamic allocation, so the number of allocated experts is determined by the capacity predictor. Make sure to call model.eval() before doing any evaluation; for training, use train mode.
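The general pattern is a forward that branches on `self.training`. The toy layer below is a stand-in we wrote to illustrate this, not the repo's DiffMOEMLP; the class name, expert count, and the way the capacity predictor picks how many experts to allocate are all assumptions:

```python
import torch
from torch import nn


class CapacityGatedMLP(nn.Module):
    """Toy sketch of a layer whose behavior depends on train/eval mode.

    In train mode all experts run (static allocation); in eval mode a
    capacity predictor decides how many experts to allocate. This is
    why calling model.eval() before evaluation matters.
    """

    def __init__(self, dim: int, n_experts: int):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
        self.capacity_predictor = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            k = len(self.experts)  # static allocation: use every expert
        else:
            # dynamic allocation: predicted capacity picks the expert count
            frac = torch.sigmoid(self.capacity_predictor(x.mean(dim=0)))
            k = max(1, int(frac.item() * len(self.experts)))
        return sum(expert(x) for expert in self.experts[:k]) / k
```

Forgetting `model.eval()` in a setup like this would silently evaluate with training-mode allocation, which is the failure mode the README warns about.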

Our LiDAR score is computed from a random subset of the training data. Because this subset is re-drawn, you may observe a change in the LiDAR score after resuming a run.
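The sensitivity to resuming comes down to how the subset is drawn. A minimal sketch (the function name and the optional seeded generator are our additions for illustration, not necessarily how the repo samples):

```python
import torch


def sample_eval_subset(n_train: int, subset_size: int, generator=None) -> torch.Tensor:
    """Draw a random subset of training indices for evaluation.

    Without a seeded generator, each call (e.g. after resuming a run)
    draws a different subset, so a score computed on it can shift even
    if the model is unchanged. Passing a fixed-seed torch.Generator
    makes the subset reproducible.
    """
    return torch.randperm(n_train, generator=generator)[:subset_size]


# Unseeded: a resumed run would draw a different subset
idx = sample_eval_subset(1000, 64)

# Seeded: two draws with the same seed agree
g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)
same = torch.equal(sample_eval_subset(1000, 64, g1), sample_eval_subset(1000, 64, g2))
```
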

  • This code only supports single-GPU training.

  • You can optionally install PILLOW-SIMD for a boost in dataloading speed.

  • You should probably disable LiDAR score logging if you have limited system RAM.

  • You can enable TOME for the encoder and predictor. We only tested this briefly and observed a distinct performance decline.

  • You can use absolute factorized learnable position embeddings instead of ROPE2D. In a short test we found this decreases performance very slightly.

  • The predictor can be trained without token dropping and without batch repeat. We found this drastically decreases downstream performance.

  • You can add register tokens to the encoder and to the predictor. The encoder's register tokens can be passed unchanged to the predictor, or be wiped. We found that adding 8 register tokens dramatically reduced downstream performance; why register tokens hurt by so much remains an open problem.

  • You can choose a feature normalization mode other than LN and DynTanh. We have batchnorm, disabled, and running batchnorm.
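One way to organize the normalization modes is a small factory keyed on a config string. The sketch below is our illustration, not the repo's API: the function and mode names are hypothetical, the DynTanh layer follows the published DyT formulation (`gamma * tanh(alpha * x) + beta`) which may differ from the repo's variant, and the "running batchnorm" mode is omitted:

```python
import torch
from torch import nn


class DynTanh(nn.Module):
    """Dynamic tanh normalization sketch: gamma * tanh(alpha * x) + beta."""

    def __init__(self, dim: int, alpha_init: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), alpha_init))
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * torch.tanh(self.alpha * x) + self.beta


def make_feature_norm(mode: str, dim: int) -> nn.Module:
    """Map a config string to a feature-normalization layer (hypothetical names)."""
    if mode == "layernorm":
        return nn.LayerNorm(dim)
    if mode == "dyntanh":
        return DynTanh(dim)
    if mode == "batchnorm":
        return nn.BatchNorm1d(dim)
    if mode == "disabled":
        return nn.Identity()  # no feature normalization at all
    raise ValueError(f"unknown feature norm mode: {mode}")
```

Swapping the mode string then changes only this one layer, which is what makes comparing normalization choices cheap.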
