summaryrefslogtreecommitdiff
path: root/src/utils/crypto/CBCCipher.scala
blob: 227b635da6c950906529052092fde18514b045a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package ixee.cryptopals.utils.crypto

import javax.crypto.Cipher
import ixee.cryptopals.utils.ConversionUtils._
import ixee.cryptopals.utils.FunctionUtils._
import ixee.cryptopals.utils.CryptoUtils._
import ixee.cryptopals.utils.TupleUtils._
import ixee.cryptopals.utils.ByteUtils._

/** Streaming CBC chaining layered over a caller-supplied block cipher.
  *
  * Data may be fed in arbitrarily sized pieces through `update`; only whole blocks are
  * transformed, and a trailing partial block is parked in `leftover` until more data arrives
  * or `end()` is called.
  *
  * NOTE(review): the chaining (xor with the previous block) is done here, so `cipher` is
  * presumably a raw single-block transform (e.g. ECB without padding) that was already
  * initialised for the same direction as `mode` -- confirm against callers.
  *
  * @param cipher block primitive; every whole block is pushed through `cipher.update`
  * @param iv     initial chaining value, xor-ed with the first block
  * @param mode   `Cipher.ENCRYPT_MODE` selects encryption; any other value is handled as
  *               decryption, both in `update` and in `end()`
  */
class CBCCipher(private[this] val cipher: Cipher, private[this] val iv: Seq[Byte], mode: Int) extends IxeeCipher {
  val blockSize: Int = cipher.getBlockSize

  // Single source of truth for the direction, so `handleBlock` and `end()` can never
  // disagree about what a mode other than ENCRYPT_MODE / DECRYPT_MODE means.
  private[this] val encrypting: Boolean = mode == Cipher.ENCRYPT_MODE

  /** Chaining value for the next block: starts as the IV, afterwards the latest ciphertext block. */
  var state: Seq[Byte] = iv

  /** Bytes received so far that do not yet fill a whole block. */
  var leftover: Seq[Byte] = Seq()

  /** Per-block transform matching the direction this instance was created with. */
  lazy val handleBlock: Seq[Byte] => Seq[Byte] =
    if (encrypting) encBlock _
    else decBlock _

  /** Feeds more input and returns the output for every whole block now available.
    *
    * Bytes left over from earlier calls are prepended to `data`; whatever does not fill a
    * whole block this time is stored in `leftover` again.
    *
    * @param data next piece of input, of any length
    * @return the transformed whole blocks, concatenated in input order
    */
  def update(data: Seq[Byte]): Seq[Byte] = {
    // `tap` (FunctionUtils) is used for its side effect: remember the new partial tail.
    val blocks = blockized(leftover ++ data).tap(updateLeftover)._1
    // Keep this an eager left-to-right fold: handleBlock mutates `state`, so every block
    // has to be transformed here, in order, before this method returns.
    blocks.foldLeft(Seq[Byte]())(_ ++ handleBlock(_))
  }

  /** Decrypts one block: the cipher output is xor-ed with the current chaining value, after
    * which the ciphertext block just consumed becomes the chaining value for the next block.
    */
  private def decBlock(data: Seq[Byte]): Seq[Byte] =
    (cipher.update(data.toArray).toSeq xor state).tap(_ => state = data)

  /** Encrypts one block: the input is xor-ed with the current chaining value and run through
    * the cipher; the resulting ciphertext block is returned and becomes the new chaining value.
    */
  private def encBlock(data: Seq[Byte]): Seq[Byte] =
    cipher.update((data xor state).toArray).tap(state = _)

  /** Finishes the stream.
    *
    * When encrypting, the buffered partial block is padded with `pkcs7pad`, chained with
    * `state`, encrypted and returned. When decrypting, nothing is returned and anything still
    * sitting in `leftover` is not processed.
    *
    * Neither `state` nor `leftover` is changed by this call, so the instance is not reset;
    * it wouldn't hurt to invalidate the object afterwards, but that is not done.
    *
    * TODO: strip padding on decryption. Doing that properly requires decryption to be written
    * as its own part: the last block has to be withheld until `end()` is called, which is a
    * very different stateful behaviour from encryption. For now padding is stripped in
    * CryptoUtils.
    *
    * @return the final padded ciphertext block when encrypting, otherwise an empty sequence
    */
  def end(): Seq[Byte] =
    if (encrypting) cipher.update((pkcs7pad(leftover, blockSize) xor state).toArray)
    else Seq()

  /** Splits `data` into its whole blocks and the trailing bytes that do not fill a block.
    *
    * @return (whole blocks of `blockSize` bytes each, remaining partial block -- possibly empty)
    */
  def blockized(data: Seq[Byte]): (Seq[Seq[Byte]], Seq[Byte]) =
    // `<-:` comes from TupleUtils; judging by the types it applies `groupBlocks` to the
    // first element of the pair -- presumably a "map first" helper, verify there.
    groupBlocks <-: data.splitAt(data.length - (data.length % blockSize))

  /** Function that chops a byte sequence into consecutive groups of `blockSize` bytes. */
  def groupBlocks: Seq[Byte] => Seq[Seq[Byte]] = _.grouped(blockSize).toSeq

  /** Stores the second element of a `blockized` result as the new `leftover`. */
  def updateLeftover(pair: (Seq[Seq[Byte]], Seq[Byte])): Unit =
    leftover = pair._2
}