/*------------------------------------------------------------------------
/ OCB Version 3 Reference Code (Optimized C)     Last modified 26-MAY-2010
/-------------------------------------------------------------------------
/ Copyright (c) 2010 Ted Krovetz and Phillip Rogaway.
/
/ Permission to use, copy, modify, and/or distribute this software for any
/ purpose with or without fee is hereby granted, provided that the above
/ copyright notice and this permission notice appear in all copies.
/
/ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
/ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
/ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
/ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
/ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
/ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
/ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
/
/ Comments are welcome: Ted Krovetz <tdk@acm.org>
/------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------- */
/* Usage requirements                                                      */
/* ----------------------------------------------------------------------- */

/* - When AE_PENDING is passed as the 'final' parameter of any function,
/    the length parameters must be a multiple of (BPI*16).
/  - When available, SSE or AltiVec registers are used to manipulate data.
/    So, when on machines with these facilities, all pointers passed to
/    any function should be 16-byte aligned.
/  - Plaintext and ciphertext pointers may be equal (ie, plaintext gets
/    encrypted in-place), but no other pair of pointers may be equal.      */

/* ----------------------------------------------------------------------- */
/* User configuration options                                              */
/* ----------------------------------------------------------------------- */

/* This implementation has built-in support for multiple AES APIs. Set any
/  one of the following to non-zero to specify which to use. USE_AES_NI set
/  by itself only supports 128-bit keys. To use AES-NI with 192 or 256 bit
/  keys, set both USE_OPENSSL_AES and USE_AES_NI, in which case OpenSSL
/  handles key setup and AES-NI intrinsics are used for encryption.        */
#if !(USE_OPENSSL_AES || USE_CRYPTOPP_AES || USE_REFERENCE_AES ||           \
      USE_AES_NI || USE_VIA_ACE_AES || USE_KASPER_AES)
#define USE_OPENSSL_AES            0         /* http://openssl.org         */
#define USE_CRYPTOPP_AES           0         /* http://cryptopp.com        */
#define USE_REFERENCE_AES          0         /* Google: rijndael-alg-fst.c */
#define USE_AES_NI                 0         /* Uses compiler's intrinsics */
#define USE_VIA_ACE_AES            0
#define USE_KASPER_AES             1
#endif

/* MAX_KEY_BYTES specifies the maximum size key you intend to supply OCB, and
/  *must* be 16, 24, or 32. In *some* AES implementations it is possible to
/  limit internal key-schedule sizes, so keep this as small as possible.   */
#define MAX_KEY_BYTES             16

/* To eliminate the use of vector types, set the following non-zero        */
#define VECTORS_OFF                0

/* ----------------------------------------------------------------------- */
/* Derived configuration options - Adjust as needed                        */
/* ----------------------------------------------------------------------- */

/* These determine whether vectors should be used.                         */
#define USE_SSE2    ((__SSE2__ || (_M_IX86_FP>=2) || _M_X64) && !VECTORS_OFF)
#define USE_ALTIVEC (__ALTIVEC__ && !VECTORS_OFF)

/* These determine how to allocate 16-byte aligned vectors, if needed.     */
#define USE_MM_MALLOC      (USE_SSE2 && !(_M_X64 || __amd64__))
#define USE_POSIX_MEMALIGN (USE_ALTIVEC && __GLIBC__ && !__PPC64__)

/* ----------------------------------------------------------------------- */
/* Includes and compiler specific definitions                              */
/* ----------------------------------------------------------------------- */

#include "ae.h"
#include <stdlib.h>
#include <string.h>

/* Define standard sized integers                                          */
#if defined(_MSC_VER) && (_MSC_VER < 1600)
	typedef unsigned __int8  uint8_t;
	typedef unsigned __int32 uint32_t;
	typedef unsigned __int64 uint64_t;
	typedef          __int64 int64_t;
#else
	#include <stdint.h>
#endif

/* How to force specific alignment, request inline, restrict pointers      */
#if __GNUC__
	#define ALIGN(n) __attribute__ ((aligned(n)))
	#define inline __inline__
	#define restrict __restrict__
#elif _MSC_VER
	#define ALIGN(n) __declspec(align(n))
	#define inline __inline
	#define restrict __restrict
#elif __STDC_VERSION__ >= 199901L   /* C99: delete align, keep others      */
	#define ALIGN(n)
#else /* Not GNU/Microsoft/C99: delete alignment/inline/restrict uses.     */
	#define ALIGN(n)
	#define inline
	#define restrict
#endif

/* How to endian reverse a uint64_t                                        */
#if _MSC_VER
    #define bswap64(x) _byteswap_uint64(x)
#elif (__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ >= 3)) && !__arm__
    #define bswap64(x) __builtin_bswap64(x)
#elif __GNUC__ && __amd64__
    #define bswap64(x) ({uint64_t y=x;__asm__("bswapq %0":"+r"(y));y;})
#else

/* Build bswap64 out of two bswap32's                                      */
#if __GNUC__ && (__ARM_ARCH_6__ || __ARM_ARCH_6J__ || __ARM_ARCH_6K__ ||    \
    __ARM_ARCH_6Z__ || __ARM_ARCH_6ZK__ || __ARM_ARCH_6T2__ ||              \
    __ARM_ARCH_7__ || __ARM_ARCH_7A__ || __ARM_ARCH_7R__ || __ARM_ARCH_7M__)
	#define bswap32(x) ({uint32_t y; __asm__("rev %0, %1":"=r"(y):"r"(x));y;})
#elif __GNUC__ && __arm__
	#define bswap32(x)                             \
		({uint32_t t,y;                            \
		__asm__("eor     %1, %2, %2, ror #16\n\t" \
				"bic     %1, %1, #0x00FF0000\n\t" \
				"mov     %0, %2, ror #8\n\t"      \
				"eor     %0, %0, %1, lsr #8"      \
				: "=r"(y), "=&r"(t) : "r"(x));y;})
#elif __GNUC__ && __i386__
	/* x86 BSWAP on a 32-bit temporary, so the operand fits one general
	 * register and the result has the same type as the other variants. */
	#define bswap32(x) ({uint32_t y=x;__asm__("bswap %0":"+r"(y));y;})
#else        /* Some compilers recognize the following pattern */
	#define bswap32(x)                         \
	   ((((x) & 0xff000000u) >> 24) | \
		(((x) & 0x00ff0000u) >>  8) | \
		(((x) & 0x0000ff00u) <<  8) | \
		(((x) & 0x000000ffu) << 24))
#endif

/* Reverse the byte order of a 64-bit value: byte-swap each 32-bit half with
 * bswap32 and exchange the two halves. */
static inline uint64_t bswap64(uint64_t x) {
	const uint32_t upper = (uint32_t)(x >> 32);
	const uint32_t lower = (uint32_t)x;
	const uint64_t swapped_lower = (uint32_t)bswap32(lower);
	const uint64_t swapped_upper = (uint32_t)bswap32(upper);
	return (swapped_lower << 32) | swapped_upper;
}

#endif

/* ----------------------------------------------------------------------- */
/* Define blocks and operations -- Patch if incorrect on your compiler.    */
/* ----------------------------------------------------------------------- */

#if USE_SSE2
    #include <xmmintrin.h>              /* SSE instructions and _mm_malloc */
    #include <emmintrin.h>              /* SSE2 instructions               */
    typedef ALIGN(16) __m128i block;
    #define xor_block(x,y)        _mm_xor_si128(x,y)
    #define zero_block()          _mm_setzero_si128()
    #define unequal_blocks(x,y) \
    					   (_mm_movemask_epi8(_mm_cmpeq_epi8(x,y)) != 0xffff)
	#if __SSSE3__
    #include <tmmintrin.h>        /* SSSE3 instructions              */
    #define swap_if_le(b) \
      _mm_shuffle_epi8(b,_mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15))
	#else
    /* Reverse all 16 bytes of a block using SSE2 instructions only. */
    static inline block swap_if_le(block b) {
		/* Reverse the order of the four 32-bit lanes. */
		const block lanes_rev = _mm_shuffle_epi32(b, _MM_SHUFFLE(0,1,2,3));
		/* Exchange the two 16-bit words inside each lane (upper half of
		 * the register first, then the lower half). */
		const block upper_done = _mm_shufflehi_epi16(lanes_rev, _MM_SHUFFLE(2,3,0,1));
		const block words_rev  = _mm_shufflelo_epi16(upper_done, _MM_SHUFFLE(2,3,0,1));
		/* Exchange the two bytes inside each 16-bit word; the two shifted
		 * values have no bits in common, so OR merges them. */
		const block to_low  = _mm_srli_epi16(words_rev, 8);
		const block to_high = _mm_slli_epi16(words_rev, 8);
		return _mm_or_si128(to_low, to_high);
    }
	#endif
#elif USE_ALTIVEC
    #include <altivec.h>
    typedef ALIGN(16) vector unsigned block;
    #define xor_block(x,y)        vec_xor(x,y)
    #define zero_block()          vec_splat_u32(0)
    #define unequal_blocks(x,y)   vec_any_ne(x,y)
    #define swap_if_le(b)         (b)
#else
    /* Portable 128-bit block: two 64-bit halves, used when no vector type
     * is selected. */
    typedef struct { uint64_t l,r; } block;
    /* Bitwise XOR of two blocks. */
    static block xor_block(block x, block y) {
        block out;
        out.l = x.l ^ y.l;
        out.r = x.r ^ y.r;
        return out;
    }
    /* Return the all-zero block. */
    static block zero_block(void) {
        block z;
        z.l = 0;
        z.r = 0;
        return z;
    }
    /* Evaluates to 1 when the two blocks differ in any bit, else 0. */
    #define unequal_blocks(x, y)         (((x).l != (y).l) || ((x).r != (y).r))
    /* On little-endian hosts byte-reverse each 64-bit half of the block;
     * on big-endian hosts return the block unchanged. */
    static inline block swap_if_le(block b) {
		/* First byte of the integer 1 is non-zero only on little-endian. */
		const union { unsigned x; unsigned char endian; } probe = { 1 };
		if (!probe.endian)
			return b;
		b.l = bswap64(b.l);
		b.r = bswap64(b.r);
		return b;
    }
#endif

/* Sometimes it is useful to view a block as an array of other types.
/  Doing so is technically undefined, but well supported in compilers.     */
/* Overlay giving 64-bit, 32-bit and byte access to the same 16 bytes that
 * hold a 'block' (member bl). */
typedef union {
	uint64_t u64[2]; uint32_t u32[4]; uint8_t u8[16]; block bl;
} block_multiview;

/* ----------------------------------------------------------------------- */
/* AES - Code uses OpenSSL API. Other implementations get mapped to it.    */
/* ----------------------------------------------------------------------- */

/*---------------*/
#if USE_OPENSSL_AES
/*---------------*/

#include <openssl/aes.h>                            /* http://openssl.org/ */

/*-----------------*/
#elif USE_KASPER_AES
/*-----------------*/

/* Key material for the Kasper AES backend: a 16-byte aligned [11][32]
 * array of 32-bit words. */
typedef struct { ALIGN(16) uint32_t bs_key[11][32]; } AES_KEY;

/* Key-setup routine provided by the external Kasper AES code.
 * NOTE(review): this prototype is inferred from the call below (argument
 * order and types); confirm the return type against the real definition. */
void kasper_keysetup(AES_KEY *key, const unsigned char *userKey);

/* OpenSSL-style key setup wrapper for the Kasper backend.
 *   userKey - raw AES key bytes
 *   bits    - key size in bits; not consulted by this wrapper
 *   key     - key structure handed to kasper_keysetup
 * Always returns 0. */
int AES_set_encrypt_key(const unsigned char *userKey, const int bits, AES_KEY *key) {
	(void)bits;
	kasper_keysetup(key, userKey);
	return 0;
}

/* External Kasper AES ECB routine. Within this file it is called with BPI
 * (8) blocks at a time and with the same buffer as input and output.
 * NOTE(review): presumably always processes eight 16-byte blocks; confirm
 * against the backend source. */
void kasper_ecb_encrypt8(
  const AES_KEY* ctx, 
  const unsigned char* input, 
  unsigned char* output);                /* Result buffer. */ 

/* Encrypt one 16-byte block with the Kasper backend.
 *   in  - 16 bytes of input
 *   out - receives 16 bytes of output; may equal 'in'
 *   key - key prepared by AES_set_encrypt_key
 * The backend is handed an 8-block buffer, so the single input block is
 * staged in a zero-padded scratch buffer. This keeps the backend from
 * reading past the caller's 16-byte input, and memcpy removes any alignment
 * requirement on 'in' and 'out'. */
void AES_encrypt(const unsigned char *in, unsigned char *out, const AES_KEY *key)
{
	ALIGN(16) unsigned char buf[8*16];
	memcpy(buf, in, 16);
	memset(buf + 16, 0, sizeof(buf) - 16);
	kasper_ecb_encrypt8(key, buf, buf);  /* in-place, as ENCRYPT_BLOCKS does */
	memcpy(out, buf, 16);
}

#define AES_set_decrypt_key    AES_set_encrypt_key

/*-----------------*/
#elif USE_VIA_ACE_AES
/*-----------------*/

typedef struct { ALIGN(16) char str[16], cword[16]; } AES_KEY;

/* Issue the 'xcryptecb' instruction for nblks blocks.
 * Operand constraints (GCC x86): S = in, D = out and c = nblks are
 * read-write operands; d = key->cword and b = key->str are inputs. The
 * "memory" clobber tells the compiler that memory is read and written
 * by the instruction.
 * NOTE(review): presumably VIA PadLock ACE; blocks are assumed to be 16
 * bytes each -- confirm against the PadLock programming guide. */
static inline
void via_xcryptecb(void *in, void *out, int nblks, const AES_KEY *key)
{
	__asm__ __volatile__("xcryptecb\n"
	        : "+S"(in), "+D"(out), "+c"(nblks)
	        : "d"(key->cword), "b"(key->str) : "memory");	
}
#define AES_encrypt(x,y,z)       via_xcryptecb(x,y,1,z)
#define AES_decrypt(x,y,z)       via_xcryptecb(x,y,1,z)

/* Load a key for the VIA ACE backend.
 *   userKey - raw key bytes
 *   bits    - key length in bits
 *   key     - receives the key bytes and a control word with ROUND = 10
 * At most sizeof(key->str) (16) key bytes are stored, so a key longer than
 * 128 bits cannot write past key->str; the control word is set for 10
 * rounds regardless of 'bits'. Always returns 0. */
int AES_set_encrypt_key(const unsigned char *userKey, const int bits, AES_KEY *key) {
	size_t nbytes = (bits > 0) ? (size_t)bits/8 : 0;
	if (nbytes > sizeof(key->str))
		nbytes = sizeof(key->str);
	__asm__ __volatile__ ("pushf\n\tpopf\n" : : : "cc"); /* Indicate new key */
	memcpy(key->str,userKey,nbytes);
	memset(key->cword,0,sizeof(key->cword));
	key->cword[0] = 10; /* Set ROUND bits to 10 */
	return 0;
}
/* Decryption key setup for the VIA ACE backend: perform the encryption
 * setup, then set the CRYPT bit in the control word. Returns whatever
 * AES_set_encrypt_key returned. */
int AES_set_decrypt_key(const unsigned char *userKey, const int bits, AES_KEY *key) {
	const int status = AES_set_encrypt_key(userKey, bits, key);
	key->cword[1] = 2;      /* CRYPT bit: request decryption */
	return status;
}

/*------------------*/
#elif USE_CRYPTOPP_AES
/*------------------*/

#include <cryptopp/cryptlib.h>
#include <cryptopp/aes.h>
using CryptoPP::BlockCipher;
using CryptoPP::AES;

typedef AES::Encryption AES_KEY;
#define AES_encrypt(x,y,z)       (z)->ProcessBlock((byte *)(x),(byte *)(y))
#define AES_decrypt(x,y,z)       (z)->ProcessBlock((byte *)(x),(byte *)(y))
#define AES_set_encrypt_key(x, y, z)         (z)->SetKey((byte *)(x),(y)/8)
#define AES_set_decrypt_key(x, y, z)         (z)->SetKey((byte *)(x),(y)/8)

/*-------------------*/
#elif USE_REFERENCE_AES
/*-------------------*/

#include "rijndael-alg-fst.h"              /* Barreto's Public-Domain Code */
typedef struct { uint32_t rd_key[MAX_KEY_BYTES+28]; int rounds; } AES_KEY;
/* Map the OpenSSL-style calls onto the rijndael-alg-fst functions.
 *   AES_encrypt/AES_decrypt(x,y,z): x = input, y = output, z = AES_KEY*
 *   AES_set_*_key(x,y,z):           x = key bytes, y = key bits, z = AES_KEY*
 * The round count is derived as bits/32 + 6. Every macro argument is
 * parenthesised so that expression arguments expand as intended. */
#define AES_encrypt(x,y,z)    rijndaelEncrypt((z)->rd_key, (z)->rounds, (x), (y))
#define AES_decrypt(x,y,z)    rijndaelDecrypt((z)->rd_key, (z)->rounds, (x), (y))
#define AES_set_encrypt_key(x, y, z) \
 do {rijndaelKeySetupEnc((z)->rd_key, (x), (y)); (z)->rounds = (y)/32+6;} while (0)
#define AES_set_decrypt_key(x, y, z) \
 do {rijndaelKeySetupDec((z)->rd_key, (x), (y)); (z)->rounds = (y)/32+6;} while (0)

#endif

/*----------*/
#if USE_AES_NI        /* It is acceptable that USE_OPENSSL_AES is true too */
/*----------*/

#include <wmmintrin.h>
#define AES_encrypt AES_encrypt_ni /* Avoid name conflict in openssl/aes.h */
#define AES_decrypt AES_decrypt_ni /* Avoid name conflict in openssl/aes.h */

#if USE_OPENSSL_AES       /* Use OpenSSL's key setup instead of intrinsics */

#define AES_ROUNDS(_key)  ((_key).rounds)

#else /* !USE_OPENSSL_AES -- Use intrinsics for key setup. AES-128 only    */

typedef struct { __m128i rd_key[7+MAX_KEY_BYTES/4]; } AES_KEY;
#define AES_ROUNDS(_key)  (10)
/* One AES-128 key-expansion step: XOR 'a' with its own 4-, 8- and 12-byte
 * left shifts, then XOR in the top 32-bit lane of 'b' broadcast to every
 * lane. Callers pass the _mm_aeskeygenassist_si128 result as 'b'. */
static __m128i assist128(__m128i a, __m128i b)
{
    const __m128i sh4  = _mm_slli_si128(a, 4);
    const __m128i sh8  = _mm_slli_si128(a, 8);
    const __m128i sh12 = _mm_slli_si128(a, 12);
    const __m128i folded =
        _mm_xor_si128(_mm_xor_si128(a, sh4), _mm_xor_si128(sh8, sh12));
    const __m128i top_lane = _mm_shuffle_epi32(b, 0xff);
    return _mm_xor_si128(folded, top_lane);
}
/* Expand a 128-bit user key into key->rd_key[0..10] with AES-NI.
 *   userKey - 16 key bytes (read with an unaligned load)
 *   bits    - ignored; this path handles AES-128 only
 *   key     - receives the 11 round keys */
static void AES_set_encrypt_key(const unsigned char *userKey,
                                const int bits, AES_KEY *key)
{
    __m128i *rk = key->rd_key;
    /* The second argument of _mm_aeskeygenassist_si128 must be a
     * compile-time constant, hence a macro per round rather than a loop. */
    #define NI_EXPAND_ROUND(n, rcon) \
        rk[n] = assist128(rk[(n)-1], _mm_aeskeygenassist_si128(rk[(n)-1], rcon))
    (void)bits; /* Suppress "unused" warning */
    rk[0] = _mm_loadu_si128((__m128i*)userKey);
    NI_EXPAND_ROUND( 1, 0x1);
    NI_EXPAND_ROUND( 2, 0x2);
    NI_EXPAND_ROUND( 3, 0x4);
    NI_EXPAND_ROUND( 4, 0x8);
    NI_EXPAND_ROUND( 5, 0x10);
    NI_EXPAND_ROUND( 6, 0x20);
    NI_EXPAND_ROUND( 7, 0x40);
    NI_EXPAND_ROUND( 8, 0x80);
    NI_EXPAND_ROUND( 9, 0x1b);
    NI_EXPAND_ROUND(10, 0x36);
    #undef NI_EXPAND_ROUND
}
/* Build the AES-128 decryption schedule from the encryption schedule:
 * the 11 round keys in reverse order, with keys 1..9 passed through
 * _mm_aesimc_si128. */
static void AES_NI_set_decrypt_key(__m128i *dkey, const __m128i *ekey)
{
    int round = 1;
    dkey[10] = ekey[0];
    do {
        dkey[10-round] = _mm_aesimc_si128(ekey[round]);
    } while (++round <= 9);
    dkey[0] = ekey[10];
}

#endif  /* !USE_OPENSSL_AES */

/* Encrypt one 16-byte block with AES-NI.
 *   in  - input block, read with an aligned load
 *   out - output block, written with an aligned store
 *   key - round keys in key->rd_key; AES_ROUNDS(*key) gives the round count */
static inline void AES_encrypt(const unsigned char *in,
                        unsigned char *out, const AES_KEY *key)
{
	const __m128i *rk = ((__m128i *)(key->rd_key));
	const int last = AES_ROUNDS(*key);
	int r = 1;
	__m128i state = _mm_xor_si128(_mm_load_si128((__m128i*)in), rk[0]);
	while (r < last) {
		state = _mm_aesenc_si128(state, rk[r]);
		++r;
	}
	/* r now indexes the final round key. */
	_mm_store_si128((__m128i*)out, _mm_aesenclast_si128(state, rk[r]));
}
/* Decrypt one 16-byte block with AES-NI.
 *   in  - input block, read with an aligned load
 *   out - output block, written with an aligned store
 *   key - decryption round keys in key->rd_key */
static inline void AES_decrypt(const unsigned char *in,
                        unsigned char *out, const AES_KEY *key)
{
	const __m128i *rk = ((__m128i *)(key->rd_key));
	const int last = AES_ROUNDS(*key);
	int r = 1;
	__m128i state = _mm_xor_si128(_mm_load_si128((__m128i*)in), rk[0]);
	while (r < last) {
		state = _mm_aesdec_si128(state, rk[r]);
		++r;
	}
	/* r now indexes the final round key. */
	_mm_store_si128((__m128i*)out, _mm_aesdeclast_si128(state, rk[r]));
}

#endif

/* ----------------------------------------------------------------------- */
/* Define OCB context structure.                                           */
/* ----------------------------------------------------------------------- */

#if USE_AES_NI
#define ALIGN_IF_USING_AES_NI      ALIGN(16)
#else
#define ALIGN_IF_USING_AES_NI
#endif

/* OCB context: the AES keys plus the state ae_encrypt saves between calls
 * when a message is processed in several pieces. */
struct _ae_ctx {
	block offset; /* Register correct */
	block checksum;        /* running plaintext checksum saved by ae_encrypt */
	block ad_checksum;     /* running associated-data checksum               */
	block ad_offset;       /* running associated-data offset                 */
	block ad_offset_base;  /* starting AD offset, computed once in ae_init   */
	ALIGN_IF_USING_AES_NI AES_KEY encrypt_key;
	ALIGN_IF_USING_AES_NI AES_KEY decrypt_key;
	int   ad_valid;        /* non-zero once a non-zero ad_len has been seen  */
};

/* ----------------------------------------------------------------------- */
/* Mask increment -- Needs to be fast for key-setup, on-the-fly N/L lookup */
/* ----------------------------------------------------------------------- */

#if USE_SSE2

/* Double a block: shift the 128-bit value left by one bit and, when a bit
 * is shifted out of the top lane, XOR 135 into the lowest lane. */
static inline block increment(block b) {
    /* Lane 3 contributes the constant 135, lanes 0-2 a single carry bit. */
    const __m128i carry_values = _mm_set_epi32(135,1,1,1);
    const __m128i doubled = _mm_slli_epi32(b, 1);
    /* All-ones in every 32-bit lane whose top bit is set. */
    __m128i carries = _mm_srai_epi32(b, 31);
    carries = _mm_and_si128(carries, carry_values);
    /* Move each lane's carry to the next-higher lane; lane 3 wraps to 0. */
    carries = _mm_shuffle_epi32(carries, _MM_SHUFFLE(2,1,0,3));
    return _mm_xor_si128(doubled, carries);
}

#elif USE_ALTIVEC

/* AltiVec version of the block doubling, done per byte: each byte is
 * shifted left by one and receives the top bit of the following byte;
 * the top bit of byte 0 becomes 135 XORed into byte 15. */
static inline block increment(block b) {
	const vector unsigned char mask   = {135,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1};
	const vector unsigned char perm   = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0};
	const vector unsigned char shift7 = vec_splat_u8(7);
	const vector unsigned char shift1 = vec_splat_u8(1);
	vector unsigned char c = (vector unsigned char)b;
	vector unsigned char t = vec_sra(c,shift7);  /* 0xFF where top bit set */
	t = vec_and(t,mask);                         /* keep 135 / 1 per byte   */
	t = vec_perm(t,t,perm);                      /* byte i takes byte i+1   */
	c = vec_sl(c,shift1);                        /* each byte << 1          */
	return (block)vec_xor(c,t);
}

#elif __GNUC__ && __arm__

/* ARM assembly version of the block doubling. b.r (%1 and its upper word
 * %H1) is added to itself first, the carry is propagated through b.l (%0,
 * %H0) with add-with-carry, and if a carry leaves the top word, 135 is
 * XORed into the lowest word. */
static inline block increment(block b) {
    __asm__ ("adds %1,%1,%1\n\t"
             "adcs %H1,%H1,%H1\n\t"
             "adcs %0,%0,%0\n\t"
             "adcs %H0,%H0,%H0\n\t"
             "eorcs %1,%1,#135"
    : "+r"(b.l), "+r"(b.r) : : "cc");
    return b;
}

#else

/* Portable block doubling: shift the 128-bit value (b.l high half, b.r low
 * half) left by one bit and, if a bit is shifted out of b.l, XOR 135 into
 * b.r. The all-ones/zero mask is built with unsigned arithmetic only, so
 * the result does not depend on how the compiler right-shifts negative
 * signed integers. */
static inline block increment(block b) {
    const uint64_t carry_out  = b.l >> 63;                  /* 0 or 1       */
    const uint64_t carry_mask = (uint64_t)0 - carry_out;    /* 0 or all-ones */
    b.l = (b.l << 1) ^ (b.r >> 63);
    b.r = (b.r << 1) ^ (carry_mask & 135);
    return b;
}

#endif

/* increment3(b) is b XOR increment(b); increment3twice(b) is b XOR
/  increment(increment(b)). Both expand their argument twice, so pass only
/  expressions without side effects.                                       */
#define increment3(b) xor_block(b,increment(b))
#define increment3twice(b) xor_block(b,increment(increment(b))) /* 3(3x)=(2(2x))^x */

/* ----------------------------------------------------------------------- */
/* Public functions                                                        */
/* ----------------------------------------------------------------------- */

/* Some systems do not 16-byte-align dynamic allocations involving 16-byte
/  vectors. Adjust the following if your system is one of these            */

/* Allocate storage for one ae_ctx using the allocator selected at compile
 * time (C++ new, _mm_malloc, posix_memalign or malloc). The memory is not
 * initialised. On the posix_memalign path a failure yields NULL; the other
 * paths return whatever the allocator returns. */
ae_ctx* ae_allocate(void *misc)
{
	(void) misc;                     /* misc unused in this implementation */
#if USE_CRYPTOPP_AES
	return (ae_ctx *)new ae_ctx;
#elif USE_MM_MALLOC
	return (ae_ctx *)_mm_malloc(sizeof(ae_ctx),16);
#elif USE_POSIX_MEMALIGN
	{
		void *mem;
		if (posix_memalign(&mem,16,sizeof(ae_ctx)) != 0)
			mem = NULL;
		return (ae_ctx *)mem;
	}
#else
	return (ae_ctx *)malloc(sizeof(ae_ctx));
#endif
}

/* Release a context obtained from ae_allocate, using the deallocator that
 * matches the allocator chosen at compile time. */
void ae_free(ae_ctx *ctx)
{
#if USE_CRYPTOPP_AES
	delete ctx;
#elif !USE_MM_MALLOC
	free(ctx);
#else
	_mm_free(ctx);
#endif
}

/* ----------------------------------------------------------------------- */

/* Zero ae_ctx and undo initialization.
 * The context holds key material, so it is wiped through a volatile
 * pointer: the stores cannot be discarded by the optimiser even when the
 * context is freed right afterwards. Always returns AE_SUCCESS. */
int ae_clear (ae_ctx *ctx)
{
	volatile unsigned char *cursor = (volatile unsigned char *)ctx;
	size_t left = sizeof(ae_ctx);
	while (left--)
		*cursor++ = 0;
	return AE_SUCCESS;
}

/* Size in bytes of an ae_ctx, for callers that provide their own storage. */
int ae_ctx_sizeof(void)
{
	const size_t ctx_bytes = sizeof(ae_ctx);
	return (int) ctx_bytes;
}

/* ----------------------------------------------------------------------- */

/* Prepare a context for use with a given key.
 *   ctx       - context to initialise
 *   key       - AES key bytes
 *   key_len   - key length in bytes (passed to the AES backend as bits)
 *   nonce_len - not consulted by this implementation
 *   tag_len   - not consulted by this implementation
 * Sets up the encryption and decryption keys, derives ad_offset_base from
 * the AES encryption of the zero block, and gives the saved per-message
 * state defined starting values. Always returns AE_SUCCESS. */
int ae_init(ae_ctx *ctx, const void *key, int key_len, int nonce_len, int tag_len)
{
    block tmp;

    (void)nonce_len;
    (void)tag_len;

    /* Initialize encryption & decryption keys */
    AES_set_encrypt_key((unsigned char *)key, key_len*8, &ctx->encrypt_key);
    #if !USE_OPENSSL_AES && USE_AES_NI
    AES_NI_set_decrypt_key(ctx->decrypt_key.rd_key,ctx->encrypt_key.rd_key);
    #else
    AES_set_decrypt_key((unsigned char *)key, (int)(key_len*8), &ctx->decrypt_key);
    #endif

    /* Initialize variables */
    tmp = zero_block();
    AES_encrypt((unsigned char *)&tmp, (unsigned char *)&tmp, &ctx->encrypt_key);
    ctx->ad_offset_base = increment3twice(swap_if_le(tmp));

    /* ae_encrypt reads these back when it continues a message (nonce ==
     * NULL) or reuses saved associated-data state (ad_len < 0); start them
     * from the same values it uses at the beginning of a message. */
    ctx->offset      = zero_block();
    ctx->checksum    = zero_block();
    ctx->ad_checksum = zero_block();
    ctx->ad_offset   = ctx->ad_offset_base;
    ctx->ad_valid    = 0;

    return AE_SUCCESS;
}

/* ----------------------------------------------------------------------- */
/* How to ECB encrypt BPI blocks in-place. BPI must be 4, 8 or 16          */
/* ----------------------------------------------------------------------- */

#if USE_AES_NI

#define BPI 4
/* AES-NI: ECB-encrypt the first 'num' blocks of 'arr' in place with
 * (ctx)->encrypt_key, applying each round to every block before moving on
 * to the next round. Arguments are parenthesised and the locals carry an
 * eb_ prefix so that names used in the arguments cannot be captured. */
#define ENCRYPT_BLOCKS(num,arr,ctx)  do {                                        \
    const block *eb_rk = (const block *)((ctx)->encrypt_key.rd_key);             \
    const int eb_n = (int)(num);                                                 \
    int eb_i, eb_r;                                                              \
    for (eb_i=0; eb_i<eb_n; ++eb_i)                                              \
        (arr)[eb_i] = _mm_xor_si128((arr)[eb_i], eb_rk[0]);                      \
    for (eb_r=1; eb_r<AES_ROUNDS((ctx)->encrypt_key); ++eb_r)                    \
        for (eb_i=0; eb_i<eb_n; ++eb_i)                                          \
            (arr)[eb_i] = _mm_aesenc_si128((arr)[eb_i], eb_rk[eb_r]);            \
    for (eb_i=0; eb_i<eb_n; ++eb_i)                                              \
        (arr)[eb_i] = _mm_aesenclast_si128((arr)[eb_i], eb_rk[eb_r]);            \
    } while (0)

/* Hand-unrolled four-block variant of ENCRYPT_BLOCKS. It ignores 'num',
 * always processes arr[0..3], and uses a variable 'j' that must already be
 * declared at the point of expansion. Nothing in the visible part of this
 * file expands it. */
#define ENCRYPT_BLOCKSx(num,arr,ctx)  do {                                          \
	arr[0] =_mm_xor_si128(arr[0], ((block*)(ctx->encrypt_key.rd_key))[0]);         \
	arr[1] =_mm_xor_si128(arr[1], ((block*)(ctx->encrypt_key.rd_key))[0]);         \
	arr[2] =_mm_xor_si128(arr[2], ((block*)(ctx->encrypt_key.rd_key))[0]);         \
	arr[3] =_mm_xor_si128(arr[3], ((block*)(ctx->encrypt_key.rd_key))[0]);         \
	for(j=1; j<AES_ROUNDS(ctx->encrypt_key); ++j) {                                \
		arr[0] = _mm_aesenc_si128(arr[0], ((block*)(ctx->encrypt_key.rd_key))[j]); \
		arr[1] = _mm_aesenc_si128(arr[1], ((block*)(ctx->encrypt_key.rd_key))[j]); \
		arr[2] = _mm_aesenc_si128(arr[2], ((block*)(ctx->encrypt_key.rd_key))[j]); \
		arr[3] = _mm_aesenc_si128(arr[3], ((block*)(ctx->encrypt_key.rd_key))[j]); \
	} \
	arr[0] =_mm_aesenclast_si128(arr[0], ((block*)(ctx->encrypt_key.rd_key))[j]);  \
	arr[1] =_mm_aesenclast_si128(arr[1], ((block*)(ctx->encrypt_key.rd_key))[j]);  \
	arr[2] =_mm_aesenclast_si128(arr[2], ((block*)(ctx->encrypt_key.rd_key))[j]);  \
	arr[3] =_mm_aesenclast_si128(arr[3], ((block*)(ctx->encrypt_key.rd_key))[j]);  \
	} while (0)

#elif USE_KASPER_AES

#define BPI 8
/* Kasper backend: encrypt 'arr' in place. 'num' is not used because the
 * backend call takes no block count. Arguments are parenthesised so that
 * expression arguments expand as intended. */
#define ENCRYPT_BLOCKS(num,arr,ctx) kasper_ecb_encrypt8(&(ctx)->encrypt_key, (unsigned char *)(arr), (unsigned char *)(arr))

#elif USE_VIA_ACE_AES

#define BPI 16
/* VIA ACE backend: encrypt 'num' blocks of 'arr' in place. Arguments are
 * parenthesised so that expression arguments expand as intended. */
#define ENCRYPT_BLOCKS(num,arr,ctx) via_xcryptecb((arr), (arr), (num), &(ctx)->encrypt_key)

#elif USE_CRYPTOPP_AES

#define BPI 16
/* Crypto++ backend: encrypt 'num' blocks of 'arr' in place. (num) is
 * parenthesised so that an argument such as k+1 yields (k+1)*16 bytes
 * rather than k+1*16. */
#define ENCRYPT_BLOCKS(num,arr,ctx) (ctx)->encrypt_key.AdvancedProcessBlocks((byte *)(arr), NULL, (byte *)(arr), ((num)*16), 0)

#else

#define BPI 4
/* Generic backend: encrypt the first 'num' blocks of 'arr' in place, one
 * AES_encrypt call per block. Arguments are parenthesised and the loop
 * index carries an eb_ prefix so it cannot capture a name used in them. */
#define ENCRYPT_BLOCKS(num,arr,ctx)                                                            \
	do { int eb_j; for (eb_j=0;eb_j<(num);++eb_j) \
		AES_encrypt((unsigned char *)&((arr)[eb_j]), (unsigned char *)&((arr)[eb_j]), &(ctx)->encrypt_key);  \
	} while (0)

#endif

/* ----------------------------------------------------------------------- */

/* Encrypt (part of) a message and, on the final call, produce its tag.
 *
 *   ctx    - context prepared by ae_init
 *   nonce  - non-NULL starts a new message (one block is read and
 *            AES-encrypted to form the initial offset); NULL continues the
 *            message whose running state was saved in ctx
 *   pt     - plaintext, pt_len bytes
 *   ad     - associated data, ad_len bytes; with a non-NULL nonce and
 *            ad_len < 0 the associated-data state saved in ctx is reused
 *   ct     - ciphertext output
 *   tag    - receives the 16-byte tag; when NULL the tag is written
 *            directly after the ciphertext
 *   final  - zero: more data follows, the running state is saved in ctx
 *            and only whole groups of BPI*16 bytes are processed;
 *            non-zero: finish the message and produce the tag
 *
 * Returns pt_len, plus 16 when the tag was appended to the ciphertext.
 */
int ae_encrypt(ae_ctx     * restrict ctx,
               const void *nonce,
               const void *pt,
               int         pt_len,
               const void *ad,
               int         ad_len,
               void       *ct,
               void       *tag,
               int         final)
{
    block t1, t2, om1, ad_offset, offset, ad_checksum, checksum, oa[BPI], ta[BPI];
    block_multiview tmp;
    unsigned j, k, remaining=0, ad_remaining=0, iters, ad_iters, ad_valid;
    block       * restrict ctp = (block *)ct;
    const block * restrict ptp = (block *)pt;
    const block * restrict adp = (block *)ad;

	/* When nonce is non-null we know that this is the start of a new message.
	 * If so, update cached AES if needed and initialize offsets/checksums.
	 */
    if (nonce) { /* Indicates start of new message */
		AES_encrypt((unsigned char *)nonce, (unsigned char *)&offset, &ctx->encrypt_key);
		offset = swap_if_le(offset);     /* Make offset "register correct" */
        checksum = zero_block();
        if (ad_len >= 0) {
        	ad_checksum = zero_block();
        	ad_offset = ctx->ad_offset_base;
        	ad_valid = 0;
        } else {
			ad_offset    = ctx->ad_offset;
			ad_checksum  = ctx->ad_checksum;
			ad_valid     = ctx->ad_valid;
        }
    } else {
        /* If not a new message, restore values from ctx */
        offset       = ctx->offset;
        checksum     = ctx->checksum;
        ad_offset    = ctx->ad_offset;
        ad_checksum  = ctx->ad_checksum;
        ad_valid     = ctx->ad_valid;
    }
    
    ad_valid |= ad_len;  /* Ensure ctx->ad_valid!=0 if ever ad_len!=0 */
    
    /* Calculate how many (BPI*16)-byte iterations are needed, and what is
     * left after them. On the final call the last chunk (1..BPI*16 bytes,
     * or 0 for empty input) is always kept for the tail code below. */
    ad_iters = ad_len/(BPI*16);
    iters    = pt_len/(BPI*16);
    if (final) {
    	ad_remaining = ad_len % (BPI*16);
    	remaining    = pt_len % (BPI*16);
    	if (ad_remaining == 0 && ad_iters != 0) {
    		ad_iters -= 1;
    		ad_remaining += (BPI*16);
    	}
    	if (remaining == 0 && iters != 0) {
    		iters -= 1;
    		remaining += (BPI*16);
    	}
    }

	/* Handle associated data BPI-blocks at a time. */
    while (ad_iters) {
		for (k=0; k<BPI; k+=4) {
        	ad_offset = increment(ad_offset);
			ta[k]     = xor_block(swap_if_le(ad_offset), adp[k]);
        	ad_offset = increment(ad_offset);
			ta[k+1]     = xor_block(swap_if_le(ad_offset), adp[k+1]);
        	ad_offset = increment(ad_offset);
			ta[k+2]     = xor_block(swap_if_le(ad_offset), adp[k+2]);
        	ad_offset = increment(ad_offset);
			ta[k+3]     = xor_block(swap_if_le(ad_offset), adp[k+3]);
		}
		ENCRYPT_BLOCKS(BPI,ta,ctx);  /* ECB BPI-blocks in-place */
		for (k=0; k<BPI; k+=4) {
			ad_checksum = xor_block(ad_checksum, ta[k]);
			ad_checksum = xor_block(ad_checksum, ta[k+1]);
			ad_checksum = xor_block(ad_checksum, ta[k+2]);
			ad_checksum = xor_block(ad_checksum, ta[k+3]);
		}
        adp += BPI;
    	--ad_iters;
    }
    
	/* Encrypt plaintext data BPI-blocks at a time. */
    while (iters) {
 		for (k=0; k<BPI; k+=4) {
			offset   = increment(offset);
			oa[k]    = swap_if_le(offset);
			ta[k]    = xor_block(oa[k], ptp[k]);
			checksum = xor_block(checksum, ptp[k]);
			offset   = increment(offset);
			oa[k+1]    = swap_if_le(offset);
			ta[k+1]    = xor_block(oa[k+1], ptp[k+1]);
			checksum = xor_block(checksum, ptp[k+1]);
			offset   = increment(offset);
			oa[k+2]    = swap_if_le(offset);
			ta[k+2]    = xor_block(oa[k+2], ptp[k+2]);
			checksum = xor_block(checksum, ptp[k+2]);
			offset   = increment(offset);
			oa[k+3]    = swap_if_le(offset);
			ta[k+3]    = xor_block(oa[k+3], ptp[k+3]);
			checksum = xor_block(checksum, ptp[k+3]);
		}
		ENCRYPT_BLOCKS(BPI,ta,ctx);
		for (k=0; k<BPI; k+=4) {
			ctp[k]   = xor_block(ta[k], oa[k]);
			ctp[k+1] = xor_block(ta[k+1], oa[k+1]);
			ctp[k+2] = xor_block(ta[k+2], oa[k+2]);
			ctp[k+3] = xor_block(ta[k+3], oa[k+3]);
		}
        ptp += BPI;
        ctp += BPI;
        --iters;
    }
    
	/* If final is non-zero, then this is the end of the message. */
    if (final) {

#if BPI <= 4
        /* Tail handling, one block at a time. */
        if (ad_valid) {
			while (ad_remaining > 16) {
				t1 = *adp;
				ad_offset = increment(ad_offset);
				om1 = swap_if_le(ad_offset);
				t1 = xor_block(t1, om1);
				AES_encrypt((unsigned char *)&t1, (unsigned char *)&t1, &ctx->encrypt_key);
				ad_checksum = xor_block(ad_checksum, t1);
				adp += 1;
				ad_remaining -= 16;
			}
			ad_offset = increment(ad_offset);
			if (ad_remaining == 16) {
				ad_offset = increment3(ad_offset);
				ad_checksum = xor_block(ad_checksum, *adp);
			} else {
				/* Short last AD block: zero-pad after a 0x80 marker. */
				ad_offset = increment3twice(ad_offset);
				t1 = zero_block();
				memcpy(&t1, adp, ad_remaining);
				((unsigned char *)&t1)[ad_remaining] = (unsigned char)0x80;
				ad_checksum = xor_block(ad_checksum, t1);
			}
			om1 = swap_if_le(ad_offset);
			ad_checksum = xor_block(ad_checksum, om1);
			AES_encrypt((unsigned char *)&ad_checksum, (unsigned char *)&ad_checksum, &ctx->encrypt_key);
		}

        while (remaining > 16) {
			t1 = *ptp;
			offset = increment(offset);
			om1 = swap_if_le(offset);
			checksum = xor_block(checksum, t1);
			t1 = xor_block(t1, om1);
			AES_encrypt((unsigned char *)&t1, (unsigned char *)&t1, &ctx->encrypt_key);
			*ctp = xor_block(t1, om1);
            ptp += 1;
            ctp += 1;
            remaining -= 16;
        }
		offset = increment(offset);
		t1 = swap_if_le(offset);
		((char *)&t1)[15] ^= (char)(remaining * 8);
		AES_encrypt((unsigned char *)&t1, (unsigned char *)&t2, &ctx->encrypt_key);
		t1 = t2;
		memcpy(&t2,ptp,remaining);  /* t2 = pt||pad */
		t1 = xor_block(t1,t2);      /* ct||0   */
		#if SAFE_OUTPUT_BUFFERS
		*ctp = t1;
		#else
		memcpy(ctp,&t1,remaining);
		#endif
		checksum = xor_block(checksum, t2);
		/* ctp was advanced past every full block above, so it already
		 * points at the last block; the tag placement below must not add
		 * a further block offset. */
		k = 0;

#else

        /* Tail handling, batching the remaining full blocks. Pointers are
         * not advanced here; k counts the blocks already consumed. */
        if (ad_valid) {
			k = 0;
			while (ad_remaining > 16) {
				ad_offset = increment(ad_offset);
				ta[k]     = xor_block(swap_if_le(ad_offset), adp[k]);
				++k;
				ad_remaining -= 16;
			}
			ENCRYPT_BLOCKS(k,ta,ctx);
			for (j=0; j<k; ++j)
				ad_checksum = xor_block(ad_checksum, ta[j]);
			ad_offset = increment(ad_offset);
			if (ad_remaining == 16) {
				ad_offset = increment3(ad_offset);
				ad_checksum = xor_block(ad_checksum, adp[k]);
			} else {
				/* Short last AD block: zero-pad after a 0x80 marker. */
				ad_offset = increment3twice(ad_offset);
				tmp.bl = zero_block();
				memcpy(tmp.u8, adp+k, ad_remaining);
				tmp.u8[ad_remaining] = (unsigned char)0x80;
				ad_checksum = xor_block(ad_checksum, tmp.bl);
			}
			ad_offset = swap_if_le(ad_offset);
			ad_checksum = xor_block(ad_checksum, ad_offset);
			AES_encrypt((unsigned char *)&ad_checksum, (unsigned char *)&ad_checksum, &ctx->encrypt_key);
		}

		k = 0;
		while (remaining > 16) {
			offset   = increment(offset);
			oa[k]    = swap_if_le(offset);
			ta[k]    = xor_block(oa[k], ptp[k]);
			checksum = xor_block(checksum, ptp[k]);
			++k;
			remaining -= 16;
		}
		offset = increment(offset);
		tmp.bl = swap_if_le(offset);
		tmp.u8[15] ^= (remaining * 8);
		ta[k] = tmp.bl;
		ENCRYPT_BLOCKS(k+1,ta,ctx);
		for (j=0; j<k; ++j)
			ctp[j] = xor_block(ta[j], oa[j]);
		if (remaining == 16) {
			ctp[k] = xor_block(ta[k], ptp[k]);
			checksum = xor_block(checksum, ptp[k]);
		} else {
			block t2, t1 = ta[k];
			t2 = t1;
			memcpy(&t2,ptp+k,remaining);  /* pt||pad */
			#if SAFE_OUTPUT_BUFFERS
			ctp[k] = xor_block(t1,t2);      /* ct||0   */
			#else
			t1 = xor_block(t1,t2);      /* ct||0   */
			/* The partial block belongs after the k full blocks. */
			memcpy(ctp+k,&t1,remaining);
			#endif
			checksum = xor_block(checksum, t2);
		}
#endif        
        /* Generate tag and place at the correct location
         */
		offset = increment3(offset);
		offset = swap_if_le(offset);
		checksum = xor_block(checksum, offset);
        AES_encrypt((unsigned char *)&checksum, tmp.u8, &ctx->encrypt_key);
        if (ad_valid)
        	tmp.bl = xor_block(tmp.bl, ad_checksum); /* ad_checksum is zero if no AD */
        if (tag) {
            *(block *)tag = tmp.bl;
        } else {
            /* Append the tag right after the last ciphertext byte: ctp + k
             * is the start of the final block in both tail variants. */
            memcpy((char *)(ctp + k) + remaining, tmp.u8, 16);
            pt_len += 16;
        }
        
    } else {
        /* If not done with message, store values to ctx */
        ctx->offset = offset;
        ctx->checksum = checksum;
        ctx->ad_offset = ad_offset;
        ctx->ad_checksum = ad_checksum;
        ctx->ad_valid = ad_valid;
    }
    return (int) pt_len;
}

/* ----------------------------------------------------------------------- */
/* Simple test program                                                     */
/* ----------------------------------------------------------------------- */

#if 0

#include <stdio.h>
#include <time.h>

/* Print an optional label followed by 'len' bytes of 'p' as upper-case
 * hexadecimal and a newline. */
static void pbuf(void *p, unsigned len, const void *s)
{
    const unsigned char *cursor = (const unsigned char *)p;
    unsigned left;
    if (s != NULL)
        printf("%s", (char *)s);
    for (left = len; left != 0; --left)
        printf("%02X", (unsigned)*cursor++);
    printf("\n");
}

#define VAL_LEN 1000
static void validate()
{
	int ptag = 1;
	int pct = 1;
    ALIGN(16) char pt[VAL_LEN];
    ALIGN(16) char ct[VAL_LEN+16];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
    ALIGN(16) char key[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
    ALIGN(16) char expected[] = {0xC4,0xEE,0xF0,0x43,0x21,0xA9,0x6B,0x2D,
                                 0x5A,0x4E,0x11,0x48,0xAD,0x9B,0xB4,0x1F};
    
    block result;
    ae_ctx ctx;
    int i;
    
    for (i = 0; i < 40; ++i)
    	pt[i] = i;
    
    ae_init(&ctx, key, 16, 16, 16);
    /* pbuf(&ctx, sizeof(ctx), "CTX: "); */

	result = zero_block();
	i = 0;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 8;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 16;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 24;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 32;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 40;
	ae_encrypt(&ctx,nonce,pt,i,NULL,0,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);

	i = 8;
	ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 16;
	ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 24;
	ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 32;
	ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);
	i = 40;
	ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);

	pbuf(&result, 16, "Xor of ID tags: ");

	i = 1000;
	memset(pt,0,1000);
	ae_encrypt(&ctx,nonce,pt,768,pt,768,ct,tag,AE_PENDING);
	ae_encrypt(&ctx,NULL,pt+768,232,pt+768,232,ct+768,tag,AE_FINALIZE);
    if (ptag) pbuf(tag, 16, "Validation string: ");
    if (pct) pbuf(ct, i, "Ciphertext: ");
	result = xor_block(result,*(block *)tag);

	if (memcmp(&result,expected,16) == 0)
		printf("Pass\n");
	else
		printf("Fail\n");
}

static void do_ocb(double Hz)
{
    ALIGN(16) char b1[2048];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[16];
    ALIGN(16) char key[] = "abcdefghijklmnop";
    clock_t c;
    #if USE_VIA_ACE_AES || USE_AES_NI
    unsigned i, CALLS = (unsigned)(Hz/5000);
    #else
    unsigned i, CALLS = (unsigned)(Hz/25000);
    #endif
    
    ae_ctx* ctx = ae_allocate(NULL);
    if (ctx) {
        ae_init(ctx, key, 16,16,16);
        ae_encrypt(ctx, nonce, b1, (int)sizeof(b1),NULL,0, b1, tag, AE_FINALIZE);
        nonce[15] += 1;
        c = clock();
        for (i = 0; i < CALLS; i++) {
            ae_encrypt(ctx, nonce, b1, (int)sizeof(b1),NULL,0, b1, tag, AE_FINALIZE);
            nonce[15] += (char) 1;
        }
        c = clock() - c;
        printf("OCB\n%.2f seconds.\n", ((float)c)/CLOCKS_PER_SEC);
        printf("%.1f cpb.\n", (((double)c)/CLOCKS_PER_SEC) * 
               Hz / (sizeof(b1) * CALLS));
        ae_free(ctx);
    }
}

/* Benchmark the bare block cipher for comparison with do_ocb.
   Runs AES_encrypt in place on the first 16 bytes of a buffer once for
   every 16-byte block that CALLS passes over a 2048-byte buffer would
   contain, then prints the elapsed CPU seconds and the derived
   cycles-per-byte figure.
   Hz: processor clock rate in cycles per second; it sizes the run and
       converts the measured time into cycles. */
static void do_raw(double Hz)
{
    ALIGN(16) AES_KEY key;
    /* Zero-initialised so the cipher never reads indeterminate bytes. */
    ALIGN(16) char b1[2048] = {0};
    ALIGN(16) char userkey[] = "abcdefghijklmnop";
    const double calls_wanted = Hz / 25000;
    /* Converting a negative or NaN double to unsigned is undefined, and a
       zero count would make the cycles-per-byte divisor zero; fall back to
       a single call in those cases. */
    const unsigned CALLS = (calls_wanted >= 1.0) ? (unsigned)calls_wanted : 1u;
    const unsigned blocks = CALLS * (unsigned)(sizeof(b1)/16);
    unsigned i;
    clock_t c;
    double seconds;

    AES_set_encrypt_key((void *)userkey, 128, &key);

    c = clock();
    for (i = 0; i < blocks; i++) {
        AES_encrypt((void*)b1, (void*)b1, &key);
    }
    c = clock() - c;

    seconds = ((double)c)/CLOCKS_PER_SEC;
    printf("RAW\n%.2f seconds.\n", seconds);
    printf("%.1f cpb.\n", seconds * Hz / (sizeof(b1) * CALLS));
}

/* Entry point of the test program.
   Expects exactly one argument: the processor clock rate in MHz, as a
   positive decimal integer.  Runs the self-test, then the OCB and raw-AES
   benchmarks.
   Returns 0 after a normal run or after printing the usage line, and 1 when
   the clock-rate argument is not acceptable. */
int main(int argc, char **argv)
{
    /* Upper bound on the accepted clock rate.  The benchmarks turn Hz into
       unsigned iteration counts; this limit keeps those counts far inside
       the range of unsigned. */
    const long max_mhz = 100000;
    char *end;
    long mhz;

    if (argc != 2) {
        /* argv[0] may be NULL when the program is started with argc == 0. */
        printf("Usage: %s MHz\n", (argc > 0 && argv[0]) ? argv[0] : "ocb");
        return 0;
    }

    mhz = strtol(argv[1], &end, 10);
    /* Reject empty input, trailing junk, and values that would lead to a
       zero or out-of-range iteration count. */
    if (end == argv[1] || *end != '\0' || mhz <= 0 || mhz > max_mhz) {
        printf("Invalid MHz value \"%s\": expected an integer from 1 to %ld\n",
               argv[1], max_mhz);
        return 1;
    }

    {
        const double Hz = 1e6 * (double)mhz;
        validate();
        do_ocb(Hz);
        do_raw(Hz);
    }
    return 0;
}
#endif

/* infoString: human-readable label for this build.  The text is chosen by
   the USE_* configuration macros; the first matching case wins, and the
   plain label is used when none is set. */
#if USE_AES_NI && USE_OPENSSL_AES
#define OCB_INFO_STRING "OCB2 (AES-NI w/ OpenSSL Keying)"
#elif USE_AES_NI
#define OCB_INFO_STRING "OCB2 (AES-NI)"
#elif USE_CRYPTOPP_AES
#define OCB_INFO_STRING "OCB2 (Crypto++)"
#elif USE_REFERENCE_AES
#define OCB_INFO_STRING "OCB2 (Reference)"
#elif USE_OPENSSL_AES
#define OCB_INFO_STRING "OCB2 (OpenSSL)"
#elif USE_KASPER_AES
#define OCB_INFO_STRING "OCB2 (Kasper/Schwabe)"
#else
#define OCB_INFO_STRING "OCB2"
#endif

char infoString[] = OCB_INFO_STRING;

/* The helper macro is only needed for the definition above. */
#undef OCB_INFO_STRING
