diff --git a/src/org/thoughtcrime/securesms/crypto/DecryptingPartInputStream.java b/src/org/thoughtcrime/securesms/crypto/DecryptingPartInputStream.java index 860f424673..82e7805339 100644 --- a/src/org/thoughtcrime/securesms/crypto/DecryptingPartInputStream.java +++ b/src/org/thoughtcrime/securesms/crypto/DecryptingPartInputStream.java @@ -24,6 +24,7 @@ import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; +import java.lang.System; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; @@ -53,6 +54,7 @@ public class DecryptingPartInputStream extends FileInputStream { private boolean done; private long totalDataSize; private long totalRead; + private byte[] overflowBuffer; public DecryptingPartInputStream(File file, MasterSecret masterSecret) throws FileNotFoundException { super(file); @@ -122,6 +124,25 @@ public class DecryptingPartInputStream extends FileInputStream { } private int readIncremental(byte[] buffer, int offset, int length) throws IOException { + int readLength = 0; + if (null != overflowBuffer) { + if (overflowBuffer.length > length) { + System.arraycopy(overflowBuffer, 0, buffer, offset, length); + overflowBuffer = Arrays.copyOfRange(overflowBuffer, length, overflowBuffer.length); + return length; + } else if (overflowBuffer.length == length) { + System.arraycopy(overflowBuffer, 0, buffer, offset, length); + overflowBuffer = null; + return length; + } else { + System.arraycopy(overflowBuffer, 0, buffer, offset, overflowBuffer.length); + readLength += overflowBuffer.length; + offset += readLength; + length -= readLength; + overflowBuffer = null; + } + } + if (length + totalRead > totalDataSize) length = (int)(totalDataSize - totalRead); @@ -131,7 +152,25 @@ public class DecryptingPartInputStream extends FileInputStream { try { mac.update(internalBuffer, 0, read); - return cipher.update(internalBuffer, 0, read, buffer, offset); + + int outputLen = cipher.getOutputSize(read); + + if (outputLen <= length) { + readLength += cipher.update(internalBuffer, 0, read, buffer, offset); + return readLength; + } + + byte[] transientBuffer = new byte[outputLen]; + outputLen = cipher.update(internalBuffer, 0, read, transientBuffer, 0); + if (outputLen <= length) { + System.arraycopy(transientBuffer, 0, buffer, offset, outputLen); + readLength += outputLen; + } else { + System.arraycopy(transientBuffer, 0, buffer, offset, length); + overflowBuffer = Arrays.copyOfRange(transientBuffer, length, outputLen); + readLength += length; + } + return readLength; } catch (ShortBufferException e) { throw new AssertionError(e); }