#!/usr/bin/ruby

# Reference implementation of draft-irtf-cfrg-ocb-03. Not optimized for speed.

require 'ruby-aes/aes_alg'  # Get it at <http://rubyforge.org/projects/ruby-aes>

# Global switch read by Ocb#encrypt: while it is true, the step-by-step
# trace of intermediate values is not printed.
$supress_printing = false

# Prints +prompt+ followed by the bytes of +s+ as two-digit uppercase hex.
# After every 24th byte (unless it is the last one) the output continues on
# a new line indented by five spaces. A newline always ends the dump.
#
# s::      string whose raw bytes are dumped
# prompt:: text printed before the first byte (default: nothing)
def pbuf(s,prompt='')
    print prompt
    # unpack yields integer byte values on every Ruby version, whereas
    # s[i] is an Integer only on Ruby 1.8 and a String afterwards.
    bytes = s.unpack('C*')
    bytes.each_with_index do |b, i|
        printf("%02X", b)
        print("\n     ") if ((i+1)%24 == 0) && (i < bytes.size-1)
    end
    putc("\n")
end

# OCB authenticated encryption built on the block cipher object returned by
# AesAlg (ECB mode, one 16-byte block at a time).
#
# Strings are accepted and returned at the public interface; internally all
# arithmetic is done on arrays of integer byte values (unpack/pack 'C*'), so
# the code does not depend on what String#[] returns in a given Ruby version.
class Ocb

    MAXL   = 11       # Number of L_i values precomputed by the constructor;
                      # further ones are derived on demand for longer inputs
    TAGLEN = 16       # Length of produced tag, in bytes

    # Creates the cipher object and derives L_*, L_$ and L_0..L_(MAXL-1).
    #
    # ocbKey:: raw key string; its byte count times 8 is passed to AesAlg
    #          as the key size.
    def initialize(ocbKey)
        @E = AesAlg::new(ocbKey.unpack('C*').length*8, 'ECB', ocbKey)
        @Lstar = encipher([0] * 16)
        @Ldollar = double(@Lstar)
        @L = [double(@Ldollar)]
        l_value(MAXL-1)   # fills @L[0] .. @L[MAXL-1]
    end

    # ------------------------------------------------------------------------

    # Encrypts and authenticates.
    #
    # n:: nonce string, at most 15 bytes (ArgumentError otherwise)
    # a:: associated data string (authenticated, not encrypted)
    # m:: plaintext string
    #
    # Returns a binary string: the ciphertext (same length as m) followed by
    # the first TAGLEN bytes of the tag. Unless $supress_printing is true,
    # intermediate values are printed through pbuf/puts along the way.
    def encrypt(n,a,m)
        plain = m.unpack('C*')
        ktop, bottom, stretch, delta = nonce_setup(n)
        trace(ktop, "Ktop      : ")
        puts("bottom    : " + bottom.to_s) unless $supress_printing
        trace(stretch, "Stretch   : ")
        trace(delta, "Offset_0  : ")
        checksum = [0] * 16
        cipher = []
        full = plain.length/16  # number of full blocks
        1.upto(full) do |i|
            l_i = l_value(ntz(i))
            delta = exor(delta, l_i)
            trace(l_i, "L_" + ntz(i).to_s + "       : ")
            trace(delta, "Offset_" + i.to_s + "  : ")
            block = plain[(i-1)*16, 16]
            checksum = exor(checksum, block)
            trace(checksum, "Checksum_" + i.to_s + ": ")
            cipher.concat(exor(encipher(exor(block, delta)), delta))
        end
        if 16*full < plain.length
            delta = exor(delta, @Lstar)
            trace(delta, "Offset_*  : ")
            trace(@Lstar, "L_*       : ")
            pad = encipher(delta)
            block = plain[full*16..-1]        # the remainder of the message
            cipher.concat(exor(block, pad))   # exor truncates pad to block's length
            checksum = exor(checksum, pad_block(block))
            trace(checksum, "Checksum_*: ")
        end
        delta = exor(delta, @Ldollar)
        trace(@Ldollar, "L_$       : ")
        tag = exor(hash_ad(a), encipher(exor(checksum, delta)))
        (cipher + tag[0,TAGLEN]).pack('C*')
    end

    # ------------------------------------------------------------------------

    # Verifies and decrypts.
    #
    # n:: nonce string used for encryption
    # a:: associated data string used for encryption
    # c:: ciphertext followed by the TAGLEN-byte tag; it is not modified
    #
    # Returns the plaintext as a binary string, or nil when the recomputed
    # tag does not match the one at the end of c. Raises RuntimeError when c
    # is shorter than TAGLEN bytes.
    def decrypt(n,a,c)
        input = c.unpack('C*')
        raise "Ciphertext must be at least TAGLEN bytes" if (input.length < TAGLEN)
        body = input[0, input.length - TAGLEN]
        received = input[input.length - TAGLEN, TAGLEN]
        delta = nonce_setup(n).last
        checksum = [0] * 16
        plain = []
        full = body.length/16   # number of full blocks
        1.upto(full) do |i|
            delta = exor(delta, l_value(ntz(i)))
            block = exor(decipher(exor(body[(i-1)*16, 16], delta)), delta)
            plain.concat(block)
            checksum = exor(checksum, block)
        end
        if 16*full < body.length
            delta = exor(delta, @Lstar)
            block = exor(body[full*16..-1], encipher(delta))
            plain.concat(block)
            checksum = exor(checksum, pad_block(block))
        end
        delta = exor(delta, @Ldollar)
        tag = exor(hash_ad(a), encipher(exor(checksum, delta)))[0,TAGLEN]
        # Accumulate differences over the whole tag instead of stopping at
        # the first mismatching byte.
        diff = 0
        TAGLEN.times { |i| diff |= tag[i] ^ received[i] }
        diff == 0 ? plain.pack('C*') : nil
    end

    # ------------------------------------------------------------------------

    private

    # ------------------------------------------------------------------------

    # Computes the 16-byte authentication value of the associated data.
    # a is a string; the result is an array of 16 byte values.
    def hash_ad(a)
        ad = a.unpack('C*')
        delta = [0] * 16
        sum = [0] * 16
        full = ad.length/16  # number of full blocks
        1.upto(full) do |i|
            delta = exor(delta, l_value(ntz(i)))
            sum = exor(sum, encipher(exor(ad[(i-1)*16, 16], delta)))
        end
        if 16*full < ad.length
            delta = exor(delta, @Lstar)
            block = pad_block(ad[full*16..-1])
            sum = exor(sum, encipher(exor(block, delta)))
        end
        sum
    end

    # ------------------------------------------------------------------------

    # Derives the per-nonce values. Returns [ktop, bottom, stretch, offset]:
    # the encrypted top part of the formatted nonce, the low six bits of its
    # last byte, the 24-byte stretch, and the initial 16-byte offset.
    def nonce_setup(n)
        nbytes = n.unpack('C*')
        raise ArgumentError, "Nonce must be at most 15 bytes" if nbytes.length > 15
        nonce = [0] * (15 - nbytes.length) + [1] + nbytes
        nonce[0] |= (((TAGLEN * 8) % 128) << 1)
        bottom = nonce[-1] & 63
        top = nonce[0, 15] + [nonce[-1] & 192]
        ktop = encipher(top)
        stretch = ktop + exor(ktop[0,8], ktop[1,8])  # 16+8 bytes
        [ktop, bottom, stretch, shift_left(stretch, bottom)[0,16]]
    end

    # ------------------------------------------------------------------------

    # Returns L_i, extending the table by repeated doubling when i lies
    # beyond what has been computed so far.
    def l_value(i)
        @L << double(@L.last) while @L.length <= i
        @L[i]
    end

    # ------------------------------------------------------------------------

    # Appends a 0x80 byte and then zero bytes to bring a partial block
    # (fewer than 16 bytes) up to 16 bytes. Returns a new array.
    def pad_block(block)
        block + [0x80] + [0] * (15 - block.length)
    end

    # ------------------------------------------------------------------------

    # Runs one 16-byte block (array of bytes) through the cipher's
    # encrypt_block and returns the result as an array of bytes.
    def encipher(bytes)
        @E.encrypt_block(bytes.pack('C*')).unpack('C*')
    end

    # Same as encipher, using the cipher's decrypt_block.
    def decipher(bytes)
        @E.decrypt_block(bytes.pack('C*')).unpack('C*')
    end

    # ------------------------------------------------------------------------

    # Prints bytes through pbuf with the given label unless
    # $supress_printing is true.
    def trace(bytes, label)
        pbuf(bytes.pack('C*'), label) unless $supress_printing
    end

    # ------------------------------------------------------------------------

    # Returns a doubled copy of the 16-byte array x: shifted left by one
    # bit, with 135 xored into the last byte when the top bit of x was set.
    def double(x)
        carry = x[0] & 0x80
        y = Array.new(16) do |i|
            low = (i < 15) ? ((x[i+1] >> 7) & 0x01) : 0
            ((x[i] << 1) & 0xff) | low
        end
        y[15] ^= 135 if (carry != 0)
        y
    end

    # ------------------------------------------------------------------------

    # Returns x << n as a new array of the same length as x; bits shifted
    # past the front are dropped and zeros fill in at the end.
    def shift_left(x,n)
        skip = n/8     # whole bytes to skip to reach the nth bit
        bits = n%8     # bits that remain to be shifted (0-7)
        out = [0] * x.length
        (x.length - skip).times do |i|
            following = x[skip+i+1] || 0   # nothing follows the last byte
            out[i] = ((x[skip+i] << bits) & 0xff) | (following >> (8-bits))
        end
        out
    end

    # ------------------------------------------------------------------------

    # Xors two byte arrays, as many bytes as the shorter of them.
    def exor(s1,s2)
        len = [s1.length, s2.length].min
        Array.new(len) { |i| s1[i] ^ s2[i] }
    end

    # ------------------------------------------------------------------------

    # Returns the number of trailing zero bits of the positive integer x.
    def ntz(x)
        raise ArgumentError, "ntz needs a positive integer" unless x > 0
        count = 0
        while (x & 1) == 0
            count += 1
            x >>= 1
        end
        count
    end

end


# Encrypts plaintext +p+ with associated data +a+ under nonce +n+ using the
# cipher object +o+, then prints a blank line followed by labelled dumps
# (via pbuf) of the associated data, the plaintext and the encrypt result.
def vector(o, n, a, p)
    ciphertext = o.encrypt(n, a, p)   # runs first, so any trace precedes the blank line
    $stdout.putc("\n")
    inputs = [[a, '  A: '], [p, '  P: ']]
    inputs.each { |bytes, label| pbuf(bytes, label) }
    pbuf(ciphertext, '  C: ')
end

# Builds the counting byte string 00 01 02 .. (len-1), fresh on every call.
counting = lambda { |len| (0...len).inject('') { |acc, i| acc << i.chr } }

k = counting.call(16)   # Key used in RFC
o = Ocb::new(k)

$supress_printing = true

n = counting.call(12)   # Nonce used in RFC
vector(o,n,'','')

# For each sample length print three vectors: associated data and plaintext
# together, associated data alone, plaintext alone.
a = p = nil
[8, 16, 24, 32, 40].each do |len|
    a = counting.call(len)
    p = counting.call(len)
    vector(o,n,a,p)
    vector(o,n,a,'')
    vector(o,n,'',p)
end

# Repeat the last vector (40-byte plaintext, no associated data) with the
# step-by-step trace switched on.
$supress_printing = false
vector(o,n,'',p)
$supress_printing = true

# Under an all-zero key, collect the encrypt outputs for all-zero inputs of
# length 0..127 in three combinations each, then feed the collection as
# associated data (with empty plaintext) and print what encrypt returns.
k = 0.chr * 16
o = Ocb::new(k)
c = ''
128.times do |i|
    s = 0.chr * i
    n = (0.chr * 11) + (i.chr)
    [[s, s], ['', s], [s, '']].each { |ad, msg| c << o.encrypt(n,ad,msg) }
end
n = 0.chr * 12
t = o.encrypt(n,c,'')
pbuf(t,'Output: ')
