src/trie.c
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "list.h"
#include "trie.h"
static struct list_node *merge_tnode_list(struct list_node *list1,
struct list_node *list2) {
struct list_node dummy_head = { NULL, NULL }, *tail = &dummy_head;
while (list1 && list2) {
char chr1 = ((struct trie_node *) list1->data)->chr;
char chr2 = ((struct trie_node *) list2->data)->chr;
struct list_node **min = chr1 <= chr2 ? &list1 : &list2;
struct list_node *next = (*min)->next;
tail = tail->next = *min;
*min = next;
}
tail->next = list1 ? list1 : list2;
return dummy_head.next;
}
/*
 * Sort a singly linked list of trie nodes by ascending `chr` using a
 * top-down merge sort and return the new head. The list is split with the
 * project helper `bisect_list`, whose exact split point is defined in list.h.
 */
struct list_node *merge_sort_tnode(struct list_node *head) {
    /* Zero or one element is already sorted. */
    if (head == NULL || head->next == NULL)
        return head;

    struct list_node *second_half = bisect_list(head);
    struct list_node *left = merge_sort_tnode(head);
    struct list_node *right = merge_sort_tnode(second_half);
    return merge_tnode_list(left, right);
}
/*
 * Find the child list node whose trie node holds character `value`.
 * Child lists are kept sorted ascending by `chr` (trie_node_insert re-sorts
 * after every push), so the scan stops as soon as a larger character is seen.
 * Returns NULL when `list` is NULL/empty or no child matches.
 */
static struct list_node *linear_search(const List *list, int value) {
    if (list == NULL || list->len == 0)
        return NULL;

    struct list_node *cur = list->head;
    while (cur != NULL) {
        int chr = ((const struct trie_node *) cur->data)->chr;
        if (chr > value)
            break; /* sorted order: no later node can match */
        if (chr == value)
            return cur;
        cur = cur->next;
    }
    return NULL;
}
static int with_char(void *arg1, void *arg2) {
struct trie_node *tn1 = ((struct list_node *) arg1)->data;
struct trie_node *tn2 = ((struct list_node *) arg2)->data;
if (tn1->chr == tn2->chr)
return 0;
return -1;
}
/*
 * Report whether `node` has no children, i.e. is a leaf that a parent may
 * unlink. `node->children` must be non-NULL; it is dereferenced unchecked.
 */
static bool trie_is_free_node(const struct trie_node *node) {
    return node->children->len == 0;
}
/*
 * Walk down from `node` following each character of `prefix` and return the
 * node reached, or NULL if some character has no matching child. An empty
 * prefix returns `node` itself. The const qualifier of `node` is cast away
 * so callers can mutate the result.
 */
static struct trie_node *trie_node_find(const struct trie_node *node, const char *prefix) {
    struct trie_node *cursor = (struct trie_node *) node;
    for (const char *p = prefix; *p != '\0'; ++p) {
        struct list_node *hit = linear_search(cursor->children, *p);
        if (hit == NULL)
            return NULL;
        cursor = hit->data;
    }
    return cursor;
}
/*
 * Allocate a trie node for character `c` with no data and an empty child list.
 *
 * Returns the new node, or NULL if either the node itself or its child list
 * could not be allocated. A node is never returned with `children == NULL`,
 * because every traversal in this file dereferences `children` unchecked.
 */
struct trie_node *trie_create_node(char c) {
    struct trie_node *new_node = malloc(sizeof(*new_node));
    if (!new_node)
        return NULL;
    new_node->chr = c;
    new_node->data = NULL;
    new_node->children = list_create(NULL);
    if (!new_node->children) {
        /* Don't hand out a half-built node; release it and report failure. */
        free(new_node);
        return NULL;
    }
    return new_node;
}
Trie *trie_create(void) {
Trie *trie = malloc(sizeof(*trie));
trie_init(trie);
return trie;
}
/*
 * Initialise `trie` in place: zero entries and a fresh root node. The root's
 * character (' ') is never matched, because lookups start at the root's
 * children. `trie->root` is NULL if trie_create_node fails.
 */
void trie_init(Trie *trie) {
    trie->size = 0;
    trie->root = trie_create_node(' ');
}
/* Return the number of keys currently counted as stored in `trie`. */
size_t trie_size(const Trie *trie) {
    const size_t count = trie->size;
    return count;
}
/*
 * Store `data` under `key` below `root`, creating any missing path nodes.
 * Each new child is pushed and the child list re-sorted so linear_search can
 * rely on ascending order.
 *
 * `*size` counts keys whose data is non-NULL, matching trie_find, which
 * reports a key as present only when its data is non-NULL. It changes only
 * when a key moves between "absent" and "present":
 *   NULL -> non-NULL increments it, non-NULL -> NULL decrements it.
 *
 * Returns the stored pointer (equal to `data`), or NULL if a path node could
 * not be allocated. Nodes created before such a failure remain as empty
 * branches.
 *
 * The const of `data` is cast away: the trie later passes stored data to
 * free() (see trie_node_recursive_delete / trie_node_free), so it takes
 * ownership.
 */
static void *trie_node_insert(struct trie_node *root, const char *key, const void *data, size_t *size) {
    struct trie_node *cursor = root;

    for (; *key; key++) {
        struct list_node *found = linear_search(cursor->children, *key);
        struct trie_node *next;
        if (found) {
            next = found->data;
        } else {
            next = trie_create_node(*key);
            if (!next)
                return NULL; /* out of memory: don't dereference NULL below */
            cursor->children = list_push(cursor->children, next);
            /* NOTE(review): only `head` is updated after sorting; if List keeps
             * a tail pointer it may be stale — confirm against list.h. */
            cursor->children->head = merge_sort_tnode(cursor->children->head);
        }
        cursor = next;
    }

    if (!cursor->data && data)
        (*size)++;
    else if (cursor->data && !data && *size > 0)
        (*size)--;
    /* NOTE(review): a previously stored pointer is overwritten without being
     * freed, although the delete/free paths do free stored data — confirm the
     * intended ownership on replace. */
    cursor->data = (void *) data;
    return cursor->data;
}
/*
 * Delete `key` from the subtree rooted at `node`, pruning nodes left empty.
 *
 * When the key's node holds data, the data is freed, `*found` is set to true,
 * and `*size` is decremented (never below 0).
 *
 * Returns true when `node` itself has become removable (no data and no
 * children), so the caller should unlink and free it. Returns false otherwise,
 * including when the key is absent.
 */
static bool trie_node_recursive_delete(struct trie_node *node, const char *key,
                                       size_t *size, bool *found) {
    if (!node)
        return false;

    if (*key == '\0') {
        /* Reached the node for the full key. */
        if (!node->data)
            return false; /* only a prefix of stored keys: nothing to delete */
        *found = true;
        free(node->data);
        node->data = NULL;
        if (*size > 0)
            (*size)--;
        /* Removable only if no longer keys pass through this node. */
        return trie_is_free_node(node);
    }

    struct list_node *cur = linear_search(node->children, *key);
    if (!cur)
        return false;
    struct trie_node *child = cur->data;
    if (!trie_node_recursive_delete(child, key + 1, size, found))
        return false;

    /* `child` is now an empty leaf: unlink it by character, then release it. */
    struct trie_node t = {*key, NULL, NULL};
    struct list_node tmp = {&t, NULL};
    list_remove(node->children, &tmp, with_char);
    trie_node_free(child, size);
    return (!node->data && trie_is_free_node(node));
}
static bool trie_node_search(const struct trie_node *root, const char *key, void **ret) {
struct trie_node *cursor = trie_node_find(root, key);
*ret = (cursor && cursor->data) ? cursor->data : NULL;
return !*ret ? false : true;
}
/*
 * Public insert: store `data` under `key` in `trie` and return the stored
 * pointer. `trie` and `key` must be non-NULL (asserted).
 */
void *trie_insert(Trie *trie, const char *key, const void *data) {
    assert(trie && key);
    size_t *counter = &trie->size;
    return trie_node_insert(trie->root, key, data, counter);
}
/*
 * Public delete: remove `key` from `trie`, freeing its data and pruning empty
 * nodes. The empty key is never deleted. Returns true if a stored value was
 * removed. `trie` and `key` must be non-NULL (asserted).
 */
bool trie_delete(Trie *trie, const char *key) {
    assert(trie && key);
    bool found = false;
    if (key[0] != '\0')
        trie_node_recursive_delete(trie->root, key, &trie->size, &found);
    return found;
}
/*
 * Public lookup: set `*ret` to the data stored under `key` (or NULL) and
 * return whether it was present. `trie` and `key` must be non-NULL (asserted).
 */
bool trie_find(const Trie *trie, const char *key, void **ret) {
    assert(trie && key);
    const struct trie_node *start = trie->root;
    return trie_node_search(start, key, ret);
}
/*
 * Remove every key that starts with `prefix`, including `prefix` itself.
 *
 * All subtrees below the prefix node are freed, which adjusts `trie->size`
 * for each stored value. The prefix key itself is then deleted via
 * trie_delete, which can prune the now-childless prefix node and any
 * ancestors left empty. Does nothing if `prefix` is not on any path.
 * `trie` and `prefix` must be non-NULL (asserted).
 */
void trie_prefix_delete(Trie *trie, const char *prefix) {
    assert(trie && prefix);
    struct trie_node *cursor = trie_node_find(trie->root, prefix);
    if (!cursor)
        return;
    if (cursor->children->len == 0) {
        trie_delete(trie, prefix);
        return;
    }

    struct list_node *cur = cursor->children->head;
    for (; cur; cur = cur->next) {
        trie_node_free(cur->data, &(trie->size));
        cur->data = NULL; /* list entries no longer own a subtree */
    }
    /* Empty the child list BEFORE deleting the prefix key. trie_delete only
     * unlinks a node whose child list is empty, so clearing afterwards left
     * a dangling, empty prefix node in the tree. `cursor` may be freed by
     * trie_delete, so it must not be touched after this call. */
    list_clear(cursor->children, 1);
    trie_delete(trie, prefix);
}
static void trie_prefix_map_func2(struct trie_node *node,
void (*mapfunc)(struct trie_node *, void *), void *arg) {
if (trie_is_free_node(node)) {
mapfunc(node, arg);
return;
}
struct list_node *child = node->children->head;
for (; child; child = child->next)
trie_prefix_map_func2(child->data, mapfunc, arg);
mapfunc(node, arg);
}
/*
 * Apply `mapfunc` in post-order to every node below `prefix`, including the
 * prefix node itself. A NULL `prefix` visits the whole trie, starting at the
 * root. If `prefix` is not on any path, nothing is visited. `trie` must be
 * non-NULL (asserted).
 */
void trie_prefix_map_tuple(Trie *trie, const char *prefix,
                           void (*mapfunc)(struct trie_node *, void *), void *arg) {
    assert(trie);
    if (prefix == NULL) {
        trie_prefix_map_func2(trie->root, mapfunc, arg);
        return;
    }
    struct trie_node *start = trie_node_find(trie->root, prefix);
    if (start != NULL)
        trie_prefix_map_func2(start, mapfunc, arg);
}
void trie_node_free(struct trie_node *node, size_t *size) {
if (!node)
return;
if (node->children) {
struct list_node *cur = node->children->head;
for (; cur; cur = cur->next)
trie_node_free(cur->data, size);
list_release(node->children, 0);
node->children = NULL;
}
if (node->data) {
free(node->data);
if (*size > 0)
(*size)--;
} else if (node->data) {
free(node->data);
if (*size > 0)
(*size)--;
}
free(node);
}
/*
 * Destroy a heap-allocated trie: free every node and stored value, then the
 * Trie structure itself. A NULL `trie` is a no-op.
 */
void trie_release(Trie *trie) {
    if (trie != NULL) {
        trie_node_free(trie->root, &trie->size);
        free(trie);
    }
}