SequenceLayers: Streaming made simple.
2019 - present tags: google machine-learning open-source

In October 2019, I created SequenceLayers to make building streaming neural networks easier. In 2023 we open-sourced the library, and in 2025 we published a tech report explaining the rationale behind its design decisions.
SequenceLayers enables declarative definitions of sequence-processing models that can be run layer-wise over a whole sequence or block-by-block over the time dimension, producing identical results in either mode. It is akin to Keras in offering a declarative API, but supports streaming and sequence modeling as first-class features.
Here is an example of a decoder-only Transformer block in SequenceLayers.
import jax
import numpy as np
import sequence_layers.jax as sl

# Hyperparameters for this example (values chosen for illustration).
d_model = 1024
dropout_rate = 0.1

config = sl.Serial.Config([
    # Self attention.
    sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.DotProductSelfAttention.Config(
            num_heads=16,
            units_per_head=64,
            # Global causal attention.
            max_past_horizon=-1,
            max_future_horizon=0,
            name='self_attention',
        ),
        sl.DenseShaped.Config([d_model], name='output_projection'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
    ], name='attention_block'),
    # Gated GeLU FFN.
    sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.Dense.Config(4 * d_model, name='dense1'),
        sl.GatedUnit.Config(jax.nn.gelu, None),
        sl.Dense.Config(d_model, name='dense2'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
    ], name='ffn_block'),
], name='transformer_block')
transformer = config.make()

k1, k2, k3 = jax.random.split(jax.random.key(42), 3)

# Random input sequence:
x = sl.Sequence(
    values=jax.random.normal(k1, (2, 4096, 1024)),
    mask=jax.random.uniform(k2, (2, 4096)) > 0.5,
)

# Run Flax layer initialization.
params = transformer.init(k3, x, training=False)
# Bind the layer for imperative/example usage.
transformer = transformer.bind(params)

# Process x layer-wise:
y_layer = transformer.layer(x, training=False)

# Process x 8 steps at a time:
block_size = 8
num_blocks = (x.shape[1] + block_size - 1) // block_size
state = transformer.get_initial_state(x.shape[0], x.channel_spec, training=False)
y_step = []
for i in range(num_blocks):
  x_i = x[:, i * block_size : (i + 1) * block_size]
  y_i, state = transformer.step(x_i, state, training=False)
  y_step.append(y_i)
y_step = sl.Sequence.concatenate_sequences(y_step)

# The layer-wise and step-wise results are identical.
np.testing.assert_allclose(y_layer.values, y_step.values)
np.testing.assert_array_equal(y_layer.mask, y_step.mask)
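The step path matches the layer path because each layer carries explicit state between calls (for example, attention key/value caches and buffered past context), so the output does not depend on the chosen block size.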
Building streaming sequence models is surprisingly tricky. There are four common pitfalls I kept running into:
- Batching unequal sequences requires tracking invalid timesteps and verifying that every layer handles padding correctly, including pooling and sampling operations (illustrated in the sketch after this list).
- Causality constraints mean modern architectures need two separate code paths, efficient parallel training and autoregressive sampling, both of which must avoid causality violations.
- Offline vs. streaming mismatch: converting parallel inference (via masking) to streaming requires re-implementation due to lookahead windows and memory constraints.
- Unnecessary coupling: architectural details become entangled with algorithms, such as pairing the two in an AutoregressiveTransformer class, even though architecture and autoregressive modeling are independent choices.
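To make the first pitfall concrete, here is a minimal sketch in plain JAX (toy values, not the SequenceLayers API) of how a naive mean over the time axis silently averages in padded timesteps, while a mask-aware mean does not:

import jax.numpy as jnp

# Two sequences batched together; the second is only two steps long,
# so its last two timesteps are zero padding.
values = jnp.array([[1.0, 2.0, 3.0, 4.0],
                    [5.0, 7.0, 0.0, 0.0]])
mask = jnp.array([[True, True, True, True],
                  [True, True, False, False]])

# Naive mean over time: padding dilutes the second row's statistics.
naive = values.mean(axis=1)  # [2.5, 3.0] -- wrong for row 1.

# Mask-aware mean: only valid timesteps contribute.
masked = (values * mask).sum(axis=1) / mask.sum(axis=1)  # [2.5, 6.0]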
SequenceLayers addresses these with three core features. Each SequenceLayer is:
- Streamable: SequenceLayers gives you streaming for free, in a production-friendly way. Every layer implements explicit state and a step method alongside the traditional layer-wise call (see the sketch after this list).
- Correct: SequenceLayers is correct by default, making entire classes of bugs impossible. Layer and step methods are tested to produce identical results, and mask-aware Sequence objects track padding throughout.
- Composable: An easy-to-understand declarative API enforces these guarantees, enabling sequence models with concise definitions that read like block diagrams.
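To illustrate the streaming contract behind these guarantees, here is a hand-written sketch of a causal moving-average layer with explicit state (plain JAX with hypothetical names, not the library's actual base class):

import jax.numpy as jnp
import numpy as np

KERNEL = 3  # Timesteps of causal context per output.

def layer(x):
  # Layer-wise: left-pad so each output sees only current and past inputs.
  padded = jnp.pad(x, [(0, 0), (KERNEL - 1, 0)])
  return jnp.stack(
      [padded[:, i:i + KERNEL].mean(axis=1) for i in range(x.shape[1])],
      axis=1)

def get_initial_state(batch_size):
  # The state is the last KERNEL - 1 inputs, initially silence.
  return jnp.zeros((batch_size, KERNEL - 1))

def step(x_block, state):
  # Step-wise: prepend the carried context, then emit one output per input.
  context = jnp.concatenate([state, x_block], axis=1)
  y = jnp.stack(
      [context[:, i:i + KERNEL].mean(axis=1)
       for i in range(x_block.shape[1])],
      axis=1)
  return y, context[:, -(KERNEL - 1):]

x = jnp.arange(16.0).reshape(2, 8)
y_layer = layer(x)

state = get_initial_state(batch_size=2)
blocks = []
for i in range(0, x.shape[1], 4):
  y_i, state = step(x[:, i:i + 4], state)
  blocks.append(y_i)

# Block-by-block processing matches the layer-wise result exactly.
np.testing.assert_allclose(y_layer, jnp.concatenate(blocks, axis=1))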
SequenceLayers has been used to abstract architectural details in:
- Classifiers
- Contrastive / distance metric learning models.
- Regression models.
- Probabilistic models (autoregressive models, normalizing flows, diffusion, VAEs, GANs).
across a wide variety of tasks:
- Audio / speech classification.
- Image classification.
- Contextualized word embedding.
- Text-to-speech synthesis.
- Speech and phoneme recognition.
- Speech translation.
- Speech vocoding.
- Audio tokenization and synthesis.
- Real-time music synthesis.
- Video understanding.
- Language modeling.
Additionally, SequenceLayers is used extensively in production at Google for many streaming applications.
For more detail on the library's design and the rationale behind it, see the tech report and the code on GitHub, available under the Apache 2.0 license.
I’m deeply thankful that Google was willing to let me open source this library. My hope is that it serves as a useful example to the community of a simple abstraction that more than pulls its weight when working with neural networks that operate over sequences.