Keyphrase extraction model with SpanMarker
from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments

from span_marker import SpanMarkerModel, Trainer

def main() -> None:
    # Load the dataset and rename columns so we have "tokens" and "ner_tags"
    dataset = load_dataset("midas/inspec", "extraction")
    dataset = dataset.rename_columns({"document": "tokens", "doc_bio_tags": "ner_tags"})
    # Map the string labels ("O", "B", "I") to integer labels instead
    real_labels = ["O", "B", "I"]
    dataset = dataset.map(
        lambda tags: {"ner_tags": [real_labels.index(tag) for tag in tags]},
        input_columns="ner_tags",
    )
    # Use more readable labels
    labels = ["O", "B-KEY", "I-KEY"]
    # Train using the concatenation of the train and validation sets
    train_dataset = concatenate_datasets((dataset["train"], dataset["validation"]))
    # Initialize a SpanMarker model using a pretrained BERT-style encoder
    model_name = "bert-base-cased"
    model = SpanMarkerModel.from_pretrained(
        model_name,
        labels=labels,
        # SpanMarker hyperparameters:
        model_max_length=256,
        marker_max_length=128,
        entity_max_length=8,
    )
    # Prepare the 🤗 transformers training arguments
    args = TrainingArguments(
        output_dir="models/span_marker_bert_base_cased_keyphrase_inspec",
        run_name="bb_keyphrase",
        # Training hyperparameters:
        learning_rate=5e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
        # Other training parameters
        logging_first_step=True,
        logging_steps=50,
        evaluation_strategy="no",
        save_strategy="steps",
        # eval_steps=300,
        save_total_limit=2,
        dataloader_num_workers=2,
    )
    # Initialize the trainer using our model, training args & dataset, and train
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model("models/span_marker_bert_base_cased_keyphrase_inspec/checkpoint-final")

    # Compute & save the metrics on the test set
    metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
    trainer.save_metrics("test", metrics)
    trainer.create_model_card()


if __name__ == "__main__":
    main()
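Once training finishes, the saved checkpoint can be loaded for keyphrase prediction. A minimal sketch, assuming the output path used above and SpanMarker's `predict` method (the example sentence is made up):

from span_marker import SpanMarkerModel

# Load the final checkpoint saved by the training script above
model = SpanMarkerModel.from_pretrained("models/span_marker_bert_base_cased_keyphrase_inspec/checkpoint-final")
# Predict keyphrase spans in a raw sentence; each entity is a dict with the span text, label, and score
entities = model.predict("Keyphrase extraction models identify the most important phrases in a scientific document.")
for entity in entities:
    print(entity["span"], entity["score"])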