LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
Paper: 2603.19312
A 26M-parameter Joint Embedding Predictive Architecture (JEPA) trained to play a 2-player, turn-based T-block puzzle against humans. It plans each turn with the cross-entropy method (CEM) entirely in embedding space; at inference time, the model's side runs no physics simulator.
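To make "CEM in embedding space" concrete, here is a minimal cross-entropy-method sketch. It is not the repository's planner. Everything in it is a hypothetical stand-in: the 2-D latent, the additive toy predictor `predict` (in place of the learned JEPA predictor), the goal embedding `z_goal`, the squared-distance cost (the card does not specify a cost), and the population, elite, and iteration sizes. The horizon of 5 is chosen only to echo the card's 5-frame action chunks. The property it does illustrate is that candidate actions are scored purely by rolling latents forward through a predictor, with no simulator call.

```python
import random

def predict(z, a):
    """Hypothetical latent predictor standing in for the learned JEPA predictor:
    the next embedding is simply the current embedding plus the action."""
    return [zi + ai for zi, ai in zip(z, a)]

def rollout_cost(z0, actions, z_goal):
    """Roll an action sequence through the predictor (no simulator involved) and
    score the final embedding by squared distance to the goal embedding."""
    z = z0
    for a in actions:
        z = predict(z, a)
    return sum((zi - gi) ** 2 for zi, gi in zip(z, z_goal))

def to_actions(flat, horizon, dim):
    """Reshape a flat vector into `horizon` actions of size `dim`."""
    return [flat[t * dim:(t + 1) * dim] for t in range(horizon)]

def cem_plan(z0, z_goal, horizon=5, dim=2, pop=64, n_elite=8, iters=10, seed=0):
    """Cross-entropy method: sample action sequences from a diagonal Gaussian,
    keep the lowest-cost elites, refit mean/std to them, and repeat."""
    rng = random.Random(seed)
    n = horizon * dim
    mean, std = [0.0] * n, [1.0] * n
    for _ in range(iters):
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)] for _ in range(pop)]
        samples.sort(key=lambda x: rollout_cost(z0, to_actions(x, horizon, dim), z_goal))
        elites = samples[:n_elite]
        mean = [sum(e[i] for e in elites) / n_elite for i in range(n)]
        std = [
            max((sum((e[i] - mean[i]) ** 2 for e in elites) / n_elite) ** 0.5, 1e-3)
            for i in range(n)
        ]
    return to_actions(mean, horizon, dim)  # the plan is the final mean sequence

if __name__ == "__main__":
    z0, z_goal = [0.0, 0.0], [3.0, -2.0]
    plan = cem_plan(z0, z_goal)
    print("first action:", plan[0], "cost:", rollout_cost(z0, plan, z_goal))
```

Refitting to a set of elites rather than keeping the single best sample keeps the search distribution smooth; the `1e-3` floor on the standard deviation is an arbitrary choice that stops the distribution from collapsing entirely.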
| component | details |
|---|---|
| Encoder | `vit_tiny_patch14_224`, trained from scratch, unfrozen, 192-dim |
| Loss | MSE_pred + 0.09·SIGReg + 10·MSE_state |
| Data | 12,500 episodes × 40 steps across 5 regimes (contact_push 40%, approach_no_push 25%, null_thrust 15%, near_t 10%, far_from_t 10%); 500K frames at 224×224 RGB with 5-frame action chunks, matching the gameplay turn structure |
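The loss weighting and data bookkeeping in this card can be checked with a short plain-Python sketch. It rests on assumptions: the three loss terms arrive as precomputed scalars (the card does not detail how SIGReg or the state loss is computed), `total_loss` and `episodes_per_regime` are hypothetical names, the regime percentages are treated as an exact stratified split (the card does not say whether episodes were allocated or sampled), and each step is taken to be one frame, which is what makes 12,500 × 40 equal the stated 500K frames.

```python
# Loss weights as listed in this card: MSE_pred + 0.09·SIGReg + 10·MSE_state.
LAMBDA_SIGREG = 0.09
LAMBDA_STATE = 10.0

def total_loss(mse_pred, sigreg, mse_state):
    """Weighted sum of the three terms; each input is an already-computed scalar."""
    return mse_pred + LAMBDA_SIGREG * sigreg + LAMBDA_STATE * mse_state

# Episode share per data regime, in percent (sums to 100).
REGIME_PCT = {
    "contact_push": 40,
    "approach_no_push": 25,
    "null_thrust": 15,
    "near_t": 10,
    "far_from_t": 10,
}
N_EPISODES = 12_500
STEPS_PER_EPISODE = 40  # assumption: one 224x224 RGB frame per step
CHUNK = 5               # frames per action chunk

def episodes_per_regime(n=N_EPISODES):
    """Hypothetical helper: exact episode counts if the split is stratified."""
    assert sum(REGIME_PCT.values()) == 100
    return {name: n * pct // 100 for name, pct in REGIME_PCT.items()}

total_frames = N_EPISODES * STEPS_PER_EPISODE    # 500,000, matching "500K frames"
chunks_per_episode = STEPS_PER_EPISODE // CHUNK  # 8 five-frame chunks per episode

if __name__ == "__main__":
    print(episodes_per_regime(), total_frames, chunks_per_episode)
```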
```bash
# Clone the repository
git clone https://github.com/SotoAlt/relay.git
cd relay

# Install dependencies
pip install torch torchvision timm einops pymunk pygame opencv-python-headless \
    shapely fastapi 'uvicorn[standard]' numpy pillow h5py

# Download the checkpoint from the Hugging Face Hub
huggingface-cli download sotoalt/relay relay_stage1_v9_trackE_ep02_uhead.pt \
    --local-dir checkpoints/

# Start the inference server on CPU
PYTHONPATH=. python -m world_model.infer_relay \
    --port 8800 --device cpu \
    --checkpoint-v9 checkpoints/relay_stage1_v9_trackE_ep02_uhead.pt \
    --model-execute-jepa
```
Then open http://localhost:8800/ in a browser.
Validation results for Phase A joint training (λ_state = 10, 5-epoch fine-tune); lower is better for every metric:
| metric | pre-joint v9 | post-joint v9 |
|---|---|---|
| val_pred | 0.0071 | 0.036 |
| val_state | 0.032 (Phase B) | 0.007 (4.6× better) |
| pymunk calibration probe `agent_mae` | ~113 px | ~54 px (-52%) |
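The relative changes quoted in the joint-training results can be recomputed from the reported values alone. The sketch below uses only the card's numbers (the ~113 px and ~54 px figures are approximate, so the derived percentage is too); `ratio` and `pct_change` are illustrative helper names. It also makes explicit a trade-off visible in the same results: val_pred rises roughly 5× after joint training, while val_state and `agent_mae` fall.

```python
# Reported values: (pre-joint v9, post-joint v9); lower is better for all three.
results = {
    "val_pred": (0.0071, 0.036),
    "val_state": (0.032, 0.007),
    "agent_mae_px": (113.0, 54.0),
}

def ratio(pre, post):
    """How many times smaller the post-joint value is (> 1 means improvement)."""
    return pre / post

def pct_change(pre, post):
    """Signed percentage change from the pre-joint to the post-joint value."""
    return (post - pre) / pre * 100.0

if __name__ == "__main__":
    for name, (pre, post) in results.items():
        print(f"{name}: {ratio(pre, post):.2f}x smaller, {pct_change(pre, post):+.1f}%")
```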
Match play with `--model-execute-jepa` (12 × 4 opponent policies):
| opponent | win % | net progress |
|---|---|---|
| random | 92% | +126 |
| passive | 100% | +149 |
| chase_t | 0% | -67 |
| adversarial | 25% | -10 |
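If "12 × 4" is read as 12 matches against each of the 4 opponent policies (an assumption; the card does not state it outright), each win percentage corresponds to a whole number of wins. This sketch recovers those counts and the pooled win rate; `implied_wins` and `MATCHES_PER_OPPONENT` are hypothetical names, and the net-progress column is left alone because its unit is not defined here.

```python
MATCHES_PER_OPPONENT = 12  # assumption: "12 × 4" means 12 matches per policy

win_pct = {"random": 92, "passive": 100, "chase_t": 0, "adversarial": 25}

def implied_wins(pct, n=MATCHES_PER_OPPONENT):
    """Nearest whole number of wins consistent with a rounded win percentage."""
    return round(pct / 100 * n)

wins = {name: implied_wins(p) for name, p in win_pct.items()}
total_wins = sum(wins.values())
total_matches = MATCHES_PER_OPPONENT * len(win_pct)

if __name__ == "__main__":
    print(wins, f"pooled: {total_wins}/{total_matches} = {total_wins / total_matches:.0%}")
```

Under this reading, every reported percentage maps back exactly to an integer win count (for example, 11/12 rounds to 92%), which supports the 12-per-opponent interpretation without proving it.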
License: MIT