LeRobot PushT Evaluation

Diffusion Policy Gymnasium
Policy Rollout (300 Steps) pusht_policy.mp4

Autonomous Multi-Modal Manipulation

This Space showcases a pre-trained Diffusion Policy running in the PushT-v0 simulation environment from Hugging Face's LeRobot ecosystem. The policy models multi-modal action trajectories and uses them to push the gray T-shaped block into the green target silhouette.

Observation Space

Pixels & Agent Position

Image Shape: (3, 384, 384)

Action Space

2D Continuous Control

Absolute (x, y) target position of the agent

Model Source

lerobot/diffusion_pusht

Pre-trained Checkpoint via HF Hub

Pipeline Optimization

Dynamic Buffer Patching

Resolves the pos_grid shape mismatch with the checkpoint state dict

Engineering Note: The PushT environment produces 384×384-pixel observations by default, so the script overrides the policy config's image input shape to (3, 384, 384). With that override, the model builds its position grid buffer (rgb_encoder.pool.pos_grid) with shape [144, 2], whereas the pre-trained checkpoint stores a [9, 2] grid. To keep the checkpoint compatible, the script re-registers the checkpoint's [9, 2] pos_grid buffer on the model after building it.

Deployment & Rollout Script

The script below downloads the checkpoint, extracts the dataset normalization statistics stored in it, patches the pos_grid buffer shape, steps the policy through the environment, and saves the rollout video.

run_pusht.py
import os
import gymnasium as gym
import gym_pusht
import torch
import imageio
from huggingface_hub import hf_hub_download
import safetensors.torch
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.factory import make_pre_post_processors
from lerobot.envs.utils import preprocess_observation

def _load_policy(repo_id):
    """Build a DiffusionPolicy for 384x384 images and load the ``repo_id`` checkpoint.

    The policy config is overridden to expect (3, 384, 384) images, so the model
    is built with a pos_grid buffer whose shape may differ from the checkpoint's.
    The checkpoint's pos_grid is re-registered *before* the weights are loaded:
    ``load_state_dict(strict=False)`` only tolerates missing/unexpected keys, and
    a shape mismatch on a key present in both still raises ``RuntimeError``.

    Args:
        repo_id: Hugging Face Hub repository holding the policy config and
            ``model.safetensors``.

    Returns:
        ``(cfg, policy, state_dict)``: the modified policy config, the policy
        moved to ``cfg.device`` in eval mode, and the raw checkpoint tensors.
    """
    print(f"Downloading config from {repo_id}...")
    cfg = PreTrainedConfig.from_pretrained(repo_id)

    # We override the observation.image feature shape to (3, 384, 384) to match the environment defaults,
    # which instantiates the model's pos_grid as [144, 2] instead of [9, 2] (checkpoint size).
    cfg.input_features['observation.image'].shape = (3, 384, 384)

    print("Building DiffusionPolicy...")
    policy = DiffusionPolicy(cfg)
    pool = policy.diffusion.rgb_encoder.pool
    print("Initial pos_grid shape in model:", pool.pos_grid.shape)

    print("Downloading and loading safetensors model weights...")
    model_file = hf_hub_download(repo_id=repo_id, filename='model.safetensors')
    state_dict = safetensors.torch.load_file(model_file)

    # Patch the pos_grid shape mismatch first, so that load_state_dict below sees
    # matching shapes instead of failing with a "size mismatch" RuntimeError.
    print("Patching the pos_grid shape mismatch...")
    pool.register_buffer('pos_grid', state_dict['diffusion.rgb_encoder.pool.pos_grid'].clone())
    print("Patched pos_grid shape in model:", pool.pos_grid.shape)

    # strict=False tolerates key differences between checkpoint and model;
    # report them instead of discarding them silently.
    incompatible = policy.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print("WARNING: keys missing from checkpoint:", incompatible.missing_keys)
    if incompatible.unexpected_keys:
        print("Ignored checkpoint keys not used by the model:", incompatible.unexpected_keys)

    policy.to(cfg.device)
    policy.eval()
    return cfg, policy, state_dict


def _dataset_stats_from_checkpoint(state_dict):
    """Rebuild the normalization statistics from the checkpoint's normalize_* buffers.

    Raises:
        KeyError: if the checkpoint does not contain the expected buffers.
    """
    return {
        'observation.image': {
            'mean': state_dict['normalize_inputs.buffer_observation_image.mean'],
            'std': state_dict['normalize_inputs.buffer_observation_image.std'],
        },
        'observation.state': {
            'max': state_dict['normalize_inputs.buffer_observation_state.max'],
            'min': state_dict['normalize_inputs.buffer_observation_state.min'],
        },
        'action': {
            'max': state_dict['normalize_targets.buffer_action.max'],
            'min': state_dict['normalize_targets.buffer_action.min'],
        },
    }


def _rollout(env, policy, preprocessor, postprocessor, num_steps):
    """Step ``env`` with ``policy`` for ``num_steps`` steps and return rendered frames.

    The result holds the frame rendered after the initial reset plus one frame
    per step. Whenever an episode terminates or is truncated, both the policy and
    the env are reset, mirroring the reset done before the first episode. This
    ensures no per-episode policy state carries over into the next episode.
    """
    policy.reset()
    obs, _ = env.reset()
    frames = [env.render()]

    for _ in range(num_steps):
        # Convert the raw gym observation into LeRobot format, then normalize it.
        obs_t = preprocessor(preprocess_observation(obs))

        with torch.no_grad():
            action = policy.select_action(obs_t)

        # Un-normalize, move to CPU, and drop the batch dimension.
        action_numpy = postprocessor(action).to("cpu").numpy()[0]
        obs, _reward, terminated, truncated, _info = env.step(action_numpy)
        frames.append(env.render())

        if terminated or truncated:
            policy.reset()
            obs, _ = env.reset()

    return frames


def main(num_steps: int = 300, output_path: str = "pusht_policy.mp4") -> None:
    """Roll out the pre-trained PushT Diffusion Policy and save the rollout as a video.

    Args:
        num_steps: Number of environment steps to roll out.
        output_path: Path of the video file written with ``imageio.mimsave``.
    """
    repo_id = 'lerobot/diffusion_pusht'

    # 1-2. Download the checkpoint, build the policy, and patch pos_grid.
    cfg, policy, state_dict = _load_policy(repo_id)

    # 3. Create preprocessor / postprocessor with the extracted dataset stats.
    print("Creating preprocessor and postprocessor...")
    dataset_stats = _dataset_stats_from_checkpoint(state_dict)
    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)

    # 4. Instantiate the environment and roll out; always release the env.
    print("Creating PushT environment...")
    env = gym.make('gym_pusht/PushT-v0', render_mode='rgb_array', obs_type='pixels_agent_pos')
    try:
        # Use the env's advertised render rate; 10 fps was the original fixed value.
        fps = env.metadata.get("render_fps", 10)
        print(f"Running {num_steps} steps rollout...")
        frames = _rollout(env, policy, preprocessor, postprocessor, num_steps)
    finally:
        env.close()

    # 5. Save the collected frames as a video.
    print(f"Saving video to {output_path}...")
    imageio.mimsave(output_path, frames, fps=fps)
    print("Done! Video saved successfully.")

# Run the rollout pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()