diff options
Diffstat (limited to 'src/utils/crypto/CBCCipher.scala')
-rw-r--r-- | src/utils/crypto/CBCCipher.scala | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/src/utils/crypto/CBCCipher.scala b/src/utils/crypto/CBCCipher.scala index d6d99ff..3bd0784 100644 --- a/src/utils/crypto/CBCCipher.scala +++ b/src/utils/crypto/CBCCipher.scala @@ -7,28 +7,38 @@ import ixee.cryptopals.utils.CryptoUtils._ import ixee.cryptopals.utils.TupleUtils._ import ixee.cryptopals.utils.ByteUtils._ -class CBCEncrypter(cipher: Cipher, init: Seq[Byte]) { - val blockSize = 16 +class CBCCipher(private[this] val cipher: Cipher, private[this] val iv: Seq[Byte], mode: Int) extends IxeeCipher { + val blockSize = cipher.getBlockSize - var state: Seq[Byte] = init + var state: Seq[Byte] = iv var leftover: Seq[Byte] = Seq() - def enc(data: Seq[Byte]): Seq[Byte] = { - val (blocks, newLeftover) = blockized(leftover ++ data)(blockSize) - leftover = newLeftover - blocks.map(encBlock _).foldLeft(Seq[Byte]())(_ ++ _) - } + lazy val handleBlock: Seq[Byte] => Seq[Byte] = + if (mode == Cipher.ENCRYPT_MODE) encBlock _ + else decBlock _ - def encBlock(data: Seq[Byte]): Seq[Byte] = { - state = cipher.update((data xor state).toArray) - state + def update(data: Seq[Byte]): Seq[Byte] = { + val blocks = blockized(leftover ++ data).tap(updateLeftover)._1 + blocks.foldLeft(Seq[Byte]())(_ ++ handleBlock(_)) } + private def decBlock(data: Seq[Byte]): Seq[Byte] = + (cipher.update(data.toArray).toSeq xor state).tap(_ => state = data) + + private def encBlock(data: Seq[Byte]): Seq[Byte] = + cipher.update((data xor state).toArray).tap(state = _) + + // wouldn't hurt to invalidate this object afterward, but meh + // TODO: strip padding! def end(): Seq[Byte] = - cipher.doFinal((pkcs7pad(leftover, blockSize) xor state).toArray) + if (mode == Cipher.DECRYPT_MODE) + else cipher.update((pkcs7pad(leftover, blockSize) xor state).toArray) - def blockized(data: Seq[Byte])(size: Int): (Seq[Seq[Byte]], Seq[Byte]) = - groupBlocks <-: data.splitAt(data.length - (data.length % size)) + def blockized(data: Seq[Byte]): (Seq[Seq[Byte]], Seq[Byte]) = + groupBlocks <-: data.splitAt(data.length - (data.length % blockSize)) def groupBlocks: Seq[Byte] => Seq[Seq[Byte]] = _.grouped(blockSize).toSeq + + def updateLeftover(pair: (Seq[Seq[Byte]], Seq[Byte])) = + leftover = pair._2 } |