# Source: GitHub Gist teh/79635edd21d49a29769f30e62bfff19c
# ("Instantly share code, notes, and snippets."), created August 17, 2016 14:03.
import sklearn.base as base
import sklearn.linear_model as lm
import numpy
class ByThreshold(base.ClassifierMixin, base.BaseEstimator):
    """Wrap a probabilistic classifier and abstain on low-confidence rows.

    For each row, the wrapped estimator's ``predict`` label is kept only when
    the largest value in its ``predict_proba`` output is at least
    ``threshold``. Otherwise the row is labelled ``-1``.

    ``get_params`` is inherited from ``BaseEstimator``. It reads the
    ``__init__`` signature (``estimator``, ``threshold``) and supports the
    ``deep`` flag that ``sklearn.base.clone`` relies on.

    Parameters
    ----------
    estimator : classifier
        Object providing ``fit``, ``predict`` and ``predict_proba``.
    threshold : float, default 0.95
        Minimum max-class probability required to keep a prediction.
    """

    def __init__(self, estimator, threshold=0.95):
        # sklearn convention: __init__ only stores its arguments unchanged.
        self.threshold = threshold
        self.estimator = estimator

    def set_params(self, params=None, **kwargs):
        """Set parameters and return ``self``.

        Accepts the legacy positional-dict form ``set_params({...})`` as well
        as sklearn's keyword form ``set_params(threshold=0.8)``. Nested keys
        such as ``estimator__C`` are forwarded to the wrapped estimator by
        ``BaseEstimator.set_params``.
        """
        merged = dict(params or {})
        merged.update(kwargs)
        return super(ByThreshold, self).set_params(**merged)

    def fit(self, X, y):
        """Fit the wrapped estimator on ``(X, y)`` and return ``self``.

        Returning ``self`` (not the inner estimator) keeps
        ``ByThreshold(...).fit(X, y).predict(X2)`` on the thresholding path.
        """
        self.estimator.fit(X, y)
        return self

    def predict(self, X):
        """Return the estimator's labels, with ``-1`` where confidence < threshold."""
        proba = numpy.asarray(self.estimator.predict_proba(X))
        uncertain = proba.max(axis=1) < self.threshold
        # Copy so the estimator's returned buffer is never mutated in place.
        labels = numpy.array(self.estimator.predict(X))
        if labels.dtype.kind not in "if":
            # Unsigned/bool arrays cannot hold -1, and fixed-width string
            # arrays would truncate it; fall back to object dtype.
            labels = labels.astype(object)
        labels[uncertain] = -1
        return labels
# Tiny demo: three training rows with binary labels, then predict two new
# rows. Predictions whose max class probability is below 0.6 come back as -1.
X = numpy.array([[0, 0, 1], [1, 0, 1], [1, 1, 1]])
y = numpy.array([0, 1, 1])
clf = ByThreshold(lm.LogisticRegression(), 0.6)
# Fit on its own statement so the demo does not depend on what fit() returns.
clf.fit(X, y)
# print(...) with a single argument behaves the same on Python 2 and 3.
print(clf.predict([
    [1, 1, 0],
    [0, 0, 1],
]))
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")