← 返回首页

NLP数据增强实战:用500条数据训出可用模型

2022-02-18 | 数据增强 NLP 小样本
做一个意图分类项目,客户只给了500条标注数据。直接训练F1只有62%,通过数据增强扩展到5000条后,F1提升到81%。这篇文章记录我尝试过的所有增强方法和效果对比。

问题背景

2022年初接了一个智能客服意图分类项目,需要识别20种用户意图。客户说"我们有很多数据",结果给过来一看:

这数据量,直接训练肯定过拟合。只能靠数据增强了。

方法1:EDA(Easy Data Augmentation)

原理

EDA包含4种简单的文本操作:

实现

import random
import jieba
from synonyms import nearby  # 同义词库

class EDAugmenter:
    """中文EDA数据增强"""
    
    def __init__(self, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1):
        self.alpha_sr = alpha_sr
        self.alpha_ri = alpha_ri
        self.alpha_rs = alpha_rs
        self.p_rd = p_rd
    
    def augment(self, text: str, num_aug: int = 4) -> list:
        """对一条文本生成num_aug条增强数据"""
        words = list(jieba.cut(text))
        augmented = []
        
        for _ in range(num_aug):
            aug_words = words.copy()
            
            # 随机选择一种增强方式
            method = random.choice(['sr', 'ri', 'rs', 'rd'])
            
            if method == 'sr':
                aug_words = self._synonym_replacement(aug_words)
            elif method == 'ri':
                aug_words = self._random_insertion(aug_words)
            elif method == 'rs':
                aug_words = self._random_swap(aug_words)
            else:
                aug_words = self._random_deletion(aug_words)
            
            augmented.append(''.join(aug_words))
        
        return augmented
    
    def _synonym_replacement(self, words: list) -> list:
        """同义词替换"""
        n = max(1, int(len(words) * self.alpha_sr))
        new_words = words.copy()
        
        # 找可以替换的词(有同义词的)
        replaceable = []
        for i, word in enumerate(words):
            syns = self._get_synonyms(word)
            if syns:
                replaceable.append((i, syns))
        
        # 随机替换n个
        random.shuffle(replaceable)
        for i, syns in replaceable[:n]:
            new_words[i] = random.choice(syns)
        
        return new_words
    
    def _get_synonyms(self, word: str) -> list:
        """获取同义词"""
        try:
            syns, scores = nearby(word)
            # 只取相似度>0.7的
            return [s for s, sc in zip(syns, scores) if sc > 0.7 and s != word][:5]
        except:
            return []
    
    def _random_insertion(self, words: list) -> list:
        """随机插入"""
        new_words = words.copy()
        n = max(1, int(len(words) * self.alpha_ri))
        
        for _ in range(n):
            # 随机选一个词,找它的同义词
            word = random.choice(words)
            syns = self._get_synonyms(word)
            if syns:
                # 随机位置插入
                pos = random.randint(0, len(new_words))
                new_words.insert(pos, random.choice(syns))
        
        return new_words
    
    def _random_swap(self, words: list) -> list:
        """随机交换"""
        new_words = words.copy()
        n = max(1, int(len(words) * self.alpha_rs))
        
        for _ in range(n):
            if len(new_words) >= 2:
                i, j = random.sample(range(len(new_words)), 2)
                new_words[i], new_words[j] = new_words[j], new_words[i]
        
        return new_words
    
    def _random_deletion(self, words: list) -> list:
        """随机删除"""
        if len(words) == 1:
            return words
        return [w for w in words if random.random() > self.p_rd]

# 使用
eda = EDAugmenter()
text = "我想查询一下订单的物流信息"
augmented = eda.augment(text, num_aug=4)
# ['我想查询订单的快递信息', '我想一下查询订单的物流信息', ...]
⚠️ EDA的问题: 中文同义词库质量参差不齐,经常替换出不通顺的句子。后来我改用了词向量找近义词,效果好一些。

方法2:回译(Back Translation)

原理

中文 → 英文 → 中文,利用翻译模型的"改写"能力生成新句子。

from transformers import MarianMTModel, MarianTokenizer

class BackTranslator:
    """回译数据增强"""
    
    def __init__(self):
        # 中→英
        self.zh2en_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
        self.zh2en_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
        
        # 英→中
        self.en2zh_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-zh')
        self.en2zh_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-zh')
    
    def translate(self, text: str, model, tokenizer) -> str:
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    def augment(self, text: str) -> str:
        # 中文 → 英文
        en_text = self.translate(text, self.zh2en_model, self.zh2en_tokenizer)
        # 英文 → 中文
        back_text = self.translate(en_text, self.en2zh_model, self.en2zh_tokenizer)
        return back_text

# 使用
bt = BackTranslator()
text = "我想查询订单的物流状态"
augmented = bt.augment(text)
# "我想检查订单的物流状况"

多语言回译

def multi_back_translate(text: str, languages: list = ['en', 'ja', 'de']) -> list:
    """
    通过多种语言回译,生成更多样化的数据
    中文→英文→中文
    中文→日文→中文
    中文→德文→中文
    """
    results = []
    for lang in languages:
        # 加载对应的模型
        zh2lang = load_model(f'zh-{lang}')
        lang2zh = load_model(f'{lang}-zh')
        
        # 回译
        intermediate = translate(text, zh2lang)
        back = translate(intermediate, lang2zh)
        
        if back != text:  # 确保有变化
            results.append(back)
    
    return results
💡 回译的优势: 生成的句子语法正确、语义保持。缺点是慢,而且需要GPU。

方法3:MLM填充

原理

用BERT的MLM能力,随机mask一些词,让模型预测替换。

from transformers import BertTokenizer, BertForMaskedLM
import torch
import random

class MLMAugmenter:
    """基于MLM的数据增强"""
    
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()
    
    def augment(self, text: str, mask_ratio: float = 0.15) -> str:
        tokens = self.tokenizer.tokenize(text)
        
        # 随机选择要mask的位置
        n_mask = max(1, int(len(tokens) * mask_ratio))
        mask_positions = random.sample(range(len(tokens)), min(n_mask, len(tokens)))
        
        # 替换为[MASK]
        masked_tokens = tokens.copy()
        for pos in mask_positions:
            masked_tokens[pos] = '[MASK]'
        
        # 编码
        masked_text = ''.join(masked_tokens).replace('##', '')
        inputs = self.tokenizer(masked_text, return_tensors='pt')
        
        # 预测
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = outputs.logits
        
        # 替换mask位置
        input_ids = inputs['input_ids'][0].tolist()
        for i, token_id in enumerate(input_ids):
            if token_id == self.tokenizer.mask_token_id:
                # 取top-5随机选一个
                top_tokens = torch.topk(predictions[0, i], 5).indices.tolist()
                new_token_id = random.choice(top_tokens)
                input_ids[i] = new_token_id
        
        # 解码
        augmented = self.tokenizer.decode(input_ids, skip_special_tokens=True)
        return augmented.replace(' ', '')

# 使用
mlm_aug = MLMAugmenter()
text = "请帮我查询快递到哪了"
augmented = mlm_aug.augment(text)
# "请帮我查看快递到哪了"

方法4:GPT生成

Few-shot生成

import openai

def gpt_augment(text: str, label: str, num: int = 5) -> list:
    """用GPT生成同类别的新样本"""
    
    prompt = f"""你是一个数据增强助手。请根据给定的样本,生成{num}条语义相似但表达不同的句子。

类别: {label}
原始样本: {text}

要求:
1. 保持语义和意图不变
2. 使用不同的词汇和句式
3. 每行一条,不要编号

生成的句子:"""

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.8,  # 适当提高多样性
        max_tokens=200
    )
    
    # 解析结果
    generated = response.choices[0].message.content.strip().split('\n')
    return [g.strip() for g in generated if g.strip()]

# 使用
text = "我想退货"
label = "退货退款"
augmented = gpt_augment(text, label, num=5)
# ['我要申请退货', '能帮我办理退货吗', '这个商品我想退', ...]
⚠️ GPT增强的成本: 每条约$0.001,500条原始数据扩10倍需要$5。效果最好但成本也最高。

方法5:实体替换

import random

class EntityReplacer:
    """基于实体替换的数据增强"""
    
    def __init__(self):
        # 预定义实体库
        self.entity_dict = {
            'PRODUCT': ['手机', '电脑', '衣服', '鞋子', '耳机', '平板'],
            'TIME': ['昨天', '今天', '上周', '前天', '刚才'],
            'NUMBER': ['一个', '两个', '三件', '五个'],
            'CITY': ['北京', '上海', '广州', '深圳', '杭州']
        }
    
    def augment(self, text: str, entities: list) -> list:
        """
        entities: [(entity_text, entity_type), ...]
        例如: [('手机', 'PRODUCT'), ('昨天', 'TIME')]
        """
        results = []
        
        for entity_text, entity_type in entities:
            if entity_type in self.entity_dict:
                # 获取同类型的其他实体
                replacements = [e for e in self.entity_dict[entity_type] 
                               if e != entity_text]
                
                # 生成替换后的文本
                for rep in replacements[:3]:  # 每个实体最多替换3个
                    new_text = text.replace(entity_text, rep)
                    if new_text != text:
                        results.append(new_text)
        
        return results

# 使用(需要先做NER识别实体)
replacer = EntityReplacer()
text = "我昨天买的手机想退货"
entities = [('昨天', 'TIME'), ('手机', 'PRODUCT')]
augmented = replacer.augment(text, entities)
# ['我今天买的手机想退货', '我昨天买的电脑想退货', ...]

效果对比

实验设置

结果

方法 增强后数据量 F1 提升 耗时
无增强(baseline) 500 62.3% - -
EDA 2500 71.5% +9.2% 5min
回译(英文) 1000 73.8% +11.5% 30min
MLM填充 2500 72.1% +9.8% 20min
实体替换 1500 68.7% +6.4% 2min
GPT生成 3000 78.2% +15.9% 15min
组合(EDA+回译+GPT) 5000 81.4% +19.1% 50min

最佳实践

def augment_dataset(samples: list, target_size: int = 5000) -> list:
    """
    综合数据增强流程
    
    策略:
    1. 每条数据先用EDA生成3条(快速扩充)
    2. 再用回译生成1条(高质量)
    3. 少样本类别用GPT额外补充
    """
    augmented = []
    eda = EDAugmenter()
    bt = BackTranslator()
    
    # 统计每个类别的样本数
    class_counts = Counter(s['label'] for s in samples)
    avg_count = sum(class_counts.values()) / len(class_counts)
    
    for sample in samples:
        text, label = sample['text'], sample['label']
        augmented.append(sample)  # 保留原始数据
        
        # EDA增强
        eda_samples = eda.augment(text, num_aug=3)
        for aug_text in eda_samples:
            augmented.append({'text': aug_text, 'label': label})
        
        # 回译增强
        bt_text = bt.augment(text)
        if bt_text != text:
            augmented.append({'text': bt_text, 'label': label})
        
        # 少样本类别额外用GPT增强
        if class_counts[label] < avg_count * 0.5:
            gpt_samples = gpt_augment(text, label, num=3)
            for gpt_text in gpt_samples:
                augmented.append({'text': gpt_text, 'label': label})
    
    # 去重
    seen = set()
    unique = []
    for s in augmented:
        if s['text'] not in seen:
            seen.add(s['text'])
            unique.append(s)
    
    return unique[:target_size]
✅ 核心经验:
  1. 多方法组合效果最好:单一方法容易引入偏差
  2. 质量优先于数量:不是越多越好,太多噪声会影响模型
  3. 少样本类别重点增强:解决类别不平衡问题
  4. 一定要去重:增强后可能有大量重复

失败教训

⚠️ 这些坑我踩过:
  • EDA参数太大:alpha设成0.3,生成的句子完全不通顺
  • 回译用机翻API:百度/Google翻译API结果太稳定,增强效果差
  • 增强数据不人工检查:混入了很多错误数据,模型学歪了
  • 只增强训练集:验证集也增强了,导致评估结果虚高

更新记录:
2022-02-18: 初版发布
2022-08-10: 补充GPT增强方法
2023-05-20: 更新效果对比数据