Created
August 17, 2023 00:40
-
-
Save navneet1v/87c693f015ae1267984173df7036ff32 to your computer and use it in GitHub Desktop.
Experiment files for measuring how long vector distance calculations take
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Build for the vector distance calculation timing experiment.
 * Learn more about Gradle by exploring our samples at https://docs.gradle.org/7.4.2/samples
 */
apply plugin: 'java'

repositories {
    mavenLocal()
    mavenCentral()
    maven { url "https://plugins.gradle.org/m2/" }
}

dependencies {
    implementation 'org.apache.lucene:lucene-core:9.7.0'
    implementation 'org.apache.lucene:lucene-queryparser:9.7.0'
    // Per the Lombok changelog, JDK 17 support arrived in 1.18.22 and JDK 21 support in 1.18.30;
    // the previously pinned 1.18.20 only supports up to JDK 16.
    compileOnly 'org.projectlombok:lombok:1.18.30'
    annotationProcessor 'org.projectlombok:lombok:1.18.30'
}

compileJava {
    // Explicitly select Lombok's annotation processor entry point.
    options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright OpenSearch Contributors | |
* SPDX-License-Identifier: Apache-2.0 | |
*/ | |
package navneev.vector; | |
import java.io.ByteArrayInputStream; | |
import java.nio.ByteBuffer; | |
import java.nio.ByteOrder; | |
import java.util.stream.IntStream; | |
/**
 * Static helpers that serialize a {@code float[]} vector as a flat sequence of 4-byte big-endian
 * IEEE-754 floats and deserialize it back.
 *
 * <p>Despite its name, this class does not implement any serializer interface; it only exposes the two
 * static methods below.
 */
public class KNNVectorAsCollectionOfFloatsSerializer {
    private static final int BYTES_IN_FLOAT = 4;

    /**
     * Encodes {@code input} as big-endian bytes, four bytes per element, in element order.
     *
     * @param input vector to encode; must not be {@code null}
     * @return a new array of length {@code input.length * 4}
     */
    public static byte[] floatToByteArray(float[] input) {
        final ByteBuffer bb = ByteBuffer.allocate(input.length * BYTES_IN_FLOAT).order(ByteOrder.BIG_ENDIAN);
        // The float view inherits the buffer's big-endian order and writes straight into its heap array,
        // and allocate() sizes that array exactly, so it can be returned without a further copy.
        bb.asFloatBuffer().put(input);
        return bb.array();
    }

    /**
     * Decodes all remaining bytes of {@code byteStream} as big-endian floats, consuming the stream.
     *
     * @param byteStream stream positioned at the first byte of the vector
     * @return the decoded vector (empty if the stream has no remaining bytes)
     * @throws IllegalArgumentException if {@code byteStream} is {@code null} or its remaining byte count
     *     is not a multiple of four
     */
    public static float[] byteToFloatArray(ByteArrayInputStream byteStream) {
        if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) {
            throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats");
        }
        // readAllBytes() returns exactly the bytes consumed, instead of relying on an unchecked
        // read(byte[], int, int) return value.
        final byte[] vectorAsByteArray = byteStream.readAllBytes();
        final float[] vector = new float[vectorAsByteArray.length / BYTES_IN_FLOAT];
        ByteBuffer.wrap(vectorAsByteArray).order(ByteOrder.BIG_ENDIAN).asFloatBuffer().get(vector);
        return vector;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package navneev.vector; | |
import lombok.Builder; | |
import lombok.SneakyThrows; | |
import lombok.Value; | |
import org.apache.lucene.index.VectorSimilarityFunction; | |
import org.apache.lucene.util.BytesRef; | |
import java.io.BufferedWriter; | |
import java.io.ByteArrayInputStream; | |
import java.io.File; | |
import java.io.FileWriter; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Locale; | |
import java.util.Map; | |
import java.util.Random; | |
/** | |
* This is a testing class that will help us understand how much time does it take to calculate the distances for | |
* different dimensions. | |
*/ | |
public class VectorDistanceCalculationTime { | |
private static final Random random = new Random(1212121212); | |
private static final List<Integer> DIMENSIONS_LIST = List.of(64, 128, 256, 512, 768, 968, 1024, 2048, 4096); | |
private static final List<Integer> CORPUS_SIZE = List.of(1000, 2000, 5000, 8000, 10000, 15000); | |
private static final VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; | |
public static void main(String[] args) { | |
runWarmup(); | |
List<Result> resultList = new ArrayList<>(); | |
for(int corpusSize : CORPUS_SIZE) { | |
System.out.printf("\nData for Corpus Size: %d\n", corpusSize); | |
for (int dimension : DIMENSIONS_LIST) { | |
System.gc(); | |
resultList.add(runBruteForce(dimension, corpusSize)); | |
} | |
} | |
writeResultsInFile(resultList); | |
} | |
private static void runWarmup() { | |
float[][] warmupQueries = loadData(10, 1); | |
float[][] warmupDataSet = loadData(10, 10); | |
List<BytesRef> byteDataSet = new ArrayList<>(); | |
for(float[] floats : warmupDataSet) { | |
byteDataSet.add(new BytesRef(KNNVectorAsCollectionOfFloatsSerializer.floatToByteArray(floats))); | |
} | |
for (float[] query : warmupQueries) { | |
for (BytesRef value : byteDataSet) { | |
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); | |
vectorSimilarityFunction.compare(query, | |
KNNVectorAsCollectionOfFloatsSerializer.byteToFloatArray(byteStream)); | |
} | |
} | |
} | |
private static Result runBruteForce(int dimensions, int corpusSize) { | |
int numberOfTimes = 10; | |
int queriesCount = 1; | |
// load data from file | |
float bestTime = Float.MAX_VALUE; | |
float totalTime = 0; | |
for(int i = 0; i < numberOfTimes; i++) { | |
float[][] dataset = loadData(dimensions, corpusSize); | |
float[][] queries = loadData(dimensions, queriesCount); | |
List<BytesRef> byteDataSet = new ArrayList<>(); | |
for(float[] floats : dataset) { | |
byteDataSet.add(new BytesRef(KNNVectorAsCollectionOfFloatsSerializer.floatToByteArray(floats))); | |
} | |
long startTime = System.nanoTime(); | |
for (float[] query : queries) { | |
for (BytesRef value : byteDataSet) { | |
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); | |
vectorSimilarityFunction.compare(query, | |
KNNVectorAsCollectionOfFloatsSerializer.byteToFloatArray(byteStream)); | |
} | |
} | |
long endTime = System.nanoTime(); | |
float tookTime = ((float) (endTime - startTime)) / 1000000; | |
bestTime = Math.min(bestTime, tookTime); | |
totalTime = tookTime + totalTime; | |
} | |
System.out.printf("For %d query, corpus size %d, dimensions %d Best Time %f millisecond, Total Time for %d " + | |
"runs : %f millisecond \n", queriesCount, corpusSize, dimensions, bestTime, numberOfTimes, totalTime); | |
return Result.builder().corpusSize(corpusSize).bestTime(bestTime).totalTime(totalTime).dimension(dimensions).build(); | |
} | |
private static float[][] loadData(int dimensions, int corpusSize) { | |
float[][] vectors = new float[corpusSize][dimensions]; | |
for(int i = 0 ; i < vectors.length; i++) { | |
vectors[i] = generateRandomVector(dimensions); | |
} | |
return vectors; | |
} | |
private static float[] generateRandomVector(int dimensions) { | |
float[] vector = new float[dimensions]; | |
for(int i = 0 ; i < dimensions; i++) { | |
vector[i] = -500 + (float) random.nextGaussian() * (1000); | |
} | |
return vector; | |
} | |
@SneakyThrows | |
private static void writeResultsInFile(List<Result> resultList) { | |
String fileName = "results.csv"; | |
File file = new File(fileName); | |
if(file.exists()) { | |
System.out.println("Deleting File"); | |
file.delete(); | |
} else { | |
file.createNewFile(); | |
} | |
final StringBuilder header = new StringBuilder("dimensions,"); | |
for(int corpusSize: CORPUS_SIZE) { | |
header.append(corpusSize).append("_latency_serialized(ms),"); | |
} | |
BufferedWriter writer = new BufferedWriter(new FileWriter(fileName, true)); | |
writer.append(header.toString()); | |
writer.newLine(); | |
Map<Integer,List<Float>> dimensionToLatencyList = new HashMap<>(); | |
for(Result res : resultList) { | |
if(!dimensionToLatencyList.containsKey(res.getDimension())) { | |
dimensionToLatencyList.put(res.getDimension(), new ArrayList<>()); | |
} | |
dimensionToLatencyList.get(res.getDimension()).add(res.bestTime); | |
} | |
for(int dim : DIMENSIONS_LIST) { | |
List<Float> latencies = dimensionToLatencyList.get(dim); | |
writer.append(String.valueOf(dim)).append(","); | |
for(float lat : latencies) { | |
writer.append(String.valueOf(lat)).append(","); | |
} | |
writer.newLine(); | |
} | |
writer.close(); | |
System.out.println("Results written in the file at : " + file.getAbsoluteFile()); | |
} | |
private static void printArray(float[] array) { | |
System.out.print("["); | |
for(float f: array) { | |
System.out.printf("%f ",f); | |
} | |
System.out.print("]"); | |
} | |
@Value | |
@Builder | |
private static class Result { | |
float bestTime; | |
float totalTime; | |
int dimension; | |
int corpusSize; | |
@Override | |
public String toString() { | |
return String.format(Locale.ROOT, "%d,%d,%f,%f", dimension, corpusSize, bestTime, totalTime); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment