Skip to content
Snippets Groups Projects
Commit 7aa915c7 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

New code for codebook caching.

parent ffb98c89
No related branches found
No related tags found
No related merge requests found
package azgracompress.cache;
import azgracompress.data.V3i;
import azgracompress.fileformat.QuantizationType;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
public class CacheFileHeader {
public static final String QCMP_CACHE_MAGIC_VALUE = "QCMPCACHE";
private String magicValue;
private QuantizationType quantizationType;
private int codebookSize;
private int trainFileNameSize;
private String trainFileName;
private int vectorSizeX;
private int vectorSizeY;
private int vectorSizeZ;
public void setQuantizationType(QuantizationType quantizationType) {
this.quantizationType = quantizationType;
}
public void setCodebookSize(int codebookSize) {
this.codebookSize = codebookSize;
}
public void setTrainFileName(String trainFileName) {
this.trainFileName = trainFileName;
this.trainFileNameSize = this.trainFileName.length();
}
public void setVectorSizeX(int vectorSizeX) {
this.vectorSizeX = vectorSizeX;
}
public void setVectorSizeY(int vectorSizeY) {
this.vectorSizeY = vectorSizeY;
}
public void setVectorSizeZ(int vectorSizeZ) {
this.vectorSizeZ = vectorSizeZ;
}
public QuantizationType getQuantizationType() {
return quantizationType;
}
public int getCodebookSize() {
return codebookSize;
}
public int getTrainFileNameSize() {
return trainFileNameSize;
}
public String getTrainFileName() {
return trainFileName;
}
public int getVectorSizeX() {
return vectorSizeX;
}
public int getVectorSizeY() {
return vectorSizeY;
}
public int getVectorSizeZ() {
return vectorSizeZ;
}
public V3i getVectorDim() {
return new V3i(vectorSizeX, vectorSizeY, vectorSizeZ);
}
/**
* Write QCMP cache file header to stream.
*
* @param outputStream Data output stream.
* @throws IOException when fails to write the header to stream.
*/
public void writeToStream(DataOutputStream outputStream) throws IOException {
outputStream.writeBytes(QCMP_CACHE_MAGIC_VALUE);
outputStream.writeByte(quantizationType.getValue());
outputStream.writeShort(codebookSize);
outputStream.writeShort(trainFileName.length());
outputStream.writeBytes(trainFileName);
outputStream.writeShort(vectorSizeX);
outputStream.writeShort(vectorSizeY);
outputStream.writeShort(vectorSizeZ);
}
/**
* Read header from the stream.
*
* @param inputStream Data input stream.
*/
public void readFromStream(DataInputStream inputStream) throws IOException {
final int MIN_AVAIL = 9;
if (inputStream.available() < MIN_AVAIL) {
throw new IOException("Invalid file. File too small.");
}
byte[] magicBuffer = new byte[QCMP_CACHE_MAGIC_VALUE.length()];
final int readFromMagic = inputStream.readNBytes(magicBuffer, 0, QCMP_CACHE_MAGIC_VALUE.length());
if (readFromMagic != QCMP_CACHE_MAGIC_VALUE.length()) {
throw new IOException("Invalid file type. Unable to read magic value");
}
magicValue = new String(magicBuffer);
if (!magicValue.equals(QCMP_CACHE_MAGIC_VALUE)) {
throw new IOException("Invalid file type. Wrong magic value.");
}
quantizationType = QuantizationType.fromByte(inputStream.readByte());
codebookSize = inputStream.readUnsignedShort();
trainFileNameSize = inputStream.readUnsignedShort();
byte[] fileNameBuffer = new byte[trainFileNameSize];
inputStream.readNBytes(fileNameBuffer, 0, trainFileNameSize);
trainFileName = new String(fileNameBuffer);
vectorSizeX = inputStream.readUnsignedShort();
vectorSizeY = inputStream.readUnsignedShort();
vectorSizeZ = inputStream.readUnsignedShort();
}
public void setVectorDims(final V3i v3i) {
this.vectorSizeX = v3i.getX();
this.vectorSizeY = v3i.getY();
this.vectorSizeZ = v3i.getZ();
}
public void report(final StringBuilder sb) {
sb.append("Magic: ").append(magicValue).append('\n');
sb.append("CodebookType: ");
switch (quantizationType) {
case Scalar:
sb.append("Scalar\n");
break;
case Vector1D:
sb.append(String.format("Vector1D %s\n", vectorSizeX));
break;
case Vector2D:
sb.append(String.format("Vector2D %s\n", new V3i(vectorSizeX, vectorSizeY, vectorSizeZ).toString()));
break;
}
sb.append("CodebookSize: ").append(codebookSize).append('\n');
sb.append("TrainFile: ").append(trainFileName).append('\n');
}
}
package azgracompress.cache;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
public interface ICacheFile {
void writeToStream(DataOutputStream outputStream) throws IOException;
void readFromStream(DataInputStream inputStream) throws IOException;
CacheFileHeader getHeader();
}
package azgracompress.cache;
import azgracompress.data.V3i;
import azgracompress.fileformat.QuantizationType;
import azgracompress.quantization.scalar.SQCodebook;
import azgracompress.quantization.vector.VQCodebook;
import java.io.*;
public class QuantizationCacheManager {
/**
* Folders where cache files are stored.
*/
private final String cacheFolder;
/**
* Create cache manager with the cache folder.
*
* @param cacheFolder Cache folder
*/
public QuantizationCacheManager(final String cacheFolder) {
this.cacheFolder = cacheFolder;
//noinspection ResultOfMethodCallIgnored
new File(this.cacheFolder).mkdirs();
}
/**
* Get cache file for scalar quantizer.
*
* @param trainFile Input image file name.
* @param codebookSize Codebook size.
* @return Cache file for scalar quantizer.
*/
private File getCacheFilePathForSQ(final String trainFile, final int codebookSize) {
final File inputFile = new File(trainFile);
return new File(cacheFolder, String.format("%s_%d_bits.qvc",
inputFile.getName(), codebookSize));
}
/**
* Get cache file for vector quantizer.
*
* @param trainFile Input image file name.
* @param codebookSize Size of the codebook.
* @param vDim Vector dimensions.
* @return Cache file for vector quantizer.
*/
private File getCacheFilePathForVQ(final String trainFile,
final int codebookSize,
final V3i vDim) {
final File inputFile = new File(trainFile);
return new File(cacheFolder, String.format("%s_%d_%dx%d.qvc", inputFile.getName(), codebookSize,
vDim.getX(), vDim.getY()));
}
/**
* Create CacheFileHeader for ScalarQuantization cache.
*
* @param trainFile Image file used for training.
* @param codebook Final SQ codebook.
* @return SQ cache file header.
*/
private CacheFileHeader createHeaderForSQ(final String trainFile, final SQCodebook codebook) {
CacheFileHeader header = new CacheFileHeader();
header.setQuantizationType(QuantizationType.Scalar);
header.setCodebookSize(codebook.getCodebookSize());
header.setTrainFileName(trainFile);
header.setVectorDims(new V3i(0));
return header;
}
/**
* Find the correct quantization type based on vector dimension.
*
* @param vectorDims Quantization vector dimensions.
* @return Correct QuantizationType.
*/
private QuantizationType getQuantizationTypeFromVectorDimensions(final V3i vectorDims) {
if (vectorDims.getX() > 1) {
if (vectorDims.getY() == 1 && vectorDims.getZ() == 1) {
return QuantizationType.Vector1D;
} else if (vectorDims.getY() > 1 && vectorDims.getZ() == 1) {
return QuantizationType.Vector2D;
} else {
return QuantizationType.Vector3D;
}
} else if (vectorDims.getX() == 1 && vectorDims.getY() > 1 && vectorDims.getZ() == 1) {
return QuantizationType.Vector1D;
}
return QuantizationType.Invalid;
}
/**
* Create CacheFileHeader for VQ cache.
*
* @param trainFile Image file used for training.
* @param codebook Final VQ codebook.
* @return VQ cache file header.
*/
private CacheFileHeader createHeaderForVQ(final String trainFile, final VQCodebook codebook) {
CacheFileHeader header = new CacheFileHeader();
header.setQuantizationType(getQuantizationTypeFromVectorDimensions(codebook.getVectorDims()));
header.setCodebookSize(codebook.getCodebookSize());
header.setTrainFileName(trainFile);
header.setVectorDims(codebook.getVectorDims());
return header;
}
/**
* Save SQ codebook to cache.
*
* @param trainFile Image file used for training.
* @param codebook SQ codebook.
* @throws IOException when fails to save the cache file.
*/
public void saveCodebook(final String trainFile, final SQCodebook codebook) throws IOException {
final String fileName = getCacheFilePathForSQ(trainFile, codebook.getCodebookSize()).getAbsolutePath();
final CacheFileHeader header = createHeaderForSQ(new File(trainFile).getName(), codebook);
final SQCacheFile cacheFile = new SQCacheFile(header, codebook);
try (FileOutputStream fos = new FileOutputStream(fileName, false);
DataOutputStream dos = new DataOutputStream(fos)) {
cacheFile.writeToStream(dos);
} catch (IOException ex) {
throw new IOException("Failed to save SQ cache file\n" + ex.getMessage());
}
}
/**
* Save VQ codebook to cache.
*
* @param trainFile Image file used for training.
* @param codebook VQ codebook.
* @throws IOException when fails to save the cache file.
*/
public void saveCodebook(final String trainFile, final VQCodebook codebook) throws IOException {
final String fileName = getCacheFilePathForVQ(trainFile,
codebook.getCodebookSize(),
codebook.getVectorDims()).getAbsolutePath();
final CacheFileHeader header = createHeaderForVQ(new File(trainFile).getName(), codebook);
final VQCacheFile cacheFile = new VQCacheFile(header, codebook);
try (FileOutputStream fos = new FileOutputStream(fileName, false);
DataOutputStream dos = new DataOutputStream(fos)) {
cacheFile.writeToStream(dos);
} catch (IOException ex) {
throw new IOException("Failed to save VQ cache file\n" + ex.getMessage());
}
}
/**
* Read data from file to cache file.
*
* @param file Cache file.
* @param cacheFile Actual cache file object.
* @return Cache file with data from disk.
* @throws IOException when fails to read the cache file from disk.
*/
private ICacheFile readCacheFile(final File file, final ICacheFile cacheFile) throws IOException {
try (FileInputStream fis = new FileInputStream(file);
DataInputStream dis = new DataInputStream(fis)) {
cacheFile.readFromStream(dis);
return cacheFile;
}
}
/**
* Load SQ cache file from disk.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
* @return SQ cache file.
*/
private SQCacheFile loadSQCacheFile(final String trainFile, final int codebookSize) {
final File path = getCacheFilePathForSQ(trainFile, codebookSize);
try {
return (SQCacheFile) readCacheFile(path, new SQCacheFile());
} catch (IOException e) {
System.err.println("Failed to read SQ cache file." + path);
e.printStackTrace(System.err);
return null;
}
}
/**
* Read VQ cache file disk.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
* @param vDim Quantization vector dimension.
* @return VQ cache file.
*/
private VQCacheFile loadVQCacheFile(final String trainFile,
final int codebookSize,
final V3i vDim) {
final File path = getCacheFilePathForVQ(trainFile, codebookSize, vDim);
try {
return (VQCacheFile) readCacheFile(path, new VQCacheFile());
} catch (IOException e) {
System.err.println("Failed to read VQ cache file." + path);
e.printStackTrace(System.err);
return null;
}
}
/**
* Read SQ codebook from cache file.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
* @return SQ codebook or null.
*/
public SQCodebook readSQCodebook(final String trainFile, final int codebookSize) {
final SQCacheFile cacheFile = loadSQCacheFile(trainFile, codebookSize);
if (cacheFile != null)
return cacheFile.getCodebook();
else
return null;
}
/**
* Read VQ codebook from cache file.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
* @param vDim Quantization vector dimension.
* @return VQ codebook.
*/
public VQCodebook readVQCodebook(final String trainFile,
final int codebookSize,
final V3i vDim) {
final VQCacheFile cacheFile = loadVQCacheFile(trainFile, codebookSize, vDim);
if (cacheFile != null)
return cacheFile.getCodebook();
else
return null;
}
/**
* Log information about SQ cache file.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
*/
public void validateAndReport(final String trainFile, final int codebookSize) {
final SQCacheFile cacheFile = loadSQCacheFile(trainFile, codebookSize);
if (cacheFile == null) {
System.err.println("Invalid SQ cache file.");
return;
}
StringBuilder sb = new StringBuilder();
cacheFile.getHeader().report(sb);
sb.append("Frequencies: ");
for (final long fV : cacheFile.getCodebook().getSymbolFrequencies())
sb.append(fV).append(", ");
sb.append('\n');
System.out.println(sb.toString());
}
/**
* Log information about VQ cache file.
*
* @param trainFile Input image file.
* @param codebookSize Codebook size.
* @param vDim Quantization vector dimension.
*/
public void validateAndReport(final String trainFile,
final int codebookSize,
final V3i vDim) {
final VQCacheFile cacheFile = loadVQCacheFile(trainFile, codebookSize, vDim);
if (cacheFile == null) {
System.err.println("Invalid VQ cache file.");
return;
}
StringBuilder sb = new StringBuilder();
cacheFile.getHeader().report(sb);
sb.append("Frequencies: ");
for (final long fV : cacheFile.getCodebook().getVectorFrequencies())
sb.append(fV).append(", ");
sb.append('\n');
System.out.println(sb.toString());
}
}
package azgracompress.cache;
import azgracompress.quantization.scalar.SQCodebook;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
public class SQCacheFile implements ICacheFile {
private CacheFileHeader header;
private SQCodebook codebook;
public SQCacheFile() {
}
public SQCacheFile(final CacheFileHeader header, final SQCodebook codebook) {
this.header = header;
this.codebook = codebook;
assert (header.getCodebookSize() == codebook.getCodebookSize());
}
public void writeToStream(DataOutputStream outputStream) throws IOException {
header.writeToStream(outputStream);
final int[] quantizationValues = codebook.getCentroids();
final long[] frequencies = codebook.getSymbolFrequencies();
for (final int qV : quantizationValues) {
outputStream.writeShort(qV);
}
for (final long sF : frequencies) {
outputStream.writeLong(sF);
}
}
public void readFromStream(DataInputStream inputStream) throws IOException {
header = new CacheFileHeader();
header.readFromStream(inputStream);
final int codebookSize = header.getCodebookSize();
final int[] centroids = new int[codebookSize];
final long[] frequencies = new long[codebookSize];
for (int i = 0; i < codebookSize; i++) {
centroids[i] = inputStream.readUnsignedShort();
}
for (int i = 0; i < codebookSize; i++) {
frequencies[i] = inputStream.readLong();
}
codebook = new SQCodebook(centroids, frequencies);
}
public CacheFileHeader getHeader() {
return header;
}
public SQCodebook getCodebook() {
return codebook;
}
}
package azgracompress.cache;
import azgracompress.quantization.vector.CodebookEntry;
import azgracompress.quantization.vector.VQCodebook;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
public class VQCacheFile implements ICacheFile {
private CacheFileHeader header;
private VQCodebook codebook;
public VQCacheFile() {
}
public VQCacheFile(final CacheFileHeader header, final VQCodebook codebook) {
this.header = header;
this.codebook = codebook;
assert (header.getCodebookSize() == codebook.getCodebookSize());
}
public void writeToStream(DataOutputStream outputStream) throws IOException {
header.writeToStream(outputStream);
final CodebookEntry[] entries = codebook.getVectors();
for (final CodebookEntry entry : entries) {
for (final int vectorValue : entry.getVector()) {
outputStream.writeShort(vectorValue);
}
}
final long[] frequencies = codebook.getVectorFrequencies();
for (final long vF : frequencies) {
outputStream.writeLong(vF);
}
}
public void readFromStream(DataInputStream inputStream) throws IOException {
header = new CacheFileHeader();
header.readFromStream(inputStream);
final int codebookSize = header.getCodebookSize();
final CodebookEntry[] vectors = new CodebookEntry[codebookSize];
final long[] frequencies = new long[codebookSize];
final int entrySize = header.getVectorSizeX() * header.getVectorSizeY() * header.getVectorSizeZ();
for (int i = 0; i < codebookSize; i++) {
int[] vector = new int[entrySize];
for (int j = 0; j < entrySize; j++) {
vector[j] = inputStream.readUnsignedShort();
}
vectors[i] = new CodebookEntry(vector);
}
for (int i = 0; i < codebookSize; i++) {
frequencies[i] = inputStream.readLong();
}
codebook = new VQCodebook(header.getVectorDim(), vectors, frequencies);
}
public CacheFileHeader getHeader() {
return header;
}
public VQCodebook getCodebook() {
return codebook;
}
}
package azgracompress.quantization;
import azgracompress.data.V2i;
import azgracompress.quantization.vector.CodebookEntry;
import java.io.*;
// TODO(Moravec): If we want to use Huffman codes we have to save additional information with the codebook.
// This information can be probability or the absolute frequencies of codebook indices.
public class QuantizationValueCache {
private final String cacheFolder;
public QuantizationValueCache(final String cacheFolder) {
this.cacheFolder = cacheFolder;
new File(this.cacheFolder).mkdirs();
}
private File getCacheFileForScalarValues(final String trainFile, final int quantizationValueCount) {
final File inputFile = new File(trainFile);
final File cacheFile = new File(cacheFolder, String.format("%s_%d_bits.qvc",
inputFile.getName(), quantizationValueCount));
return cacheFile;
}
private File getCacheFileForVectorValues(final String trainFile,
final int codebookSize,
final int entryWidth,
final int entryHeight) {
final File inputFile = new File(trainFile);
final File cacheFile = new File(cacheFolder, String.format("%s_%d_%dx%d.qvc",
inputFile.getName(),
codebookSize,
entryWidth,
entryHeight));
return cacheFile;
}
public void saveQuantizationValues(final String trainFile, final int[] quantizationValues) throws IOException {
final int quantizationValueCount = quantizationValues.length;
final String cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount).getAbsolutePath();
try (FileOutputStream fos = new FileOutputStream(cacheFile, false);
DataOutputStream dos = new DataOutputStream(fos)) {
for (final int qv : quantizationValues) {
dos.writeInt(qv);
}
} catch (IOException ex) {
throw new IOException(String.format("Failed to write cache to file: %s.\nInner Ex:\n%s",
cacheFile,
ex.getMessage()));
}
}
public void saveQuantizationValues(final String trainFile,
final CodebookEntry[] entries,
final V2i qVecDims) throws IOException {
final int codebookSize = entries.length;
final int entryWidth = qVecDims.getX();
final int entryHeight = qVecDims.getY();
final String cacheFile = getCacheFileForVectorValues(trainFile,
codebookSize,
entryWidth,
entryHeight).getAbsolutePath();
try (FileOutputStream fos = new FileOutputStream(cacheFile, false);
DataOutputStream dos = new DataOutputStream(fos)) {
dos.writeInt(codebookSize);
dos.writeInt(entryWidth);
dos.writeInt(entryHeight);
for (final CodebookEntry entry : entries) {
for (final int vectorValue : entry.getVector()) {
dos.writeInt(vectorValue);
}
}
}
}
public int[] readCachedValues(final String trainFile, final int quantizationValueCount) throws IOException {
final File cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount);
int[] values = new int[quantizationValueCount];
try (FileInputStream fis = new FileInputStream(cacheFile);
DataInputStream dis = new DataInputStream(fis)) {
for (int i = 0; i < quantizationValueCount; i++) {
values[i] = dis.readInt();
}
}
return values;
}
public CodebookEntry[] readCachedValues(final String trainFile,
final int codebookSize,
final int entryWidth,
final int entryHeight) throws IOException {
final File cacheFile = getCacheFileForVectorValues(trainFile, codebookSize, entryWidth, entryHeight);
CodebookEntry[] codebook = new CodebookEntry[codebookSize];
try (FileInputStream fis = new FileInputStream(cacheFile);
DataInputStream dis = new DataInputStream(fis)) {
final int savedCodebookSize = dis.readInt();
final int savedEntryWidth = dis.readInt();
final int savedEntryHeight = dis.readInt();
assert (savedCodebookSize == codebookSize) : "Wrong codebook size";
assert (savedEntryWidth == entryWidth) : "Wrong entry width";
assert (savedEntryHeight == entryHeight) : "Wrong entry height";
final int entrySize = entryWidth * entryHeight;
for (int i = 0; i < codebookSize; i++) {
int[] vector = new int[entrySize];
for (int j = 0; j < entrySize; j++) {
vector[j] = dis.readInt();
}
codebook[i] = new CodebookEntry(vector);
}
}
return codebook;
}
public boolean areQuantizationValueCached(final String trainFile, final int quantizationValueCount) {
final File cacheFile = getCacheFileForScalarValues(trainFile, quantizationValueCount);
return cacheFile.exists();
}
public boolean areVectorQuantizationValueCached(final String trainFile,
final int codebookSize,
final int entryWidth,
final int entryHeight) {
final File cacheFile = getCacheFileForVectorValues(trainFile, codebookSize, entryWidth, entryHeight);
return cacheFile.exists();
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment