Skip to content
Snippets Groups Projects
Commit 45c3b2b9 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Support cache file reading from path.

Until now we could only read cache file by providing cache hint and
quantization type. This commit adds static method which can read any
valid cache file from file.
parent 9012d801
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,8 @@ public interface ICacheFile { ...@@ -10,6 +10,8 @@ public interface ICacheFile {
void readFromStream(DataInputStream inputStream) throws IOException; void readFromStream(DataInputStream inputStream) throws IOException;
void readFromStream(DataInputStream inputStream, CacheFileHeader header) throws IOException;
CacheFileHeader getHeader(); CacheFileHeader getHeader();
void report(StringBuilder builder); void report(StringBuilder builder);
......
...@@ -36,7 +36,7 @@ public class QuantizationCacheManager { ...@@ -36,7 +36,7 @@ public class QuantizationCacheManager {
private File getCacheFilePathForSQ(final String trainFile, final int codebookSize) { private File getCacheFilePathForSQ(final String trainFile, final int codebookSize) {
final File inputFile = new File(trainFile); final File inputFile = new File(trainFile);
return new File(cacheFolder, String.format("%s_%d_bits.qvc", return new File(cacheFolder, String.format("%s_%d_bits.qvc",
inputFile.getName(), codebookSize)); inputFile.getName(), codebookSize));
} }
/** /**
...@@ -52,7 +52,7 @@ public class QuantizationCacheManager { ...@@ -52,7 +52,7 @@ public class QuantizationCacheManager {
final V3i vDim) { final V3i vDim) {
final File inputFile = new File(trainFile); final File inputFile = new File(trainFile);
return new File(cacheFolder, String.format("%s_%d_%dx%dx%d.qvc", inputFile.getName(), codebookSize, return new File(cacheFolder, String.format("%s_%d_%dx%dx%d.qvc", inputFile.getName(), codebookSize,
vDim.getX(), vDim.getY(), vDim.getZ())); vDim.getX(), vDim.getY(), vDim.getZ()));
} }
...@@ -143,8 +143,8 @@ public class QuantizationCacheManager { ...@@ -143,8 +143,8 @@ public class QuantizationCacheManager {
*/ */
public String saveCodebook(final String trainFile, final VQCodebook codebook) throws IOException { public String saveCodebook(final String trainFile, final VQCodebook codebook) throws IOException {
final String fileName = getCacheFilePathForVQ(trainFile, final String fileName = getCacheFilePathForVQ(trainFile,
codebook.getCodebookSize(), codebook.getCodebookSize(),
codebook.getVectorDims()).getAbsolutePath(); codebook.getVectorDims()).getAbsolutePath();
final CacheFileHeader header = createHeaderForVQ(new File(trainFile).getName(), codebook); final CacheFileHeader header = createHeaderForVQ(new File(trainFile).getName(), codebook);
final VQCacheFile cacheFile = new VQCacheFile(header, codebook); final VQCacheFile cacheFile = new VQCacheFile(header, codebook);
...@@ -282,6 +282,21 @@ public class QuantizationCacheManager { ...@@ -282,6 +282,21 @@ public class QuantizationCacheManager {
return null; return null;
} }
public static ICacheFile readCacheFile(final String path) {
try (FileInputStream fis = new FileInputStream(path);
DataInputStream dis = new DataInputStream(fis)) {
CacheFileHeader header = new CacheFileHeader();
header.readFromStream(dis);
ICacheFile cacheFile = getCacheFile(header.getQuantizationType());
assert (cacheFile != null);
cacheFile.readFromStream(dis, header);
return cacheFile;
} catch (IOException e) {
return null;
}
}
/** /**
* Inspect cache file specified by the path. * Inspect cache file specified by the path.
* *
......
...@@ -5,8 +5,6 @@ import azgracompress.quantization.scalar.SQCodebook; ...@@ -5,8 +5,6 @@ import azgracompress.quantization.scalar.SQCodebook;
import java.io.DataInputStream; import java.io.DataInputStream;
import java.io.DataOutputStream; import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Collectors;
public class SQCacheFile implements ICacheFile { public class SQCacheFile implements ICacheFile {
private CacheFileHeader header; private CacheFileHeader header;
...@@ -36,9 +34,11 @@ public class SQCacheFile implements ICacheFile { ...@@ -36,9 +34,11 @@ public class SQCacheFile implements ICacheFile {
public void readFromStream(DataInputStream inputStream) throws IOException { public void readFromStream(DataInputStream inputStream) throws IOException {
header = new CacheFileHeader(); header = new CacheFileHeader();
header.readFromStream(inputStream); header.readFromStream(inputStream);
readFromStream(inputStream, header);
}
public void readFromStream(DataInputStream inputStream, CacheFileHeader header) throws IOException {
final int codebookSize = header.getCodebookSize(); final int codebookSize = header.getCodebookSize();
final int[] centroids = new int[codebookSize]; final int[] centroids = new int[codebookSize];
final long[] frequencies = new long[codebookSize]; final long[] frequencies = new long[codebookSize];
......
...@@ -38,9 +38,12 @@ public class VQCacheFile implements ICacheFile { ...@@ -38,9 +38,12 @@ public class VQCacheFile implements ICacheFile {
public void readFromStream(DataInputStream inputStream) throws IOException { public void readFromStream(DataInputStream inputStream) throws IOException {
header = new CacheFileHeader(); header = new CacheFileHeader();
header.readFromStream(inputStream); header.readFromStream(inputStream);
readFromStream(inputStream, header);
}
@Override
public void readFromStream(DataInputStream inputStream, CacheFileHeader header) throws IOException {
final int codebookSize = header.getCodebookSize(); final int codebookSize = header.getCodebookSize();
final CodebookEntry[] vectors = new CodebookEntry[codebookSize]; final CodebookEntry[] vectors = new CodebookEntry[codebookSize];
final long[] frequencies = new long[codebookSize]; final long[] frequencies = new long[codebookSize];
......
...@@ -43,6 +43,16 @@ public class VQCodebook { ...@@ -43,6 +43,16 @@ public class VQCodebook {
return vectors; return vectors;
} }
public int[][] getRawVectors() {
assert (codebookSize == vectors.length);
assert (vectors[0].getVector().length == (int) vectorDims.multiplyTogether());
final int[][] rawCodebook = new int[vectors.length][(int) vectorDims.multiplyTogether()];
for (int i = 0; i < codebookSize; i++) {
rawCodebook[i] = vectors[i].getVector();
}
return rawCodebook;
}
/** /**
* Get frequencies of codebook vectors at indices. * Get frequencies of codebook vectors at indices.
* *
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment