Module deepsegment.train

Expand source code
import random
random.seed(42)

import logging

import tensorflow as tf

# Best-effort TensorFlow session setup: build a session whose config has
# gpu_options.allow_growth enabled. tf.ConfigProto / tf.Session are TF 1.x
# names, so on installs where they are missing this step is skipped.
# NOTE(review): the session is created but not handed to Keras in this block
# -- confirm whether a set_session call is expected elsewhere.
try:
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
except Exception as exc:
    # Deliberately non-fatal, but no longer a bare ``except`` so that
    # KeyboardInterrupt / SystemExit are not swallowed at import time.
    logging.getLogger(__name__).debug('Skipping TF session configuration: %s', exc)

import seqtag_keras

from seqtag_keras.models import BiLSTMCRF
from seqtag_keras.utils import load_glove
from seqtag_keras.trainer import Trainer
from progressbar import progressbar

from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam

import os
import re
import string
import pickle
import datetime

from seqeval.metrics import f1_score

def bad_sentence_generator(sent, remove_punctuation = None):
    """
        Returns sentence with completely/ partially removed punctuation, randomly lower-cased.

        Parameters:
        sent (str): Sentence on which the punctuation removal operation is performed.

        remove_punctuation (int): 0 or 1 removes all punctuation; 2 removes punctuation
        either before or after a randomly selected point; any other value (e.g. 3) leaves
        the punctuation untouched. If None, a value is drawn at random from 0..3.

        Returns:
        str: Sentence with modified punctuation. The result is additionally lower-cased
        in two out of three random draws.

    """

    # Compare against None explicitly: 0 is a valid mode and must not be
    # mistaken for "not provided" (which would re-randomize it).
    if remove_punctuation is None:
        remove_punctuation = random.randint(0, 3)

    # randint(1, len(sent) - 2) needs at least 3 characters; shorter inputs
    # get a fixed split point instead of raising ValueError.
    if len(sent) >= 3:
        break_point = random.randint(1, len(sent) - 2)
    else:
        break_point = min(1, len(sent))
    lower_case = random.randint(0, 2)

    # Escape the punctuation set so that ']', '\\', '^' and '-' are treated as
    # literal members of the character class (unescaped, the backslash was
    # consumed as an escape for ']' and therefore never removed).
    punctuation_pattern = '[' + re.escape(string.punctuation) + ']'

    if remove_punctuation <= 1:
        # removing punctuation completely if remove_punctuation ==0 or ==1
        sent = re.sub(punctuation_pattern, '', sent)

    elif remove_punctuation == 2:
        head, tail = sent[:break_point], sent[break_point:]
        if random.randint(0, 1) == 0:
            # removing punctuation till the randomly selected point
            sent = re.sub(punctuation_pattern, '', head) + tail
        else:
            # removing punctuation after the randomly selected point
            sent = head + re.sub(punctuation_pattern, '', tail)

    if lower_case <= 1:
        # lower casing sentence
        sent = sent.lower()

    return sent

def generate_data(lines, max_sents_per_example=6, n_examples=1000):
    """
        Builds synthetic deepsegment training examples by concatenating randomly
        chosen, randomly corrupted sentences.

        Parameters:
        lines (list): Base sentences for data generation. Sentences are drawn with
        random.choice, so the same sentence may be picked more than once.

        max_sents_per_example (int): Upper bound (inclusive) on the number of sentences
        combined into a single example; at least one sentence is always used.

        n_examples (int): Number of training examples to be generated.

        Returns:
        list, list: One list of words per example, and a parallel list of labels in
        which the first word of every sentence is 'B-sent' and every other word is 'O'.

    """
    examples, labels = [], []

    for _ in progressbar(range(n_examples)):
        # Pick all sentences first, then corrupt them, so the sequence of
        # random draws stays the same for a given seed.
        n_sents = random.randint(1, max_sents_per_example)
        picked = [random.choice(lines) for _ in range(n_sents)]

        corrupted = []
        for sentence in picked:
            mode = random.randint(0, 3)
            corrupted.append(bad_sentence_generator(sentence, remove_punctuation=mode))

        tokens, tags = [], []
        for sentence in corrupted:
            words = sentence.strip().split()
            if words:
                tokens.extend(words)
                # Only the sentence-initial word is marked as a boundary.
                tags.append('B-sent')
                tags.extend(['O'] * (len(words) - 1))

        examples.append(tokens)
        labels.append(tags)

    return examples, labels

def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path):
    """
        Trains a deepsegment model and writes its files into save_folder.

        Parameters:
        x (list): training words, as returned by generate_data

        y (list): training labels, as returned by generate_data

        vx (list): validation words, as returned by generate_data

        vy (list): validation labels, as returned by generate_data

        epochs (int): Max number of epochs.

        batch_size (int): batch_size

        save_folder (str): directory that receives the 'checkpoint', 'final_weights',
        'params' and 'utils' files.

        glove_path (str): path to 100d word vectors.

    """
    word_vectors = load_glove(glove_path)

    # Every artefact is stored side by side inside save_folder.
    checkpoint_path, final_weights_path, params_path, utils_path = [
        os.path.join(save_folder, leaf)
        for leaf in ('checkpoint', 'final_weights', 'params', 'utils')
    ]

    # Both callbacks watch the 'f1' metric and treat higher as better.
    callbacks = [
        ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, mode='max', monitor='f1'),
        EarlyStopping(patience=3, monitor='f1', mode='max'),
    ]

    tagger = seqtag_keras.Sequence(embeddings=word_vectors)
    tagger.fit(
        x, y,
        x_valid=vx, y_valid=vy,
        epochs=epochs, batch_size=batch_size,
        callbacks=callbacks,
    )

    # Saved separately from the checkpoint written by ModelCheckpoint.
    tagger.save(final_weights_path, params_path, utils_path)


# Full language names accepted in place of a language code; finetune()
# translates them to the code used in the '.DeepSegment_<code>' folder name.
lang_code_mapping = dict(
    english='en',
    french='fr',
    italian='it',
)

def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0001):
    """
        Finetunes an existing deepsegment model.

        Parameters:
        lang_code (str): language code of the base model, or one of the language names
        listed in lang_code_mapping. The base model is read from ~/.DeepSegment_<lang_code>.

        x (list): training words, as returned by generate_data

        y (list): training labels, as returned by generate_data

        vx (list): validation words, as returned by generate_data

        vy (list): validation labels, as returned by generate_data

        name (str): name of the finetuned checkpoint; it is written next to the base
        checkpoint as checkpoint_<name>. (params, utils will be same for finetuned and
        base models.) Defaults to the current timestamp.

        epochs (int): Max number of epochs.

        batch_size (int): batch_size

        lr (float): initial learning rate.

    """

    if not name:
        name = str(datetime.datetime.now()).split()
        name = '-'.join(name)
        print('Name not provided. The checkpoint will be named checkpoint_' + name)

    # Accept both full language names and plain codes.
    lang_code = lang_code_mapping.get(lang_code, lang_code)

    home = os.path.expanduser("~")
    lang_path = os.path.join(home, '.DeepSegment_' + lang_code)
    checkpoint_path = os.path.join(lang_path, 'checkpoint')
    utils_path = os.path.join(lang_path, 'utils')

    # NOTE(review): unpickling can execute arbitrary code -- only point this at
    # a model directory from a trusted source.
    # The context manager closes the file handle that was previously leaked.
    with open(utils_path, 'rb') as utils_file:
        p = pickle.load(utils_file)

    # NOTE(review): these sizes presumably have to match the architecture the
    # base checkpoint was trained with -- confirm before changing them.
    model = BiLSTMCRF(char_vocab_size=p.char_vocab_size,
                          word_vocab_size=p.word_vocab_size,
                          num_labels=p.label_size,
                          word_embedding_dim=100,
                          char_embedding_dim=25,
                          word_lstm_size=100,
                          char_lstm_size=25,
                          fc_dim=100,
                          dropout=0.2,
                          embeddings=None,
                          use_char=True,
                          use_crf=True)

    model, loss = model.build()
    model.compile(loss=loss, optimizer=Adam(learning_rate=lr))

    model.load_weights(checkpoint_path)

    # Baseline score on the validation set before any finetuning.
    # A concrete list (rather than a lazy map object) can be iterated more
    # than once by inverse_transform.
    lengths = [len(labels) for labels in vy]
    y_pred = model.predict(p.transform(vx))
    y_pred = p.inverse_transform(y_pred, lengths)
    orig_score = f1_score(vy, y_pred)
    print('Scores before finetuning: ')
    print(orig_score)

    trainer = Trainer(model, preprocessor=p)

    # The finetuned weights get their own file; the base checkpoint is kept.
    checkpoint_path = checkpoint_path + '_' + name
    checkpoint = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, mode='max', monitor='f1')
    earlystop = EarlyStopping(patience=3, monitor='f1', mode='max')

    trainer.train(x, y, vx, vy,
                      epochs=epochs, batch_size=batch_size,
                      verbose=1, callbacks=[checkpoint, earlystop],
                      shuffle=True)

Functions

def bad_sentence_generator(sent, remove_punctuation=None)

Returns sentence with completely/ partially removed punctuation.

Parameters: sent (str): Sentence on which the punctuation removal operation is performed.

remove_punctuation (int): 0 or 1 removes all punctuation; 2 removes punctuation either before or after a randomly selected point; 3 leaves the punctuation unchanged. If not provided, a value between 0 and 3 is chosen at random.

Returns: str: Sentence with modified punctuation

Expand source code
def bad_sentence_generator(sent, remove_punctuation = None):
    """
        Returns sentence with completely/ partially removed punctuation, randomly lower-cased.

        Parameters:
        sent (str): Sentence on which the punctuation removal operation is performed.

        remove_punctuation (int): 0 or 1 removes all punctuation; 2 removes punctuation
        either before or after a randomly selected point; any other value (e.g. 3) leaves
        the punctuation untouched. If None, a value is drawn at random from 0..3.

        Returns:
        str: Sentence with modified punctuation. The result is additionally lower-cased
        in two out of three random draws.

    """

    # Compare against None explicitly: 0 is a valid mode and must not be
    # mistaken for "not provided" (which would re-randomize it).
    if remove_punctuation is None:
        remove_punctuation = random.randint(0, 3)

    # randint(1, len(sent) - 2) needs at least 3 characters; shorter inputs
    # get a fixed split point instead of raising ValueError.
    if len(sent) >= 3:
        break_point = random.randint(1, len(sent) - 2)
    else:
        break_point = min(1, len(sent))
    lower_case = random.randint(0, 2)

    # Escape the punctuation set so that ']', '\\', '^' and '-' are treated as
    # literal members of the character class (unescaped, the backslash was
    # consumed as an escape for ']' and therefore never removed).
    punctuation_pattern = '[' + re.escape(string.punctuation) + ']'

    if remove_punctuation <= 1:
        # removing punctuation completely if remove_punctuation ==0 or ==1
        sent = re.sub(punctuation_pattern, '', sent)

    elif remove_punctuation == 2:
        head, tail = sent[:break_point], sent[break_point:]
        if random.randint(0, 1) == 0:
            # removing punctuation till the randomly selected point
            sent = re.sub(punctuation_pattern, '', head) + tail
        else:
            # removing punctuation after the randomly selected point
            sent = head + re.sub(punctuation_pattern, '', tail)

    if lower_case <= 1:
        # lower casing sentence
        sent = sent.lower()

    return sent
def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0001)

Finetunes an existing deepsegment model.

Parameters: x (list): x generated from generate_data

y (list): y generated from generate_data

vx (list): x generated from generate_data

vy (list): y generated from generate_data

name (str): name of the finetuned checkpoint. (params, utils will be the same for finetuned and base models.)

epochs (int): Max number of epochs.

batch_size (int): batch_size

lr (float): initial learning rate.

Expand source code
def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0001):
    """
        Finetunes an existing deepsegment model.

        Parameters:
        lang_code (str): language code of the base model, or one of the language names
        listed in lang_code_mapping. The base model is read from ~/.DeepSegment_<lang_code>.

        x (list): training words, as returned by generate_data

        y (list): training labels, as returned by generate_data

        vx (list): validation words, as returned by generate_data

        vy (list): validation labels, as returned by generate_data

        name (str): name of the finetuned checkpoint; it is written next to the base
        checkpoint as checkpoint_<name>. (params, utils will be same for finetuned and
        base models.) Defaults to the current timestamp.

        epochs (int): Max number of epochs.

        batch_size (int): batch_size

        lr (float): initial learning rate.

    """

    if not name:
        name = str(datetime.datetime.now()).split()
        name = '-'.join(name)
        print('Name not provided. The checkpoint will be named checkpoint_' + name)

    # Accept both full language names and plain codes.
    lang_code = lang_code_mapping.get(lang_code, lang_code)

    home = os.path.expanduser("~")
    lang_path = os.path.join(home, '.DeepSegment_' + lang_code)
    checkpoint_path = os.path.join(lang_path, 'checkpoint')
    utils_path = os.path.join(lang_path, 'utils')

    # NOTE(review): unpickling can execute arbitrary code -- only point this at
    # a model directory from a trusted source.
    # The context manager closes the file handle that was previously leaked.
    with open(utils_path, 'rb') as utils_file:
        p = pickle.load(utils_file)

    # NOTE(review): these sizes presumably have to match the architecture the
    # base checkpoint was trained with -- confirm before changing them.
    model = BiLSTMCRF(char_vocab_size=p.char_vocab_size,
                          word_vocab_size=p.word_vocab_size,
                          num_labels=p.label_size,
                          word_embedding_dim=100,
                          char_embedding_dim=25,
                          word_lstm_size=100,
                          char_lstm_size=25,
                          fc_dim=100,
                          dropout=0.2,
                          embeddings=None,
                          use_char=True,
                          use_crf=True)

    model, loss = model.build()
    model.compile(loss=loss, optimizer=Adam(learning_rate=lr))

    model.load_weights(checkpoint_path)

    # Baseline score on the validation set before any finetuning.
    # A concrete list (rather than a lazy map object) can be iterated more
    # than once by inverse_transform.
    lengths = [len(labels) for labels in vy]
    y_pred = model.predict(p.transform(vx))
    y_pred = p.inverse_transform(y_pred, lengths)
    orig_score = f1_score(vy, y_pred)
    print('Scores before finetuning: ')
    print(orig_score)

    trainer = Trainer(model, preprocessor=p)

    # The finetuned weights get their own file; the base checkpoint is kept.
    checkpoint_path = checkpoint_path + '_' + name
    checkpoint = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, mode='max', monitor='f1')
    earlystop = EarlyStopping(patience=3, monitor='f1', mode='max')

    trainer.train(x, y, vx, vy,
                      epochs=epochs, batch_size=batch_size,
                      verbose=1, callbacks=[checkpoint, earlystop],
                      shuffle=True)
def generate_data(lines, max_sents_per_example=6, n_examples=1000)

Generates training data for deepsegment from list of sentences.

Parameters: lines (list): Base sentences for data generation.

max_sents_per_example (int): Maximum number of sentences to be combined to form a single paragraph.

n_examples (int): Number of training examples to be generated.

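Returns: list, list: Training data (one list of words per example) and the corresponding labels, where the first word of each sentence is labelled B-sent and every other word O.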

Expand source code
def generate_data(lines, max_sents_per_example=6, n_examples=1000):
    """
        Builds synthetic deepsegment training examples by concatenating randomly
        chosen, randomly corrupted sentences.

        Parameters:
        lines (list): Base sentences for data generation. Sentences are drawn with
        random.choice, so the same sentence may be picked more than once.

        max_sents_per_example (int): Upper bound (inclusive) on the number of sentences
        combined into a single example; at least one sentence is always used.

        n_examples (int): Number of training examples to be generated.

        Returns:
        list, list: One list of words per example, and a parallel list of labels in
        which the first word of every sentence is 'B-sent' and every other word is 'O'.

    """
    examples, labels = [], []

    for _ in progressbar(range(n_examples)):
        # Pick all sentences first, then corrupt them, so the sequence of
        # random draws stays the same for a given seed.
        n_sents = random.randint(1, max_sents_per_example)
        picked = [random.choice(lines) for _ in range(n_sents)]

        corrupted = []
        for sentence in picked:
            mode = random.randint(0, 3)
            corrupted.append(bad_sentence_generator(sentence, remove_punctuation=mode))

        tokens, tags = [], []
        for sentence in corrupted:
            words = sentence.strip().split()
            if words:
                tokens.extend(words)
                # Only the sentence-initial word is marked as a boundary.
                tags.append('B-sent')
                tags.extend(['O'] * (len(words) - 1))

        examples.append(tokens)
        labels.append(tags)

    return examples, labels
def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path)

Trains a deepsegment model.

Parameters: x (list): x generated from generate_data

y (list): y generated from generate_data

vx (list): x generated from generate_data

vy (list): y generated from generate_data

epochs (int): Max number of epochs.

batch_size (int): batch_size

save_folder (str): path for the directory where checkpoints should be saved.

glove_path (str): path to 100d word vectors.

Expand source code
def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path):
    """
        Trains a deepsegment model and writes its files into save_folder.

        Parameters:
        x (list): training words, as returned by generate_data

        y (list): training labels, as returned by generate_data

        vx (list): validation words, as returned by generate_data

        vy (list): validation labels, as returned by generate_data

        epochs (int): Max number of epochs.

        batch_size (int): batch_size

        save_folder (str): directory that receives the 'checkpoint', 'final_weights',
        'params' and 'utils' files.

        glove_path (str): path to 100d word vectors.

    """
    word_vectors = load_glove(glove_path)

    # Every artefact is stored side by side inside save_folder.
    checkpoint_path, final_weights_path, params_path, utils_path = [
        os.path.join(save_folder, leaf)
        for leaf in ('checkpoint', 'final_weights', 'params', 'utils')
    ]

    # Both callbacks watch the 'f1' metric and treat higher as better.
    callbacks = [
        ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, mode='max', monitor='f1'),
        EarlyStopping(patience=3, monitor='f1', mode='max'),
    ]

    tagger = seqtag_keras.Sequence(embeddings=word_vectors)
    tagger.fit(
        x, y,
        x_valid=vx, y_valid=vy,
        epochs=epochs, batch_size=batch_size,
        callbacks=callbacks,
    )

    # Saved separately from the checkpoint written by ModelCheckpoint.
    tagger.save(final_weights_path, params_path, utils_path)