apply same fix used on DecryptingPartInputStream to AttachmentCipherInputStream // FREEBIE

This commit is contained in:
Jake McGinty 2014-02-18 16:41:59 -08:00
parent 9d9a0ec218
commit 649f037ed8

View File

@ -52,6 +52,7 @@ public class AttachmentCipherInputStream extends FileInputStream {
private boolean done; private boolean done;
private long totalDataSize; private long totalDataSize;
private long totalRead; private long totalRead;
private byte[] overflowBuffer;
public AttachmentCipherInputStream(File file, byte[] combinedKeyMaterial) public AttachmentCipherInputStream(File file, byte[] combinedKeyMaterial)
throws IOException, InvalidMessageException throws IOException, InvalidMessageException
@ -125,6 +126,25 @@ public class AttachmentCipherInputStream extends FileInputStream {
} }
private int readIncremental(byte[] buffer, int offset, int length) throws IOException { private int readIncremental(byte[] buffer, int offset, int length) throws IOException {
int readLength = 0;
if (null != overflowBuffer) {
if (overflowBuffer.length > length) {
System.arraycopy(overflowBuffer, 0, buffer, offset, length);
overflowBuffer = Arrays.copyOfRange(overflowBuffer, length, overflowBuffer.length);
return length;
} else if (overflowBuffer.length == length) {
System.arraycopy(overflowBuffer, 0, buffer, offset, length);
overflowBuffer = null;
return length;
} else {
System.arraycopy(overflowBuffer, 0, buffer, offset, overflowBuffer.length);
readLength += overflowBuffer.length;
offset += readLength;
length -= readLength;
overflowBuffer = null;
}
}
if (length + totalRead > totalDataSize) if (length + totalRead > totalDataSize)
length = (int)(totalDataSize - totalRead); length = (int)(totalDataSize - totalRead);
@ -133,7 +153,24 @@ public class AttachmentCipherInputStream extends FileInputStream {
totalRead += read; totalRead += read;
try { try {
return cipher.update(internalBuffer, 0, read, buffer, offset); int outputLen = cipher.getOutputSize(read);
if (outputLen <= length) {
readLength += cipher.update(internalBuffer, 0, read, buffer, offset);
return readLength;
}
byte[] transientBuffer = new byte[outputLen];
outputLen = cipher.update(internalBuffer, 0, read, transientBuffer, 0);
if (outputLen <= length) {
System.arraycopy(transientBuffer, 0, buffer, offset, outputLen);
readLength += outputLen;
} else {
System.arraycopy(transientBuffer, 0, buffer, offset, length);
overflowBuffer = Arrays.copyOfRange(transientBuffer, length, outputLen);
readLength += length;
}
return readLength;
} catch (ShortBufferException e) { } catch (ShortBufferException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }