SHRAM: Sparse Hybrid Token Routed Attention Mixture

A research baseline implementing the SHRAM architecture from "An Examination of Sparse Attention for Long Context Purposes." The repository ships no pretrained weights: pull the architecture from the Hub and instantiate a freshly initialised model from the config. Every parameter is overridable at instantiation time via kwargs.

Important: trust_remote_code=True is required. It downloads the architecture source files from the Hub and imports them into your Python process. Review the source at smithblack-0/SHRAM before use. Those interested can also clone the git repository at https://github.com/smithblack-0/advanced-transformers-lib.
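
If you want to guarantee that the code you reviewed is the code that runs, you can pin the download to a specific commit with the standard transformers revision kwarg. A minimal sketch; the SHA below is a placeholder, not a real commit:

from transformers import AutoConfig

# Pin the remote code to the commit you audited.
# "<reviewed-commit-sha>" is a placeholder; substitute the real SHA.
config = AutoConfig.from_pretrained(
    "smithblack-0/SHRAM",
    trust_remote_code=True,
    revision="<reviewed-commit-sha>",
)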

Architecture

SHRAM replaces every standard attention layer with a hybrid layer H(x) = h_l(x) + h_s(x):

  • h_l β€” local sliding-window causal attention path.
  • h_s β€” MoSRAH sparse routed path. Each token selects K of L available expert heads via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.

All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
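
To make the shape of the computation concrete, here is a minimal PyTorch sketch of the hybrid layer. It is illustrative only: names like HybridAttention are not the repository's module API, RoPE and the BEA bottleneck are omitted, and the routed path computes every expert head densely before masking, where a real implementation would skip unselected heads. The repository source is authoritative.

import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridAttention(nn.Module):
    """Sketch of H(x) = h_l(x) + h_s(x): sliding-window plus routed path."""
    def __init__(self, hidden_size, head_dim, num_sliding_window_heads,
                 num_mosrah_heads, num_selected_heads, window_size):
        super().__init__()
        self.hd, self.k, self.window = head_dim, num_selected_heads, window_size
        self.nl, self.ne = num_sliding_window_heads, num_mosrah_heads
        self.local_qkv = nn.Linear(hidden_size, 3 * self.nl * head_dim, bias=False)
        self.local_o = nn.Linear(self.nl * head_dim, hidden_size, bias=False)
        self.expert_qkv = nn.Linear(hidden_size, 3 * self.ne * head_dim, bias=False)
        self.expert_o = nn.Linear(self.ne * head_dim, hidden_size, bias=False)
        self.router = nn.Linear(hidden_size, self.ne, bias=False)

    def _split(self, qkv, n):
        b, t, _ = qkv.shape
        q, k, v = qkv.chunk(3, dim=-1)
        return [u.view(b, t, n, self.hd).transpose(1, 2) for u in (q, k, v)]

    def forward(self, x):                        # x: (batch, seq, hidden)
        b, t, _ = x.shape
        # h_l: causal attention restricted to a trailing window of keys.
        q, k, v = self._split(self.local_qkv(x), self.nl)
        i = torch.arange(t, device=x.device)
        mask = (i[:, None] >= i[None, :]) & (i[:, None] - i[None, :] < self.window)
        h_l = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        h_l = self.local_o(h_l.transpose(1, 2).reshape(b, t, -1))
        # h_s: token-choice routing. Every expert head is computed densely
        # here for clarity; each token then keeps only its top-K heads,
        # weighted by the router gate.
        gates = self.router(x).softmax(dim=-1)          # (b, t, num_experts)
        topv, topi = gates.topk(self.k, dim=-1)         # K of L heads per token
        keep = torch.zeros_like(gates).scatter_(-1, topi, topv)
        q, k, v = self._split(self.expert_qkv(x), self.ne)
        dense = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        dense = dense.transpose(1, 2)                   # (b, t, experts, head_dim)
        h_s = self.expert_o((dense * keep.unsqueeze(-1)).reshape(b, t, -1))
        return h_l + h_s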

Usage

This repository contains no pretrained weights. The intended workflow is: pull the architecture config from the Hub, instantiate a model with fresh random weights, then train it yourself.

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Step 1: pull the architecture config from the Hub.
# AutoConfig.from_pretrained downloads config.json only; no weights are loaded.
# Override any parameter via kwargs.
config = AutoConfig.from_pretrained(
    "smithblack-0/SHRAM",
    trust_remote_code=True,
    num_hidden_layers=16,       # example override
    num_mosrah_heads=32,        # example override
)

# Step 2: instantiate with fresh random weights.
# from_config never loads a checkpoint; it always produces a randomly initialised model.
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# Step 3: load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
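
From there, training is up to you. A minimal training step, assuming only the standard Hugging Face causal-LM interface (forward accepts labels and returns a loss); a real run would use an actual data pipeline, scheduler, and checkpointing:

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Stand-in batch of random token ids; substitute your own data loader.
ids = torch.randint(0, config.vocab_size, (2, 128))
loss = model(input_ids=ids, labels=ids).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()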

After training your own checkpoint, save and reload it in the standard way:

model.save_pretrained("./my-checkpoint")
model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)

Constructor Defaults

The values below are the defaults produced by AutoConfig.from_pretrained when no overrides are passed. They are not the parameters of a pretrained model; this repository contains no weights. All values are overridable via kwargs, as the snippet after the table demonstrates.

Parameter                    Default
alpha                        1.0
attention_dropout            0.0
beta                         32.0
dtype                        None
head_dim                     16
hidden_size                  512
inference_sequence_length    1024
intermediate_size            1366
local_rope_theta             10000.0
mosrah_rope_theta            10000.0
num_hidden_layers            12
num_mosrah_heads             16
num_selected_heads           16
num_sliding_window_heads     16
output_hidden_states         False
rms_norm_eps                 1e-05
rope_mode                    main_sequence
tie_word_embeddings          False
training_sequence_length     1024
use_cache                    True
vocab_size                   50277
window_size                  128
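
A quick sanity check that kwargs overrides actually land in the config (assumes Hub access; window_size values taken from the table above):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("smithblack-0/SHRAM", trust_remote_code=True)
print(config.window_size)   # 128, the table default

config = AutoConfig.from_pretrained(
    "smithblack-0/SHRAM", trust_remote_code=True, window_size=256
)
print(config.window_size)   # 256, the kwargs override wins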

License

MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX (EleutherAI/gpt-neox-20b, Apache 2.0).
