Skip to content

Instantly share code, notes, and snippets.

@frogermcs
Created June 9, 2019 12:28
Show Gist options
  • Save frogermcs/3f7a3317615ad3b2db9f5670c39b23ec to your computer and use it in GitHub Desktop.
Save frogermcs/3f7a3317615ad3b2db9f5670c39b23ec to your computer and use it in GitHub Desktop.
/**
 * Classifies bitmaps with a TensorFlow Lite {@code Interpreter} whose model file, labels file
 * and input size are described by a {@link ModelConfig}.
 *
 * <p>The interpreter owns native resources, so this class is {@link AutoCloseable}. Call
 * {@link #close()} (or use try-with-resources) once the classifier is no longer needed.
 */
public class ModelClassificator implements AutoCloseable {
    // NOTE(review): presumably consumed by getSortedResult (body elided in this excerpt) to cap
    // the number of returned results and drop low-confidence ones — confirm.
    private static final int MAX_CLASSIFICATION_RESULTS = 3;
    private static final float CLASSIFICATION_THRESHOLD = 0.2f;

    private final Interpreter interpreter;
    private final List<String> labels;
    private final ModelConfig modelConfig;

    /**
     * Loads the model and its labels from the app assets and creates the interpreter.
     *
     * @param context passed to AssetsUtils to read the asset files
     * @param modelConfig supplies the model/labels file names and the model's input size
     * @throws IOException propagated when loading the model or labels file fails
     */
    public ModelClassificator(Context context,
                              ModelConfig modelConfig) throws IOException {
        ByteBuffer model = AssetsUtils.loadFile(context, modelConfig.getModelFilename());
        // Load labels BEFORE creating the interpreter. If this throws, no native interpreter
        // has been allocated yet, so nothing is leaked.
        this.labels = AssetsUtils.loadLines(context, modelConfig.getLabelsFilename());
        this.modelConfig = modelConfig;
        this.interpreter = new Interpreter(model);
    }

    /**
     * Classifies {@code bitmap}.
     *
     * <p>The bitmap is first reduced to the model's input width/height with
     * {@code ThumbnailUtils.extractThumbnail}, then converted into the interpreter's input
     * buffer. The interpreter writes its output into a {@code [1][labels.size()]} array, i.e.
     * one value per label, which is then handed to getSortedResult.
     *
     * @param bitmap image to classify
     * @return the list built by getSortedResult from the raw output scores
     */
    public List<ClassificationResult> process(Bitmap bitmap) {
        Bitmap toClassify = ThumbnailUtils.extractThumbnail(
                bitmap, modelConfig.getInputWidth(), modelConfig.getInputHeight()
        );
        ByteBuffer byteBufferToClassify = bitmapToModelsMatchingByteBuffer(toClassify);
        float[][] result = new float[1][labels.size()];
        interpreter.run(byteBufferToClassify, result);
        /* ... */
        return getSortedResult(result);
    }

    /**
     * Releases the interpreter's native resources.
     *
     * <p>{@link #process(Bitmap)} must not be called after this method.
     */
    @Override
    public void close() {
        interpreter.close();
    }

    /** Converts the already-scaled bitmap into the interpreter's input buffer (body elided). */
    private ByteBuffer bitmapToModelsMatchingByteBuffer(Bitmap bitmap) { /* ... */ }

    // NOTE(review): presumably splits one packed pixel into per-channel model inputs — confirm.
    private float[] pixelToChannelValues(int pixel) { /* ... */ }

    // NOTE(review): presumably sorts the scores and applies CLASSIFICATION_THRESHOLD /
    // MAX_CLASSIFICATION_RESULTS — body elided, confirm.
    private List<ClassificationResult> getSortedResult(float[][] resultsArray) { /* ... */ }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment