egmkang 服务端开发工程师

用Trie实现脏词过滤以及其优化

2014-06-06

公司之前的脏词过滤是用暴力匹配做的, 就是拿每一个脏词到输入里面做strstr, 匹配到了就过滤. 这样做的效率太低了, 因为有关部门给的脏词库已经有几千条, 甚至快上万条, 肯定不行.

trie, 又名字典树, 主要用来做单词/文字查找, 通过子节点不同选择不同匹配路径, 可以极大的减少匹配的次数, 从而获得相当高的效率. 这是trie sample: trie tree

比如我们有to和ta两个输入单词, 那么遍历的过程便是:

root->'t'->'o', 发现'o'可以结尾, 返回
root->'t'->'a', 没有找到'a', 返回

由于算法比较简单, 我手写了一个demo, 运行了一下, 比暴力匹配算法性能提高四十多倍.

但是还不太甘心, 想着怎么做性能优化, 肯定有办法.

  1. UTF-32字符串构造trie, 这样整个树的高度就会急剧减小. 因为UTF-8表达汉字, 通常需要3个字节(char), 如果直接用UTF-8构造trie的话, 那么整棵树的高度会增加3倍左右, 匹配次数肯定也会增加3倍, 这样肯定不合算. 当时也想过用UTF-16, 只是我们的代码不支持C++11, 否则也可以这样搞.
  2. 查找子节点优化 昨天想到这个. 绝大部分脏词没有什么共性, 今天在gdb里面看了一下, 果然root下面有2000+个child. 最开始实现的查找是std::find, 果然用std::lower_bound查找之后, 性能又提升了十几倍. 这个应该也可以用hash table来搞, 查找的速度会更快.
  3. 访问内存的优化(还未做) 因为每次查找下一个子节点, 都是需要到另外一块内存上面去查找, 不是CPU友好的. 所以一直想做优化, 但是还没有想到好的方法.

这边是代码:

#ifndef __TRIE_TREE_H__
#define __TRIE_TREE_H__
#include <ctype.h>
#include <locale.h>
#include <stdlib.h>
#include <string.h>
#include <wchar.h>
#include <wctype.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
  
// One node of the trie.
//
// A node stores the character that leads to it, a flag telling whether a
// complete word ends at this node, and its children. The children are kept
// sorted by character so that lookups can use binary search
// (std::lower_bound) instead of a linear scan.
class TrieTreeChild
{
public:
  // Creates a node for `inputChar` with no children that ends no word.
  TrieTreeChild(wchar_t inputChar)
    : current_char_(inputChar)
    , leaf_(false)
  {
  }

  // Returns the child for `inputChar`, creating it at its sorted position
  // when it does not exist yet. When `finished` is true the child is marked
  // as the end of a word; an existing mark is never removed.
  //
  // The returned reference points into this node's child vector, so it is
  // invalidated by the next insertion into this same node.
  TrieTreeChild& AddChild(wchar_t inputChar, bool finished)
  {
    std::vector<TrieTreeChild>::iterator slot = FindSlot(inputChar);
    if (slot == children_.end() || slot->current_char_ != inputChar)
    {
      // insert() returns an iterator to the new element, so there is no need
      // to recompute the position after a possible reallocation.
      slot = children_.insert(slot, TrieTreeChild(inputChar));
    }
    if (finished)
    {
      slot->leaf_ = true;
    }
    return *slot;
  }

  // Returns the child for `inputChar`, or NULL when there is none.
  TrieTreeChild* GetChild(wchar_t inputChar)
  {
    std::vector<TrieTreeChild>::iterator slot = FindSlot(inputChar);
    if (slot == children_.end() || slot->current_char_ != inputChar)
    {
      return NULL;
    }
    return &*slot;
  }

  // Nodes compare by their character only; children and the leaf flag are
  // ignored.
  bool operator == (const TrieTreeChild& node) const
  {
    return this->current_char_ == node.current_char_;
  }

  bool operator < (const TrieTreeChild& node) const
  {
    return this->current_char_ < node.current_char_;
  }

  // True when a complete word ends at this node.
  bool IsLeaf() const { return leaf_; }

  // Returns the node to the state of a freshly constructed node for the
  // null character: no children, no end-of-word mark.
  // NOTE(review): this overwrites current_char_, so calling it on a node
  // that is stored in a parent's sorted child list would break that ordering;
  // it looks intended for the root only.
  void Reset()
  {
    this->children_.clear();
    this->current_char_ = wchar_t();
    this->leaf_ = false;
  }
private:
  // Binary search: first child whose character is not less than `inputChar`
  // (children_.end() when every child is smaller).
  std::vector<TrieTreeChild>::iterator FindSlot(wchar_t inputChar)
  {
    const TrieTreeChild probe(inputChar);
    return std::lower_bound(children_.begin(), children_.end(), probe);
  }

  friend class TrieTree;
  wchar_t current_char_;
  bool leaf_;
  std::vector<TrieTreeChild> children_;
};
  
// Dirty-word filter backed by a trie of wide characters.
//
// Words and inputs are multibyte strings (UTF-8 is the intended encoding).
// They are decoded with mbstowcs() and case-folded with towlower(), so
// matching is case-insensitive.
//
// Locale handling: decoding depends on the process-wide LC_CTYPE locale. The
// constructor only *queries* that locale. When it is not a UTF-8 locale,
// every public call switches LC_CTYPE to a UTF-8 locale for its duration and
// restores the previous one afterwards. setlocale() changes process-global
// state, so in that mode the class must not be used concurrently with other
// threads that depend on the locale.
class TrieTree
{
public:
  TrieTree()
    : root_(0)
    , need_set_locale_(true)
  {
    // Passing NULL queries the current locale without changing it.
    const char *current = setlocale(LC_CTYPE, NULL);
    if (current != NULL && IsUtf8LocaleName(current))
    {
      need_set_locale_ = false;
    }
  }

  // Adds `str` to the set of filtered words. NULL, empty strings and strings
  // that cannot be decoded in the active locale are ignored.
  void AddMatchString(const char* str)
  {
    if (str == NULL)
    {
      return;
    }
    ScopedUtf8Locale guard(need_set_locale_);

    std::vector<wchar_t> word;
    if (!Decode(str, word))
    {
      return;
    }
    FoldCase(word);
    AddMatchString(&word[0], word.size() - 1);
  }

  // Removes every word from the filter.
  void Clear()
  {
    root_.Reset();
  }

  // Returns true when `str` contains at least one filtered word. The scan
  // stops at the first hit. NULL or undecodable input yields false.
  bool Match(const char *str)
  {
    if (str == NULL)
    {
      return false;
    }
    ScopedUtf8Locale guard(need_set_locale_);

    std::vector<wchar_t> folded;
    if (!Decode(str, folded))
    {
      return false;
    }
    FoldCase(folded);
    return Scan(&folded[0], folded.size() - 1, NULL, L'*') > 0;
  }

  // Replaces every character of every filtered word found in `str` with
  // `mask`, in place, and returns true when at least one word was found.
  //
  // - Characters outside the matched words keep their original case.
  // - Matching always runs against the unmasked text, so overlapping words
  //   are all masked.
  // - The result never exceeds the original byte length of `str`: if `mask`
  //   encodes to more bytes than the characters it replaces, the output is
  //   truncated at a character boundary.
  // - If `mask` cannot be encoded in the active locale, `str` is left
  //   unchanged (the return value is still true).
  // - NULL or undecodable input yields false and is left unchanged.
  bool Transform(char* str, wchar_t mask = L'*')
  {
    if (str == NULL)
    {
      return false;
    }
    ScopedUtf8Locale guard(need_set_locale_);

    std::vector<wchar_t> text;  // original case; this is what gets written back
    if (!Decode(str, text))
    {
      return false;
    }
    std::vector<wchar_t> folded(text);  // lower-cased copy used only for matching
    FoldCase(folded);

    if (Scan(&folded[0], folded.size() - 1, &text[0], mask) == 0)
    {
      return false;
    }

    // Encode into a scratch buffer first so that a failed conversion cannot
    // leave `str` half-written.
    const size_t capacity = strlen(str);
    std::vector<char> encoded(capacity + 1, '\0');
    const size_t written = wcstombs(&encoded[0], &text[0], capacity);
    if (written != static_cast<size_t>(-1))
    {
      memcpy(str, &encoded[0], written);
      // wcstombs() omits the terminator when it stops at the size limit.
      str[written] = '\0';
    }
    return true;
  }

private:
  // Switches LC_CTYPE to a UTF-8 locale for the lifetime of the object and
  // restores the previous locale on destruction. Does nothing when
  // constructed with enabled == false or when no UTF-8 locale is installed;
  // in the latter case conversions run in the current locale.
  class ScopedUtf8Locale
  {
  public:
    explicit ScopedUtf8Locale(bool enabled)
      : restore_(false)
    {
      if (!enabled)
      {
        return;
      }
      const char *current = setlocale(LC_CTYPE, NULL);
      if (current == NULL)
      {
        return;
      }
      // Copy the name: the buffer returned by setlocale() may be overwritten
      // by the next setlocale() call.
      saved_ = current;
      restore_ = Activate();
    }

    ~ScopedUtf8Locale()
    {
      if (restore_)
      {
        setlocale(LC_CTYPE, saved_.c_str());
      }
    }

  private:
    // Tries the common spellings of a UTF-8 locale; true on the first success.
    static bool Activate()
    {
      static const char *const kCandidates[] = { "en_US.utf8", "en_US.UTF-8", "C.UTF-8" };
      for (size_t i = 0; i < sizeof(kCandidates) / sizeof(kCandidates[0]); ++i)
      {
        if (setlocale(LC_CTYPE, kCandidates[i]) != NULL)
        {
          return true;
        }
      }
      return false;
    }

    // Non-copyable: two guards must never restore the same saved locale.
    ScopedUtf8Locale(const ScopedUtf8Locale&);
    ScopedUtf8Locale& operator=(const ScopedUtf8Locale&);

    std::string saved_;
    bool restore_;
  };

  // True when the locale name mentions UTF-8 ("utf8" / "utf-8", any case).
  static bool IsUtf8LocaleName(const char *name)
  {
    std::string lowered(name);
    for (size_t i = 0; i < lowered.size(); ++i)
    {
      lowered[i] = static_cast<char>(tolower(static_cast<unsigned char>(lowered[i])));
    }
    return lowered.find("utf8") != std::string::npos
        || lowered.find("utf-8") != std::string::npos;
  }

  // Decodes the multibyte string `str` into `out`, which always ends with a
  // L'\0' terminator (so out.size() - 1 is the character count). Returns
  // false when `str` is not valid in the active locale.
  static bool Decode(const char *str, std::vector<wchar_t>& out)
  {
    // Every wide character consumes at least one byte, so strlen() bounds
    // the number of wide characters.
    const size_t byte_len = strlen(str);
    out.assign(byte_len + 1, L'\0');
    const size_t converted = mbstowcs(&out[0], str, byte_len);
    if (converted == static_cast<size_t>(-1))
    {
      out.assign(1, L'\0');
      return false;
    }
    out.resize(converted + 1);
    out[converted] = L'\0';
    return true;
  }

  // Lower-cases wide characters in place. towlower() maps one character to
  // one character, so positions stay aligned with the unfolded text.
  static void FoldCase(std::vector<wchar_t>& text)
  {
    for (size_t i = 0; i < text.size(); ++i)
    {
      text[i] = static_cast<wchar_t>(towlower(static_cast<wint_t>(text[i])));
    }
  }

  // Inserts the first `len` characters of `str` as one word.
  void AddMatchString(const wchar_t* str, size_t len)
  {
    TrieTreeChild *node = &root_;
    for(size_t i = 0; i < len; ++i)
    {
      node = &node->AddChild(str[i], i == (len-1));
    }
  }

  // Walks the trie from every start position of `folded` (length `len`) and
  // returns the number of word occurrences found.
  //
  // target == NULL: detection only, returns right after the first hit.
  // target != NULL: must point to `len` characters aligned with `folded`;
  //                 the span of every hit is overwritten with `mask`.
  size_t Scan(const wchar_t *folded, size_t len, wchar_t *target, wchar_t mask)
  {
    size_t match_count = 0;
    for (size_t start = 0; start < len; ++start)
    {
      TrieTreeChild *node = &root_;
      for (size_t pos = start; pos < len; ++pos)
      {
        node = node->GetChild(folded[pos]);
        if (node == NULL)
        {
          break;
        }
        if (!node->IsLeaf())
        {
          continue;
        }
        ++match_count;
        if (target == NULL)
        {
          return match_count;
        }
        // Keep walking afterwards: a longer word may share this prefix.
        std::fill(target + start, target + pos + 1, mask);
      }
    }
    return match_count;
  }

private:
  TrieTreeChild root_;
  bool need_set_locale_;
};
#endif

这边是使用:

#include "trie_tree.h"
#include <iostream>
  
int main()
{
  TrieTree tree;
  tree.AddMatchString("脏词");
  tree.AddMatchString("脏字");
  
  char str[] = "我想说脏词B脏字你又把我怎么样";
  
  std::cout << tree.Match(str) << std::endl; //will output true
  std::cout << tree.Transform(str) << std::endl; //will ouput 2
  std::cout << str << std::endl;  //will output "我想说**B**你又把我怎么样"
  
  return 0;
}

Comments