Policy Rollout (300 Steps)
pusht_policy.mp4
Autonomous Multi-Modal Manipulation
This Space showcases a trained Diffusion Policy operating within the PushT-v0 simulation environment using Hugging Face's LeRobot ecosystem. The agent learns multi-modal trajectories to effectively guide the gray T-shaped block completely into the target green silhouette zone.
Observation Space
Pixels & Agent Position
Image Shape: (3, 384, 384)
Action Space
2D Continuous Control
End-effector delta position
Model Source
lerobot/diffusion_pusht
Pre-trained Checkpoint via HF Hub
Pipeline Optimization
Dynamic Buffer Patching
Overrides state dict pos_grid mismatch
Engineering Note: The default environment instantiation sets the visual input size to 384×384 pixels, which forces the model's position grid (pos_grid) to shape [144, 2], whereas the pre-trained checkpoint stores it with shape [9, 2]. To preserve checkpoint compatibility, the script explicitly re-registers the pos_grid buffer with the checkpoint's [9, 2] tensor right after the weights are loaded.
Deployment & Rollout Script
The script below is the complete pipeline: it extracts the dataset statistics from the checkpoint, patches the pos_grid buffer shape, steps through the environment dynamics, and saves the rollout video.
run_pusht.py
import os
import gymnasium as gym
import gym_pusht
import torch
import imageio
from huggingface_hub import hf_hub_download
import safetensors.torch
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.factory import make_pre_post_processors
from lerobot.envs.utils import preprocess_observation
def main(num_steps: int = 300, output_path: str = "pusht_policy.mp4", fps: int = 10) -> None:
    """Roll out the pre-trained PushT Diffusion Policy and save the frames as a video.

    The function downloads the ``lerobot/diffusion_pusht`` config and weights,
    builds the policy, re-registers the ``pos_grid`` buffer from the checkpoint,
    builds the pre/post-processors from the normalization statistics stored in
    the checkpoint, steps the ``gym_pusht/PushT-v0`` environment, and writes the
    rendered frames to ``output_path``.

    Args:
        num_steps: Number of environment steps to run (across episodes).
        output_path: Destination path of the rendered video.
        fps: Frame rate passed to ``imageio.mimsave``.
    """
    repo_id = 'lerobot/diffusion_pusht'
    pos_grid_key = 'diffusion.rgb_encoder.pool.pos_grid'

    # 1. Download checkpoint and load config
    print("Downloading config from lerobot/diffusion_pusht...")
    cfg = PreTrainedConfig.from_pretrained(repo_id)

    # Override the observation.image feature shape to (3, 384, 384) to match the
    # environment defaults. Per the original author's note this makes the model
    # build pos_grid as [144, 2] instead of the checkpoint's [9, 2].
    cfg.input_features['observation.image'].shape = (3, 384, 384)

    # Build the DiffusionPolicy
    print("Building DiffusionPolicy...")
    policy = DiffusionPolicy(cfg)
    pool = policy.diffusion.rgb_encoder.pool
    print("Initial pos_grid shape in model:", pool.pos_grid.shape)

    # Load weights with strict=False. The pos_grid entry is left out of the load
    # because strict=False only relaxes missing/unexpected *keys*; a tensor whose
    # shape differs from the model's would still be rejected by load_state_dict.
    print("Downloading and loading safetensors model weights...")
    model_file = hf_hub_download(repo_id=repo_id, filename='model.safetensors')
    state_dict = safetensors.torch.load_file(model_file)
    loadable_weights = {
        key: tensor for key, tensor in state_dict.items() if key != pos_grid_key
    }
    policy.load_state_dict(loadable_weights, strict=False)

    # 2. Patch the pos_grid shape mismatch so inference works: replace the
    # freshly built buffer with the tensor stored in the checkpoint.
    print("Patching the pos_grid shape mismatch...")
    pool.register_buffer('pos_grid', state_dict[pos_grid_key])
    print("Patched pos_grid shape in model:", pool.pos_grid.shape)

    # Move policy to correct device and set to eval mode
    policy.to(cfg.device)
    policy.eval()

    # 3. Create preprocessor / postprocessor with the dataset stats that are
    # stored as normalization buffers inside the checkpoint.
    print("Creating preprocessor and postprocessor...")
    dataset_stats = {
        'observation.image': {
            'mean': state_dict['normalize_inputs.buffer_observation_image.mean'],
            'std': state_dict['normalize_inputs.buffer_observation_image.std'],
        },
        'observation.state': {
            'max': state_dict['normalize_inputs.buffer_observation_state.max'],
            'min': state_dict['normalize_inputs.buffer_observation_state.min'],
        },
        'action': {
            'max': state_dict['normalize_targets.buffer_action.max'],
            'min': state_dict['normalize_targets.buffer_action.min'],
        },
    }
    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)

    # 4. Instantiate the gym environment
    print("Creating PushT environment...")
    env = gym.make('gym_pusht/PushT-v0', render_mode='rgb_array', obs_type='pixels_agent_pos')

    frames = []
    try:
        # Reset policy and env, and cache the initial frame
        policy.reset()
        obs, _info = env.reset()
        frames.append(env.render())

        print(f"Running {num_steps} steps rollout...")
        for _ in range(num_steps):
            # Format observations to LeRobot format, then normalize them
            batch = preprocessor(preprocess_observation(obs))

            # Select action without tracking gradients
            with torch.no_grad():
                action = policy.select_action(batch)
                action = postprocessor(action)

            # Extract numpy action and apply to env (drop batch dimension)
            action_numpy = action.to("cpu").numpy()[0]
            obs, _reward, terminated, truncated, _info = env.step(action_numpy)

            # Render frame
            frames.append(env.render())

            if terminated or truncated:
                obs, _info = env.reset()
                # A new episode starts here, so the policy's internal state from
                # the finished episode must be cleared along with the env.
                policy.reset()
    finally:
        # Close the env even if the rollout fails part-way through
        env.close()

    # 5. Save the frames as a video
    print(f"Saving video to {output_path}...")
    imageio.mimsave(output_path, frames, fps=fps)
    print("Done! Video saved successfully.")


if __name__ == "__main__":
    main()