axonml-llm

AxonML Logo

License: MIT/Apache-2.0 · Rust 1.75+ · Part of AxonML


Overview

axonml-llm provides nine large-language-model architectures for the AxonML framework, all implemented in pure Rust on top of axonml-tensor and axonml-autograd. Shared infrastructure includes multi-head / causal self-attention with a KV cache, a FlashAttention kernel, RoPE, RMSNorm, a HuggingFace weight loader (safetensors), a state-dict name-mapping helper, a pretrained-weights hub with on-disk caching, a configurable text-generation sampler, and an HF-style tokenizer.


Supported architectures

Architecture Module Notes
GPT-2 gpt2 Decoder-only transformer with learned positional embeddings and GPT2LMHead.
BERT bert Bidirectional encoder with BertForSequenceClassification and BertForMaskedLM.
LLaMA llama LLaMA-2/3-style: split-halves RoPE, grouped-query attention, SwiGLU MLP, RMSNorm. LLaMAForCausalLM.
Mistral mistral LLaMA + sliding-window attention. MistralForCausalLM.
Phi phi Phi-1/2/3-style: partial RoPE, GELU MLP, optional parallel-attn block. PhiForCausalLM.
SSM ssm Mamba/S6-style selective state-space model: depthwise Conv1d + selective scan + RMSNorm. SSMBlock, SSMForCausalLM.
Hydra hydra Hybrid: alternates SSM blocks and windowed (local) attention. HydraModel.
Chimera chimera Sparse MoE (top-k routing) + differential attention, with load-balancing auxiliary loss. ChimeraModel.
Trident trident 1.58-bit ternary weights (TernaryLinear) + RoPE + GQA + ReLU²-gated FFN + SubLN (BitNet b1.58-2B-4T recipe). TridentModel.

Preset configurations

Config Presets
GPT2Config tiny, small, medium, large, xl
BertConfig tiny, base, large
LLaMAConfig llama2_7b, llama2_13b, llama3_8b, tiny
MistralConfig mistral_7b, mistral_7b_instruct, mixtral_8x7b, tiny
PhiConfig phi1, phi2, phi3_mini, tiny
SSMConfig from_d_model(d_model, vocab_size) builder
HydraConfig base (~300M), small, tiny
ChimeraConfig default_2b (8 experts, top-2), small, tiny
TridentConfig default_150m, tiny, medium, plus 1B/3B/smoke constructors exposed for training

Shared building blocks

  • Attention (attention) — MultiHeadSelfAttention, CausalSelfAttention, scaled_dot_product_attention, plus FlashAttention + FlashAttentionConfig. KVCache and per-layer LayerKVCache for incremental decoding.
  • Transformer (transformer) — TransformerBlock, TransformerEncoder, TransformerDecoder with configurable depth/width/heads/activation and pre- or post-norm.
  • Embeddings (embedding) — TokenEmbedding, PositionalEmbedding (sinusoidal), GPT2Embedding, BertEmbedding.
  • Generation (generation) — GenerationConfig with greedy(), sampling(temp), top_k_sampling(k, temp), nucleus_sampling(p, temp), beam_search(beams), plus builder methods with_max_tokens, with_eos_token, with_repetition_penalty. TextGenerator drives the next-token logic.
  • Weight loading (hf_loader, state_dict) — HFLoader (safetensors), load_llama_from_hf, load_mistral_from_hf, and a generic LoadStateDict trait with LoadResult + map_hf_to_axonml / map_axonml_to_hf name mappers.
  • Hub (hub) — PretrainedLLM registry (llm_registry(), list_models()), download_weights(name, force) with an on-disk cache under $XDG_CACHE_HOME/axonml/hub/llm.
  • Tokenizer (tokenizer) — HFTokenizer + SpecialTokens (HuggingFace-compatible tokenizer.json).

Modules

Module Description
attention Multi-head / causal / flash attention, KV cache
bert BERT encoder + classification and MLM heads
chimera Sparse MoE + differential attention model
config GPT2Config, BertConfig, TransformerConfig
embedding Token, positional, BERT/GPT-2 combined embeddings
error LLMError / LLMResult
generation GenerationConfig, TextGenerator, sampling strategies
gpt2 GPT-2 + GPT2LMHead
hf_loader HuggingFace safetensors loading (LLaMA, Mistral, …)
hub Pretrained-weight registry and downloader
hydra SSM + windowed-attention hybrid
llama LLaMA with RoPE, GQA, SwiGLU, RMSNorm
mistral Mistral (LLaMA + sliding-window attention)
phi Phi with partial RoPE / GELU / optional parallel-attn
ssm Mamba-style selective SSM blocks and LM head
state_dict LoadStateDict trait, HF ↔ AxonML name mapping
tokenizer HuggingFace-compatible HFTokenizer
transformer Encoder / decoder blocks, feed-forward, layer norm
trident 1.58-bit ternary SLM (BitNet b1.58 recipe)

Usage

Add the crate to your Cargo.toml:

[dependencies]
axonml-llm = "0.6.1"

GPT-2 text generation

use axonml_llm::{GPT2LMHead, GPT2Config};
use axonml_tensor::Tensor;

let config = GPT2Config::small();
let model = GPT2LMHead::new(&config);

let input_ids = Tensor::from_vec(vec![50256u32, 1, 2, 3], &[1, 4]).unwrap();
let output  = model.generate(&input_ids, 50, 0.8, Some(50));
let greedy  = model.generate_greedy(&input_ids, 50);

BERT sequence classification

use axonml_llm::{BertForSequenceClassification, BertConfig};
use axonml_tensor::Tensor;

let config = BertConfig::base();
let model  = BertForSequenceClassification::new(&config, /*num_labels=*/ 2);

let input_ids = Tensor::from_vec(vec![101u32, 2054, 2003, 1996, 102], &[1, 5]).unwrap();
let logits = model.forward_classification(&input_ids);

BERT masked language modeling

use axonml_llm::{BertForMaskedLM, BertConfig};
use axonml_tensor::Tensor;

let model = BertForMaskedLM::new(&BertConfig::base());
let input_ids = Tensor::from_vec(vec![101u32, 2054, 103, 1996, 102], &[1, 5]).unwrap();
let logits = model.forward_mlm(&input_ids); // [batch, seq, vocab]

LLaMA / Mistral / Phi

use axonml_llm::{LLaMAForCausalLM, LLaMAConfig, MistralForCausalLM, MistralConfig, PhiForCausalLM, PhiConfig};

let llama   = LLaMAForCausalLM::new(&LLaMAConfig::llama2_7b());
let mistral = MistralForCausalLM::new(&MistralConfig::mistral_7b());
let phi     = PhiForCausalLM::new(&PhiConfig::phi2());

SSM (Mamba-style) causal LM

use axonml_llm::{SSMForCausalLM, SSMConfig};

let cfg = SSMConfig::from_d_model(/*d_model=*/ 512, /*vocab_size=*/ 32000);
let ssm = SSMForCausalLM::new(&cfg);

Hydra hybrid SSM + windowed attention

use axonml_llm::{HydraModel, HydraConfig};

let hydra = HydraModel::new(&HydraConfig::base()); // 768d, 24 layers, window 256

Chimera sparse MoE + differential attention

use axonml_llm::{ChimeraModel, ChimeraConfig};

let cfg = ChimeraConfig::default_2b(); // 8 experts per layer, top-2 active
let model = ChimeraModel::new(&cfg);
// Returns (logits, lb_loss) when used via forward_with_loss for training.
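
For training, the load-balancing auxiliary loss is exposed through forward_with_loss. A minimal sketch, assuming it accepts the same token-id tensor as the other examples; the weighting of the auxiliary term is a user choice, not a crate default:

use axonml_llm::{ChimeraModel, ChimeraConfig};
use axonml_tensor::Tensor;

let model = ChimeraModel::new(&ChimeraConfig::tiny());
let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();

// Logits for next-token prediction plus the router's load-balancing loss.
let (logits, lb_loss) = model.forward_with_loss(&input_ids);
// total = cross_entropy(&logits, &targets) + 0.01 * lb_loss;  // illustrative weighting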

Trident 1.58-bit ternary LM

use axonml_llm::{TridentModel, TridentConfig};

let cfg = TridentConfig::default_150m(); // ternary weights + RoPE + GQA + ReLU² + SubLN
let model = TridentModel::new(&cfg);

Custom transformer encoder

use axonml_llm::TransformerEncoder;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let encoder = TransformerEncoder::new(
    6,       // num_layers
    512,     // hidden_size
    8,       // num_heads
    2048,    // intermediate_size
    0.1,     // dropout
    1e-12,   // layer_norm_eps
    "gelu",  // activation
    false,   // pre_norm
);

let input  = Variable::new(Tensor::randn(&[2, 128, 512]), false);
let output = encoder.forward(&input);

Generation configuration

use axonml_llm::{GenerationConfig, TextGenerator};

let config = GenerationConfig::nucleus_sampling(0.95, 0.8)
    .with_max_tokens(100)
    .with_repetition_penalty(1.2)
    .with_eos_token(50256);

let generator = TextGenerator::new(config);
// `logits` are the model's next-token logits for the current step;
// `generated_so_far` holds the token ids produced so far (e.g. for the repetition penalty).
let next_token = generator.get_next_token(&logits, &generated_so_far);

Loading HuggingFace weights

use axonml_llm::{load_llama_from_hf, LLaMAConfig};

let mut model = axonml_llm::LLaMAForCausalLM::new(&LLaMAConfig::llama2_7b());
load_llama_from_hf(&mut model, "path/to/model.safetensors")?;

Pretrained hub

use axonml_llm::{llm_registry, download_llm_weights};

for m in llm_registry().values() {
    println!("{}  ({} params)", m.name, m.num_parameters);
}

let path = download_llm_weights("bert-base-uncased", /*force=*/ false)?;
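
Tokenization

HFTokenizer reads a HuggingFace-compatible tokenizer.json. The method names in this sketch (from_file, encode, decode) are assumptions for illustration only; see the tokenizer module for the actual API.

use axonml_llm::HFTokenizer;

// Hypothetical calls -- shown to indicate where the tokenizer fits, not confirmed signatures.
let tokenizer = HFTokenizer::from_file("tokenizer.json").unwrap();
let ids: Vec<u32> = tokenizer.encode("The quick brown fox");
let text = tokenizer.decode(&ids);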

Configuration reference

BERT

Config Hidden size Layers Heads
BertConfig::tiny() 128 2 2
BertConfig::base() 768 12 12
BertConfig::large() 1024 24 16

GPT-2

Config Embedding dim Layers Heads
GPT2Config::tiny() 128 2 2
GPT2Config::small() 768 12 12
GPT2Config::medium() 1024 24 16
GPT2Config::large() 1280 36 20
GPT2Config::xl() 1600 48 25

LLaMA / Mistral / Phi

Config Presets
LLaMA llama2_7b, llama2_13b, llama3_8b, tiny
Mistral mistral_7b, mistral_7b_instruct, mixtral_8x7b, tiny
Phi phi1, phi2, phi3_mini, tiny

Hybrid / SSM / MoE / Ternary

Config Presets / builder
SSMConfig from_d_model(d_model, vocab)
HydraConfig base, small, tiny
ChimeraConfig default_2b, small, tiny
TridentConfig default_150m, tiny, medium (plus 1B/3B/smoke constructors)

Generation strategies

Strategy Method Description
Greedy GenerationConfig::greedy() Always takes the argmax
Sampling GenerationConfig::sampling(temp) Temperature-scaled softmax sampling
Top-K GenerationConfig::top_k_sampling(k, temp) Sample from top-k tokens
Nucleus GenerationConfig::nucleus_sampling(p, temp) Sample from top-p probability mass
Beam Search GenerationConfig::beam_search(beams) Beam search decoding
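
Each row in the table maps to a GenerationConfig constructor; the argument values below are illustrative, not recommended defaults:

use axonml_llm::GenerationConfig;

let greedy  = GenerationConfig::greedy();
let sampled = GenerationConfig::sampling(0.8);               // temperature
let top_k   = GenerationConfig::top_k_sampling(50, 0.8);     // k, temperature
let nucleus = GenerationConfig::nucleus_sampling(0.95, 0.8); // p, temperature
let beams   = GenerationConfig::beam_search(4);              // number of beams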

Tests

cargo test -p axonml-llm
cargo test -p axonml-llm -- --nocapture

License

Licensed under either of:

  • MIT License
  • Apache License, Version 2.0

at your option.


Last updated: 2026-04-16 (v0.6.1)
