Created February 21, 2021 19:16
-
-
Save eugeneyan/5ca7d2aa5683be497fa3c03bf4c608cb to your computer and use it in GitHub Desktop.
Test that RandomForest outperforms DecisionTree at the same depth limit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _evaluate(model, X_test, y_test):
    """Return ``(accuracy, auc_roc)`` for a fitted model on the test set.

    The raw ``model.predict`` output is used as the ranking score for ROC AUC,
    and its rounded values are used as hard labels for accuracy.
    NOTE(review): this assumes ``predict`` returns scores in [0, 1] for a
    binary target, so rounding yields 0/1 labels — confirm against the models.
    """
    pred_test = model.predict(X_test)
    pred_test_binary = np.round(pred_test)
    acc_test = accuracy_score(y_test, pred_test_binary)
    auc_test = roc_auc_score(y_test, pred_test)
    return acc_test, auc_test


def test_rf_better_than_dt(dummy_titanic):
    """RandomForest should beat a DecisionTree of the same depth on test data.

    Both models are trained with ``depth_limit=10`` on the training split of
    the ``dummy_titanic`` fixture. The test asserts that the forest has strictly
    higher test-set accuracy and strictly higher ROC AUC than the single tree.

    Args:
        dummy_titanic: pytest fixture unpacked as
            ``(X_train, y_train, X_test, y_test)``.
    """
    X_train, y_train, X_test, y_test = dummy_titanic

    dt = DecisionTree(depth_limit=10)
    dt.fit(X_train, y_train)

    # NOTE(review): row/column subsampling is random and no seed is visible
    # here, so this strict comparison may be flaky. Consider seeding if
    # RandomForest (or the global RNG it uses) supports it.
    rf = RandomForest(
        depth_limit=10,
        num_trees=7,
        col_subsampling=0.8,
        row_subsampling=0.8,
    )
    rf.fit(X_train, y_train)

    acc_test_dt, auc_test_dt = _evaluate(dt, X_test, y_test)
    acc_test_rf, auc_test_rf = _evaluate(rf, X_test, y_test)

    assert acc_test_rf > acc_test_dt, 'RandomForest should have higher accuracy than DecisionTree on test set.'
    assert auc_test_rf > auc_test_dt, 'RandomForest should have higher AUC ROC than DecisionTree on test set.'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment