- Safely escape user input with mysql_real_escape_string

- fixed some memory leaks
This commit is contained in:
Geir Hauge 2012-12-03 12:20:20 +00:00
parent 59e7d4782e
commit dc6b93166b
4 changed files with 79 additions and 72 deletions

View File

@ -124,7 +124,7 @@ int
member(char *gr) { member(char *gr) {
char *username; char *username;
char *group; char group[65];
struct group *g; struct group *g;
struct passwd *p; struct passwd *p;
@ -141,11 +141,8 @@ member(char *gr) {
username = p->pw_name; username = p->pw_name;
/* Copy string, but cut at '_' */ /* Copy string, but cut at '_' */
group = strdup(gr); strncpy(group, gr, 64);
if (group == NULL) { group[64] = '\0';
fprintf(stderr, "Couldn't allocate memory. Terminating.");
exit(1);
}
// ettersom man kan få inn gruppenavn med underscore i, må man rett og // ettersom man kan få inn gruppenavn med underscore i, må man rett og
// slett prøve seg fram for å sjekke om det er en gruppe personen er med // slett prøve seg fram for å sjekke om det er en gruppe personen er med
@ -170,9 +167,9 @@ member(char *gr) {
if (g) { if (g) {
/* Check if user is member of group */ /* Check if user is member of group */
while(*g->gr_mem != NULL) { while(*g->gr_mem != NULL) {
char * member = *g->gr_mem; char * member = *g->gr_mem++;
#if DEBUG #if DEBUG
printf("Medlem: %s\n", *g->gr_mem); printf("Medlem: %s\n", member);
#endif #endif
if (strcmp(member,username) == 0) { if (strcmp(member,username) == 0) {
@ -180,8 +177,6 @@ member(char *gr) {
printf("You have access to '%s'\n", gr); printf("You have access to '%s'\n", gr);
#endif #endif
return 1; /* OK */ return 1; /* OK */
} else {
*g->gr_mem++;
} }
} }
#if DEBUG #if DEBUG
@ -222,7 +217,7 @@ char **get_group_names(int *numgroups)
return NULL; return NULL;
} }
grouplist = malloc(sizeof(char *)); grouplist = malloc(33 * sizeof(char *));
real_nr_groups = 0; real_nr_groups = 0;
for (i = 0; i < nr_groups; i++) { for (i = 0; i < nr_groups; i++) {
@ -230,9 +225,7 @@ char **get_group_names(int *numgroups)
/* Go to next grp if it doesn't have a name */ /* Go to next grp if it doesn't have a name */
if (g != NULL) { if (g != NULL) {
grouplist = (char **) realloc(grouplist, (real_nr_groups+2) * sizeof(char *)); grouplist[real_nr_groups++] = strdup(g->gr_name);
grouplist[real_nr_groups] = strdup(g->gr_name);
real_nr_groups++;
} else { } else {
fprintf(stderr, "Omitting gid %d, no entry in group-file.\n", gids[i]); fprintf(stderr, "Omitting gid %d, no entry in group-file.\n", gids[i]);
} }
@ -250,3 +243,10 @@ reload(MYSQL *pmysql)
{ {
return mysql_reload(pmysql); return mysql_reload(pmysql);
} }
/* same as strcpy, but returns a pointer to the end of dest instead of start */
char *strmov(char *dest, const char *src) {
while ((*dest++ = *src++))
;
return dest-1;
}

View File

@ -33,6 +33,9 @@ version(void);
extern int extern int
read_config_file(void); read_config_file(void);
/* same as strcpy, but returns a pointer to the end of dest instead of start */
extern char *strmov(char *, const char *);
#ifdef _mysql_h #ifdef _mysql_h

View File

@ -23,12 +23,6 @@
const char dbname_validchars[] = const char dbname_validchars[] =
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"; "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-";
/* same as strcpy, but returns a pointer to the end of dest instead of start */
char *strmov(char *dest, const char *src) {
while ((*dest++ = *src++))
;
return dest-1;
}
/* Returns true if dbname contains only characters in dbname_validchars. */ /* Returns true if dbname contains only characters in dbname_validchars. */
int dbname_isclean(char* dbname) { int dbname_isclean(char* dbname) {

View File

@ -37,32 +37,50 @@ usage()
int int
is_password_set(MYSQL *pmysql, const char *user) is_password_set(MYSQL *pmysql, const char *user)
{ {
char query[1024]; char query[1024], *end;
MYSQL_RES *res; MYSQL_RES *res;
int rows; int rows;
MYSQL_ROW row; MYSQL_ROW row;
int check = 0;
end = strmov(query, "SELECT password FROM user WHERE user='");
end += mysql_real_escape_string(pmysql, end, user, strlen(user));
*end++ = '\'';
*end = '\0';
sprintf(query, "select password from user where user = '%s'", user);
if (mysql_query(pmysql, query)) if (mysql_query(pmysql, query))
dberror(pmysql, "Failed to look up password for user '%s'.", user); dberror(pmysql, "Failed to look up password for user '%s'.", user);
res = mysql_store_result(pmysql); res = mysql_store_result(pmysql);
rows = mysql_num_rows(res); rows = mysql_num_rows(res);
if (rows == 0)
return -1;
if (rows > 1) if (rows > 1)
{
mysql_free_result(res);
return dberror(NULL, "Query for password for user '%s' gave %d results!", return dberror(NULL, "Query for password for user '%s' gave %d results!",
user, rows); user, rows);
}
else if (rows < 1) {
mysql_free_result(res);
return -1;
}
row = mysql_fetch_row(res); row = mysql_fetch_row(res);
return (row[0] && (strlen(row[0]) > 0)); check = (row[0] && (strlen(row[0]) > 0));
mysql_free_result(res);
return check;
} }
int int
create(MYSQL *pmysql, const char *user) create(MYSQL *pmysql, const char *user)
{ {
char query[1024]; char query[1024], *end;
end = strmov(query, "INSERT INTO user (host, user) VALUES ('%', '");
end += mysql_real_escape_string(pmysql, end, user, strlen(user));
end = strmov(end, "')");
sprintf(query, "insert into user (host, user) values ('%%', '%s')", user);
if (mysql_query(pmysql, query)) if (mysql_query(pmysql, query))
return dberror(pmysql, "Failed to create user '%s'.", user); return dberror(pmysql, "Failed to create user '%s'.", user);
@ -73,9 +91,13 @@ create(MYSQL *pmysql, const char *user)
int int
delete(MYSQL *pmysql, const char *user) delete(MYSQL *pmysql, const char *user)
{ {
char query[1024]; char query[1024], *end;
end = strmov(query, "DELETE FROM user WHERE user='");
end += mysql_real_escape_string(pmysql, end, user, strlen(user));
*end++ = '\'';
*end = '\0';
sprintf(query, "delete from user where user = '%s'", user);
if (mysql_query(pmysql, query)) if (mysql_query(pmysql, query))
return dberror(pmysql, "Failed to delete user '%s'.", user); return dberror(pmysql, "Failed to delete user '%s'.", user);
@ -87,39 +109,12 @@ int
passwd(MYSQL *pmysql, const char *user) passwd(MYSQL *pmysql, const char *user)
{ {
char prompt[1024]; char prompt[1024];
char query[1024]; char query[1024], *end;
char *password, *confirm_password; char *password, *confirm_password;
MYSQL_RES *res;
int rows;
MYSQL_ROW row;
/* if (is_password_set(pmysql, user) == -1) /* no such mysql user */
if (is_password_set(pmysql, user)) return dberror(NULL, "User '%s' does not exist."
{ " You must create it first.\n", user);
sprintf(prompt, "(current) MySQL password for user '%s': ", user);
password = getpass(prompt);
sprintf(query, "select password = password('%s') from user "
"where user = '%s'", password, user);
if (mysql_query(pmysql, query))
return dberror(pmysql, "Failed to check old password for user '%s'.",
user);
res = mysql_store_result(pmysql);
rows = mysql_num_rows(res);
if (rows == 0)
return dberror(NULL, "Check for old password for user '%s' "
"returned empty.", user);
if (rows > 1)
return dberror(NULL, "Check for old password for user '%s' "
"returned more than one row!", user);
row = mysql_fetch_row(res);
if (strcmp(row[0], "1") != 0)
{
fprintf(stderr, "%s: Wrong password entered for user '%s'.\n",
program_name, user);
return 1;
}
}
*/
sprintf(prompt, "New MySQL password for user '%s': ", user); sprintf(prompt, "New MySQL password for user '%s': ", user);
password = getpass(prompt); password = getpass(prompt);
@ -133,11 +128,16 @@ passwd(MYSQL *pmysql, const char *user)
} }
free(confirm_password); free(confirm_password);
sprintf(query, "update user set password = password('%s') " end = strmov(query, "UPDATE user SET password = PASSWORD('");
"where user = '%s'", password, user); end += mysql_real_escape_string(pmysql, end, password, strlen(password));
end = strmov(end, "') WHERE user='");
end += mysql_real_escape_string(pmysql, end, user, strlen(user));
*end++ = '\'';
*end = '\0';
if (mysql_query(pmysql, query)) if (mysql_query(pmysql, query))
return dberror(pmysql, "Failed to set new password for user '%s'.", user); return dberror(pmysql, "Failed to set new password for user '%s'.", user);
if (mysql_affected_rows(pmysql) > 1) if (mysql_affected_rows(pmysql) != 1)
dberror(NULL, "%d rows affected by password update for user '%s'!", dberror(NULL, "%d rows affected by password update for user '%s'!",
mysql_affected_rows(pmysql), user); mysql_affected_rows(pmysql), user);
@ -170,7 +170,7 @@ show(MYSQL *pmysql, const char *user)
char ** char **
list(MYSQL *pmysql) list(MYSQL *pmysql)
{ {
char query[4096]; char query[4096], *end;
char **usrgroups, **cp; char **usrgroups, **cp;
MYSQL_RES *res; MYSQL_RES *res;
int rows, numgroups; int rows, numgroups;
@ -180,18 +180,27 @@ list(MYSQL *pmysql)
struct passwd *p; struct passwd *p;
p = getpwuid(getuid()); p = getpwuid(getuid());
sprintf(query, "select user from user where user='%s' or user like '%s\\_%%'",
p->pw_name, p->pw_name); end = strmov(query, "SELECT user FROM user WHERE user='");
// sprintf(query, "select user from user where user='myhr' or user like 'myhr\\_%%'"); end += mysql_real_escape_string(pmysql, end, p->pw_name, strlen(p->pw_name));
end = strmov(end, "' OR user LIKE '");
end += mysql_real_escape_string(pmysql, end, p->pw_name, strlen(p->pw_name));
end = strmov(end, "\\_%'");
numgroups = 0; numgroups = 0;
usrgroups = get_group_names(&numgroups); usrgroups = get_group_names(&numgroups);
cp = usrgroups; cp = usrgroups;
while (*cp) { while (*cp) {
sprintf(&query[strlen(query)], " or user='%s' or user like '%s\\_%%'", *cp, *cp); end = strmov(end, " OR user='");
end += mysql_real_escape_string(pmysql, end, *cp, strlen(*cp));
end = strmov(end, "' OR user LIKE '");
end += mysql_real_escape_string(pmysql, end, *cp, strlen(*cp));
end = strmov(end, "\\_%'");
free(*cp);
cp++; cp++;
} }
free(usrgroups);
#ifdef DEBUG #ifdef DEBUG
printf("about to run query: %s\n", query); printf("about to run query: %s\n", query);
@ -218,6 +227,8 @@ list(MYSQL *pmysql)
userlist[i] = NULL; userlist[i] = NULL;
mysql_free_result(res);
return userlist; return userlist;
} }
@ -229,7 +240,7 @@ main(int argc, char *argv[])
MYSQL mysql; MYSQL mysql;
mysql_init(&mysql); mysql_init(&mysql);
char **dblist, **p; char **dblist, **p;
char *user; char user[65];
program_name = argv[0]; program_name = argv[0];
@ -262,7 +273,7 @@ main(int argc, char *argv[])
else else
return wrong_use("unrecognized command '%s'.", argv[1]); /* XXX */ return wrong_use("unrecognized command '%s'.", argv[1]); /* XXX */
/* all other than show requires at lease one USER argument. */ /* all other than show requires at least one USER argument. */
if ((command != c_show) && (argc < 3)) if ((command != c_show) && (argc < 3))
return wrong_use(NULL); return wrong_use(NULL);
@ -288,7 +299,6 @@ main(int argc, char *argv[])
free(dblist); free(dblist);
} }
else { else {
user = malloc(64);
/* for each supplied database name, perform the requested action */ /* for each supplied database name, perform the requested action */
for (i = 2; i < argc; i++) { for (i = 2; i < argc; i++) {