package ixee.cryptopals.utils

import ixee.cryptopals.utils.crypto._
import ixee.cryptopals.utils.TupleUtils._
import ixee.cryptopals.utils.StreamUtils._
import ixee.cryptopals.utils.FunctionUtils._
import ixee.cryptopals.utils.ConversionUtils._
import javax.crypto.Cipher
import javax.crypto.spec.SecretKeySpec

/**
 * PKCS#7 padding helpers, thin ECB/CBC wrappers around the project's cipher
 * builders, an ECB-vs-CBC heuristic, and a byte-at-a-time attack against an
 * ECB encryption oracle.
 */
object CryptoUtils {

  /**
   * Appends PKCS#7 padding to `s`.
   *
   * The pad length is `blockSize - (s.length % blockSize)`, so an input that
   * is already a multiple of `blockSize` receives a whole extra block of
   * padding. Every pad byte carries the pad length (truncated with `toByte`).
   *
   * @param s         bytes to pad
   * @param blockSize cipher block size in bytes
   * @return `s` followed by the padding bytes
   */
  def pkcs7pad(s: Seq[Byte], blockSize: Int): Seq[Byte] = {
    val padLength = blockSize - (s.length % blockSize)
    s ++ Seq.fill(padLength)(padLength.toByte)
  }

  /**
   * Removes PKCS#7 padding by dropping as many trailing bytes as the last
   * byte announces. The padding is not validated.
   *
   * The last byte is read as an unsigned value (0..255); a signed read would
   * turn pad lengths above 127 into a negative count and strip nothing.
   * An empty input is returned unchanged instead of failing on `last`.
   *
   * @param s padded bytes
   * @return `s` without its trailing padding
   */
  def stripPkcs7Pad(s: Seq[Byte]): Seq[Byte] =
    s.lastOption.fold(s)(padByte => s.dropRight(padByte & 0xff))

  /** Encrypts `data` by running it through `builder.encrypt.end`. */
  def cbcEncrypt(builder: CbcBuilder)(data: Seq[Byte]) =
    builder.encrypt.end(data)

  /** Runs `data` through `builder.decrypt.end`, then strips the PKCS#7 padding. */
  def cbcDecrypt(builder: CbcBuilder)(data: Seq[Byte]) =
    stripPkcs7Pad(builder.decrypt.end(data))

  /** Encrypts `data` by running it through `builder.encrypt.end`. */
  def ecbEncrypt(builder: EcbBuilder)(data: Seq[Byte]) =
    builder.encrypt.end(data)

  /** Runs `data` through `builder.decrypt.end`, then strips the PKCS#7 padding. */
  def ecbDecrypt(builder: EcbBuilder)(data: Seq[Byte]) =
    stripPkcs7Pad(builder.decrypt.end(data))

  /**
   * Guesses the block cipher mode of a ciphertext.
   *
   * The ciphertext is cut into 16-byte groups, the final group is ignored,
   * and the remaining groups are compared pairwise (via `pairsOf`).
   *
   * @param xs ciphertext bytes
   * @return "ECB" if at least one pair of groups is equal, otherwise "CBC"
   */
  def detectMode(xs: Seq[Byte]): String = {
    def dupBlocks(xs: Seq[Seq[Byte]]) =
      pairsOf(xs).map(tup(_ == _)).count(_ == true)

    // dropRight(1) rather than init, so an empty ciphertext does not fail here.
    def countDupBlocks(xs: Seq[Byte]): Int =
      dupBlocks(xs.grouped(16).toSeq.dropRight(1).toStream)

    // A repeated block means ECB... well, very probably.
    if (countDupBlocks(xs) > 0) "ECB" else "CBC"
  }

  /** True when [[detectMode]] classifies `xs` as "ECB". */
  def isEcb(xs: Seq[Byte]): Boolean =
    detectMode(xs) == "ECB"

  /**
   * Probes an encryption oracle with runs of '@' of length 2, 4, ..., 512 and
   * returns the number attached to the first probe whose ciphertext
   * [[isEcb]] accepts.
   *
   * NOTE(review): presumably `mapAll` pairs each probe with its 1-based index,
   * i.e. half of the probe length - confirm against TupleUtils.
   *
   * @param encryptor the oracle under test
   * @return the detected block size
   * @throws NoSuchElementException if no probe is classified as ECB
   */
  def detectEcbBlockSize(encryptor: Seq[Byte] => Seq[Byte]): Int = {
    val pad = (2 to 512 by 2).map("@" * _).zipWithIndex.map { _.mapAll(_1 = _.asBytes, _2 = _ + 1) }

    val minPadSize = pad
      .map( { encryptor(_) } <-: _ )
      .find( { x => isEcb(x._1) })
      .map(_._2)

    // A None here means there is a serious possibility of this not being ECB.
    minPadSize.getOrElse(
      throw new NoSuchElementException(
        "no probe produced repeated blocks; the oracle is probably not ECB"))
  }

  /**
   * Byte-at-a-time recovery of the unknown plaintext that `encrypt` combines
   * with attacker-supplied input before ECB-encrypting it.
   *
   * Progress is printed to stdout while the attack runs.
   *
   * @param encrypt the encryption oracle
   * @return the recovered bytes
   */
  def extractUnknownViaEcbOracle(encrypt: Seq[Byte] => Seq[Byte]): Seq[Byte] = {
    val blockSize = detectEcbBlockSize(encrypt)

    def blocksIn(xs: Seq[Byte]) = xs.length / blockSize

    // Number of ciphertext blocks produced when nothing is supplied.
    val baseBlockCount = blocksIn(encrypt(Seq()))

    // Maps the first ciphertext block of `prefix :+ b` to `b`, for every byte b.
    def rainbow(prefix: Seq[Byte]): Map[Seq[Byte], Byte] = {
      (0 to 255)
        .map(_.toByte)
        .map(prefix :+ _)
        .map(encrypt)
        .map(_.take(blockSize).toSeq)
        .zipWithIndex
        .map(_ :-> { _.toByte } )
        .toMap
    }

    // Maps the first ciphertext block of `b +: prefix` to `b`, for every byte b.
    def postfixRainbow(prefix: Seq[Byte]): Map[Seq[Byte], Byte] = {
      (0 to 255)
        .map(_.toByte)
        .map(_ +: prefix)
        .map(encrypt)
        .map(_.take(blockSize).toSeq)
        .zipWithIndex
        .map(_ :-> { _.toByte } )
        .toMap
    }

    // Recovers the first block one byte at a time and keeps, per step, the
    // full ciphertext obtained with that step's space filler.
    def probeFirstBlockAndPaddings: (Seq[Byte], Map[Int, Seq[Byte]]) = {
      // Filler that leaves exactly one unknown byte at the end of block 0.
      def prefix(known: Seq[Byte]) = (" " * (blockSize - 1 - known.length)).asBytes
      def genRainbow(known: Seq[Byte]) = rainbow(prefix(known) ++ known)
      def firstCryptedBlock(known: Seq[Byte]) = encrypt(prefix(known)).take(blockSize).toSeq
      def nextByte(known: Seq[Byte]) = genRainbow(known)(firstCryptedBlock(known))

      (0 until blockSize).foldLeft((Seq[Byte](), Map[Int, Seq[Byte]]())) { (ac, idx) =>
        (ac._1 :+ nextByte(ac._1), ac._2 + (idx -> encrypt(prefix(ac._1))))
      }
    }

    // Smallest number of filler bytes that changes the ciphertext block count.
    def probeLastBlockSize: Int = {
      // The range includes blockSize itself: when the base ciphertext has no
      // slack, only a filler of that length changes the block count.
      val firstLargerCiphertext = (0 to blockSize)
        .map(" " * _).map(_.asBytes).map(encrypt)
        .zipWithIndex
        .find(x => blocksIn(x._1) != baseBlockCount)

      firstLargerCiphertext
        .map(_._2)
        .getOrElse(
          throw new NoSuchElementException(
            "ciphertext never grew while adding up to one block of filler"))
    }

    val lastBlockSize = probeLastBlockSize
    println("Last block is " + lastBlockSize + " bytes")

    val (firstBlock, ciphertexts) = probeFirstBlockAndPaddings

    /*
     * zip together cipher blocks so that they look like
     * rot0b0, rot1b0, rot2b0, rot3b0, rot4b0, rot5b0, rot6b0, rot7b0
     * rot0b1, rot1b1, rot2b1, ...
     */
    val cipherBlocks: Seq[(Int, Seq[Seq[Byte]])] =
      ciphertexts.toSeq.sortBy(_._1).map(_ :-> { x => x.grouped(blockSize).toSeq })

    // Per rotation: the blocks after the first (up to baseBlockCount), and the
    // extra trailing block when that rotation's ciphertext grew past baseBlockCount.
    val (middleBlocks, lastBlocks) =
      cipherBlocks.foldLeft((Seq[Seq[Seq[Byte]]](), Seq[Option[Seq[Byte]]]()))(
        (ac, curr) => {
          val maybeLastBlock =
            if (curr._2.length > baseBlockCount)
              Some(curr._2.last)
            else
              None

          val middleBlocks = curr._2.take(baseBlockCount).tail

          (ac._1 :+ middleBlocks, ac._2 :+ maybeLastBlock)
        }) :-> { _.flatten }

    println(lastBlocks.mkString("\n"))

    // Recovers one byte per supplied block by matching it against
    // postfixRainbow of the bytes found so far plus their PKCS#7 padding.
    //
    // NOTE(review): each recovered byte is appended (`bytes :+ b`) and blocks
    // are consumed in rotation order; recovering from the end of the plaintext
    // backwards would seem to need a prepend and the opposite block order -
    // confirm against a known plaintext before relying on this.
    def breakLastBlock(blocks: Seq[Seq[Byte]]) = {
      /*
       * plaintext looks something like...
       * X 15 15 15 15 15 ... 15
       * Y X  14 14 14 14 ... 14
       * ... pkcs7
       */
      blocks.foldLeft(Seq[Byte]()) { (bytes, block) => {
        val postfix = pkcs7pad("?".asBytes ++ bytes, blockSize).tail
        val b = postfixRainbow(postfix)(block)
        println("Got " + b)
        bytes :+ b
      }}
    }

    // Recovers one byte per rotation of a block position, using the last
    // blockSize - 1 recovered bytes as the known rainbow prefix.
    def breakBlock(plaintext: Seq[Byte], padded: Seq[Seq[Byte]]): Seq[Byte] =
      padded.foldLeft(Seq[Byte]()) { (blockText: Seq[Byte], block: Seq[Byte]) => {
        val prefix = (plaintext ++ blockText).takeRight(blockSize - 1)
        val b = rainbow(prefix)(block)
        blockText :+ b
      }}

    // NOTE(review): `.init` skips the last base block position and no other
    // visible code recovers it - confirm this is intended.
    val middleBytes = middleBlocks.transpose.init.foldLeft(firstBlock) {
      (plaintext, padded: Seq[Seq[Byte]]) =>
        println("Block " + padded(0))
        val bytes = breakBlock(plaintext, padded)
        println(".... Done!")
        plaintext ++ bytes
    }

    middleBytes ++ breakLastBlock(lastBlocks)
  }
}