Created August 17, 2016 14:03
Save teh/79635edd21d49a29769f30e62bfff19c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sklearn.base as base | |
import sklearn.linear_model as lm | |
import numpy | |
class ByThreshold(base.ClassifierMixin, base.BaseEstimator):
    """Classifier wrapper that abstains when the wrapped model is unsure.

    Predictions whose highest class probability (as reported by the wrapped
    estimator's ``predict_proba``) is strictly below ``threshold`` are
    replaced by the label ``-1``.

    Parameters
    ----------
    estimator : object
        Classifier exposing ``fit``, ``predict`` and ``predict_proba``.
        It is fitted in place by :meth:`fit`.
    threshold : float, default=0.95
        Minimum top-class probability required to keep a prediction.

    Notes
    -----
    ``-1`` is the "rejected" marker, so results are ambiguous if ``-1`` is
    also one of the real class labels.

    The mixin is listed before ``BaseEstimator`` as scikit-learn recommends.
    Presumably this matters for estimator tags in recent versions (e.g. for
    ``is_classifier``); confirm against the installed scikit-learn.
    """

    def __init__(self, estimator, threshold=0.95):
        # Store arguments unchanged: scikit-learn's get_params/clone expect
        # each __init__ parameter to map 1:1 onto an attribute of the same name.
        self.threshold = threshold
        self.estimator = estimator

    def get_params(self, deep=True):
        """Return the constructor parameters as a dict.

        ``deep`` matches scikit-learn's ``get_params(deep=True)`` signature,
        which ``clone`` and the model-selection utilities call. The previous
        zero-argument signature made those calls fail. With ``deep=True``,
        the wrapped estimator's own parameters are also included under
        ``estimator__<name>`` keys.
        """
        return base.BaseEstimator.get_params(self, deep=deep)

    def set_params(self, params=None, **kwargs):
        """Set parameters on this wrapper and return ``self``.

        Accepts both the legacy single-dict form ``set_params({...})`` and
        scikit-learn's keyword form ``set_params(threshold=0.8)``. Keyword
        arguments take precedence when a key appears in both. Nested keys
        such as ``estimator__C`` are forwarded to the wrapped estimator.
        Unknown parameter names raise ``ValueError``.
        """
        merged = dict(params or {})
        merged.update(kwargs)
        return base.BaseEstimator.set_params(self, **merged)

    def fit(self, X, y):
        """Fit the wrapped estimator on ``(X, y)`` and return ``self``.

        Returning ``self`` (instead of the inner estimator) follows the
        scikit-learn convention. It also makes a chained
        ``clf.fit(X, y).predict(X)`` go through this wrapper's thresholding
        rather than silently bypassing it.
        """
        self.estimator.fit(X, y)
        return self

    def predict(self, X):
        """Predict labels for ``X``, using ``-1`` where confidence is too low.

        Returns the wrapped estimator's ``predict(X)`` output, with every row
        whose maximum ``predict_proba`` value is below ``threshold``
        overwritten by ``-1``.
        """
        proba = self.estimator.predict_proba(X)
        low_confidence = proba.max(axis=1) < self.threshold
        labels = self.estimator.predict(X)
        # NOTE(review): assumes the label array's dtype can represent -1
        # (e.g. signed integers or object); confirm for other label types.
        labels[low_confidence] = -1
        return labels
if __name__ == "__main__":
    # Small demo: fit a logistic regression on three points, then predict two
    # query rows. Any row whose top-class probability is below 0.6 is
    # reported as -1.
    X = numpy.array([[0, 0, 1], [1, 0, 1], [1, 1, 1]])
    y = numpy.array([0, 1, 1])
    clf = ByThreshold(lm.LogisticRegression(), 0.6)
    clf.fit(X, y)
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the original `print expr` statement is a SyntaxError on Python 3.
    print(clf.predict([
        [1, 1, 0],
        [0, 0, 1],
    ]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.