# https://github.com/state-spaces/mamba
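# Dump per-layer input/weight/output shapes of every nn.Linear while generating a few
# tokens, for either a state-spaces Mamba model (via mamba_ssm) or an HF causal LM.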
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from functools import partial
from collections import OrderedDict, defaultdict
import os
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
def annotate_module_static_attr(top_module, family_name=None):
    # Static attributes set on every module:
    # first_name, last_name, full_name, class_name, is_leaf_module, leaf_has_weight
    if family_name is None:
        family = top_module.__class__.__name__.lower() + "class_as_family_name"
    else:
        family = family_name
    for parent_name, parent_module in top_module.named_modules():
        # Handle the top-level module explicitly: the children loop below operates one
        # level down, so the top-level module would otherwise be missed.
        if parent_name == "":
            parent_module.first_name = family
            parent_module.last_name = ""
        for child_name, child_module in parent_module.named_children():
            child_module.first_name = child_name
            if parent_name == "":
                # Avoid a leading period when the parent is the root module.
                child_module.last_name = f"{family}"
            else:
                child_module.last_name = f"{family}.{parent_name}"
        # The following applies to every module.
        parent_module.is_leaf_module = False
        parent_module.leaf_has_weight = False
        if len(list(parent_module.children())) == 0:
            parent_module.is_leaf_module = True
            if len(list(parent_module.parameters())) > 0:
                parent_module.leaf_has_weight = True
        parent_module.class_name = parent_module.__class__.__name__
        parent_module.full_name = f"{parent_module.last_name}.{parent_module.first_name}"  # must be set last
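# Illustrative usage on a toy model (not part of the dump below):
#   toy = torch.nn.Sequential(torch.nn.Linear(4, 4))
#   annotate_module_static_attr(toy, family_name="toy")
#   toy[0].full_name        # -> "toy.0"
#   toy[0].is_leaf_module   # -> True
#   toy[0].leaf_has_weight  # -> True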
model_id="microsoft/Phi-3-mini-4k-instruct"
model_id="state-spaces/mamba-2.8b"
model_id="state-spaces/mamba2-2.7b"
# model_id="mistralai/Mamba-Codestral-7B-v0.1"
maxlen = 10
device = "cuda"
dtype = torch.float16
if model_id.startswith("mistralai/Mamba"):
from huggingface_hub import snapshot_download
from pathlib import Path
mistral_models_path = Path.home().joinpath('mistral_models', 'Mamba-Codestral-7B-v0.1')
if not mistral_models_path.exists():
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(repo_id="mistralai/Mamba-Codestral-7B-v0.1", allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], local_dir=mistral_models_path)
exit()
#TODO not working for "mistralai/Mamba-Codestral-7B-v0.1", doesnt work natively with HF transformers
is_mamba = model_id.startswith("state-spaces/mamba") or model_id.startswith("state-spaces/transformerpp") or model_id.startswith("mistralai/Mamba")
if is_mamba:
    # state-spaces checkpoints do not ship a tokenizer; they use the EleutherAI GPT-NeoX tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(model_id, device=device, dtype=dtype)
else:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=dtype)
top_module = model
annotate_module_static_attr(top_module=top_module, family_name=os.path.basename(model_id))
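# Lookup tables: class name -> list of module full_names, full_name -> class name, full_name -> module object.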
modtype_to_modlist = defaultdict(list)
modname_to_modtype = OrderedDict()
modname_to_module = OrderedDict()
for n, m in top_module.named_modules():
modtype_to_modlist[m.class_name].append(f"{m.last_name}.{m.first_name}")
modname_to_modtype[m.full_name] = m.class_name
modname_to_module[m.full_name] = m
layer_dump = defaultdict(list)
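# Forward hook: append one record per forward call of the hooked module, with the
# input feature map (ifm), weight, and output feature map (ofm) shapes.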
def hook(module, input, output):
layer_dump[module.full_name].append(
dict(
ifm=tuple(input[0].shape),
wei=tuple(module.weight.shape),
ofm=tuple(output.shape),
)
)
hooks = []
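# Hook only nn.Linear modules, so `module.weight` is guaranteed to exist inside the hook.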
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
hooks.append(module.register_forward_hook(hook))
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
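# MambaLMHeadModel.generate has its own signature (max_length is the total length,
# cg=True enables CUDA-graph decoding); HF models take max_new_tokens instead.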
if is_mamba:
out = model.generate(
input_ids=input_ids,
max_length=maxlen,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=False,
temperature=0.7,
top_k=1,
top_p=0.9,
min_p=0.0,
repetition_penalty=1.2,
)
else:
out = model.generate(input_ids, max_new_tokens=maxlen)
if is_mamba:
print(tokenizer.batch_decode(out.sequences.tolist()))
else:
print(tokenizer.batch_decode(out))
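# Dump the recorded shapes to a semicolon-separated file, one row per layer per step:
# l<idx>;<full_name>;<step>;i:<input shape>;w:<weight shape>;o:<output shape>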
with open(f"layerwise_dump_{os.path.basename(model_id)}.csv", "w") as csvout:
step=1
for step in [0, 1]:
for lid, l in enumerate(layer_dump):
d = layer_dump[l][step]
rpt_str = f"l{lid};{l};{step};i:{d['ifm']};w:{d['wei']};o:{d['ofm']}"
print(rpt_str)
csvout.write(rpt_str+"\n")
print("end.")