読者です 読者をやめる 読者になる 読者になる

ツタンラーメンの備忘録

プログラミングや精神疾患、ラーメンについて書いていきます。たぶん。

chainerで音声認識

あとで整形します。
メインコード

from chainer import Link, Chain, ChainList, Variable
import chainer.functions as F
import chainer.links as L
import chainer
from chainer import training
from chainer.training import extensions
import numpy as np
import MFCC

class VoiceAnalysis(Chain):
    """Three-layer MLP: two sigmoid hidden layers and a linear output layer.

    Returns raw (pre-softmax) class scores; pair with L.Classifier, which
    applies softmax cross entropy during training.
    """

    def __init__(self, n_units=64, n_out=1):
        # Input dimension of each Linear is left as None so Chainer infers
        # it from the first batch passed through the network.
        super(VoiceAnalysis, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out),
        )

    def __call__(self, x):
        """Forward pass: x -> sigmoid -> sigmoid -> linear scores."""
        hidden = F.sigmoid(self.l1(x))
        hidden = F.sigmoid(self.l2(hidden))
        return self.l3(hidden)

def main():
    """Train a 3-class voice classifier on MFCC features, then classify test files.

    Classes (by label index): 0 = English 'a', 1 = Japanese 'a', 2 = no voice.
    Expects MFCC.read_ceps() to return (float32 feature vector, int32 label) pairs.
    """
    batchsize = 100
    epoch = 150
    out = 'result'       # Trainer output directory (logs, snapshots)
    unit = 1000          # hidden-layer width

    model = L.Classifier(VoiceAnalysis(unit, 3))
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Directory names double as class labels: index in this list == label.
    name_list = ["en_a", "ja_a", "no_voice"]
    train = MFCC.read_ceps(name_list)
    # NOTE(review): evaluation reuses the training set — there is no held-out
    # validation split, so 'validation/main/*' numbers are optimistic.
    test = train
    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out)
    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()

    # Bug fix: test directories must be listed in the SAME class order as
    # name_list above (en -> 0, ja -> 1, no_voice -> 2). The original listed
    # ja before en, so the printed labels did not match the trained classes.
    test = MFCC.read_ceps(["test_en_a", "test_ja_a", "test_no_voice"])
    pred_list = []
    for data, label in test:
        pred_data = model.predictor(np.array([data]).astype(np.float32)).data
        pred_list.append((pred_data, label))
        print(pred_data, label)

if __name__ == '__main__':
    main()


MFCCの部分。ネット上で配布されているMFCC実装では、ライブラリが普通に読み込めなかったため断念し、python_speech_features を利用した。

from python_speech_features import mfcc
import scipy
from scipy import io
from scipy.io import wavfile
import glob
import numpy as np
import os
from chainer.datasets import tuple_dataset


def write_ceps(ceps, fn):
    """Save an MFCC coefficient matrix next to the source file.

    The extension of *fn* is replaced with ".ceps"; np.save then appends
    ".npy", producing "<name>.ceps.npy".
    """
    root, _ext = os.path.splitext(fn)
    np.save(root + ".ceps", ceps)

def create_ceps(fn):
    """Compute MFCC features for a wav file and save them via write_ceps().

    Files whose feature matrix contains any NaN (e.g. silent or corrupt
    audio) are skipped rather than written.
    """
    (rate, X) = io.wavfile.read(fn)
    ceps = mfcc(X, rate)
    # Robustness fix: the original looped over frames checking only
    # coefficient column 1 (num[1]), which misses NaNs in other columns and
    # assumes at least two coefficients per frame. Check the whole matrix.
    if not np.isnan(ceps).any():
        write_ceps(ceps, fn)

def read_ceps(name_list):
    """Load saved MFCC features for each class directory.

    For every name in *name_list*, reads all "<cwd>/<name>/*.ceps.npy"
    files and averages the MFCC frames over time into one fixed-length
    feature vector per file. The position of the name in the list is the
    class label.

    Returns a list of (float32 feature vector, int32 label) tuples, ready
    for a Chainer iterator.
    """
    train = []
    base_dir = os.getcwd()
    for label, name in enumerate(name_list):
        for fn in glob.glob(os.path.join(base_dir, name, "*.ceps.npy")):
            ceps = np.load(fn)
            # Cleanup: the original also built parallel X/y lists that were
            # never used, and computed this mean twice per file.
            features = np.float32(np.mean(ceps, axis=0))
            train.append((features, np.int32(label)))
    return train