-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathminecraft_dataset.py
More file actions
377 lines (321 loc) · 14.4 KB
/
minecraft_dataset.py
File metadata and controls
377 lines (321 loc) · 14.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#!/usr/bin/env python3
"""
PyTorch dataset for Minecraft schematic files.
This dataset loads schematic files and provides 16×16×16 chunks with masks.
Uses sentence transformers to create embeddings for block types.
"""
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pickle
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
import torch.nn as nn
from schematic_loader import load_schematic_to_numpy, BLOCK_ID_TO_NAME
class MinecraftSchematicDataset(Dataset):
    """
    PyTorch dataset for Minecraft schematic files.
    Provides 16×16×16 chunks with masks for areas where dimensions are smaller than 16.
    Uses sentence transformers to create embeddings for block types.
    """

    def __init__(
        self,
        schematics_dir,
        chunk_size=16,
        preload=False,
        cache_file=None,
        embedding_cache_file=None,
        max_files=None,
        min_dimension=16,
        embedding_dim=32,
    ):
        """
        Initialize the dataset.

        Args:
            schematics_dir (str): Directory containing schematic files
            chunk_size (int): Size of the chunks to extract (default: 16)
            preload (bool): Whether to preload all data into memory (default: False)
            cache_file (str, optional): Path to cache file for block mappings
            embedding_cache_file (str, optional): Path to cache file for block embeddings
            max_files (int, optional): Maximum number of files to load
            min_dimension (int): Minimum dimension required for a schematic to be included
            embedding_dim (int): Dimension for the block embeddings after PCA reduction
        """
        self.schematics_dir = schematics_dir
        self.chunk_size = chunk_size
        self.preload = preload
        self.min_dimension = min_dimension
        self.embedding_dim = embedding_dim
        # Default embedding cache location when the caller supplies none.
        self.embedding_cache_file = embedding_cache_file or "cache/block_embeddings.pt"
        # Recursively collect every .schematic/.schem file under schematics_dir.
        self.schematic_files = []
        for root, _, files in os.walk(schematics_dir):
            for file in files:
                if file.endswith(".schematic") or file.endswith(".schem"):
                    self.schematic_files.append(os.path.join(root, file))
        # Optionally cap the corpus size (useful for quick experiments).
        if max_files is not None:
            self.schematic_files = self.schematic_files[:max_files]
        # Bidirectional block <-> vocabulary-index mappings; filled either
        # from the cache below or by _scan_files().
        self.block_to_idx = {}
        self.idx_to_block = {}
        # Try to load block mappings (and per-file metadata) from cache.
        if cache_file and os.path.exists(cache_file):
            print(f"Loading block mappings from cache: {cache_file}")
            with open(cache_file, "rb") as f:
                cache_data = pickle.load(f)
            self.block_to_idx = cache_data["block_to_idx"]
            self.idx_to_block = cache_data["idx_to_block"]
            # Older caches may lack "file_info"; fall back to an empty list.
            self.file_info = cache_data.get("file_info", [])
        else:
            # Scan files to build block mappings and collect file info.
            self._scan_files()
            # Save block mappings to cache if specified.
            if cache_file:
                print(f"Saving block mappings to cache: {cache_file}")
                # NOTE(review): assumes cache_file contains a directory
                # component — os.makedirs("") raises for a bare filename.
                os.makedirs(os.path.dirname(cache_file), exist_ok=True)
                with open(cache_file, "wb") as f:
                    pickle.dump(
                        {
                            "block_to_idx": self.block_to_idx,
                            "idx_to_block": self.idx_to_block,
                            "file_info": self.file_info,
                        },
                        f,
                    )
        # One embedding row per vocabulary index: [num_blocks, embedding_dim].
        self.block_embeddings = self._create_block_embeddings()
        # Preload data if requested.
        self.preloaded_data = None
        if self.preload:
            self._preload_data()
def _create_block_embeddings(self):
"""
Create embeddings for block types using sentence transformers.
Uses PCA to reduce dimensions to the specified embedding_dim.
Returns:
torch.Tensor: Tensor of shape [num_blocks, embedding_dim] with embeddings for each block type
"""
# Check if embeddings already exist in cache
if os.path.exists(self.embedding_cache_file):
print(f"Loading block embeddings from cache: {self.embedding_cache_file}")
return torch.load(self.embedding_cache_file)
print("Creating block embeddings using sentence transformers...")
# Create a mapping from block indices to block names
block_names = []
for idx in range(len(self.idx_to_block)):
block = self.idx_to_block.get(idx)
if isinstance(block, str):
# Special tokens
if block == "<pad>":
block_names.append("padding")
elif block == "<unk>":
block_names.append("unknown block")
else:
block_names.append(block)
else:
# Numeric block ID, get name from BLOCK_ID_TO_NAME
block_name = BLOCK_ID_TO_NAME.get(block, f"unknown block {block}")
block_names.append(block_name)
# Load the sentence transformer model
print("Loading sentence transformer model...")
model = SentenceTransformer("all-mpnet-base-v2")
# Generate embeddings for all block names
print("Generating embeddings for block names...")
embeddings = model.encode(block_names)
# Apply PCA to reduce dimensions
print(f"Applying PCA to reduce dimensions to {self.embedding_dim}...")
pca = PCA(n_components=self.embedding_dim)
reduced_embeddings = pca.fit_transform(embeddings)
embeddings_tensor = torch.tensor(reduced_embeddings, dtype=torch.float32)
# Set padding token embedding to zeros
if "<pad>" in self.block_to_idx:
pad_idx = self.block_to_idx["<pad>"]
embeddings_tensor[pad_idx] = torch.zeros(self.embedding_dim)
# Save embeddings to cache
os.makedirs(os.path.dirname(self.embedding_cache_file), exist_ok=True)
torch.save(embeddings_tensor, self.embedding_cache_file)
print(f"Created embeddings of shape: {embeddings_tensor.shape}")
return embeddings_tensor
def _scan_files(self):
"""Scan all schematic files to build block mappings and collect file info."""
print("Scanning schematic files to build block mappings...")
# Add special tokens
self.block_to_idx["<pad>"] = 0
self.block_to_idx["<unk>"] = 1
self.idx_to_block[0] = "<pad>"
self.idx_to_block[1] = "<unk>"
# Collect block types and file info
block_counts = defaultdict(int)
self.file_info = []
for file_path in tqdm(self.schematic_files):
try:
# Load the schematic
blocks, dimensions = load_schematic_to_numpy(file_path)
# Check if dimensions meet the minimum requirement
height, length, width = blocks.shape
if (
height < self.min_dimension
or length < self.min_dimension
or width < self.min_dimension
):
continue
# Count unique blocks
unique_blocks = np.unique(blocks)
for block in unique_blocks:
block_counts[block] += 1
# Store file info
self.file_info.append(
{"path": file_path, "dimensions": dimensions, "shape": blocks.shape}
)
except Exception as e:
print(f"Error loading {file_path}: {e}")
# Create block mappings (starting from 2 because 0 and 1 are reserved)
next_idx = 2
for block, count in sorted(
block_counts.items(), key=lambda x: x[1], reverse=True
):
if block not in self.block_to_idx:
self.block_to_idx[block] = next_idx
self.idx_to_block[next_idx] = block
next_idx += 1
print(f"Found {len(self.file_info)} valid schematic files")
print(f"Found {len(self.block_to_idx)} unique block types")
def _preload_data(self):
"""Preload all data into memory."""
print("Preloading data into memory...")
self.preloaded_data = []
for info in tqdm(self.file_info):
try:
blocks, _ = load_schematic_to_numpy(info["path"])
self.preloaded_data.append(blocks)
except Exception as e:
print(f"Error preloading {info['path']}: {e}")
self.preloaded_data.append(None)
def _get_blocks(self, idx):
"""Get blocks array for a given index."""
if self.preload and self.preloaded_data[idx] is not None:
return self.preloaded_data[idx]
else:
blocks, _ = load_schematic_to_numpy(self.file_info[idx]["path"])
return blocks
def _extract_chunk(self, blocks):
"""
Extract a chunk from the blocks array.
Args:
blocks: The blocks array
start_y, start_z, start_x: Starting coordinates for the chunk
Returns:
chunk: The extracted chunk
mask: Mask indicating valid positions (1) vs padding (0)
"""
sliding_window_width = self.chunk_size
sliding_window_height = self.chunk_size
sliding_window_depth = self.chunk_size
block_map = np.full(
(sliding_window_width, sliding_window_height, sliding_window_depth),
"minecraft:air",
dtype=object,
)
minimum_width = min(sliding_window_width, blocks.shape[0])
minimum_height = min(sliding_window_height, blocks.shape[1])
minimum_depth = min(sliding_window_depth, blocks.shape[2])
x_start = np.random.randint(0, blocks.shape[0] - minimum_width + 1)
y_start = np.random.randint(0, blocks.shape[1] - minimum_height + 1)
z_start = np.random.randint(0, blocks.shape[2] - minimum_depth + 1)
x_end = x_start + minimum_width
y_end = y_start + minimum_height
z_end = z_start + minimum_depth
random_roll_x_value = np.random.randint(
0, sliding_window_width - minimum_width + 1
)
random_roll_y_value = np.random.randint(
0, sliding_window_height - minimum_height + 1
)
random_roll_z_value = np.random.randint(
0, sliding_window_depth - minimum_depth + 1
)
block_map[
random_roll_x_value : random_roll_x_value + minimum_width,
random_roll_y_value : random_roll_y_value + minimum_height,
random_roll_z_value : random_roll_z_value + minimum_depth,
] = blocks[x_start:x_end, y_start:y_end, z_start:z_end]
block_map = np.vectorize(self.block_to_idx.get)(block_map)
block_map = block_map.astype(int)
block_map_mask = np.zeros((16, 16, 16), dtype=int)
block_map_mask[
random_roll_x_value : random_roll_x_value + minimum_width,
random_roll_y_value : random_roll_y_value + minimum_height,
random_roll_z_value : random_roll_z_value + minimum_depth,
] = 1
return block_map, block_map_mask
def __len__(self):
"""Return the number of chunks in the dataset."""
return len(self.file_info)
def __getitem__(self, idx):
"""
Get a chunk from the dataset.
Args:
idx (int): Index
Returns:
dict: A dictionary containing:
- 'blocks': Tensor of shape (chunk_size, chunk_size, chunk_size) with block indices
- 'block_embeddings': Tensor of shape (chunk_size, chunk_size, chunk_size, embedding_dim) with embeddings
- 'mask': Tensor of shape (chunk_size, chunk_size, chunk_size) with 1 for valid positions, 0 for padding
- 'file_path': Path to the source schematic file
"""
# Get the blocks array
blocks = self._get_blocks(idx)
# Extract the chunk
chunk, mask = self._extract_chunk(blocks)
# Convert to tensors
chunk_tensor = torch.tensor(chunk, dtype=torch.long)
mask_tensor = torch.tensor(mask, dtype=torch.float)
# Create embeddings tensor for the chunk
# Shape: [chunk_size, chunk_size, chunk_size, embedding_dim]
chunk_embeddings = torch.zeros(
(self.chunk_size, self.chunk_size, self.chunk_size, self.embedding_dim),
dtype=torch.float32,
)
# Fill in embeddings for each position
for y in range(self.chunk_size):
for z in range(self.chunk_size):
for x in range(self.chunk_size):
block_idx = chunk_tensor[y, z, x].item()
chunk_embeddings[y, z, x] = self.block_embeddings[block_idx]
return {
"blocks": chunk_tensor, # Original block indices
"block_embeddings": chunk_embeddings, # Block embeddings
"mask": mask_tensor,
"file_path": self.file_info[idx]["path"],
}
# Example usage
if __name__ == "__main__":
    # Build a small dataset for smoke-testing the pipeline.
    ds = MinecraftSchematicDataset(
        schematics_dir="minecraft-schematics-raw",
        chunk_size=16,
        cache_file="cache/block_mappings.pkl",
        max_files=100,  # Limit to 100 files for testing
    )
    # Wrap it in a DataLoader and pull one batch.
    loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0)
    sample_batch = next(iter(loader))
    # Report batch shapes.
    print(f"Batch size: {sample_batch['blocks'].shape}")
    print(f"Mask size: {sample_batch['mask'].shape}")
    # Report dataset-level statistics.
    print(f"\nDataset size: {len(ds)}")
    print(f"Number of unique blocks: {len(ds.block_to_idx)}")
    # Show the first few vocabulary entries.
    print("\nBlock mapping:")
    for vocab_idx in range(min(10, len(ds.idx_to_block))):
        print(f"{vocab_idx}: {ds.idx_to_block.get(vocab_idx, '<unknown>')}")
    # Fraction of non-padded positions in the sampled batch.
    valid_percentage = sample_batch["mask"].float().mean().item() * 100
    print(f"\nPercentage of valid positions in batch: {valid_percentage:.2f}%")