summaryrefslogtreecommitdiff
path: root/src/solvers/XorDecrypt.scala
blob: 1a28637787c542111fe1c3605bda26137950a195 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package ixee.cryptopals.solvers

import ixee.cryptopals.utils._
import ixee.cryptopals.utils.ByteUtils._
import ixee.cryptopals.utils.StreamUtils._
import ixee.cryptopals.utils.FunctionUtils._
import ixee.cryptopals.utils.ConversionUtils._

object XorDecrypt {
  // Frequency table brought into implicit scope for the text-scoring code used below
  // (presumably consumed by TextScorer.score — NOTE(review): confirm). The commented-out
  // name is an alternative table.
  implicit val freq = Frequencies.cryptologicalMathematics //cornell40kSample

  /** Candidate repeating-XOR key sizes for `xs`, most likely first. */
  def inferKeySizes(xs: Seq[Byte]): Seq[Int] = {
    inferKeySizesWithHammingDistance(xs).map(_._2)
  }

  /** The single most likely repeating-XOR key size for `xs`. */
  def inferKeySize(xs: Seq[Byte]): Int =
    inferKeySizes(xs).head

  /**
   * Scores every key size in [1, 40] by the normalized Hamming distance between adjacent
   * ciphertext blocks of that size. Returns `(distance, keySize)` pairs sorted by
   * ascending distance, so the most plausible key size comes first.
   *
   * Why this works (assuming `hammingDistance` counts differing bits): when blocks are
   * aligned to the true key size, both blocks were XOR'd with the same key bytes. So
   * `c1 xor c2 == p1 xor p2`, and the key cancels out. Plaintext such as English differs
   * in fewer bits per byte than uniformly random bytes do. Misaligned sizes therefore
   * look closer to random and score higher.
   *
   * Not so great for small key sizes. Each block then contributes only a few bytes, so
   * the average is noisy. Multiples of the true key size are also aligned and score
   * similarly low.
   */
  def inferKeySizesWithHammingDistance(xs: Seq[Byte]): Seq[(Double, Int)] = {
    val MinKeySize = 1
    val MaxKeySize = 40

    // Pair each score with its own key size rather than reconstructing the size from
    // zipWithIndex, which would silently break if MinKeySize ever changed.
    (MinKeySize to MaxKeySize)
      .map(keySize => (normalizedHammingDistanceForKeySize(xs, keySize), keySize))
      .sortBy(_._1)
  }

  /**
   * Average Hamming distance between each pair of adjacent `keySize`-byte blocks of `xs`,
   * divided by `keySize` so that different key sizes are comparable.
   *
   * Only full-length blocks are used; a trailing partial block is ignored. When fewer
   * than two full blocks exist there is nothing to compare. In that case the result is
   * `Double.PositiveInfinity`, so such key sizes sort last instead of throwing.
   */
  def normalizedHammingDistanceForKeySize(xs: Seq[Byte], keySize: Int): Double = {
    val blocks = fullBlocks(xs, keySize)
    if (blocks.length < 2) Double.PositiveInfinity
    else {
      val total = (blocks.init zip blocks.tail).map(tup(hammingDistance(_, _))).reduce(_ + _)
      total / (blocks.length - 1.0) / keySize
    }
  }

  /** Consecutive `keySize`-byte blocks of `xs`, omitting a trailing partial block. */
  private def fullBlocks(xs: Seq[Byte], keySize: Int): Seq[Seq[Byte]] =
    xs.grouped(keySize).filter(_.length == keySize).toVector

  /** Infers the key, then decrypts `ciphertext` to ASCII text with it. */
  def crackForMultiByteKey(ciphertext: Seq[Byte], inspectSet: Seq[Int] = Seq()): String =
    decryptToAscii(ciphertext)(findBestMultiByteKey(ciphertext, inspectSet))

  /**
   * Infers the most likely key size for `xs` and recovers a key of that size.
   * The result is passed through `continuous` before being returned.
   */
  def findBestMultiByteKey(xs: Seq[Byte], inspectSet: Seq[Int] = Seq()): Seq[Byte] =
    continuous(findBestMultiByteKeyOfSize(inferKeySize(xs))(xs, inspectSet))
/*
  def crackForKeySize(xs: Seq[Byte], keySize: Int): String = {
    (findBestMultiByteKeyOfSize(keySize) _ :| continuous _ :| decryptToAscii(xs) _)(xs)
  }
*/
  /** XORs `ciphertext` with `key`. */
  def decrypt(ciphertext: Seq[Byte])(key: Seq[Byte]): Seq[Byte] =
    ciphertext xor key

  /** XORs `ciphertext` with `key` and renders the result as ASCII. */
  def decryptToAscii(ciphertext: Seq[Byte])(key: Seq[Byte]): String =
    decrypt(ciphertext)(key).asAscii

  // TODO: make an xor key a proper type.
  /** Decrypts `ciphertext` with `key` repeated for every byte, rendered as ASCII. */
  def decryptToAsciiSingleByteKey(ciphertext: Seq[Byte])(key: Byte): String =
    decryptToAscii(ciphertext)(Stream.continually(key))

  /** Decrypts `ciphertext` with the best-scoring single-byte key, repeated over its length. */
  def tryBestDecrypt(ciphertext: Seq[Byte]): Seq[Byte] = {
    (findBestSingleByteKey _ :| { x => Stream.continually(x) } :| decrypt(ciphertext) _)(ciphertext)
  }

  /** Like [[tryBestDecrypt]], but renders the result as ASCII. */
  def tryBestDecryptToText: Seq[Byte] => String =
    tryBestDecrypt _ :| ((_: Seq[Byte]).asAscii)

  /**
   * Recovers a `keySize`-byte repeating-XOR key for `ciphertext`. It cracks each key
   * position as an independent single-byte-key problem.
   *
   * Only full `keySize`-byte blocks are used. If `ciphertext` holds no full block, the
   * returned key is empty.
   *
   * @param inspectSet key-byte indices for which candidate rankings are printed, for debugging
   */
  def findBestMultiByteKeyOfSize(keySize: Int)(ciphertext: Seq[Byte], inspectSet: Seq[Int] = Seq()): Seq[Byte] = {
    // After transposing, column i holds every ciphertext byte that was XOR'd with key byte i.
    fullBlocks(ciphertext, keySize).transpose.zipWithIndex
      .map { case (byteCiphertext, idx) =>
        if (inspectSet.contains(idx)) {
          println("For idx = " + idx)
          println(findBestSingleByteKeyWithCandidates(byteCiphertext))
        }
        findBestSingleByteKey(byteCiphertext)
      }
  }

  /**
   * Every surviving single-byte key for `ciphertext`, best (lowest score) first.
   * Each entry is `(score, decrypted text, key byte)`.
   */
  def findBestSingleByteKeyWithCandidates(ciphertext: Seq[Byte]): Seq[(Double, String, Byte)] =
    candidates(ciphertext).sortBy(_._1).map {
      case (score, key) => (score, decryptToAsciiSingleByteKey(ciphertext)(key.toByte), key.toByte)
    }

  /**
   * The single-byte key whose decryption has the lowest score. On ties the later key
   * value wins. Returns 0 when every candidate was rejected.
   */
  def findBestSingleByteKey(ciphertext: Seq[Byte]): Byte = {
    candidates(ciphertext)
      .reduceOption( (a: (Double, Int), b: (Double, Int)) =>
        if (a._1 < b._1) a else b
      ).map(_._2.toByte).getOrElse(0.toByte)
  }

  /**
   * Scores `ciphertext` decrypted under every key value 0..255, repeated over its length.
   * Returns `(score, keyValue)` pairs in key order. A score of -1.0 is treated as a
   * "reject" sentinel, and those entries are dropped.
   */
  def candidates(ciphertext: Seq[Byte]) =
    (0 until 256)
      .map(_.toByte)
      .map(x => TextScorer.score(ciphertext xor Stream.continually(x)))
      .zipWithIndex
      .filter(_._1 != -1.0)

  /**
   * The ASCII decryption of `ciphertext` under each surviving single-byte key, best first.
   *
   * The key byte is repeated over the whole ciphertext. Previously a one-element key
   * `Seq(k)` was passed instead. Every other caller here repeats the key explicitly,
   * which suggests `xor` stops at the shorter operand. If so, `Seq(k)` decrypted only
   * the first byte.
   */
  def candidatesAsAscii(ciphertext: Seq[Byte]) =
    candidates(ciphertext).sortBy(_._1).map { case (_, key) =>
      decryptToAsciiSingleByteKey(ciphertext)(key.toByte)
    }
}