Source code for alex.corpustools.cued2utt_da_pairs
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This code is PEP8-compliant. See http://www.python.org/dev/peps/pep-0008.
from __future__ import unicode_literals
import argparse
from collections import namedtuple
import os
import os.path
import random
import xml.dom.minidom
from xml.parsers.expat import ExpatError
if __name__ == "__main__":
import autopath
from alex.corpustools.cued import find_logs
from alex.corpustools.text_norm_en import exclude_asr, exclude_by_dict, normalise_text
from alex.utils.various import get_text_from_xml_node
"""
This program extracts CUED semantic annotations from CUED call logs into
a format which can be later processed by the `cued-sem2ufal-sem.py' program.
It scans for 'user-transcription.norm.xml' (or `user-transcription.xml' if the
former is not found in the log directory) to extract the transcriptions and the
semantics.
"""
# File names of the normalised and the plain transcription XML, respectively.
XML_NORM_FNAME = 'user-transcription.norm.xml'
XML_PLAIN_FNAME = 'user-transcription.xml'

# The data record for one user turn.
TurnRecord = namedtuple(
    'TurnRecord',
    ('transcription', 'cued_da', 'cued_dahyp', 'asrhyp', 'audio'))


def _exactly_one(attr):
    """Makes a check that the record's attribute `attr' has length one."""
    return lambda turn: len(getattr(turn, attr)) == 1


def _at_least_one(attr):
    """Makes a check that the record's attribute `attr' is not empty."""
    return lambda turn: len(getattr(turn, attr)) > 0


# Requirements on individual fields of the data record for a user turn.  Keys
# are field names; each value is a predicate taking the record and checking
# the attribute that holds the field.
_field_requirements = {
    'transcription': _exactly_one('transcription'),
    'semitran': _exactly_one('cued_da'),
    'semihyp': _at_least_one('cued_dahyp'),
    'asrhyp': _at_least_one('asrhyp'),
    'rec': _exactly_one('audio'),
}
def _make_rec_filter(fields):
    """
    Builds a predicate over turn records for the given field names.

    The requirement for each name in `fields' is looked up immediately, so an
    unknown field name raises KeyError here, not when the predicate is used.
    The returned function is true only for records meeting every requirement.
    """
    checks = tuple([_field_requirements[name] for name in fields])

    def rec_filter(rec):
        # Stop at the first requirement that is not met.
        for check in checks:
            if not check(rec):
                return False
        return True

    return rec_filter
def _find_audio_for_turn(uturn, recs):
    """
    Finds the recording that belongs to a given user turn by comparing time
    spans.

    Arguments:
        uturn -- the XML element "userturn" for which the corresponding
            recording should be found
        recs -- a list of XML elements "rec" in the whole log

    A recording matches if the turn's `starttime' falls within the
    recording's [starttime, endtime) span.  If several recordings match, those
    whose parent element is a "userturn" are preferred; among the remaining
    candidates the longest one wins, and the first one in `recs' order is
    taken if there is a tie.

    Returns one of `recs' in case that at least one matches the user turn, else
    None.

    Raises ValueError if a `starttime' or `endtime' attribute cannot be
    converted to a float.

    """
    def duration(rec):
        return (float(rec.getAttribute('endtime'))
                - float(rec.getAttribute('starttime')))

    # Find recordings that were captured at the time the turn was starting.
    uturn_starttime = float(uturn.getAttribute('starttime'))
    matching_recs = []
    for rec in recs:
        start = float(rec.getAttribute('starttime'))
        end = float(rec.getAttribute('endtime'))
        if start <= uturn_starttime < end:
            matching_recs.append(rec)

    if not matching_recs:
        return None

    # If multiple recordings were made, prefer those that belong to a user
    # turn.  (A node without a parent element simply does not qualify.)
    if len(matching_recs) > 1:
        user_recs = [
            rec for rec in matching_recs
            if getattr(rec.parentNode, 'tagName', None) == 'userturn']
        if user_recs:
            matching_recs = user_recs

    # Choose the longest one.  Using a key function keeps the XML elements
    # themselves out of the comparison, so equally long recordings are
    # resolved deterministically (the first one wins).
    return max(matching_recs, key=duration)
def _normalise_and_screen(text, normalise, do_exclude, known_words):
    """
    Normalises one utterance and decides whether it should be excluded.

    Arguments:
        text -- the utterance text
        normalise -- whether to pass the text through `normalise_text'
        do_exclude -- whether to screen the text at all
        known_words -- if not None, the text is screened with
            `exclude_by_dict' against this collection, else with
            `exclude_asr'

    Returns the pair (text, excluded), where `text' is the possibly normalised
    text and `excluded' is the verdict of the screening function (false if
    `do_exclude' is false).

    """
    if normalise:
        text = normalise_text(text)
    excluded = False
    if do_exclude:
        if known_words is not None:
            excluded = exclude_by_dict(text, known_words)
        else:
            excluded = exclude_asr(text)
    return text, excluded


def extract_trns_sems_from_file(fname, verbose, fields=None, normalise=True,
                                do_exclude=True, known_words=None,
                                robust=False):
    """
    Extracts transcriptions and their semantic annotation from a CUED call log
    file.

    Arguments:
        fname -- path towards the call log file
        verbose -- print lots of output?
        fields -- names of fields that should be required for the output.
            Field names are strings corresponding to the element names in the
            transcription XML format.  (default: all five of them)
        normalise -- whether to do normalisation on transcriptions
        do_exclude -- whether to exclude transcriptions not considered suitable
        known_words -- a collection of words.  If provided, transcriptions
            and ASR hypotheses are screened with `exclude_by_dict' against
            it; if not provided, they are screened with `exclude_asr'.  An
            excluded text causes the whole turn to be skipped if its field is
            listed in `fields', otherwise the text is replaced with None.
        robust -- whether to assign recordings to turns robustly or trust where
            they are in the log.  This could be useful for older CUED logs
            where the elements sometimes escape to another <turn> than they
            belong.  However, in cases where `robust' leads to finding the
            correct recording for the user turn, the log is damaged at other
            places too, and the resulting turn record would be misleading.
            Therefore, we recommend leaving robust=False.

    Returns a list of TurnRecords.

    Raises whatever `xml.dom.minidom.parse' raises for an unreadable or
    malformed file, and KeyError for an unknown name in `fields'.

    """
    # NOTE: Every `print' below is called with a single argument, which
    # produces the same output under the print statement and the print
    # function.
    if verbose:
        print('Processing {0}'.format(fname))

    # Interpret the arguments.
    if fields is None:
        fields = ("transcription", "semitran", "semihyp", "asrhyp", "rec")
    rec_filter = _make_rec_filter(fields)

    # Load the file.
    doc = xml.dom.minidom.parse(fname)
    uturns = doc.getElementsByTagName("userturn")
    if robust:
        # Candidate recordings for the robust lookup; files named
        # `*_all.wav' are left out.
        audios = [audio for audio in doc.getElementsByTagName("rec")
                  if not audio.getAttribute('fname').endswith('_all.wav')]

    trns_sems = []
    for uturn in uturns:
        transcription = uturn.getElementsByTagName("transcription")
        cued_da = uturn.getElementsByTagName("semitran")
        cued_dahyp = uturn.getElementsByTagName("semihyp")
        asrhyp = uturn.getElementsByTagName("asrhyp")
        audio = uturn.getElementsByTagName("rec")

        # If there was something recognised but nothing recorded, if in the
        # robust mode,
        if asrhyp and not audio and robust:
            # Look for the recording elsewhere.  An unsuccessful lookup has
            # to leave the list empty: a list holding None would count as
            # "exactly one recording" and break the attribute access below.
            found = _find_audio_for_turn(uturn, audios)
            audio = [] if found is None else [found]

        # This is the first form of the turn record, containing lists of XML
        # elements and suited only for internal use.
        rec = TurnRecord(transcription, cued_da, cued_dahyp, asrhyp, audio)
        if not rec_filter(rec):
            # Skip this node, it contains a wrong number of elements of either
            # transcription, cued_da, cued_dahyp, asrhyp, or audio.
            continue

        # XXX Here we take always the first tag having the respective tag name.
        transcription = get_text_from_xml_node(
            rec.transcription[0]).lower() if rec.transcription else None
        asrhyp = get_text_from_xml_node(
            rec.asrhyp[0]).lower() if rec.asrhyp else None

        # Filter the transcription and the ASR hypothesis through normalisation
        # and excluding non-conformant utterances.
        if transcription is not None:
            transcription, trs_excluded = _normalise_and_screen(
                transcription, normalise, do_exclude, known_words)
            if trs_excluded:
                if verbose:
                    print('Excluded transcription: "{trs}".'.format(
                        trs=transcription))
                if 'transcription' in fields:
                    continue
                transcription = None
        if asrhyp is not None:
            asrhyp, asr_excluded = _normalise_and_screen(
                asrhyp, normalise, do_exclude, known_words)
            if asr_excluded:
                if verbose:
                    print('Excluded ASR hypothesis: "{asr}".'.format(
                        asr=asrhyp))
                if 'asrhyp' in fields:
                    continue
                asrhyp = None

        cued_da = get_text_from_xml_node(
            rec.cued_da[0]) if rec.cued_da else None
        cued_dahyp = get_text_from_xml_node(
            rec.cued_dahyp[0]) if rec.cued_dahyp else None
        audio = rec.audio[0].getAttribute(
            'fname').strip() if rec.audio else None

        # Construct the resulting turn record, now holding plain values
        # instead of lists of XML elements.
        rec = TurnRecord(transcription, cued_da, cued_dahyp, asrhyp, audio)

        if verbose:
            print('#1 f: {0}'.format(rec.audio))
            print('#2 t: {0} # s: {1}'.format(rec.transcription, rec.cued_da))
            print('#3 a: {0} # s: {1}'.format(rec.asrhyp, rec.cued_dahyp))
            print('')

        # A turn whose gold semantics came out empty is useless if the gold
        # semantics was required.
        if rec.cued_da or 'semitran' not in fields:
            trns_sems.append(rec)

    return trns_sems
def extract_trns_sems(infname, verbose, fields=None, ignore_list_file=None,
                      do_exclude=True, normalise=True, known_words=None):
    """
    Extracts transcriptions and their semantic annotation from a directory
    containing CUED call log files.

    Arguments:
        infname -- either a directory, or a file.  In the first case, logs are
            looked for below that directory.  In the latter case, the file is
            read line by line, each line specifying a directory or a glob
            determining the call log to include.
        verbose -- print lots of output?
        fields -- names of fields that should be required for the output.
            Field names are strings corresponding to the element names in the
            transcription XML format.  (default: all five of them)
        ignore_list_file -- a file of absolute paths or globs (can be mixed)
            specifying logs that should be skipped
        normalise -- whether to do normalisation on transcriptions
        do_exclude -- whether to exclude transcriptions not considered suitable
        known_words -- a collection of words, passed on to
            `extract_trns_sems_from_file' to decide which texts to exclude

    Logs that cannot be parsed as XML (for instance empty files) are skipped;
    in the verbose mode, each skipped log is reported.

    Returns a list of TurnRecords, collected from the logs in the sorted order
    of their paths.

    """
    # Interpret the arguments.
    if fields is None:
        fields = ("transcription", "semitran", "semihyp", "asrhyp", "rec")

    # Find all the log files and call the worker function on them in sequel.
    # Sorting makes the order of the output independent of the order in which
    # the logs were discovered.
    log_paths = sorted(find_logs(infname, ignore_list_file=ignore_list_file))
    turn_recs = []
    for log_path in log_paths:
        try:
            file_recs = extract_trns_sems_from_file(
                log_path, verbose, fields=fields, normalise=normalise,
                do_exclude=do_exclude, known_words=known_words)
        except ExpatError as err:
            # This happens for empty XML files, or whenever the XML file cannot
            # be parsed.  The log is skipped, but not silently.
            if verbose:
                print('Skipping unparsable log "{path}": {err}'.format(
                    path=log_path, err=err))
            continue
        turn_recs.extend(file_recs)
    return turn_recs
def write_data(outdir, fname, data, tpt):
    """
    Writes one formatted line per record into a file.

    Arguments:
        outdir -- the directory to write to; it is created if it does not
            exist yet
        fname -- name of the output file within `outdir'; an existing file is
            overwritten
        data -- an iterable of records
        tpt -- a format string, expanded once per record with the record
            passed as the keyword argument `rec'.  Nothing is appended to the
            result, so the template has to contain its own line terminator.

    The file is written in UTF-8, so that records holding non-ASCII text can
    be saved regardless of the default encoding.

    """
    # Imported locally to keep this function self-contained.
    import io

    if outdir and not os.path.isdir(outdir):
        os.makedirs(outdir)
    out_path = os.path.join(outdir, fname)
    with io.open(out_path, 'w', encoding='utf-8') as outfile:
        for rec in data:
            line = tpt.format(rec=rec)
            # A byte-string template yields a byte string, which the text
            # file would refuse.
            if isinstance(line, bytes):
                line = line.decode('utf-8')
            outfile.write(line)
def write_trns_sem(outdir, fname, data):
    """
    Writes a "transcription <=> cued_da" line for every record in `data' to
    the file `fname' in the directory `outdir'.
    """
    line_tpt = '{rec.transcription} <=> {rec.cued_da}\n'
    write_data(outdir, fname, data, line_tpt)
def write_asrhyp_sem(outdir, fname, data):
    """
    Writes an "asrhyp <=> cued_da" line for every record in `data' to the
    file `fname' in the directory `outdir'.
    """
    line_tpt = '{rec.asrhyp} <=> {rec.cued_da}\n'
    write_data(outdir, fname, data, line_tpt)
def write_asrhyp_semhyp(outdir, fname, data):
    """
    Writes an "asrhyp <=> cued_dahyp" line for every record in `data' to the
    file `fname' in the directory `outdir'.
    """
    line_tpt = '{rec.asrhyp} <=> {rec.cued_dahyp}\n'
    write_data(outdir, fname, data, line_tpt)
if __name__ == '__main__':
    # Command-line entry point: extracts turn records from the call logs,
    # shuffles them reproducibly, splits them 80/10/10 into train/dev/test
    # and saves each part in every output format the chosen fields allow.
    arger = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""
      This program extracts CUED semantic annotations from CUED call logs into
      a format which can be later processed by the `cued-sem2ufal-sem.py'
      program.

      Transcriptions and the recognised speech are lower-cased and
      normalised, and utterances considered unsuitable are excluded before
      the data is written out.

      It scans for 'user-transcription.norm.xml' (or `user-transcription.xml'
      if the former is not found in the log directory) to extract the
      transcriptions and the semantics.

      """)
    arger.add_argument('-i', '--indir',
                       default='./cued_call_logs',
                       help=('an input directory with CUED call log files '
                             '(default: ./cued_call_logs)'))
    arger.add_argument('-o', '--outdir', default='./cued_data',
                       help=('an output directory for files with audio and '
                             'their transcription (default: ./cued_data)'))
    arger.add_argument('-v', '--verbose', action="store_true",
                       help='set verbose output')
    # Restricting the values to the known field names turns a typo into
    # a usage error instead of a KeyError traceback.
    arger.add_argument('-f', '--fields', nargs='+',
                       choices=sorted(_field_requirements),
                       help=('fields of the XML transcription file that '
                             'should be extracted'))
    args = arger.parse_args()

    if args.fields is None:
        fields = ("transcription", "semitran", "semihyp", "asrhyp", "rec")
    else:
        fields = args.fields

    # Make sure the output directory exists before the extraction starts, so
    # that the results cannot be lost on a missing directory at the very end.
    if args.outdir and not os.path.isdir(args.outdir):
        os.makedirs(args.outdir)

    # NOTE: Every `print' below is called with a single argument, which
    # produces the same output under the print statement and the print
    # function.
    print('Extracting semantics from the call logs...')
    trns_sems = extract_trns_sems(args.indir, args.verbose, fields=args.fields)
    num_turns = len(trns_sems)

    # Fix shuffling of the data.
    random.seed(0)
    random.shuffle(trns_sems)
    print('Total number of annotated user turns: {0}'.format(num_turns))

    train_end = int(0.8 * num_turns)
    dev_end = int(0.9 * num_turns)
    data_parts = {
        'train': trns_sems[:train_end],
        'dev': trns_sems[train_end:dev_end],
        'test': trns_sems[dev_end:]}

    # Each output format: the two fields it needs, the progress message, the
    # writer function, and the template of the output file name.
    outputs = (
        ('transcription', 'semitran',
         'Saving gold transcriptions, gold semantics...',
         write_trns_sem, 'caminfo-{part}.sem'),
        ('asrhyp', 'semitran',
         'Saving ASR transcriptions, gold semantics...',
         write_asrhyp_sem, 'caminfo-{part}.asr.sem'),
        ('asrhyp', 'semihyp',
         'Saving ASR transcriptions, SLU semantics...',
         write_asrhyp_semhyp, 'caminfo-{part}.asr.shyp.sem'),
    )
    for text_fld, sem_fld, message, writer, fname_tpt in outputs:
        if text_fld in fields and sem_fld in fields:
            print(message)
            for part_name, part in data_parts.items():
                writer(args.outdir, fname_tpt.format(part=part_name), part)

    print('Done.  Output written to "{outdir}".'.format(outdir=args.outdir))