/*
 * @(#) $Header: /tmp/cvs/mysql-admutils/mysql-useradm.c,v 1.12 2011-09-22 12:17:18 geirha Exp $
 *
 * mysql-useradm.c
 *
 */ 

#include <stdio.h>
#include <stdarg.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <pwd.h>
#include <sys/types.h>
#include <unistd.h>
#include <mysql.h>
#include "mysql-admutils.h"

/*
 * Print the --help text for this program to stdout.
 * Returns 0 so that main() can simply `return usage();`.
 */
int
usage()
{
  printf("Usage: %s COMMAND [USER]...\n", program_name);
  fputs("Create, delete or change password for the USER(s),\n"
        "as determined by the COMMAND.  Valid COMMANDs:\n"
        "\n"
        "  create     create the USER(s).\n"
        "  delete     delete the USER(s).\n"
        "  passwd     change the MySQL password for the USER(s).\n"
        "  show       give information about the USERS(s), or, if\n"
        "             none are given, all the users you have.\n"
        "\n"
        "Report bugs to orakel@ntnu.no\n",
        stdout);
  return 0;
}


/*
 * Report whether the MySQL account USER has a password.
 *
 * Selects authentication_string from the `user` table of the currently
 * selected database for rows whose user column equals USER.  USER is escaped
 * with mysql_real_escape_string() before being placed in the query.
 *
 * Returns:
 *    1  exactly one row matched and its authentication_string is non-empty;
 *    0  exactly one row matched and authentication_string is NULL or empty;
 *   -1  no row matched (no such account).
 * In any of the following cases the problem is reported via dberror() and
 * dberror()'s return value is returned instead:
 *   - the query fails;
 *   - no result set can be fetched;
 *   - more than one row matches (presumably the same name defined for
 *     several hosts -- confirm).
 */
int
is_password_set(MYSQL *pmysql, const char *user)
{
  char query[1024], *end;
  MYSQL_RES *res;
  int rows;
  MYSQL_ROW row;
  int check = 0;

  end = strmov(query, "SELECT authentication_string FROM user WHERE user='");
  end += mysql_real_escape_string(pmysql, end, user, strlen(user));
  *end++ = '\'';
  *end = '\0';

  /* Must return here: after a failed query there is no result set, and
     mysql_num_rows() on a NULL result would crash. */
  if (mysql_query(pmysql, query))
    return dberror(pmysql, "Failed to look up password for user '%s'.", user);
  res = mysql_store_result(pmysql);
  if (!res)
    return dberror(pmysql, "Failed to look up password for user '%s'.", user);
  rows = mysql_num_rows(res);

  if (rows > 1)
    {
      mysql_free_result(res);
      return dberror(NULL, "Query for password for user '%s' gave %d results!",
                     user, rows);
    }
  else if (rows < 1)
    {
      mysql_free_result(res);
      return -1;
    }

  row = mysql_fetch_row(res);
  if (!row)
    {
      mysql_free_result(res);
      return dberror(pmysql, "Failed to look up password for user '%s'.", user);
    }
  /* A NULL column and an empty string both mean "no password". */
  check = (row[0] && row[0][0] != '\0');
  mysql_free_result(res);

  return check;
}


/*
 * Issue "CREATE USER 'USER'" with USER escaped for the SQL string literal.
 * Returns 0 on success; otherwise reports via dberror() (which is given the
 * connection) and returns dberror()'s value.
 */
int
create(MYSQL *pmysql, const char *user)
{
  char query[1024];
  char *cursor = strmov(query, "CREATE USER '");

  cursor += mysql_real_escape_string(pmysql, cursor, user, strlen(user));
  *cursor++ = '\'';
  *cursor = '\0';

  if (!mysql_query(pmysql, query))
    return 0;
  return dberror(pmysql, "Failed to create user '%s'.", user);
}


/*
 * Issue "DROP USER 'USER'" with USER escaped for the SQL string literal.
 * Returns 0 on success; otherwise reports via dberror() (which is given the
 * connection) and returns dberror()'s value.
 */
int
delete(MYSQL *pmysql, const char *user)
{
  char query[1024];
  char *tail;
  size_t user_len = strlen(user);

  tail = strmov(query, "DROP USER '");
  tail += mysql_real_escape_string(pmysql, tail, user, user_len);
  strmov(tail, "'");

  if (mysql_query(pmysql, query) != 0)
    return dberror(pmysql, "Failed to delete user '%s'.", user);
  return 0;
}


/*
 * Overwrite LEN bytes at BUF with zeros.  The stores go through a volatile
 * pointer so the compiler cannot discard them as dead writes, which it may
 * do with a plain memset() right before free() or return.
 */
static void
wipe_buffer(char *buf, size_t len)
{
  volatile char *v = buf;

  while (len--)
    *v++ = '\0';
}


/*
 * Interactively set a new password for the existing MySQL account USER.
 *
 * Steps:
 *   1. Refuse if is_password_set() reports the account does not exist.
 *   2. Prompt twice with getpass() and require both entries to match.
 *   3. Run "ALTER user 'USER' IDENTIFIED BY 'PASSWORD'", with both values
 *      escaped by mysql_real_escape_string().
 *
 * Every copy of the password handled here (the strdup()'d confirmation,
 * getpass()'s buffer and the query text) is zeroed before returning.
 *
 * Returns 0 on success.  On failure the problem is reported via dberror()
 * and dberror()'s return value is returned.
 */
int
passwd(MYSQL *pmysql, const char *user)
{
  char prompt[1024];
  char query[1024], *end;
  char *password, *confirm_password;
  size_t pw_len;
  int failed;
  int rc;

  if (is_password_set(pmysql, user) == -1) /* no such mysql user */
    return dberror(NULL, "User '%s' does not exist."
                         " You must create it first.\n", user);

  snprintf(prompt, sizeof prompt, "New MySQL password for user '%s': ", user);
  password = getpass(prompt);
  if (!password)
    return dberror(NULL, "Failed to read password for user '%s'.", user);
  /* getpass() reuses one static buffer, so keep a private copy. */
  confirm_password = strdup(password);
  wipe_buffer(password, strlen(password));
  if (!confirm_password)
    return dberror(NULL, "Out of memory.");

  snprintf(prompt, sizeof prompt,
           "Retype new MySQL password for user '%s': ", user);
  password = getpass(prompt);
  if (!password)
    {
      rc = dberror(NULL, "Failed to read password for user '%s'.", user);
      goto out;
    }
  if (strcmp(password, confirm_password) != 0)
    {
      rc = dberror(NULL, "Sorry, passwords do not match.");
      goto out;
    }

  pw_len = strlen(password);
  /* mysql_real_escape_string() may double every byte (and writes a NUL);
     getpass() does not bound the input length everywhere, so refuse anything
     that could overflow query[]. */
  if (2 * strlen(user) + 2 * pw_len + sizeof "ALTER user '' IDENTIFIED BY ''"
      > sizeof query)
    {
      rc = dberror(NULL, "Password for user '%s' is too long.", user);
      goto out;
    }

  end = strmov(query, "ALTER user '");
  end += mysql_real_escape_string(pmysql, end, user, strlen(user));
  end = strmov(end, "' IDENTIFIED BY '");
  end += mysql_real_escape_string(pmysql, end, password, pw_len);
  *end++ = '\'';
  *end = '\0';

  failed = mysql_query(pmysql, query);
  wipe_buffer(query, sizeof query);
  if (failed)
    {
      rc = dberror(pmysql, "Failed to set new password for user '%s'.", user);
      goto out;
    }
  /* No affected-rows check here: ALTER USER does not modify rows the way
     UPDATE does and reports 0 affected rows on success, so the former
     "!= 1" check fired on every successful change. */

  fprintf(stderr, "Password updated for user '%s'.\n", user);
  rc = 0;

out:
  if (password)
    wipe_buffer(password, strlen(password));
  wipe_buffer(confirm_password, strlen(confirm_password));
  free(confirm_password);
  return rc;
}


/*
 * Print whether USER has a MySQL password, as reported by is_password_set():
 *   "User 'USER': password set."     when it returns 1,
 *   "User 'USER': no password set."  when it returns 0.
 * Nothing is printed for any other result (e.g. -1, no such account).
 * Always returns 0.
 */
int
show(MYSQL *pmysql, const char *user)
{
  int state = is_password_set(pmysql, user);

  if (state != 0 && state != 1)
    return 0;

  printf("User '%s': ", user);
  fputs(state ? "password set.\n" : "no password set.\n", stdout);
  return 0;
}


/*
 * Append SEP followed by "user='NAME' OR user LIKE 'NAME\_%'" at *END, with
 * NAME escaped by mysql_real_escape_string().  LIMIT points one past the last
 * byte of the query buffer.
 *
 * On success, advances *END to the new terminating NUL and returns 0.  If the
 * worst-case escaped text might not fit, writes nothing and returns -1.
 */
static int
append_user_match(MYSQL *pmysql, char **end, const char *limit,
                  const char *sep, const char *name)
{
  size_t len = strlen(name);
  /* Each escape may emit up to 2*len bytes plus a NUL; sizeof on the fixed
     text already counts the final terminator. */
  size_t need = strlen(sep) + 2 * (2 * len + 1)
                + sizeof "user='' OR user LIKE '\\_%'";
  char *e = *end;

  if ((size_t) (limit - e) < need)
    return -1;

  e = strmov(e, sep);
  e = strmov(e, "user='");
  e += mysql_real_escape_string(pmysql, e, name, len);
  e = strmov(e, "' OR user LIKE '");
  e += mysql_real_escape_string(pmysql, e, name, len);
  e = strmov(e, "\\_%'");
  *end = e;
  return 0;
}


/*
 * Return the MySQL account names the invoking Unix user may manage.
 *
 * For the user's login name and for every name from get_group_names(), an
 * account matches if it is named exactly NAME or starts with "NAME_".
 *
 * The result is a malloc()ed, NULL-terminated array of strdup()ed strings;
 * the caller frees each entry and then the array.
 *
 * Returns NULL, after reporting via dberror(), in any of these cases:
 *   - the login name cannot be determined;
 *   - the query fails;
 *   - memory runs out.
 * If too many groups exist to fit in the query, the extra groups are left out
 * and a warning is reported.
 */
char **
list(MYSQL *pmysql)
{
  char query[40960], *end;
  const char *limit = query + sizeof query;
  char **usrgroups, **cp;
  MYSQL_RES *res;
  MYSQL_ROW row;
  char **userlist;
  size_t rows, count, j;
  int numgroups = 0;
  int truncated = 0;
  struct passwd *p;

  p = getpwuid(getuid());
  if (!p)
    {
      dberror(NULL, "Cannot determine your user name.");
      return NULL;
    }

  end = strmov(query, "SELECT user FROM user WHERE ");
  if (append_user_match(pmysql, &end, limit, "", p->pw_name) != 0)
    {
      dberror(NULL, "User name '%s' is too long.", p->pw_name);
      return NULL;
    }

  usrgroups = get_group_names(&numgroups);
  for (cp = usrgroups; cp && *cp; cp++)
    {
      /* Stop adding groups once query[] is full, but keep freeing names. */
      if (!truncated
          && append_user_match(pmysql, &end, limit, " OR ", *cp) != 0)
        {
          truncated = 1;
          dberror(NULL, "Too many groups; accounts for group '%s' and any "
                        "later groups are not listed.", *cp);
        }
      free(*cp);
    }
  free(usrgroups);

#ifdef DEBUG
  printf("about to run query: %s\n", query);
#endif

  /* A failed query leaves no result set; mysql_store_result() can also
     return NULL on its own errors. */
  if (mysql_query(pmysql, query) || !(res = mysql_store_result(pmysql)))
    {
      dberror(pmysql, "Failed to look up %s's users.", p->pw_name);
      return NULL;
    }

  rows = (size_t) mysql_num_rows(res);
  userlist = calloc(rows + 1, sizeof *userlist);
  if (!userlist)
    {
      mysql_free_result(res);
      dberror(NULL, "%s: Out of memory.\n", program_name);
      return NULL;
    }

  /* Count only the entries actually stored, so the array is always fully
     initialised up to its NULL terminator. */
  count = 0;
  while (count < rows && (row = mysql_fetch_row(res)) != NULL)
    {
      if (!row[0])
        continue;
      userlist[count] = strdup(row[0]);
      if (!userlist[count])
        {
          for (j = 0; j < count; j++)
            free(userlist[j]);
          free(userlist);
          mysql_free_result(res);
          dberror(NULL, "%s: Out of memory.\n", program_name);
          return NULL;
        }
      count++;
    }
  userlist[count] = NULL;

  mysql_free_result(res);

  return userlist;
}

/*
 * Entry point: mysql-useradm COMMAND [USER]...
 *
 * Setup:
 *   - Handles --help / --version anywhere on the command line.
 *   - Validates COMMAND.
 *   - Reads the configuration with read_toml_file().
 *   - Connects to db_server and selects db_name.
 *
 * Dispatch:
 *   - "show" with no USER shows every account returned by list().
 *   - Otherwise each USER argument must be at most 32 characters, pass
 *     owner()/member() and pass name_isclean().  A USER that fails any check
 *     is reported and skipped.
 *
 * reload() is called before disconnecting.
 *
 * Returns 0 once the command loop has run; per-user failures do not change
 * the exit status.  Early exits return the value of usage(), version(),
 * wrong_use() or dberror().
 */
int
main(int argc, char *argv[])
{
  /* Longest account name accepted on the command line. */
  enum { user_name_max = 32 };
  int i;
  enum { c_create, c_delete, c_passwd, c_show } command;
  MYSQL mysql;
  char **dblist, **p;
  char user[65];

  mysql_init(&mysql);
  program_name = argv[0];

  for (i = 1; i < argc; i++)
    if (strcmp(argv[i], "--help") == 0)
      return usage();
  for (i = 1; i < argc; i++)
    if (strcmp(argv[i], "--version") == 0)
      return version();

#ifdef DEBUG
  /* "NB: this build was compiled with -DDEBUG and may print extra
     information.  This is harmless and the program should work as usual." */
  printf("NB NB NB: denne versjonen av programmet er kompilert med -DDEBUG, og\n");
  printf("kan komme til å skrive ut ekstra informasjon. Dette er ikke farlig,\n");
  printf("og programmet bør virke som vanlig.\n");
#endif

  if (argc < 2)
    return wrong_use(NULL);

  /* check that the supplied command is valid */

  if (strcmp(argv[1], "create") == 0)
    command = c_create;
  else if (strcmp(argv[1], "delete") == 0)
    command = c_delete;
  else if (strcmp(argv[1], "passwd") == 0)
    command = c_passwd;
  else if (strcmp(argv[1], "show") == 0)
    command = c_show;
  else
    return wrong_use("unrecognized command '%s'.", argv[1]); /* XXX */

  /* all other than show requires at least one USER argument. */
  if ((command != c_show) && (argc < 3))
    return wrong_use(NULL);

/*  read_config_file(); */
  read_toml_file();

  /* connect to the database server and select the mysql database */
  if (!mysql_real_connect(&mysql, db_server, db_user, db_passwd, db_name, 0, NULL, 0))
    return dberror(&mysql, "Cannot connect to database server '%s'.",
                   db_server);
  if (mysql_select_db(&mysql, db_name))
    return dberror(&mysql, "Cannot select database '%s'.", db_name);

  if ((command == c_show) && (argc == 2))
    {
      dblist = list(&mysql);
      p = dblist;
      while (p && *p)
        {
          show(&mysql, *p);
          free(*p);
          p++;
        }
      free(dblist);
    }
  else
    {
      /* for each supplied user name, perform the requested action */
      for (i = 2; i < argc; i++)
        {
          /* Reject over-long names instead of silently truncating them and
             then acting on a different account name. */
          if (strlen(argv[i]) > (size_t) user_name_max)
            {
              dberror(NULL, "User name '%s' is too long (at most %d characters).  Skipping.",
                      argv[i], user_name_max);
              continue;
            }
          /* Fits: user[] holds more than user_name_max + 1 bytes. */
          snprintf(user, sizeof user, "%s", argv[i]);

          if (!(owner(user) || member(user)))
            {
              if (command == c_create)
                {
                  struct passwd *pw = getpwuid(getuid());

                  dberror(NULL, "Unable to create mysql-user '%s'.\n"
                          "A mysql-user must start with either '%s_' or "
                          "'groupname_', where groupname is a unix  group you are a "
                          "member of. Type \"groups\" to see which groups you are "
                          "a member of.\n",
                          user, pw ? pw->pw_name : "<your username>");
                }
              else
                dberror(NULL, "You are not in charge of mysql-user: '%s'.  Skipping.", user);
              continue;
            }

          /* Every command requires a clean name; check once up front. */
          if (!name_isclean(user))
            {
              dberror(NULL, "User name '%s' contains invalid characters.\n"
                      "Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.", user);
              continue;
            }

          switch (command)
            {
            case c_create:
              create(&mysql, user);
              break;
            case c_delete:
              delete(&mysql, user);
              break;
            case c_passwd:
              passwd(&mysql, user);
              break;
            case c_show:
              show(&mysql, user);
              break;
            default:
              fprintf(stderr, "This point should never be reached.\n");
              exit(1);
            }
        }
    }

  reload(&mysql);
  mysql_close(&mysql);

  return 0;
}