summaryrefslogtreecommitdiff
path: root/src/utils/crypto/CBCCipher.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/utils/crypto/CBCCipher.scala')
-rw-r--r--src/utils/crypto/CBCCipher.scala38
1 files changed, 24 insertions, 14 deletions
diff --git a/src/utils/crypto/CBCCipher.scala b/src/utils/crypto/CBCCipher.scala
index d6d99ff..3bd0784 100644
--- a/src/utils/crypto/CBCCipher.scala
+++ b/src/utils/crypto/CBCCipher.scala
@@ -7,28 +7,38 @@ import ixee.cryptopals.utils.CryptoUtils._
import ixee.cryptopals.utils.TupleUtils._
import ixee.cryptopals.utils.ByteUtils._
-class CBCEncrypter(cipher: Cipher, init: Seq[Byte]) {
- val blockSize = 16
+class CBCCipher(private[this] val cipher: Cipher, private[this] val iv: Seq[Byte], mode: Int) extends IxeeCipher {
+ val blockSize = cipher.getBlockSize
- var state: Seq[Byte] = init
+ var state: Seq[Byte] = iv
var leftover: Seq[Byte] = Seq()
- def enc(data: Seq[Byte]): Seq[Byte] = {
- val (blocks, newLeftover) = blockized(leftover ++ data)(blockSize)
- leftover = newLeftover
- blocks.map(encBlock _).foldLeft(Seq[Byte]())(_ ++ _)
- }
+ lazy val handleBlock: Seq[Byte] => Seq[Byte] =
+ if (mode == Cipher.ENCRYPT_MODE) encBlock _
+ else decBlock _
- def encBlock(data: Seq[Byte]): Seq[Byte] = {
- state = cipher.update((data xor state).toArray)
- state
+ def update(data: Seq[Byte]): Seq[Byte] = {
+ val blocks = blockized(leftover ++ data).tap(updateLeftover)._1
+ blocks.foldLeft(Seq[Byte]())(_ ++ handleBlock(_))
}
+ private def decBlock(data: Seq[Byte]): Seq[Byte] =
+ (cipher.update(data.toArray).toSeq xor state).tap(_ => state = data)
+
+ private def encBlock(data: Seq[Byte]): Seq[Byte] =
+ cipher.update((data xor state).toArray).tap(state = _)
+
+ // wouldn't hurt to invalidate this object afterward, but meh
+ // TODO: strip padding!
def end(): Seq[Byte] =
- cipher.doFinal((pkcs7pad(leftover, blockSize) xor state).toArray)
+ if (mode == Cipher.DECRYPT_MODE)
+ else cipher.update((pkcs7pad(leftover, blockSize) xor state).toArray)
- def blockized(data: Seq[Byte])(size: Int): (Seq[Seq[Byte]], Seq[Byte]) =
- groupBlocks <-: data.splitAt(data.length - (data.length % size))
+ def blockized(data: Seq[Byte]): (Seq[Seq[Byte]], Seq[Byte]) =
+ groupBlocks <-: data.splitAt(data.length - (data.length % blockSize))
def groupBlocks: Seq[Byte] => Seq[Seq[Byte]] = _.grouped(blockSize).toSeq
+
+ def updateLeftover(pair: (Seq[Seq[Byte]], Seq[Byte])) =
+ leftover = pair._2
}