Last active
May 24, 2024 23:44
-
-
Save wassname/9f410d11f33cec75393b64d62286dd41 to your computer and use it in GitHub Desktop.
Craftax symbolic to env.state to pixel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
craftax = "1.4.1" | |
jax = "^0.4.28" | |
jax = | |
https://gist.github.com/wassname/9f410d11f33cec75393b64d62286dd41 | |
""" | |
import numpy as np | |
import numpy as np | |
from craftax.craftax.craftax_state import EnvState, Inventory, Mobs | |
from craftax.craftax.constants import MAX_OBS_DIM, OBS_DIM, BlockType, ItemType, MONSTERS_KILLED_TO_CLEAR_LEVEL | |
from craftax.craftax.util.game_logic_utils import is_boss_vulnerable | |
def _decode_sqrt_count(encoded):
    """Invert a ``sqrt(count) / 10`` encoding, rounding to the nearest integer.

    Rounding (rather than ``int()`` truncation) is required because the encoded
    values are floats: a value that should decode to 3 may come back as 2.998.
    Accepts a scalar or an array; returns an ``int32`` array of the same shape.
    """
    return np.rint((np.asarray(encoded, dtype=np.float64) * 10.0) ** 2).astype(np.int32)


def _decode_scaled(encoded, scale):
    """Invert a ``value / scale`` encoding, rounding to the nearest integer.

    e.g. ``float32(0.7) * 10`` is 6.9999999, which ``int()`` would turn into 6.
    Accepts a scalar or an array; returns an ``int32`` array of the same shape.
    """
    return np.rint(np.asarray(encoded, dtype=np.float64) * scale).astype(np.int32)


def _embed_view(view, full_shape, level, centre, fill=0):
    """Write a local ``(rows, cols)`` view into a new ``full_shape`` array.

    The view is written on ``level`` so that its cell ``(rows // 2, cols // 2)``
    lands on ``centre``; cells that would fall outside the map are dropped.
    Every other cell of the result holds ``fill``.
    """
    view = np.asarray(view)
    out = np.full(tuple(full_shape), fill, dtype=view.dtype)
    _, height, width = out.shape
    rows, cols = view.shape
    top = int(centre[0]) - rows // 2
    left = int(centre[1]) - cols // 2
    # Clip the destination window to the map, then take the matching part of the view.
    r0, r1 = max(top, 0), min(top + rows, height)
    c0, c1 = max(left, 0), min(left + cols, width)
    if r0 < r1 and c0 < c1:
        out[level, r0:r1, c0:c1] = view[r0 - top:r1 - top, c0 - left:c1 - left]
    return out


def inverse_render_craftax_symbolic(flattened, env_state):
    """Rebuild an approximate ``EnvState`` from a Craftax symbolic observation.

    The observation only describes the player's local view plus some scalar
    statistics. Every field it does not encode (ladders, mobs, projectiles,
    plants, achievements, RNG, timers, ...) is copied from ``env_state``. The
    decoded local view is written into otherwise-zero full-size maps (shape
    taken from ``env_state.map``) on the decoded player level, centred on
    ``env_state.player_position``.

    Args:
        flattened: 1-D symbolic observation (numpy or jax array), laid out as
            map view, inventory, potions, intrinsics, direction, armour,
            armour enchantments, special values.
        env_state: state supplying the fields the observation cannot recover.

    Returns:
        An ``EnvState`` with the decoded fields filled in.

    Raises:
        ValueError: if ``flattened`` does not have the expected length.
    """
    n_blocks, n_items = len(BlockType), len(ItemType)
    # Per-cell channels: block one-hot, item one-hot, 5 * 8 mob channels, light flag.
    n_channels = n_blocks + n_items + 5 * 8 + 1
    section_sizes = (
        int(np.prod(OBS_DIM)) * n_channels,  # local map view
        16,  # inventory
        6,  # potions
        9,  # intrinsics
        4,  # direction (one-hot)
        4,  # armour
        4,  # armour enchantments
        8,  # special values
    )
    flattened = np.asarray(flattened).reshape(-1)
    expected = sum(section_sizes)
    if flattened.size != expected:
        raise ValueError(
            f"expected a symbolic observation of length {expected}, got {flattened.size}"
        )
    (
        map_flat,
        inventory,
        potions,
        intrinsics,
        direction,
        armour,
        armour_enchantments,
        special_values,
    ) = np.split(flattened, np.cumsum(section_sizes)[:-1])
    all_map = map_flat.reshape((*OBS_DIM, n_channels))

    # Decode the local view.
    block_view = all_map[:, :, :n_blocks].argmax(-1).astype(np.int32)
    item_view = all_map[:, :, n_blocks:n_blocks + n_items].argmax(-1).astype(np.int32)
    # argmax cannot tell "no mob" from mob channel 0, so test for any active mob channel.
    mob_view = (all_map[:, :, n_blocks + n_items:-1] > 0).any(-1)
    light_view = (all_map[:, :, -1] > 0.05).astype(np.float32)

    # Embed the view around the player on the decoded level instead of assuming
    # a 9x48x48 world with the player fixed at (24, 24) on level 0.
    full_shape = tuple(env_state.map.shape)
    player_level = int(np.clip(_decode_scaled(special_values[5], 10.0), 0, full_shape[0] - 1))
    centre = np.asarray(env_state.player_position)
    block_map = _embed_view(block_view, full_shape, player_level, centre)
    item_map = _embed_view(item_view, full_shape, player_level, centre)
    mob_map = _embed_view(mob_view, full_shape, player_level, centre)
    light_map = _embed_view(light_view, full_shape, player_level, centre)

    def count(index):
        # Inventory entries encoded as sqrt(count) / 10.
        return int(_decode_sqrt_count(inventory[index]))

    state = EnvState(
        map=block_map,
        item_map=item_map,
        mob_map=mob_map,
        light_map=light_map,
        down_ladders=env_state.down_ladders,  # not in the observation: copied through
        up_ladders=env_state.up_ladders,
        chests_opened=env_state.chests_opened,
        monsters_killed=env_state.monsters_killed,
        player_position=env_state.player_position,
        player_level=player_level,
        player_direction=int(np.argmax(direction)) + 1,
        player_health=intrinsics[0] * 10.0,  # kept as a float, as before
        player_food=int(_decode_scaled(intrinsics[1], 10.0)),
        player_drink=int(_decode_scaled(intrinsics[2], 10.0)),
        player_energy=int(_decode_scaled(intrinsics[3], 10.0)),
        player_mana=int(_decode_scaled(intrinsics[4], 10.0)),
        is_sleeping=bool(special_values[1]),
        is_resting=bool(special_values[2]),
        player_recover=env_state.player_recover,
        player_hunger=env_state.player_hunger,
        player_thirst=env_state.player_thirst,
        player_fatigue=env_state.player_fatigue,
        player_recover_mana=env_state.player_recover_mana,
        player_xp=int(_decode_scaled(intrinsics[5], 10.0)),
        player_dexterity=int(_decode_scaled(intrinsics[6], 10.0)),
        player_strength=int(_decode_scaled(intrinsics[7], 10.0)),
        player_intelligence=int(_decode_scaled(intrinsics[8], 10.0)),
        # NOTE(review): inventory index -> item mapping (notably sapling/ruby/sapphire
        # at 5/6/7) is carried over unchanged; confirm against the craftax renderer.
        inventory=Inventory(
            wood=count(0),
            stone=count(1),
            coal=count(2),
            iron=count(3),
            diamond=count(4),
            sapling=count(5),
            pickaxe=int(_decode_scaled(inventory[11], 4.0)),
            sword=int(_decode_scaled(inventory[12], 4.0)),
            bow=int(_decode_scaled(inventory[15], 1.0)),
            arrows=count(9),
            # Decode after scaling; casting to int first truncated 0.5 -> 0.
            armour=_decode_scaled(armour, 2.0),
            torches=count(8),
            ruby=count(6),
            sapphire=count(7),
            # Decode after scaling; casting to int first zeroed every potion count.
            potions=_decode_sqrt_count(potions),
            books=int(_decode_scaled(inventory[10], 2.0)),
        ),
        melee_mobs=env_state.melee_mobs,
        passive_mobs=env_state.passive_mobs,
        ranged_mobs=env_state.ranged_mobs,
        mob_projectiles=env_state.mob_projectiles,
        mob_projectile_directions=env_state.mob_projectile_directions,
        player_projectiles=env_state.player_projectiles,
        player_projectile_directions=env_state.player_projectile_directions,
        growing_plants_positions=env_state.growing_plants_positions,
        growing_plants_age=env_state.growing_plants_age,
        growing_plants_mask=env_state.growing_plants_mask,
        potion_mapping=env_state.potion_mapping,
        learned_spells=special_values[3:5].astype(bool),
        sword_enchantment=int(_decode_scaled(inventory[13], 1.0)),
        bow_enchantment=int(_decode_scaled(inventory[14], 1.0)),
        armour_enchantments=_decode_scaled(armour_enchantments, 1.0),
        boss_progress=env_state.boss_progress,
        boss_timesteps_to_spawn_this_round=env_state.boss_timesteps_to_spawn_this_round,
        light_level=float(special_values[0]),
        achievements=env_state.achievements,
        state_rng=env_state.state_rng,
        timestep=env_state.timestep,
        fractal_noise_angles=env_state.fractal_noise_angles,
    )
    return state
# Demo: decode a fresh state's symbolic observation and render it next to the real state.
from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.renderer import render_craftax_pixels
import jax
import matplotlib.pyplot as plt

# Create environment
env = make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=True)
env_params = env.default_params

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

# Get an initial state and observation. `reset` returns the state itself; the
# env object holds no `env_state`/`state` attribute.
obs, state = env.reset(rngs[0], env_params)

# `obs` is assumed to be the un-batched 1-D observation for a single key;
# jax arrays have no `.numpy()`, so convert with np.asarray.
env_state2 = inverse_render_craftax_symbolic(np.asarray(obs), env_state=state)

# Show both renders side by side (two imshow calls on one axis overwrite each other).
fig, (ax_decoded, ax_original) = plt.subplots(1, 2, figsize=(10, 5))
ax_decoded.imshow(np.asarray(render_craftax_pixels(env_state2, block_pixel_size=10)).astype(np.uint8))
ax_decoded.set_title("decoded from symbolic obs")
# compare to
ax_original.imshow(np.asarray(render_craftax_pixels(state, block_pixel_size=10)).astype(np.uint8))
ax_original.set_title("original state")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment