A non-diagonal SSM RNN computed in parallel without requiring stabilization

Original link: https://github.com/glassroom/goom_ssm_rnn

This work presents a new implementation of a deep recurrent neural network (RNN) that uses generalized orders of magnitude (GOOMs) with a non-diagonal linear state-space model (SSM) to improve performance and scalability. Its core innovation is the use of complex-typed GOOMs, which make it possible to compute non-diagonal recurrences in parallel without any stabilization techniques, a long-standing challenge for RNNs. The model is built in PyTorch, designed for tasks such as generative language modeling (using GPT-2's vocabulary), and has been tested on several benchmarks, including Sequential MNIST and Wikitext-103. Although limited complex-tensor support means the model cannot yet be fully compiled with PyTorch, partial compilation significantly improves speed and memory use. Key features include custom autograd functions for backpropagating gradients, plus methods for parameter grouping, loss computation, and token generation. Training uses standard techniques such as AdamW optimization and a one-cycle learning-rate schedule; results show competitive performance, with the potential to scale to larger datasets. The code is available to clone and experiment with; see the goom_ssm_rnn.py file for details.

A new state-space model (SSM) recurrent neural network (RNN) approach that allows parallel computation without the usual stability problems. The approach was shared on Hacker News, where it was noted that the core innovation is not the RNN itself but a new logarithmic number system detailed in a recent paper (https://arxiv.org/abs/2510.03426). That system excels at computations that produce very large or very small numbers through repeated multiplication, which is useful well beyond deep learning. It leverages existing complex-number implementations for efficiency and includes optimized operations such as dot products and matrix products (using PyTorch's stabilized functions). The system is specialized, however: it is best suited to chains of multiplications, and for computations that rely heavily on addition, standard floating point remains preferable. A PyTorch implementation is available on GitHub (github.com/glassroom).
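
As a toy illustration of that idea (ours, not an example from the paper): representing real numbers by their complex logarithms turns a long product, which overflows even float64, into a sum that stays comfortably in range, with negative signs carried by the imaginary part.

import torch

# Toy illustration: a chain of multiplications that overflows float64
# stays finite when carried out as a sum of complex logarithms.
x = torch.full((100,), -1e100, dtype=torch.float64)

naive = x.prod()                           # overflows; the true magnitude 1e10000 is not representable
log_x = torch.log(x.to(torch.complex128))  # log|x| + i*pi for each negative factor
log_prod = log_x.sum()                     # the product, held in log space

print(naive)     # inf
print(log_prod)  # finite: about 23025.85 + 314.16j (an even power of -1, so a positive real)
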
Original text

Reference implementation of a deep RNN that captures dependencies with a non-diagonal linear state-space model (SSM) over our implementation of generalized orders of magnitude (GOOMs). GOOMs allow recurrent states to fluctuate freely over a greater dynamic range of real values than previously possible, enabling computation of non-diagonal recurrences in parallel, via a prefix scan, without any form of stabilization.
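
To make the parallelization claim concrete, here is a toy sketch (ours, not the repository's scan): a non-diagonal linear recurrence h_t = A_t h_{t-1} is a chain of matrix products, and because matrix multiplication is associative, the cumulative products can be computed in log2(T) parallel rounds with a prefix scan. The repository performs this kind of scan over complex-typed GOOMs rather than raw matrices, but the associativity argument is the same.

import torch

T, d = 8, 4
A = torch.randn(T, d, d) * 0.5
h0 = torch.randn(d)

# Sequential reference: h_t = A_t @ h_{t-1}.
h = h0
for t in range(T):
    h = A[t] @ h

# Hillis-Steele prefix scan over the associative op (matrix product):
# after log2(T) rounds, P[t] = A_t @ A_{t-1} @ ... @ A_0.
P = A.clone()
step = 1
while step < T:
    nxt = P.clone()
    nxt[step:] = P[step:] @ P[:-step]  # combine with the partial product `step` positions behind
    P, step = nxt, 2 * step

h_scan = P[-1] @ h0
assert torch.allclose(h, h_scan, atol=1e-4)
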

  1. Clone this repository.

  2. Install the Python dependencies in requirements.txt.

  3. There is no third step.

The following code instantiates a small RNN for generative language modeling tasks with GPT-2's vocabulary:

import torch
import tiktoken
import goom_ssm_rnn

DEVICE = 'cuda'  # change as needed

# Get GPT-2 encoder:
enc = tiktoken.get_encoding('gpt2')

# Instantiate an RNN for natural language generation:
model = goom_ssm_rnn.GenerativeRNN(
    vocab_sz=enc.n_vocab, d_emb=768, n_hid=24, d_hid=32, n_res=24)

# Move model to cuda device:
model.to(device=DEVICE)

# You must provide your own training code.

Use of Complex-Typed GOOMs

Recurrent layers in the model capture sequential dependencies with a non-diagonal linear SSM, executed via a parallel prefix scan, over GOOMs, implemented as torch.complex64 tensors (i.e., with torch.float32 real and imaginary components). As we explain in our paper, the use of complex-typed GOOMs makes it possible for each layer to compute non-diagonal recurrent states in parallel without requiring any form of stabilization.

The rest of the model operates conventionally, over torch.float32 tensors, optionally autocasting to torch.float16 if you specify it. As we explain in our paper, each recurrent layer scales complex-typed GOOMs before exponentiating them to torch.float32 real tensors, because GOOM magnitudes can fall outside the bounds representable by torch.float32.
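
As a rough sketch of that scaling step (our paraphrase; see the paper and goom_ssm_rnn.py for the exact rule), one can shift each GOOM's real part by the largest log-magnitude before exponentiating, so the resulting float32 values are bounded:

import torch

# Rough sketch, not the repository's exact code: shift real parts by
# their max before exponentiating, so magnitudes fit in torch.float32
# even when the raw GOOMs are far outside its range.
def goom_to_float32(log_h: torch.Tensor) -> torch.Tensor:
    shift = log_h.real.amax(dim=-1, keepdim=True)  # largest log-magnitude
    return torch.exp(log_h - shift).real           # all magnitudes <= 1
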

Besides the standard PyTorch forward() method, the model provides three additional methods, for convenience:

  • model.get_param_groups(), which accepts a scalar weight_decay value as input and returns two parameter groups for training, one with weight decay and one without.

  • model.compute_loss_and_metrics(), which accepts predicted scores over the model's vocabulary and true token ids, and returns a cross-entropy loss and a dictionary with one metric: 'accuracy'.

  • model.generate(), for generating new token ids, given a sequence of preceding token ids, after the model has been trained on a language-generation task. Please see our code for additional arguments.
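
Together, these methods support a conventional training loop. A minimal sketch (get_batch, num_steps, and the forward call's exact signature are our placeholders; adapt them to your pipeline):

import torch

# Minimal training-loop sketch using the helper methods above.
groups = model.get_param_groups(weight_decay=1e-1)
opt = torch.optim.AdamW(groups, lr=3e-4)

for step in range(num_steps):
    tokens = get_batch().to(DEVICE)    # (batch, seq_len) of token ids
    scores = model(tokens[:, :-1])     # predicted scores per position
    loss, metrics = model.compute_loss_and_metrics(scores, tokens[:, 1:])
    opt.zero_grad()
    loss.backward()
    opt.step()
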

Training and Testing the Model

We have implemented the model as a standard PyTorch nn.Module that you can train and test on any task, using conventional techniques, including autocasting. However, at present the model can be only partially compiled, because PyTorch's compiler doesn't yet fully support complex tensors. For information on the current state of PyTorch's support for complex tensors, please see the PyTorch documentation on complex numbers.

When we apply torch.compile() to the entire model and start training it, lazy compilation emits a variety of warnings related to the lack of support for complex tensors, but compilation succeeds -- and significantly reduces execution time and memory use. Our implementation of GOOMs incorporates custom torch.autograd.Function transformations under the hood to ensure proper backpropagation of gradients, taking special care to handle the singularity at zero gracefully. (As a real number approaches zero, the real component of its complex logarithm approaches negative infinity.)

Note: We have tested autocasting of float tensors only to torch.float16.
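
In practice, that means a setup like the following sketch (tokens is a placeholder input; the complex GOOM operations fall back to eager mode with the warnings described above):

import torch

# Sketch: partial compilation plus float16 autocasting of float ops.
model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    scores = model(tokens)  # complex GOOM ops run in eager mode, uncast
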

Replicating Published Results

We trained the RNN model in this repository on natural language generation and multiple other tasks.

Natural Language Generation

We trained an instance of the RNN with 768 embedding dimensions (d_emb=768), 24 heads per token (n_hid=24), 32 features per head (d_hid=32), 24 recurrent residual layers (n_res=24), and GPT-2's vocabulary, resulting in 124M parameters, on approximately 10B tokens randomly sampled from The Pile, with a sequence length of 1024 tokens. We trained the RNN with the hyper-parameter values listed below. Cross-entropy loss declined to approximately 2.7 after training on 10B tokens. For comparison, cross-entropy for state-of-the-art models of comparable size, trained on 30x or more tokens sampled from higher-quality datasets, is approximately 2.4, suggesting our RNN model can be scaled up to larger tasks.

Hyper-parameter values:

  • Batch size: 960 sequences, split into micro-batches that accumulate gradients
  • Micro-batch size: largest integer factor of 960 that fits in GPU memory
  • Optimizer: AdamW, using torch.optim.AdamW
  • Weight decay: 1e-1
  • Parameter groups: 2, obtained with model.get_param_groups(weight_decay=1e-1)
  • Learning rate schedule: one cycle, using torch.optim.lr_scheduler.OneCycleLR
  • Maximum learning rate: 3e-4
  • Ending learning rate: 1e-5
  • Maximum momentum: 0.99
  • Minimum momentum: 0.85
  • Warm-up period: 10 batches (9600 sample sequences)
  • Compilation: yes (applies only to operations on floats, not complex GOOMs)
  • Autocasting: yes, to torch.float16 (only floats, not complex GOOMs)
  • Training iterations: 10240 batches
  • Cumulative tokens: 10B (1024 tokens/sequence x 960 sequences/batch x 10240 batches)
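
One way to wire up the optimizer and schedule from these values (mapping the published settings to OneCycleLR arguments is our assumption; verify against your own training code if exactness matters):

import torch

# Sketch of the optimizer and one-cycle schedule listed above.
opt = torch.optim.AdamW(model.get_param_groups(weight_decay=1e-1), lr=3e-4)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=3e-4, total_steps=10240,
    pct_start=10 / 10240,             # 10 warm-up batches
    max_momentum=0.99, base_momentum=0.85,
    div_factor=25.0,                  # initial lr = 3e-4 / 25 = 1.2e-5
    final_div_factor=1.2)             # ending lr = 1.2e-5 / 1.2 = 1e-5
# Call sched.step() once per batch, after opt.step().
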

Other tasks include Sequential MNIST generation (unrolling the images into sequences of 784 pixel-tokens, using a vocabulary size of 256 gray levels, and generating each next pixel), Sequential MNIST classification (replacing the generative-language-modeling head with a linear-classification head that predicts 10 classes from the last pixel-token's hidden state), Wikitext-103 (using the GPT-2 vocabulary for convenience), and Copy-Memory tasks. For all such tasks, we instantiated the RNN with 512 embedding dimensions (d_emb=512), 16 heads per token (n_hid=16), 32 features per head (d_hid=32), eight recurrent residual layers (n_res=8), a task-specific vocabulary, and a task-specific model head, resulting in 12.8M to 38M parameters, as sketched after this paragraph. We trained all models with the hyper-parameter values listed below. The models trained to competitive performance on all tasks we tested.
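
For example, the smaller configuration can be instantiated with the same constructor shown earlier; for Sequential MNIST generation, the vocabulary is the 256 gray levels:

# The smaller configuration used for these tasks, instantiated for
# Sequential MNIST generation (vocabulary of 256 gray levels):
model = goom_ssm_rnn.GenerativeRNN(
    vocab_sz=256, d_emb=512, n_hid=16, d_hid=32, n_res=8)
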

Hyper-parameter values:

  • Batch size: 1000 sequences, split into micro-batches that accumulate gradients
  • Micro-batch size: largest integer factor of 1000 that fits in GPU memory
  • Optimizer: AdamW, using torch.optim.AdamW
  • Weight decay: 1e-1
  • Parameter groups: 2, obtained with model.get_param_groups(weight_decay=1e-1)
  • Learning rate schedule: one cycle, using torch.optim.lr_scheduler.OneCycleLR
  • Maximum learning rate: 3e-4
  • Ending learning rate: 1e-5
  • Maximum momentum: 0.99
  • Minimum momentum: 0.85
  • Warm-up period: 10 batches (10,000 samples)
  • Compilation: yes (applies only to operations on floats, not complex GOOMs)
  • Autocasting: yes, to torch.float16 (only floats, not complex GOOMs)
  • Data augmentation: yes, conventional (e.g., affine transforms on training images)
  • Training iterations: at least 1,800 batches (1.8M samples); harder tasks require more samples

Modifying the RNN for Other Tasks

You can modify or replace the model's language-modeling head, as needed, for tasks other than generative language modeling. All model components are defined in a single file:

goom_ssm_rnn.py
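
For instance, a Sequential MNIST classifier could replace the generative head with a 10-way linear layer. A hypothetical sketch (the attribute name head is our guess, not taken from the repository; consult goom_ssm_rnn.py for the actual module names):

import torch.nn as nn

# Hypothetical sketch: swap the language-modeling head for a 10-class
# linear classifier. The attribute name `head` is assumed; check
# goom_ssm_rnn.py for the module actually used.
model.head = nn.Linear(512, 10)  # d_emb=512 in the smaller configuration
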

Citing

@article{heinsen2025generalized,
    title={Generalized Orders of Magnitude for Scalable, Parallel, High-Dynamic-Range Computation},
    author={Franz A. Heinsen and Leo Kozachkov},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2025},
    url={https://openreview.net/forum?id=SUuzb0SOGu}
}

The work here originated with casual conversations over email between us, the authors, in which we wondered if it might be possible to find a succinct expression for computing non-diagonal linear recurrences in parallel, by mapping them to the complex plane. Our casual conversations gradually evolved into the development of generalized orders of magnitude, along with an algorithm for estimating Lyapunov exponents in parallel, and a novel method for selectively resetting interim states in a parallel prefix scan.

We hope others find our work and our code useful.
