逆文档频率TF-IDF算法

 2022-08-01    0 条评论    16327 浏览

nlp

功能

计算每个文档中每个词在文本中的重要程度,也就是词权重。以浮点值的形式显示。官方名称:逆文档频率。

应用场景

  1. 用于抽取出最能代表当前文本的前n个词。也叫特征选择。
  2. 以文本为单位,得出来每个文本对应词的权重集,也就相当于特征向量。也叫特征抽取。一般用于文本相似度对比。

公式

定义

t:词、d:文本、D文本集

公式描述

原理解释

  • TF:当前词在当前文本中的词频。可以是出现的次数,也可以是出现次数除以当前文本的词数。
  • DF:当前词在整个文本集中,出现过的文本的个数。
  • IDF:为了减小一些通用的不重要的词组,我们使用逆文档频率来度量一个词是否是无代表意义的通用性的词组。就是文档总数|D|除以当前词在整个文本集中,出现过的文本的个数(DF(t,D)得出DF值,再求它的对数log值。

注意:这里求得是自然对数。并且由于使用了对数,所以上下+1,避免出现当词出现在所有文本集中的情况下,log计算结果得0.

Spark公式实现

spark针对机器学习的ml包下面又开源的api用来计算TF-IDF值。

源码地址:http://spark.apache.org/docs/latest/ml-features.html#tf-idf

官网源码有调用实例,包括新版本和老版本的。这里就不举例了。

源码说明

入参
  1. 默认分好词的文本集,格式为dataSet格式。
  2. numFeatures 预先设置好hashing表的长度。必须为2的n次方,官网说明默认为2的18次方。

返回

返回每个词对应的hashing卡槽的位置,每个词在当前文本出现的频率和每个词的词权重TF-IDF值。

代码顺序

先计算TF值,再根据TF值计算IDF值。这里IDFModel模型可以保存下来。一般使用大量标准语料训练模型,保存下来。用来预测接下来的文本特征。

代码原理

spark提供的API使用的是hash表的方式存储每个词的hash值。再对hash值进行计算,以此来提高速度。

源码jar包位置

org.apache.spark.mllib.feature.HashingTF; org.apache.spark.mllib.feature.IDF;

源码

创建TF表:创建了一个指定长度HashTable表,key值为单词的murmur3Hash计算取mod后的值,value为TF值

@Since("1.1.0")
def transform(document: Iterable[_]): Vector = {
  //创建一个指定长度的HashMap<Int,Double>
  val termFrequencies = mutable.HashMap.empty[Int, Double]
  //计算单词的单词频数和hash码
  val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
  val hashFunc: Any => Int = getHashFunction
  //遍历文档的单词
  document.foreach { term =>
    val i = Utils.nonNegativeMod(hashFunc(term), numFeatures) //获得单词的hash码
    termFrequencies.put(i, setTF(i)) //单词频数统计
  }
  //把结果转换成稀疏向量
  Vectors.sparse(numFeatures, termFrequencies.toSeq)
}
//取正余
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

计算hash值使用的是Murmur3_x86_32算法,如下:

package com.alg.tfidf;

import org.apache.spark.unsafe.Platform;

/**
 * 32-bit Murmur3 hasher.  This is based on Guava's Murmur3_32HashFunction.
 */
public final class Murmur3_x86_32 {
  private static final int C1 = 0xcc9e2d51;
  private static final int C2 = 0x1b873593;

  private final int seed;

  public Murmur3_x86_32(int seed) {
    this.seed = seed;
  }

  @Override
  public String toString() {
    return "Murmur3_32(seed=" + seed + ")";
  }

  public int hashInt(int input) {
    return hashInt(input, seed);
  }

  public static int hashInt(int input, int seed) {
    int k1 = mixK1(input);
    int h1 = mixH1(seed, k1);

    return fmix(h1, 4);
  }

  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
    return hashUnsafeWords(base, offset, lengthInBytes, seed);
  }

  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
    // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
    assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
    int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
    return fmix(h1, lengthInBytes);
  }

  public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
    int lengthAligned = lengthInBytes - lengthInBytes % 4;
    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
    for (int i = lengthAligned; i < lengthInBytes; i++) {
      int halfWord = Platform.getByte(base, offset + i);
      int k1 = mixK1(halfWord);
      h1 = mixH1(h1, k1);
    }
    return fmix(h1, lengthInBytes);
  }

  private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
    assert (lengthInBytes % 4 == 0);
    int h1 = seed;
    for (int i = 0; i < lengthInBytes; i += 4) {
      int halfWord = Platform.getInt(base, offset + i);
      int k1 = mixK1(halfWord);
      h1 = mixH1(h1, k1);
    }
    return h1;
  }

  public int hashLong(long input) {
    return hashLong(input, seed);
  }

  public static int hashLong(long input, int seed) {
    int low = (int) input;
    int high = (int) (input >>> 32);

    int k1 = mixK1(low);
    int h1 = mixH1(seed, k1);

    k1 = mixK1(high);
    h1 = mixH1(h1, k1);

    return fmix(h1, 8);
  }

  private static int mixK1(int k1) {
    k1 *= C1;
    k1 = Integer.rotateLeft(k1, 15);
    k1 *= C2;
    return k1;
  }

  private static int mixH1(int h1, int k1) {
    h1 ^= k1;
    h1 = Integer.rotateLeft(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
    return h1;
  }

  // Finalization mix - force all bits of a hash block to avalanche
  private static int fmix(int h1, int length) {
    h1 ^= length;
    h1 ^= h1 >>> 16;
    h1 *= 0x85ebca6b;
    h1 ^= h1 >>> 13;
    h1 *= 0xc2b2ae35;
    h1 ^= h1 >>> 16;
    return h1;
  }
}

通过这个算法计算出来的hash值,再进行正余计算,得出来的就是每个词对应的mod。 也就是所谓的hash表对应的卡槽位置

图片中:5000是numFeatures的值,第二个int类型的数组就是每个词计算出来的mod值,第三个double数组就是每个词对应的tf-idf值。

纯代码实现TF-IDF

这里写了一个纯java实现的tf-idf工具类。只为了学习代码原理:

import java.util.*;

public class TF_IDF_OLD {
    public static void main(String[] args) {
        List<String> list = getStrings();

        // 计算每个文档的TF值
        List<Map<String, Double>> tfList = new ArrayList<>();
        for (String document : list) {
            tfList.add(tfCalculate(document));
        }

        // 计算每个词的IDF值
        Map<String, Double> idfMap = calculateIDF(list);

        // 输出每个文档的TF-IDF值
        for (int i = 0; i < list.size(); i++) {
            System.out.println("Document " + (i + 1) + " TF-IDF:");
            Map<String, Double> tfidf = tfidfCalculate(list.size(), list, tfList.get(i), idfMap);
            for (Map.Entry<String, Double> entry : tfidf.entrySet()) {
                System.out.println(entry.getKey() + ": " + entry.getValue());
            }
            System.out.println();
        }
    }

    // 获取示例数据
    private static List<String> getStrings() {
        List<String> list = new ArrayList<>();
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人 士 人物 人士 人氏 人选");
        list.add("人类 生人 全人类");
        list.add("人手 人员 人口 人丁 口 食指");
        list.add("劳力 劳动力 工作者");
        list.add("匹夫 个人");
        list.add("家伙 东西 货色 厮 崽子 兔崽子 狗崽子 小子 杂种 畜生 混蛋 王八蛋 竖子 鼠辈 小崽子");
        list.add("者 手 匠 客 主 子 家 夫 翁 汉 员 分子 鬼 货 棍 徒");
        list.add("每人 各人 每位");
        list.add("该人 此人");
        list.add("人民 民 国民 公民 平民 黎民 庶 庶民 老百姓 苍生 生灵 生人 布衣 白丁 赤子 氓 群氓 黔首 黎民百姓 庶人 百姓 全民 全员 萌");
        list.add("群众 大众 公众 民众 万众 众生 千夫");
        return list;
    }

    // 计算每个文档的TF值
    public static Map<String, Double> tfCalculate(String document) {
        Map<String, Double> wordCount = new HashMap<>();
        String[] words = document.split(" ");
        for (String word : words) {
            wordCount.put(word, wordCount.getOrDefault(word, 0.0) + 1.0);
        }

        double totalWords = words.length;
        Map<String, Double> tf = new HashMap<>();
        for (Map.Entry<String, Double> entry : wordCount.entrySet()) {
            tf.put(entry.getKey(), entry.getValue() / totalWords);
        }

        return tf;
    }

    // 计算所有文档的IDF值
    public static Map<String, Double> calculateIDF(List<String> documents) {
        Map<String, Integer> docCount = new HashMap<>();
        int totalDocs = documents.size();

        for (String document : documents) {
            Set<String> uniqueWords = new HashSet<>(Arrays.asList(document.split(" ")));
            for (String word : uniqueWords) {
                docCount.put(word, docCount.getOrDefault(word, 0) + 1);
            }
        }

        Map<String, Double> idf = new HashMap<>();
        for (Map.Entry<String, Integer> entry : docCount.entrySet()) {
            String word = entry.getKey();
            int count = entry.getValue();
            idf.put(word, Math.log((double) (totalDocs + 1) / (count + 1)));
        }

        return idf;
    }

    // 计算TF-IDF值
    public static Map<String, Double> tfidfCalculate(int totalDocs, List<String> documents, Map<String, Double> tf, Map<String, Double> idf) {
        Map<String, Double> tfidf = new HashMap<>();
        for (String word : tf.keySet()) {
            // 获取该词在文档中出现的文档数量
            int docFrequency = 0;
            for (String document : documents) {
                if (document.contains(word)) {
                    docFrequency++;
                }
            }

            // 计算IDF值
            double idfValue = idf.getOrDefault(word, 0.0);
            double tfidfValue = tf.get(word) * idfValue;
            tfidf.put(word, tfidfValue);
        }

        return tfidf;
    }
}