/*
 * Copyright (c) 1997-2005 Erez Zadok <ezk@cs.stonybrook.edu>
 * Copyright (c) 2001-2005 Stony Brook University
 *
 * For specific licensing information, see the COPYING file distributed with
 * this package, or get one from ftp://ftp.filesystems.org/pub/fistgen/COPYING.
 *
 * This Copyright notice must be kept intact and distributed with all
 * fistgen sources INCLUDING sources generated by fistgen.
 */
/*
 * fxns.c: expand fist functions to their values, if needed
 * Fistgen sources.
 */

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif /* HAVE_CONFIG_H */


/*
 * Support functions generated on demand by the ioctl and fileformat
 * expanders in this file.  They are looked up by function name with
 * fist_search_bdt(), added with append_to_bdt(), and written out by
 * print_auto_generated_functions().
 */
bdt_t *auto_generated_functions;
/*
 * "extern" declarations for the auto-generated functions above, accumulated
 * with ecl_strcat() (intended for the generated fist header file).
 */
ecl_t auto_generated_externs;

/*
 * Forward declarations of the expanders referenced by the aux_exp_fxns
 * dispatch table below.  Each one writes its expansion (or an error
 * message) into buf and returns TRUE on success, FALSE on error.
 */
static int expand_fistMemCpy(char *buf, int argc, char *argv[]);
static int expand_fistSetErr(char *buf, int argc, char *argv[]);
static int expand_fistLastErr(char *buf, int argc, char *argv[]);
static int expand_fistReturnErr(char *buf, int argc, char *argv[]);
static int expand_fistMalloc(char *buf, int argc, char *argv[]);
static int expand_fistFree(char *buf, int argc, char *argv[]);
static int expand_fistStrEq(char *buf, int argc, char *argv[]);
static int expand_fistLookup(char *buf, int argc, char *argv[]);
static int expand_fistSkipName(char *buf, int argc, char *argv[]);
static int expand_fistGetIoctlData(char *buf, int argc, char *argv[]);
static int expand_fistSetIoctlData(char *buf, int argc, char *argv[]);
static int expand_fistGetFileData(char *buf, int argc, char *argv[]);
static int expand_fistSetFileData(char *buf, int argc, char *argv[]);
/* forward declarations for two dummy test functions */
static int expand_fistFoo(char *buf, int argc, char *argv[]);
static int expand_fistBar(char *buf, int argc, char *argv[]);

/*
 * Dispatch table of fist auxiliary function expanders.
 *
 * expand_fist_fxn() searches this table linearly.  Names omit the "fist"
 * prefix, so fistMemCpy is listed as "MemCpy".  A call with fewer than
 * minargs or more than maxargs arguments is rejected before func is
 * invoked.  The table ends with a NULL-name sentinel entry.
 */
static aux_exp_fxn_t aux_exp_fxns[] = {
  /* name         minargs maxargs func */
  {"MemCpy",		3, 3, expand_fistMemCpy},
  {"SetErr",		1, 1, expand_fistSetErr},
  {"LastErr",		0, 0, expand_fistLastErr},
  {"ReturnErr",		0, 1, expand_fistReturnErr},
  {"Malloc",		1, 1, expand_fistMalloc},
  {"Free",		2, 2, expand_fistFree},
  {"StrEq",		2, 2, expand_fistStrEq},
  {"Lookup",		5, 5, expand_fistLookup},
  {"SkipName",		0, 1, expand_fistSkipName},
  {"GetIoctlData",	3, 3, expand_fistGetIoctlData},
  {"SetIoctlData",	3, 3, expand_fistSetIoctlData},
  {"GetFileData",	4, 4, expand_fistGetFileData},
  {"SetFileData",	4, 4, expand_fistSetFileData},
  /* two dummy test functions */
  {"Foo",		0, 2, expand_fistFoo},
  {"Bar",		2, 3, expand_fistBar},
  /* sentinel: terminates the search loop in expand_fist_fxn() */
  {NULL,		0, 0, NULL}
};

/*
 * Append a parenthesized, comma-separated argument list, e.g.
 * "(a, b, c)", to the NUL-terminated string already in buf.
 * A non-positive argc appends "()".  buf must have room for the result;
 * no bounds checking is done.
 */
static void
print_func_args(char *buf, int argc, char *argv[])
{
  int idx;

  strcat(buf, "(");
  for (idx = 0; idx < argc; idx++) {
    /* separator goes before every argument except the first */
    if (idx > 0)
      strcat(buf, ", ");
    strcat(buf, argv[idx]);
  }
  strcat(buf, ")");
}


/*
 * Expand a call to the fist auxiliary function FXN (e.g. "fistMemCpy")
 * that was given ARGC arguments in ARGV.  The replacement text is written
 * into the caller-supplied buffer BUF.
 *
 * Returns TRUE if the function could be expanded.  Returns FALSE if FXN
 * lacks the "fist" prefix, names no known auxiliary function, was given
 * the wrong number of arguments (a syntax error in the input), or if the
 * specific expander reports an error.  On FALSE, BUF holds an error
 * message instead of an expansion.
 */
int
expand_fist_fxn(char *buf, const char *fxn, int argc, char *argv[])
{
  const char *auxname;
  aux_exp_fxn_t *efp;

  /* every auxiliary function name must start with "fist" */
  if (strncmp(fxn, "fist", 4) != 0) {
    sprintf(buf, "invalid fist auxiliary function name: %s", fxn);
    return FALSE;
  }
  /* skip past "fist"; no cast needed, the name is only read */
  auxname = fxn + 4;

  /* the table ends with a NULL-name sentinel */
  for (efp = aux_exp_fxns; efp->auxname; efp++) {
    if (!STREQ(auxname, efp->auxname))
      continue;
    if (argc < efp->min_args) {
      sprintf(buf, "%s takes at least %d arguments", fxn, efp->min_args);
      return FALSE;
    }
    if (argc > efp->max_args) {
      sprintf(buf, "%s takes no more than %d arguments", fxn, efp->max_args);
      return FALSE;
    }
    return (efp->func)(buf, argc, argv);
  }

  /* any unknown function is an error */
  sprintf(buf, "unknown fist auxiliary function: %s", fxn);
  return FALSE;
}


/*
 * Expand fistMemCpy(dst, src, len) into a call to the platform's
 * memory-copy routine: memcpy(dst, src, len), or bcopy(src, dst, len) on
 * Solaris.  bcopy takes its first two arguments in the opposite order, so
 * they are swapped in a local copy.  The caller's argv array is never
 * modified.  expand_fist_fxn() guarantees exactly 3 arguments.
 */
static int
expand_fistMemCpy(char *buf, int argc, char *argv[])
{
#ifdef __solaris__
  char *bcopy_args[3];

  if (argc != 3) {
    strcpy(buf, "fistMemCpy takes exactly 3 arguments");
    return FALSE;
  }
  bcopy_args[0] = argv[1];
  bcopy_args[1] = argv[0];
  bcopy_args[2] = argv[2];
  strcpy(buf, "bcopy");
  print_func_args(buf, argc, bcopy_args);
#else /* not __solaris__ */
  strcpy(buf, "memcpy");
  print_func_args(buf, argc, argv);
#endif /* not __solaris__ */
  return TRUE;
}


/*
 * Expand fistSetErr(code) into an assignment of the error code:
 * "err = -code" on Linux, "error = code" on every other platform.
 * NOTE(review): the non-Linux branch assigns "error", while
 * expand_fistReturnErr (non-Linux) and expand_fistLastErr (non-FreeBSD)
 * use "err"; confirm this matches the variable names in each platform's
 * templates.
 */
static int
expand_fistSetErr(char *buf, int argc, char *argv[])
{
  const char *code = argv[0];

#ifndef __linux__
  sprintf(buf, "error = %s", code);
#else /* __linux__ */
  /* Linux stores the negated code */
  sprintf(buf, "err = -%s", code);
#endif /* __linux__ */
  return TRUE;
}


/*
 * Expand fistLastErr() into the name of the variable that holds the
 * current error code: "error" on FreeBSD, "err" elsewhere.
 * Takes no arguments.
 */
static int
expand_fistLastErr(char *buf, int argc, char *argv[])
{
  const char *errvar =
#ifdef __freebsd__
    "error";
#else /* not __freebsd__ */
    "err";
#endif /* not __freebsd__ */

  strcpy(buf, errvar);
  return TRUE;
}


/*
 * Expand fistReturnErr([code]) into an early exit from the generated
 * function.  On Linux the exit is "goto out"; elsewhere it is
 * "return err".  With an argument, err is set first: negated on Linux,
 * unmodified elsewhere.  The set-and-exit pair is wrapped in
 * do { } while(0).
 */
static int
expand_fistReturnErr(char *buf, int argc, char *argv[])
{
  /* no argument: leave with whatever err already holds */
  if (argc == 0) {
#ifdef __linux__
    strcpy(buf, "goto out");
#else /* not __linux */
    strcpy(buf, "return err");
#endif /* not __linux */
    return TRUE;
  }

  /* one argument: store the code, then leave */
#ifdef __linux__
  sprintf(buf, "do {err = -%s; goto out;} while(0)", argv[0]);
#else /* not __linux */
  sprintf(buf, "do {err = %s; return err;} while(0)", argv[0]);
#endif /* not __linux */
  return TRUE;
}


/*
 * Expand fistMalloc(size) by substituting its single argument into the
 * FIST_KMEM_ALLOC format (presumably the platform's kernel allocation
 * call; see that macro's definition).
 */
static int
expand_fistMalloc(char *buf, int argc, char *argv[])
{
  char *size_expr = argv[0];

  sprintf(buf, FIST_KMEM_ALLOC, size_expr);
  return TRUE;
}


/*
 * Expand fistFree(ptr, arg2) by substituting into the FIST_KMEM_FREE
 * format.  FreeBSD's format consumes only the first argument; the other
 * platforms consume both.
 */
static int
expand_fistFree(char *buf, int argc, char *argv[])
{
#ifndef __freebsd__
  sprintf(buf, FIST_KMEM_FREE, argv[0], argv[1]);
#else /* __freebsd__ */
  /* second argument is accepted but not used here */
  sprintf(buf, FIST_KMEM_FREE, argv[0]);
#endif /* __freebsd__ */
  return TRUE;
}


/*
 * Expand fistStrEq(a, b) into the C expression "(!strcmp(a, b))",
 * which is true when the two strings are equal.
 */
static int
expand_fistStrEq(char *buf, int argc, char *argv[])
{
  const char *lhs = argv[0];
  const char *rhs = argv[1];

  sprintf(buf, "(!strcmp(%s, %s))", lhs, rhs);
  return TRUE;
}


/*
 * Expand fistLookup(a0, a1, a2, a3, a4) into a platform-specific call to
 * fist_lookup() and flag that the auxiliary sources are needed (presumably
 * where fist_lookup() lives).
 *   Linux:   hidden_dentry = fist_lookup(a0, a1, &(a2), a3, a4)
 *   Solaris: err = fist_lookup(a0, a1, vpp, pnp, flags, rdir, cr, a3, a4)
 *   FreeBSD: err = fist_lookup(a0, a1, ap->a_vpp, ap->a_cnp, flags, NULL,
 *                              cr, a3, a4)
 * Only Linux uses a2.  On any other platform buf is left untouched.
 */
static int
expand_fistLookup(char *buf, int argc, char *argv[])
{
  const char *first = argv[0], *second = argv[1];
  const char *fourth = argv[3], *fifth = argv[4];

  need_aux_sources = TRUE;

#ifdef __linux__
  sprintf(buf, "(hidden_dentry = fist_lookup(%s, %s, &(%s), %s, %s))",
	  first, second, argv[2], fourth, fifth);
#endif /* __linux__ */
#ifdef __solaris__
  sprintf(buf, "(err = fist_lookup(%s, %s, vpp, pnp, flags, rdir, cr, %s, %s))",
	  first, second, fourth, fifth);
#endif /* __solaris__ */
#ifdef __freebsd__
  sprintf(buf, "(err = fist_lookup(%s, %s, ap->a_vpp, ap->a_cnp, flags, NULL, cr, %s, %s))",
	  first, second, fourth, fifth);
#endif /* __freebsd__ */

  return TRUE;
}


/*
 * Expand fistSkipName() into the platform's statement for skipping the
 * current name.  On Linux this is "return 0".  On Solaris and FreeBSD it
 * frees temp_name and continues the enclosing loop.  Arguments are
 * ignored, and on any other platform buf is left untouched.
 */
static int
expand_fistSkipName(char *buf, int argc, char *argv[])
{
  /* the strings contain no conversions, so a plain copy suffices */
#ifdef __linux__
  strcpy(buf, "return 0");
#endif /* __linux__ */
#ifdef __solaris__
  strcpy(buf, "{kmem_free(temp_name, temp_length); continue;}");
#endif /* __solaris__ */
#ifdef __freebsd__
  strcpy(buf, "{free(temp_name, M_TEMP); continue;}");
#endif /* __freebsd__ */

  return TRUE;
}


/****************************************************************************/
/*** IOCTL FUNCTIONS SUPPORT						  ***/
/****************************************************************************/

/*
 * Expand fistGetIoctlData(IOCTL, FIELD, VAR).
 *
 * The call is replaced in BUF by
 *     fist_get_ioctl_data_IOCTL_FIELD(VAR, sizeof(VAR), (void *)arg)
 * The first time a given IOCTL/FIELD pair is seen, the body of that support
 * function is generated and stored in auto_generated_functions, and its
 * extern declaration is appended to auto_generated_externs.
 *
 * All local buffers are filled with snprintf() and checked for truncation.
 * Overly long identifiers therefore produce an error instead of overflowing
 * the stack buffers.  Returns TRUE on success, or FALSE with an error
 * message in BUF.
 */
static int
expand_fistGetIoctlData(char *buf, int argc, char *argv[])
{
  char name[MAX_BUF_LEN], fxnbuf[MAX_BUF_LEN];
  char mcpfxn[MAX_BUF_LEN];
  int n;

  /* form name of special ioctl function */
  n = snprintf(name, sizeof(name), "fist_get_ioctl_data_%s_%s",
	       argv[0], argv[1]);
  if (n < 0 || (size_t) n >= sizeof(name))
    goto too_long;

  /* replace function */
  sprintf(buf, "%s(%s, sizeof(%s), (void *)arg)", name, argv[2], argv[2]);

  /* the support function only needs to be generated once per name */
  if (fist_search_bdt(name, auto_generated_functions))
    return TRUE;

  /* produce code for function */
  n = snprintf(mcpfxn, sizeof(mcpfxn), FIST_MEMCPY_4_FROM_IOCTL, argv[1]);
  if (n < 0 || (size_t) n >= sizeof(mcpfxn))
    goto too_long;
#ifndef __freebsd__
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(void *out, int len, void *arg)\n{\n" \
"  int ret;\n" \
"  struct _fist_ioctl_%s this_ioctl;\n" \
"\n" \
"  ret = %s;\n" \
"  if (ret >= 0)\n" \
"    %s;\n" \
"  return ret;\n" \
"}\n", name, argv[0], FIST_COPY_FROM_USER_4_IOCTL, mcpfxn);
#else /* __freebsd__ */
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(void *out, int len, void *arg)\n{\n" \
"  struct _fist_ioctl_%s this_ioctl;\n" \
"\n" \
"  %s;\n" \
"  %s;\n" \
"\n" \
"  return 0;\n" \
"}\n", name, argv[0], FIST_COPY_FROM_USER_4_IOCTL, mcpfxn);
#endif /*__freebsd__ */
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;

  if (!append_to_bdt(&auto_generated_functions, fxnbuf, name)) {
    sprintf(buf, "cannot store ioctl function %s", name);
    return FALSE;
  }
  /* side effect: append "extern" definition to fist header file */
  n = snprintf(fxnbuf, sizeof(fxnbuf),
	       "extern int %s(void *out, int len, void *arg);\n", name);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;
  if (!ecl_strcat(&auto_generated_externs, fxnbuf)) {
    sprintf(buf, "cannot store ioctl extern definition %s", name);
    return FALSE;
  }
  return TRUE;

 too_long:
  strcpy(buf, "fistGetIoctlData: generated code exceeds internal buffer size");
  return FALSE;
}


/*
 * Expand fistSetIoctlData(IOCTL, FIELD, VAR).
 *
 * The call is replaced in BUF by
 *     fist_set_ioctl_data_IOCTL_FIELD(VAR, sizeof(VAR), (void *)arg)
 * The first time a given IOCTL/FIELD pair is seen, the body of that support
 * function is generated and stored in auto_generated_functions, and its
 * extern declaration is appended to auto_generated_externs.
 *
 * All local buffers are filled with snprintf() and checked for truncation.
 * Overly long identifiers therefore produce an error instead of overflowing
 * the stack buffers.  Returns TRUE on success, or FALSE with an error
 * message in BUF.
 */
static int
expand_fistSetIoctlData(char *buf, int argc, char *argv[])
{
  char name[MAX_BUF_LEN], fxnbuf[MAX_BUF_LEN];
  char mcpfxn[MAX_BUF_LEN];
  int n;

  /* form name of special ioctl function */
  n = snprintf(name, sizeof(name), "fist_set_ioctl_data_%s_%s",
	       argv[0], argv[1]);
  if (n < 0 || (size_t) n >= sizeof(name))
    goto too_long;

  /* replace function */
  sprintf(buf, "%s(%s, sizeof(%s), (void *)arg)", name, argv[2], argv[2]);

  /* the support function only needs to be generated once per name */
  if (fist_search_bdt(name, auto_generated_functions))
    return TRUE;

  /* produce code for function */
  n = snprintf(mcpfxn, sizeof(mcpfxn), FIST_MEMCPY_4_TO_IOCTL, argv[1]);
  if (n < 0 || (size_t) n >= sizeof(mcpfxn))
    goto too_long;
#ifndef __freebsd__
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(void *in, int len, void *arg)\n{\n" \
"  int ret;\n" \
"  struct _fist_ioctl_%s this_ioctl;\n" \
"\n" \
"  %s;\n" \
"  ret = %s;\n" \
"  return ret;\n" \
"}\n", name, argv[0], mcpfxn, FIST_COPY_TO_USER_4_IOCTL);
#else /* __freebsd__ */
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(void *in, int len, void *arg)\n{\n" \
"  struct _fist_ioctl_%s this_ioctl;\n" \
"\n" \
"  %s;\n" \
"  %s;\n" \
"  return 0;\n" \
"}\n", name, argv[0], mcpfxn, FIST_COPY_TO_USER_4_IOCTL);
#endif /*__freebsd__ */
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;

  if (!append_to_bdt(&auto_generated_functions, fxnbuf, name)) {
    sprintf(buf, "cannot store ioctl function %s", name);
    return FALSE;
  }
  /* side effect: append "extern" definition to fist header file */
  n = snprintf(fxnbuf, sizeof(fxnbuf),
	       "extern int %s(void *in, int len, void *arg);\n", name);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;
  if (!ecl_strcat(&auto_generated_externs, fxnbuf)) {
    sprintf(buf, "cannot store ioctl extern definition %s", name);
    return FALSE;
  }
  return TRUE;

 too_long:
  strcpy(buf, "fistSetIoctlData: generated code exceeds internal buffer size");
  return FALSE;
}


/*
 * Write every auto-generated support function to fp, in list order.
 * When the list is non-empty, the output starts with a blank-line
 * separator, and each function body is followed by "\n\n".
 * Nothing is written when no functions were generated.
 */
void
print_auto_generated_functions(FILE *fp)
{
  bdt_t *node;

  if (auto_generated_functions == NULL)
    return;

  fputs("\n\n", fp);
  for (node = auto_generated_functions; node != NULL; node = node->next) {
    fputs(node->full_line, fp);
    fputs("\n\n", fp);
  }
}


/****************************************************************************/
/*** FILEFORMAT FUNCTIONS SUPPORT					  ***/
/****************************************************************************/

/*
 * Expand fistGetFileData(FILENAME, FORMAT, FIELD, VAR).
 *
 * The call is replaced in BUF by
 *     fist_get_fileformat_data_FORMAT_FIELD(FILENAME, &VAR, sizeof(VAR))
 * The first time a given FORMAT/FIELD pair is seen, the body of that
 * support function is generated and stored in auto_generated_functions.
 * The generated body reads the file with <fsname>_read_file() and copies
 * the field out.  Its extern declaration is appended to
 * auto_generated_externs.
 *
 * All local buffers are filled with snprintf() and checked for truncation.
 * Overly long identifiers therefore produce an error instead of overflowing
 * the stack buffers.  Returns TRUE on success, or FALSE with an error
 * message in BUF.
 */
static int
expand_fistGetFileData(char *buf, int argc, char *argv[])
{
  char name[MAX_BUF_LEN], fxnbuf[MAX_BUF_LEN];
  char mcpfxn[MAX_BUF_LEN];
  int n;

  /* form name of special fileformat function */
  n = snprintf(name, sizeof(name), "fist_get_fileformat_data_%s_%s",
	       argv[1], argv[2]);
  if (n < 0 || (size_t) n >= sizeof(name))
    goto too_long;

  /* replace function */
  sprintf(buf, "%s(%s, &%s, sizeof(%s))", name, argv[0], argv[3], argv[3]);

  /* the support function only needs to be generated once per name */
  if (fist_search_bdt(name, auto_generated_functions))
    return TRUE;

  /* produce code for function */
  n = snprintf(mcpfxn, sizeof(mcpfxn), FIST_MEMCPY_4_FF_READ, argv[2]);
  if (n < 0 || (size_t) n >= sizeof(mcpfxn))
    goto too_long;
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(const char *filename, void *out, int len)\n{\n" \
"  int ret;\n" \
"  struct _fist_fileformat_%s this_ff;\n" \
"\n" \
"  ret = %s_read_file(filename, &this_ff,\n" \
"			sizeof(struct _fist_fileformat_%s));\n" \
"  if (ret < 0)\n" \
"    return ret;\n" \
"\n" \
"  %s;\n" \
"  return 0;\n" \
"}\n", name, argv[1], fist_globals.fg_fsname, argv[1], mcpfxn);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;

  if (!append_to_bdt(&auto_generated_functions, fxnbuf, name)) {
    sprintf(buf, "cannot store fileformat function %s", name);
    return FALSE;
  }
  /* side effect: append "extern" definition to fist header file */
  n = snprintf(fxnbuf, sizeof(fxnbuf),
	       "extern int %s(const char *filename, void *out, int len);\n",
	       name);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;
  if (!ecl_strcat(&auto_generated_externs, fxnbuf)) {
    sprintf(buf, "cannot store fileformat extern definition %s", name);
    return FALSE;
  }
  return TRUE;

 too_long:
  strcpy(buf, "fistGetFileData: generated code exceeds internal buffer size");
  return FALSE;
}


/*
 * Expand fistSetFileData(FILENAME, FORMAT, FIELD, VAR).
 *
 * The call is replaced in BUF by
 *     fist_set_fileformat_data_FORMAT_FIELD(FILENAME, &VAR, sizeof(VAR))
 * The first time a given FORMAT/FIELD pair is seen, the body of that
 * support function is generated and stored in auto_generated_functions.
 * The generated body reads the file with <fsname>_read_file(), copies the
 * field in, and writes it back with <fsname>_write_file().  Its extern
 * declaration is appended to auto_generated_externs.
 *
 * All local buffers are filled with snprintf() and checked for truncation.
 * Overly long identifiers therefore produce an error instead of overflowing
 * the stack buffers.  Returns TRUE on success, or FALSE with an error
 * message in BUF.
 */
static int
expand_fistSetFileData(char *buf, int argc, char *argv[])
{
  char name[MAX_BUF_LEN], fxnbuf[MAX_BUF_LEN];
  char mcpfxn[MAX_BUF_LEN];
  int n;

  /* form name of special fileformat function */
  n = snprintf(name, sizeof(name), "fist_set_fileformat_data_%s_%s",
	       argv[1], argv[2]);
  if (n < 0 || (size_t) n >= sizeof(name))
    goto too_long;

  /* replace function */
  sprintf(buf, "%s(%s, &%s, sizeof(%s))", name, argv[0], argv[3], argv[3]);

  /* the support function only needs to be generated once per name */
  if (fist_search_bdt(name, auto_generated_functions))
    return TRUE;

  /* produce code for function */
  n = snprintf(mcpfxn, sizeof(mcpfxn), FIST_MEMCPY_4_FF_WRITE, argv[2]);
  if (n < 0 || (size_t) n >= sizeof(mcpfxn))
    goto too_long;
  n = snprintf(fxnbuf, sizeof(fxnbuf), "int\n" \
"%s(const char *filename, void *in, int len)\n{\n" \
"  int ret;\n" \
"  struct _fist_fileformat_%s this_ff;\n" \
"\n" \
"  ret = %s_read_file(filename, &this_ff,\n" \
"			sizeof(struct _fist_fileformat_%s));\n" \
"  if (ret < 0)\n" \
"    return ret;\n" \
"  %s;\n" \
"  ret = %s_write_file(filename, &this_ff,\n" \
"			 sizeof(struct _fist_fileformat_%s));\n" \
"  if (ret < 0)\n" \
"    return ret;\n" \
"  return 0;\n" \
"}\n", name, argv[1], fist_globals.fg_fsname, argv[1], mcpfxn, fist_globals.fg_fsname, argv[1]);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;

  if (!append_to_bdt(&auto_generated_functions, fxnbuf, name)) {
    sprintf(buf, "cannot store fileformat function %s", name);
    return FALSE;
  }
  /* side effect: append "extern" definition to fist header file */
  n = snprintf(fxnbuf, sizeof(fxnbuf),
	       "extern int %s(const char *filename, void *in, int len);\n",
	       name);
  if (n < 0 || (size_t) n >= sizeof(fxnbuf))
    goto too_long;
  if (!ecl_strcat(&auto_generated_externs, fxnbuf)) {
    sprintf(buf, "cannot store fileformat extern definition %s", name);
    return FALSE;
  }
  return TRUE;

 too_long:
  strcpy(buf, "fistSetFileData: generated code exceeds internal buffer size");
  return FALSE;
}


/****************************************************************************/
/*** DUMMY TEST FUNCTIONS						  ***/
/****************************************************************************/

/*
 * Dummy test expander: fistFoo(args...) becomes
 * "fist_foo<<(arg0, arg1)>>" (0 to 2 arguments).
 */
static int
expand_fistFoo(char *buf, int argc, char *argv[])
{
  /* literal prefix; no conversions, so a plain copy is enough */
  strcpy(buf, "fist_foo<<");
  print_func_args(buf, argc, argv);
  strcat(buf, ">>");
  return TRUE;
}


/*
 * Dummy test expander: fistBar(args...) becomes
 * "fist_bar<<(arg0, arg1[, arg2])>>" (2 or 3 arguments).
 */
static int
expand_fistBar(char *buf, int argc, char *argv[])
{
  /* literal prefix; no conversions, so a plain copy is enough */
  strcpy(buf, "fist_bar<<");
  print_func_args(buf, argc, argv);
  strcat(buf, ">>");
  return TRUE;
}
