package ixee.cryptopals.utils

import ixee.cryptopals.utils.crypto._
import ixee.cryptopals.utils.TupleUtils._
import ixee.cryptopals.utils.StreamUtils._
import ixee.cryptopals.utils.FunctionUtils._
import ixee.cryptopals.utils.ConversionUtils._

import javax.crypto.Cipher
import javax.crypto.spec.SecretKeySpec

/** PKCS#7 padding helpers, thin CBC/ECB wrappers and ECB-oracle attack utilities. */
object CryptoUtils {

  /** Appends PKCS#7 padding to `s` so that its length becomes a multiple of `blockSize`.
    *
    * A full block of padding is appended when `s` is already block-aligned.
    *
    * @param s         bytes to pad
    * @param blockSize block size in bytes; must be in 1..255 so the pad length fits in one byte
    * @return `s` followed by `n` copies of the byte `n`, where `n` is the pad length
    * @throws IllegalArgumentException if `blockSize` is outside 1..255
    */
  def pkcs7pad(s: Seq[Byte], blockSize: Int): Seq[Byte] = {
    require(
      blockSize > 0 && blockSize < 256,
      s"PKCS#7 block size must be between 1 and 255, got $blockSize"
    )
    val padLength = blockSize - (s.length % blockSize)
    s ++ Seq.fill(padLength)(padLength.toByte)
  }

  /** Removes PKCS#7 padding by dropping as many trailing bytes as the last byte says.
    *
    * The padding bytes themselves are not validated. Empty input is returned unchanged.
    *
    * @param s padded bytes
    * @return `s` without its trailing padding
    */
  def stripPkcs7Pad(s: Seq[Byte]): Seq[Byte] =
    // Mask to read the pad length as unsigned: a plain Byte -> Int widening turns
    // pad lengths >= 128 into negative numbers, and dropRight would then drop nothing.
    s.lastOption.fold(s)(padByte => s.dropRight(padByte & 0xff))

  /** Encrypts `data` by delegating to `builder.encrypt.end`. */
  def cbcEncrypt(builder: CbcBuilder)(data: Seq[Byte]) =
    builder.encrypt.end(data)

  /** Decrypts `data` via `builder.decrypt.end` and strips the PKCS#7 padding from the result. */
  def cbcDecrypt(builder: CbcBuilder)(data: Seq[Byte]): Seq[Byte] =
    stripPkcs7Pad(builder.decrypt.end(data))

  /** Encrypts `data` by delegating to `builder.encrypt.end`. */
  def ecbEncrypt(builder: EcbBuilder)(data: Seq[Byte]) =
    builder.encrypt.end(data)

  /** Decrypts `data` via `builder.decrypt.end` and strips the PKCS#7 padding from the result. */
  def ecbDecrypt(builder: EcbBuilder)(data: Seq[Byte]): Seq[Byte] =
    stripPkcs7Pad(builder.decrypt.end(data))

  /** Guesses the cipher mode of `xs` by looking for repeated 16-byte groups.
    *
    * The final group (possibly partial) is ignored. Empty input yields "CBC".
    *
    * @param xs ciphertext bytes
    * @return "ECB" if at least one pair of compared groups is equal, otherwise "CBC"
    */
  def detectMode(xs: Seq[Byte]): String = {
    // NOTE(review): pairsOf is defined elsewhere; presumably it enumerates pairs of blocks - confirm.
    def dupBlocks(xs: Seq[Seq[Byte]]) =
      pairsOf(xs).map(tup(_ == _)).count(_ == true)

    // dropRight(1) instead of init: same result for non-empty input, no exception for empty input.
    def countDupBlocks(xs: Seq[Byte]): Int =
      dupBlocks(xs.grouped(16).toSeq.dropRight(1).toStream)

    //.... well very probably.
    if (countDupBlocks(xs) > 0) "ECB" else "CBC"
  }

  /** True when [[detectMode]] classifies `xs` as "ECB". */
  def isEcb(xs: Seq[Byte]): Boolean =
    detectMode(xs) == "ECB"

  /** Estimates the ECB block size of an encryption oracle.
    *
    * Feeds the oracle runs of "@" of length 2, 4, ..., 512 and returns half the length of the
    * first run whose ciphertext [[isEcb]] accepts. The search is lazy, so the oracle is not
    * called again once a match is found.
    *
    * NOTE(review): [[isEcb]] compares fixed 16-byte groups, so this looks tuned for
    * 16-byte block ciphers - confirm before relying on it for other block sizes.
    *
    * @param encryptor oracle mapping chosen input bytes to ciphertext
    * @return the detected block size
    * @throws NoSuchElementException if no probe produced ECB-looking output
    */
  def detectEcbBlockSize(encryptor: Seq[Byte] => Seq[Byte]): Int = {
    val pad = (2 to 512 by 2).iterator.map("@" * _).zipWithIndex.map {
      _.mapAll(_1 = _.asBytes, _2 = _ + 1)
    }

    val minPadSize = pad
      .map( { encryptor(_) } <-: _ )
      .find( { x => isEcb(x._1) })
      .map(_._2)

    // if this was a None we have serious possibility of this not being ECB
    minPadSize.getOrElse(
      throw new NoSuchElementException(
        "no probe produced ECB-looking output; the oracle is probably not ECB"
      )
    )
  }

  /** Recovers the unknown plaintext handled by an ECB oracle, one byte at a time.
    *
    * Each unknown byte is found by encrypting all 256 candidate blocks and looking up the
    * observed cipher block among them.
    *
    * NOTE(review): the first block is always probed for `blockSize` bytes and the middle-block
    * step calls `init` on the transposed blocks, so this appears to require the unknown
    * plaintext to fill at least one whole block - confirm against callers.
    *
    * @param encrypt oracle mapping chosen input bytes to ciphertext
    * @return the recovered plaintext bytes
    * @throws NoSuchElementException if an observed block matches none of the 256 candidates,
    *                                or if the block size cannot be detected
    */
  def extractUnknownViaEcbOracle(encrypt: Seq[Byte] => Seq[Byte]): Seq[Byte] = {
    // Ciphertext of the oracle for one chosen prefix, split into blocks on demand.
    class Ciphertext(encrypt: Seq[Byte] => Seq[Byte],
                     prefix: Seq[Byte] = Seq(),
                     private val overrideBlockSize: Option[Int] = None) {
      private lazy val cipherBytes = encrypt(prefix)
      lazy val blockSize = overrideBlockSize.getOrElse(detectEcbBlockSize(encrypt))
      lazy val blocks = cipherBytes.grouped(blockSize).toSeq
      lazy val blockCount = blocks.length
    }

    val baseCiphert = new Ciphertext(encrypt)
    val blockSize = baseCiphert.blockSize
    val baseBlockCount = baseCiphert.blockCount

    // One ciphertext per prefix length 0 .. blockSize - 1, all sharing the detected block size.
    val ciphertexts = Seq(baseCiphert) ++ (1 until blockSize)
      .map( x => ("@" * x).asBytes )
      .map(new Ciphertext(encrypt, _, overrideBlockSize = Some(blockSize)))

    // Maps the first cipher block of every candidate input to the candidate byte.
    def rainbow(const: Seq[Byte], generator: (Byte, Seq[Byte]) => Seq[Byte]) =
      (0 to 255)
        .map(_.toByte)
        .map(generator(_, const))
        .map(encrypt)
        .map(_.take(blockSize).toSeq)
        .zipWithIndex
        .map(_ :-> { _.toByte } )
        .toMap

    // Candidates are `prefix` followed by the guessed byte.
    def rainbowSuffix(prefix: Seq[Byte]): Map[Seq[Byte], Byte] =
      rainbow(prefix, (byte, seq) => seq :+ byte)

    // Candidates are the guessed byte followed by `suffix`.
    def rainbowPrefix(suffix: Seq[Byte]): Map[Seq[Byte], Byte] =
      rainbow(suffix, (byte, seq) => byte +: seq)

    // Recovers the first block by shrinking the chosen prefix one byte per round.
    def probeFirstBlock: Seq[Byte] = {
      def prefix(known: Seq[Byte]) =
        ("@" * (blockSize - 1 - known.length)).asBytes

      def genRainbow(known: Seq[Byte]) =
        rainbowSuffix(prefix(known) ++ known)

      def firstCryptedBlock(known: Seq[Byte]) =
        encrypt(prefix(known)).take(blockSize).toSeq

      def nextByte(known: Seq[Byte]) =
        known :+ genRainbow(known)(firstCryptedBlock(known))

      (0 until blockSize).foldLeft(Seq[Byte]()) { (ac, _) => nextByte(ac) }
    }

    val firstBlock = probeFirstBlock

    /*
     * zip together cipher blocks so that they look like
     * rot0b0, rot1b0, rot2b0, rot3b0, rot4b0, rot5b0, rot6b0, rot7b0
     * rot0b1, rot1b1, rot2b1, ...
     *
     * middleBlocks ends up ordered from the longest prefix to the shortest; lastBlocks holds,
     * in order of growing prefix, the extra trailing block of every ciphertext that grew past
     * the base block count, followed by the last block of the base ciphertext.
     */
    val (middleBlocks, lastBlocks) =
      ciphertexts.foldRight((Seq[Seq[Seq[Byte]]](), Seq[Option[Seq[Byte]]]()))( (curr, ac) => {
        val maybeLastBlock =
          if (curr.blockCount > baseBlockCount) Some(curr.blocks.last)
          else None
        val currMiddle = curr.blocks.take(baseBlockCount).tail
        (ac._1 :+ currMiddle, maybeLastBlock +: ac._2)
      }) :-> { _.flatten :+ baseCiphert.blocks.last }

    // Recovers the trailing bytes back to front from the padded last blocks.
    def breakLastBlock(blocks: Seq[Seq[Byte]]) = {
      /*
       * plaintext looks something like...
       * X 15 15 15 15 15 ... 15
       * Y X  14 14 14 14 ... 14
       * ... pkcs7
       */
      // drop the first block because it will be 16 16 16 16 16 16 ... 16
      blocks.tail.foldLeft(Seq[Byte]()) { (bytes, block) => {
        // "?" is a placeholder for the byte being guessed; it is removed again by .tail
        val postfix = pkcs7pad("?".asBytes ++ bytes, blockSize).tail
        rainbowPrefix(postfix)(block) +: bytes
      }}
    }

    // Recovers one block; `padded` holds that block under every prefix, longest prefix first.
    def breakBlock(plaintext: Seq[Byte], padded: Seq[Seq[Byte]]): Seq[Byte] =
      padded.foldLeft(Seq[Byte]()) { (blockText: Seq[Byte], block: Seq[Byte]) => {
        val prefix = (plaintext ++ blockText).takeRight(blockSize - 1)
        val b = rainbowSuffix(prefix)(block)
        blockText :+ b
      }}

    // init: the final block index is left to breakLastBlock.
    val middleBytes = middleBlocks.transpose.init.foldLeft(firstBlock) {
      (plaintext, padded: Seq[Seq[Byte]]) => plaintext ++ breakBlock(plaintext, padded)
    }

    middleBytes ++ breakLastBlock(lastBlocks)
  }
}