relay — JEPA plans a turn-based contact-push puzzle in latent space

A 26M-parameter Joint Embedding Predictive Architecture (JEPA) trained to play a 2-player, turn-based T-block puzzle against humans. It plans each turn with the cross-entropy method (CEM) entirely in embedding space. At inference, the model's side uses no physics simulator.
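The card does not spell out the planning loop. Below is a minimal sketch of a generic cross-entropy method over action sequences that scores every candidate by rolling it out in latent space only. It assumes a hypothetical `predict(z, a)` transition function standing in for the JEPA predictor and a squared-distance cost to a goal embedding. The horizon, sample counts, and toy dynamics are illustrative choices, not the repo's settings.

```python
import random
from typing import Callable, List

Vec = List[float]


def cem_plan(
    z0: Vec,
    z_goal: Vec,
    predict: Callable[[Vec, Vec], Vec],  # hypothetical stand-in for the latent predictor
    horizon: int = 5,                    # illustrative; echoes the 5-frame action chunks
    action_dim: int = 2,
    n_samples: int = 64,
    n_elites: int = 8,
    n_iters: int = 5,
    seed: int = 0,
) -> List[Vec]:
    """Cross-entropy method: search action sequences by rolling them out in latent space."""
    rng = random.Random(seed)
    dim = horizon * action_dim
    mean = [0.0] * dim  # Gaussian over the flattened action sequence
    std = [1.0] * dim

    def cost(flat: Vec) -> float:
        # Roll the candidate sequence through the latent predictor; no simulator involved.
        z = z0
        for t in range(horizon):
            z = predict(z, flat[t * action_dim:(t + 1) * action_dim])
        # Squared Euclidean distance between the final latent and the goal latent.
        return sum((zi - gi) ** 2 for zi, gi in zip(z, z_goal))

    for _ in range(n_iters):
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)] for _ in range(n_samples)]
        elites = sorted(samples, key=cost)[:n_elites]  # lowest-cost candidates
        # Refit the sampling distribution to the elites (floor the std to keep exploring).
        mean = [sum(e[i] for e in elites) / n_elites for i in range(dim)]
        std = [
            max(1e-3, (sum((e[i] - mean[i]) ** 2 for e in elites) / n_elites) ** 0.5)
            for i in range(dim)
        ]
    # Return the final mean as the plan, one action vector per step.
    return [mean[t * action_dim:(t + 1) * action_dim] for t in range(horizon)]


if __name__ == "__main__":
    # Toy dynamics (hypothetical): the latent simply moves by the action.
    toy_predict = lambda z, a: [zi + ai for zi, ai in zip(z, a)]
    print(cem_plan([0.0, 0.0], [5.0, 0.0], toy_predict))
```

In a real setup `predict` would wrap the encoder/predictor pair and the goal would be an embedding of a target frame; only the sample-score-refit loop is the CEM part.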

Architecture

  • ViT-Tiny encoder (vit_tiny_patch14_224, trained from scratch, unfrozen, 192-dim)
  • Projector MLP(192→2048→192) + BatchNorm
  • AR causal Transformer predictor (AdaLN, 6 layers, 16 heads, mlp_dim=2048)
  • DexWM-style joint state head MLP(192→256→256→8) trained alongside the predictor
  • Loss: MSE_pred + 0.09·SIGReg + 10·MSE_state
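Written out, the loss above reads roughly as follows. The weights 0.09 and 10 come from the card. The notation is an assumption for illustration: $z_t$ is the projector output for frame $t$, $\hat z_{t+1}$ is the predictor's output given $z_{\le t}$ and action chunk $a_t$, $h_\phi$ is the state head, $s_t$ is its 8-dim target, and $Z$ is the batch of embeddings that SIGReg regularizes. Whether the state head reads encoded or predicted embeddings, and how each MSE is normalized, is not stated in the card.

```latex
\mathcal{L}
  = \underbrace{\frac{1}{T}\sum_{t=1}^{T}\bigl\lVert \hat z_{t+1} - z_{t+1}\bigr\rVert_2^2}_{\mathrm{MSE}_{\mathrm{pred}}}
  \;+\; 0.09\,\mathrm{SIGReg}(Z)
  \;+\; 10\,\underbrace{\frac{1}{T}\sum_{t=1}^{T}\bigl\lVert h_\phi(z_t) - s_t\bigr\rVert_2^2}_{\mathrm{MSE}_{\mathrm{state}}}
```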

Training data

12,500 episodes × 40 steps = 500K frames at 224×224 RGB, split across 5 regimes: contact_push 40%, approach_no_push 25%, null_thrust 15%, near_t 10%, far_from_t 10%. Actions are grouped into 5-frame chunks, matching the gameplay turn structure.
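The dataset figures can be cross-checked with a few lines of arithmetic. This sketch assumes one frame per step, non-overlapping 5-frame action chunks, and, for the size estimate only, raw uint8 RGB pixels; none of these storage details are stated in the card.

```python
# Assumptions (not from the card): one frame per step, non-overlapping
# 5-frame action chunks, raw uint8 RGB storage for the size estimate.
EPISODES = 12_500
STEPS_PER_EPISODE = 40
CHUNK_LEN = 5
H = W = 224
CHANNELS = 3

REGIME_SHARE = {
    "contact_push": 0.40,
    "approach_no_push": 0.25,
    "null_thrust": 0.15,
    "near_t": 0.10,
    "far_from_t": 0.10,
}


def dataset_summary() -> dict:
    """Derive per-regime episode counts and dataset totals from the card's numbers."""
    episodes = {name: round(EPISODES * share) for name, share in REGIME_SHARE.items()}
    frames = EPISODES * STEPS_PER_EPISODE
    chunks_per_episode = STEPS_PER_EPISODE // CHUNK_LEN
    return {
        "episodes_per_regime": episodes,
        "total_frames": frames,
        "chunks_per_episode": chunks_per_episode,
        "total_chunks": EPISODES * chunks_per_episode,
        "raw_uint8_bytes": frames * H * W * CHANNELS,
    }


if __name__ == "__main__":
    for key, value in dataset_summary().items():
        print(f"{key}: {value}")
```

Under these assumptions the split is 5,000 / 3,125 / 1,875 / 1,250 / 1,250 episodes, each episode holds 8 action chunks, and the raw pixels come to roughly 75 GB before any compression.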

Quick start

git clone https://github.com/SotoAlt/relay.git
cd relay
pip install torch torchvision timm einops pymunk pygame opencv-python-headless \
    shapely fastapi 'uvicorn[standard]' numpy pillow h5py

huggingface-cli download sotoalt/relay relay_stage1_v9_trackE_ep02_uhead.pt \
    --local-dir checkpoints/

PYTHONPATH=. python -m world_model.infer_relay \
    --port 8800 --device cpu \
    --checkpoint-v9 checkpoints/relay_stage1_v9_trackE_ep02_uhead.pt \
    --model-execute-jepa

Open http://localhost:8800/.

Results

Validation after Phase A joint training (λ_state = 10, 5-epoch fine-tune):

metric                               pre-joint v9      post-joint v9
val_pred                             0.0071            0.036
val_state                            0.032 (Phase B)   0.007 (4.6× better)
pymunk calibration probe agent_mae   ~113 px           ~54 px (-52%)
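The relative changes quoted in the table follow from the raw numbers. This sketch treats the "~" values as exact and simply recomputes the ratios; the same arithmetic applied to val_pred gives a ratio of about 5.1 in the opposite direction.

```python
# Values copied from the results table; "~" entries are treated as exact here.
PRE = {"val_pred": 0.0071, "val_state": 0.032, "agent_mae_px": 113.0}
POST = {"val_pred": 0.036, "val_state": 0.007, "agent_mae_px": 54.0}


def fold_change(before: float, after: float) -> float:
    """How many times larger `before` is than `after` (>1 means the metric dropped)."""
    return before / after


def pct_change(before: float, after: float) -> float:
    """Signed relative change in percent (negative means the metric decreased)."""
    return 100.0 * (after - before) / before


state_gain = fold_change(PRE["val_state"], POST["val_state"])
mae_change = pct_change(PRE["agent_mae_px"], POST["agent_mae_px"])
pred_ratio = POST["val_pred"] / PRE["val_pred"]

print(f"val_state: {state_gain:.1f}x lower")
print(f"agent_mae: {mae_change:.0f}%")
print(f"val_pred: {pred_ratio:.1f}x higher")
```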

Match play with --model-execute-jepa, 12 matches against each of 4 opponent policies:

opponent      win %   net progress
random        92%     +126
passive       100%    +149
chase_t       0%      -67
adversarial   25%     -10
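Reading "12 × 4" as 12 matches per opponent (an assumption, though every win rate above is consistent with it), the percentages convert back to whole-match records. This sketch recovers the implied win counts and the aggregate record across all four opponents.

```python
GAMES_PER_OPPONENT = 12  # assumption: "12 × 4" means 12 matches per opponent policy
WIN_PCT = {"random": 92, "passive": 100, "chase_t": 0, "adversarial": 25}


def implied_wins(pct: int, games: int = GAMES_PER_OPPONENT) -> int:
    """Nearest integer win count for a rounded win percentage."""
    return round(pct * games / 100)


wins = {name: implied_wins(pct) for name, pct in WIN_PCT.items()}

# Sanity check: each implied count must round back to the reported percentage.
assert all(round(100 * w / GAMES_PER_OPPONENT) == WIN_PCT[n] for n, w in wins.items())

total_wins = sum(wins.values())
total_games = GAMES_PER_OPPONENT * len(WIN_PCT)
overall_pct = 100 * total_wins / total_games

print(wins)
print(f"overall: {total_wins}/{total_games} = {overall_pct:.1f}%")
```

Under that reading the model went 11/12, 12/12, 0/12, and 3/12, or 26 of 48 matches overall (about 54%).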

License

MIT
