"""Defines objects used to read and write XML for Axiell EMu"""
from __future__ import annotations
import csv
import datetime as dt
import glob
import json
import logging
import mmap
import os
import re
import shutil
import sys
import tempfile
import time
import zipfile
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Any
from warnings import warn
from joblib import Parallel, delayed
from lxml import etree
from .utils import flatten, is_nesttab, is_ref, is_ref_tab, is_tab, strip_tab
logger = logging.getLogger(__name__)
[docs]
class EMuReader:
    """Read records from an EMu XML file into dicts

    Parameters
    ----------
    path : str | Path
        path to an XML file, a zip file, or a directory containing XML files
    rec_class : Callable
        class used to build each record. Defaults to dict. Any other class
        is called with the module name as the ``module`` keyword argument.
    json_path : str | Path
        path to a JSON file used to cache records for faster reading

    Attributes
    ----------
    path : str
        path to a file or directory, stored as a string
    json_path : str | Path
        path to a JSON file used to cache records for faster reading
    files : list
        list of file-like objects, each of which is an EMu XML file
    module : str
        the name of an EMu module
    """

    #: EMuConfig : module-wide configuration parameters. Set automatically
    #: when an EMuConfig object is created.
    #:
    #: :meta hide-value:
    config = None
    #: EMuSchema : info about a specific EMu configuration. Set automatically
    #: when an EMuSchema object is created.
    #:
    #: :meta hide-value:
    schema = None
def __init__(
    self, path: str | Path, rec_class: Callable = dict, json_path: str | Path = None
):
    """Set up the reader and inspect the source files

    Parameters
    ----------
    path : str | Path
        path to a file or directory. Stored as a string.
    rec_class : Callable
        class used to build each record
    json_path : str | Path
        path to a JSON file used to cache records for faster reading
    """
    # Private attributes used to display notifications
    self._job_start = self._notify_start = None
    self._job_done = False
    self._notify_count = 0

    # Public state, filled in by the helpers called below
    self.path = str(path)
    self.json_path = json_path
    self.files = []
    self.module = None

    self._rec_class = rec_class
    self._fields = None

    # Locate the source files first because the module name read from them
    # is passed to the record class when the schema is loaded
    self._get_files()
    self._load_schema()
def __iter__(self) -> Generator:
    """Iterate over the records produced by from_file()"""
    records = self.from_file()
    for record in records:
        yield record
def __len__(self) -> int:
    """Return the number of records when there is a single source file

    Raises
    ------
    NotImplementedError
        if counts() returns per-file counts instead of a single integer
    """
    total = self.counts()
    if not isinstance(total, int):
        raise NotImplementedError(
            "Not implemented when multiple source files provided"
        )
    return total
@property
def fields(self) -> tuple[str, ...]:
    """Top-level field names found in the schema header of the source files

    The header is parsed with _parse_file_schema() on first access and the
    result is cached for later lookups.

    Returns
    -------
    tuple[str, ...]
        names of the top-level fields
    """
    if self._fields is None:
        self._fields = self._parse_file_schema()
    return self._fields
[docs]
def from_file(self) -> Generator[dict]:
    """Reads data from file, using JSON if possible

    If no JSON path was given, records are read from XML. Otherwise the
    JSON cache is created if it does not exist, recreated if it is older
    than the newest source file, and then read.

    Returns
    -------
    Generator[dict]
        generator of EMu records, as returned by from_xml() or from_json()
    """
    if not self.json_path:
        return self.from_xml()

    # Only the timestamp lookup is guarded. A FileNotFoundError raised while
    # writing the cache must propagate instead of being mistaken for a
    # missing cache and triggering a second conversion.
    try:
        json_mtime = os.path.getmtime(self.json_path)
    except FileNotFoundError:
        logger.info("Generating JSON (JSON not found)")
        self.to_json()
    else:
        # Source files are ordered oldest to newest, so the last one
        # determines whether the cache is stale
        if json_mtime < self.files[-1].getmtime():
            logger.info("Regenerating JSON (XML is newer)")
            self.to_json()

    return self.from_json()
[docs]
def from_xml(self, start: int = 0, limit: int = None) -> Generator[dict]:
    """Reads data from XML

    The record counter is reset for every source file, so start and limit
    are applied to each file separately.

    Parameters
    ----------
    start : int
        index of record to start processing
    limit : int
        number of records to process from start. If omitted, all records
        are processed.

    Yields
    ------
    dict
        EMu record
    """
    try:
        for filelike in self.files:
            logger.info("Reading records from %s", filelike)
            self._job_start = None
            self._job_done = False
            self._notify_start = None
            self._notify_count = 0
            with filelike.open("rb") as source:
                # Create the parser before entering the try block so that
                # a failure here is not masked by an UnboundLocalError
                # raised when the cleanup below deletes the name
                context = etree.iterparse(source, events=["end"], tag="tuple")
                try:
                    for _, element in context:
                        # Process children of module table only
                        parent = element.getparent().get("name")
                        if parent is None or not parent.startswith("e"):
                            continue
                        try:
                            if self._notify_count >= start:
                                yield self._parse(element)
                        finally:
                            # Clear the element once it has been handled,
                            # and count the record even if the consumer
                            # stops iterating here
                            element.clear()
                            self._notify_count += 1
                            if not self._notify_count % 5000:
                                logger.info(
                                    "Read %s records from %s",
                                    self._notify_count,
                                    filelike,
                                )
                        if (
                            limit is not None
                            and self._notify_count >= start + limit
                        ):
                            break
                finally:
                    del context
        logger.info("Read %s records total", self._notify_count)
        self._job_done = True
    finally:
        # Print the final progress message if report_progress() was used
        # while reading
        if self._job_start:
            self.report_progress()
[docs]
def from_xml_parallel(
    self,
    callback: Callable,
    num_parts: int = 64,
    handle_repeated_keys: str = "overwrite",
) -> Any:
    """Reads data from XML in parallel

    Experimental. Works by splitting the XML file into temporary files on
    the row comments, then running the callback on those files in parallel.
    Seems to work best with a small number of parts.

    Parameters
    ----------
    callback : function
        function to run on each temporary file. It is called with the path
        to the file.
    num_parts : int
        number of parts to split the file into
    handle_repeated_keys : str
        defines how to handle keys that repeat across dicts returned by
        different jobs. Must be one of 'combine' (which combines entries in
        a list), 'keep' (which keeps the first key found), 'overwrite'
        (which overwrites the existing key), 'raise' (which raises a
        KeyError), or 'sum' (which sums the values for each key, falling
        back to the combined list if they cannot be summed). Ignored if
        callback does not return a dict.

    Returns
    -------
    Any
        result of callback function combined across jobs. If dict, results
        are combined into a single dict. If list, results are combined into
        a single list. If another type, returns a list of results returned
        by the callback.

    Raises
    ------
    NotImplementedError
        if the reader does not wrap exactly one XML file
    ValueError
        if num_parts or handle_repeated_keys is invalid
    KeyError
        if handle_repeated_keys is 'raise' and a key repeats
    """
    # Zip members have no path, so test for that before calling lower()
    path = self.files[0].path
    if len(self.files) > 1 or not path or not path.lower().endswith(".xml"):
        raise NotImplementedError("Only implemented for a single XML source file")

    allowed = ("combine", "keep", "overwrite", "raise", "sum")
    if handle_repeated_keys not in allowed:
        raise ValueError(
            f"handle_repeated_keys must be one of the following: {allowed}"
        )
    if num_parts < 1:
        raise ValueError("num_parts must be a positive integer")

    # Create temporary directory
    tmpdir = tempfile.mkdtemp(prefix="xmu-")
    try:
        with open(path, "rb") as f:
            content = f.read()

        # Everything before the first row comment is repeated at the top of
        # each part
        sep = b"<!-- Row"
        records = content.split(sep)
        header = records.pop(0)

        step = len(records) // num_parts + 1
        part_paths = []
        for i in range(0, len(records), step):
            group = records[i : i + step]
            # Each part is closed as soon as it is written so that the
            # callback can open it by name
            with tempfile.NamedTemporaryFile(
                "wb", prefix="xmu-", suffix=".xml", dir=tmpdir, delete=False
            ) as tmp:
                tmp.write(header)
                tmp.write(b"".join(sep + r for r in group))
                # Only the last part includes the closing tag from the
                # original file
                if not group[-1].rstrip().endswith(b"</table>"):
                    tmp.write(b"\n</table>")
            part_paths.append(tmp.name)

        results = Parallel(n_jobs=-1)(delayed(callback)(p) for p in part_paths)

        # No rows were found, so there is nothing to combine
        if not results:
            return results

        if isinstance(results[0], dict):
            merged = {}
            for part in results:
                if handle_repeated_keys in ("combine", "sum"):
                    for key, val in part.items():
                        if not isinstance(val, list):
                            val = [val]
                        merged.setdefault(key, []).extend(val)
                elif handle_repeated_keys == "keep":
                    for key, val in part.items():
                        merged.setdefault(key, val)
                elif handle_repeated_keys == "overwrite":
                    merged.update(part)
                else:
                    if set(part) & set(merged):
                        raise KeyError("Duplicate keys returned")
                    merged.update(part)

            if handle_repeated_keys == "sum":
                summed = {}
                for key, vals in merged.items():
                    try:
                        summed[key] = sum(vals)
                    except TypeError:
                        # Values that cannot be summed are left combined
                        summed[key] = vals
                merged = summed
            return merged

        if isinstance(results[0], list):
            combined = []
            for part in results:
                combined.extend(part)
            return combined

        return results
    finally:
        # Removes the directory and every part written to it
        shutil.rmtree(tmpdir)
[docs]
def from_json(self, chunk_size: int = 2097152) -> Generator[dict]:
    """Reads data from JSON

    The file is read in chunks rather than loaded whole. Each chunk is
    wrapped in brackets and parsed as a list. When parsing fails, text is
    moved from the end of the chunk to the start of the next one until the
    remainder parses.

    Parameters
    ----------
    chunk_size : int
        size of chunk to use when reading the file

    Yields
    ------
    dict
        EMu record
    """
    logger.info("Reading records from %s", self.json_path)
    # Reset the counters used for progress notifications
    self._job_start = None
    self._job_done = False
    self._notify_start = None
    self._notify_count = 0
    with open(self.json_path, encoding="utf-8") as f:
        # Skip the first character (presumably the opening bracket of the
        # list written by to_json)
        f.read(1)
        add_to_next_chunk = []
        while True:
            chunk = f.read(chunk_size)
            if add_to_next_chunk:
                # Trailers were collected from the end of the previous chunk
                # backwards, so reverse them to restore the original order
                chunk = "".join(add_to_next_chunk[::-1]).lstrip(",") + chunk
                add_to_next_chunk = []
            # Nothing left to parse
            if len(chunk) <= 1:
                break
            while True:
                try:
                    # The last character of the chunk is dropped before
                    # parsing (presumably a separator or the closing bracket)
                    for rec in json.loads(f"[{chunk.lstrip(',')[:-1]}]"):
                        try:
                            yield rec
                        finally:
                            self._notify_count += 1
                            if not self._notify_count % 5000:
                                logger.info(
                                    "Read %s records from %s",
                                    self._notify_count,
                                    self.json_path,
                                )
                    break
                except json.JSONDecodeError:
                    # Cut the chunk at its last opening brace and carry the
                    # removed text over to the next chunk, then try again
                    chunk, trailer = chunk.rsplit("{", 1)
                    add_to_next_chunk.append(f"{{{trailer}")
    logger.info("Read %s records total", self._notify_count)
    self._job_done = True
    # Print the final progress message if report_progress() was used
    if self._job_start:
        self.report_progress()
[docs]
def to_csv(self, path: str, **kwargs) -> None:
    """Writes the records in this reader to a CSV file

    Thin wrapper around write_csv() that passes the reader itself as the
    collection of records.

    Parameters
    ----------
    path : str
        path to write the CSV file
    kwargs :
        any keyword argument accepted by open()
    """
    result = write_csv(self, path, **kwargs)
    return result
[docs]
def to_json(self, path: str = None, **kwargs) -> None:
    """Writes JSON version of XML to file

    Records are read with from_xml() and written in batches as a single
    JSON list. The file is removed if the conversion fails so that a
    truncated cache is never left behind.

    Parameters
    ----------
    path : str
        path to write JSON. Defaults to the json_path of the reader.
    kwargs :
        keyword arguments for json.dumps(). The first item of separators is
        also used to join batches.

    Raises
    ------
    IOError
        if the conversion is interrupted from the keyboard
    """
    if path is None:
        path = self.json_path
    logger.info("Writing records from %s to JSON", self.path)
    params = {
        "ensure_ascii": False,
        "indent": None,
        "sort_keys": False,
        "separators": (",", ":"),
    }
    params.update(**kwargs)
    sep = params["separators"][0]

    def _remove_partial():
        # The file may not exist if it could not be opened
        try:
            os.remove(path)
        except OSError:
            pass

    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write("[")
            # The separator is written before every batch except the first.
            # Writing it after each full batch leaves a trailing separator,
            # and therefore invalid JSON, when the last batch is full.
            prefix = ""
            batch = []
            for rec in self.from_xml():
                batch.append(rec)
                if len(batch) > 1000:
                    # Strip the brackets from the dumped list
                    f.write(prefix + json.dumps(batch, **params)[1:-1])
                    prefix = sep
                    batch = []
            if batch:
                f.write(prefix + json.dumps(batch, **params)[1:-1])
            f.write("]")
    except KeyboardInterrupt as exc:
        # Remove the partial JSON file if write is interrupted
        _remove_partial()
        raise IOError("Conversion to JSON failed") from exc
    except Exception:
        # A partial file would be newer than the XML and so would be
        # treated as a valid cache the next time the file is read
        _remove_partial()
        raise
[docs]
def counts(self) -> dict | int:
    """Counts the number of records in each file

    Records are counted by searching each file for the line that opens a
    top-level tuple.

    Returns
    -------
    dict | int
        If one file, the number of records. Otherwise a dict of path: counts
        for each file. Members of a zip file are keyed by file name.
    """
    pattern = re.compile(rb"\n <tuple>")
    totals = {}
    for filelike in self.files:
        # Members of a zip file have no path and cannot be memory-mapped
        if filelike.path is None:
            with filelike.open("rb") as f:
                totals[filelike.filename] = len(pattern.findall(f.read()))
            continue

        # An empty file cannot be memory-mapped
        if not os.path.getsize(filelike.path):
            totals[filelike.path] = 0
            continue

        with open(filelike.path, mode="rb") as f:
            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as m:
                # Search the mapping directly instead of copying the file
                totals[filelike.path] = len(pattern.findall(m))

    return totals[list(totals)[0]] if len(totals) == 1 else totals
[docs]
def map_to_select(self):
    """Maps report schema to select parameter for EMu API

    Reads the schema header from each source file and builds a nested dict
    of the fields it lists. Fields for which is_ref() is true map to a
    dict of their child fields. All other fields map to None.

    Returns
    -------
    dict
        nested dict of field names
    """
    select = {}
    for filelike in self.files:
        # Read only the header, which ends with the line that closes the
        # schema
        with filelike.open("r", encoding="utf-8") as f:
            lines = []
            for line in f:
                lines.append(line)
                if line.startswith("?>"):
                    break
        content = "".join(lines)

        # The question mark is escaped so that the pattern matches the
        # literal opening of the schema instruction. Unescaped, it makes
        # the angle bracket optional.
        schema = re.search(
            r"<\?schema\s+(.*?)\?>", content, flags=re.DOTALL
        ).group(1)

        containers = []
        for line in schema.splitlines():
            line = line.strip()
            try:
                dtype, field = [s.strip() for s in line.rsplit(" ", 1)]
            except ValueError:
                # A line with a single token closes the current container
                containers.pop()
            else:
                if dtype in ("table", "tuple"):
                    containers.append(field)
                elif containers and not containers[-1].startswith(field):
                    # Walk down from the top of the mapping, skipping the
                    # outermost container
                    select_ = select
                    for part in containers[1:] + [field]:
                        default = {} if is_ref(part) else None
                        select_ = select_.setdefault(part, default)
    return select
[docs]
def verify_group(self, path: str | list | tuple, module: str = None) -> None:
    """Verifies that all fields in a group are present in the export

    The group is looked up in the schema under the "GroupFields" key of
    the field info and compared to the fields found in the source files.

    Parameters
    ----------
    path : str
        the path to one field in a group
    module : str
        the name of an EMu module. Defaults to the module of the reader.

    Raises
    ------
    ValueError
        if one or more fields missing
    """
    if module is None:
        module = self.module

    field_info = self.schema.get_field_info(module, path)
    group = tuple(field_info.get("GroupFields", []))

    missing = set(group) - set(self.fields)
    if not missing:
        return
    raise ValueError(f"Group including '{path}' is missing fields: {missing}")
[docs]
def report_progress(self, by: str = "time", at: int = 5) -> None:
    """Prints progress notification messages when reading a file

    The timers are started on the first call. A message is always printed
    once the job is done, in which case the time shown is measured from
    the start of the job rather than from the last message.

    Parameters
    ----------
    by : str
        either "count" or "time"
    at : int
        number of seconds (if by time) or number of records (if by count)
    """
    if self._notify_start is None:
        self._job_start = time.time()
        self._notify_start = time.time()

    since = self._job_start if self._job_done else self._notify_start
    elapsed = time.time() - since

    # Work out whether a message is due
    if self._job_done:
        due = True
    elif by == "time" and elapsed >= at:
        due = True
    elif by == "count" and self._notify_count and not (self._notify_count % at):
        due = True
    else:
        due = False

    if not due:
        return

    label = "otal" if self._job_done else ""
    print(
        "{:,} records processed (t{}={:.1f}s)".format(
            self._notify_count, label, elapsed
        )
    )
    # Restart the timer used for time-based notifications
    self._notify_start = time.time()
def _parse(self, xml: etree.Element) -> dict:
    """Parses a record from XML

    The tree is walked one level at a time. Each pending item pairs the
    container being filled (a record, dict, or list) with the XML node
    whose children are read into it.

    Parameters
    ----------
    xml : lxml.etree.Element
        XML representing a single record

    Returns
    -------
    dict
        EMu record as a dict. If `rec_class` was specified when creating the
        EMuReader object, the record will use that class.
    """
    # Record classes other than dict are created with the module name
    if self._rec_class != dict:
        dct = self._rec_class(module=self.module)
    else:
        dct = self._rec_class()
    # Each item is (container to fill, name of the field the container
    # belongs to, node or list whose children are read)
    elements = [(dct, "", xml)]
    while elements:
        new_elems = []
        for obj, parent_name, elem in elements:
            for child in elem:
                # Add an empty row to a nested table. These rows do not
                # contain child nodes when exported from EMu and are queued
                # as [None] further down.
                if child is None:
                    obj.append(None)
                    continue
                # Get field name
                name = child.get("name")
                if name is None:
                    name = ""
                # Field names for reverse attachments are based on the field
                # name in the linking module and may therefore not follow the
                # normal EMu naming conventions. These should always be
                # tables. Since some attachment fields are already tabs, this
                # adds the _tab suffix for consistency. Reverse attachment
                # fields must be explicitly defined in .xmurc.
                if child.tag == "table" and not is_tab(name):
                    warn(
                        f"Renaming reverse attachment field {repr(name)}"
                        f" to {repr(name + '_tab')}. You must use the latter"
                        f" value to access this field."
                    )
                    name += "_tab"
                # Get field text
                text = child.text
                if text is not None:
                    text = text.strip()
                # Add an atomic field. Omit empty irns.
                if child.tag == "atom" and (text or name != "irn"):
                    # Item assignment raises TypeError when the container
                    # is a list, in which case the value is appended
                    try:
                        obj[name] = text
                    except TypeError:
                        obj.append(text)
                # Add a reference
                elif child.tag == "tuple" and is_ref(name) and not is_tab(name):
                    obj[name] = {}
                    new_elems.append((obj[name], name, child))
                # Add a table or reference table
                elif child.tag == "table" or (child.tag == "tuple" and name):
                    # As above, fall back to appending when the container
                    # is a list
                    try:
                        obj[name] = []
                        new_elems.append((obj[name], name, child))
                    except TypeError:
                        obj.append([])
                        new_elems.append((obj[-1], name, child))
                # Add a row to a reference table
                elif (
                    child.tag == "tuple"
                    and is_ref_tab(parent_name)
                    and not is_nesttab(parent_name)
                ):
                    obj.append({})
                    new_elems.append((obj[-1], name, child))
                # Add an empty row to an outer nested table
                elif (
                    child.tag == "tuple"
                    and is_nesttab(parent_name)
                    and not len(child)
                ):
                    new_elems.append((obj, name, [None]))
                # Any other row is read into the same container
                elif child.tag == "tuple":
                    new_elems.append((obj, name, child))
        elements = new_elems
    return dct
def _get_files(self) -> None:
    """Analyzes source files on self.path

    Populates self.files with one FileLike object per source file, ordered
    from oldest to newest, and reads the module name from the first table
    tag of the oldest file.

    Raises
    ------
    IOError
        if the path is not a directory, XML file, or zip file, or if no
        source files are found
    """
    files = []
    zip_file = None
    if self.path:
        if os.path.isdir(self.path):
            files = glob.glob(os.path.join(self.path, "*.xml"))
        elif self.path.lower().endswith(".xml"):
            files = [self.path]
        elif self.path.lower().endswith(".zip"):
            # List the members from the handle that is kept for reading
            # instead of opening the archive a second time
            zip_file = zipfile.ZipFile(self.path)
            files = zip_file.infolist()
        else:
            raise IOError(f"Invalid path: {self.path}")

    # Fail with a clear message instead of an IndexError below
    if not files:
        raise IOError(f"No source files found: {self.path}")

    # Order source files from oldest to newest
    self.files = [FileLike(obj, zip_file=zip_file) for obj in files]
    self.files.sort(key=lambda f: f.getmtime())

    # Get the module name from the first table tag
    with self.files[0].open(encoding="utf-8") as f:
        for line in f:
            if line.strip().startswith("<table"):
                self.module = line.split("=", 1)[-1].strip('">\r\n')
                break
def _load_schema(self) -> "EMuSchema":
    """Tries to load the schema based on the rec_class

    Nothing is assigned here. The schema attribute of the reader is
    returned as it stands after the record class has been given the chance
    to load its schema.

    Returns
    -------
    EMuSchema
        the schema attribute of the reader, which may still be None
    """
    if self.schema is not None:
        return self.schema

    try:
        loaded = self._rec_class.schema
        if loaded is None:
            # This will also load the configuration
            loaded = self._rec_class(module=self.module).schema
    except (AttributeError, ValueError):
        # The record class has no schema (for example, dict) or could not
        # be created
        loaded = None
    return self.schema
def _parse_file_schema(self) -> tuple[str, ...]:
    """Parses top-level fields from header of EMu XML file

    Returns
    -------
    tuple[str, ...]
        tuple with the top-level fields in the schema, in the order they
        are first seen across the source files
    """
    # A dict is used as an ordered set
    fields = {}
    for filelike in self.files:
        # Open through the file-like object so that members of a zip file
        # are supported. Only the header, which ends with the line that
        # closes the schema, is read.
        with filelike.open("r", encoding="utf-8") as f:
            lines = []
            for line in f:
                lines.append(line)
                if line.startswith("?>"):
                    break
        content = "".join(lines)

        # The question mark is escaped so that the pattern matches the
        # literal opening of the schema instruction. Unescaped, it makes
        # the angle bracket optional.
        schema = re.search(
            r"<\?schema\s+(.*?)\?>", content, flags=re.DOTALL
        ).group(1)

        containers = []
        for line in schema.splitlines():
            line = line.strip()
            try:
                dtype, field = [s.strip() for s in line.rsplit(" ", 1)]
            except ValueError:
                # A line with a single token closes the current container
                containers.pop()
            else:
                if dtype in ("table", "tuple"):
                    containers.append(field)
                else:
                    # The top-level field is the first container inside the
                    # module table, or the field itself if it is not nested
                    top = containers[1] if len(containers) > 1 else field
                    fields[top] = 1
    return tuple(fields)
[docs]
class FileLike:
    """Open text and zip files using the same interface

    Parameters
    ----------
    filelike : str | zipfile.ZipInfo
        either the path to an XML file or a ZipInfo object
    zip_file : zipfile.ZipFile
        if filelike is a ZipInfo object, the zip file containing that object

    Attributes
    ----------
    path : str
        path to file. None for members of a zip file.
    zip_info : zipfile.ZipInfo
        member of a zip archive. None for regular files.
    zip_file : zipfile.ZipFile
        the zip file containing the ZipInfo object. None for regular files.
    """

    def __init__(
        self, filelike: str | zipfile.ZipInfo, zip_file: zipfile.ZipFile = None
    ):
        self.path = None
        self.zip_info = None
        self.zip_file = None
        if zip_file:
            self.zip_info, self.zip_file = filelike, zip_file
        else:
            self.path = os.path.realpath(filelike)

    def __str__(self) -> str:
        return f'<FileLike name="{self.filename}">'

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def filename(self) -> str:
        """Name of the file-like object"""
        if self.path:
            return os.path.basename(self.path)
        return self.zip_info.filename

    def open(self, mode: str = "r", encoding: str = None):
        """Opens a file or ZipInfo object

        Members of a zip file are opened as binary streams. If an encoding
        is given, the stream is wrapped so that iterating it yields
        decoded lines.
        """
        if self.zip_info:
            # Trailing "b" characters are removed from the mode before it is
            # passed to the zip file
            stream = self.zip_file.open(self.zip_info, mode.rstrip("b"))
            return _ByteDecoder(stream, encoding) if encoding else stream
        return open(self.path, mode=mode, encoding=encoding)

    def getmtime(self) -> float:
        """Returns last modification timestamp from a file or ZipInfo object"""
        if self.path is None:
            # Members of a zip file record the time as a tuple of
            # date and time parts
            return dt.datetime(*self.zip_info.date_time).timestamp()
        return os.path.getmtime(self.path)
class _ByteDecoder:
    """File-like context manager that decodes a binary stream from a zip file

    Only line-by-line iteration is supported.

    Parameters
    ----------
    stream :
        binary stream to read from
    encoding : str
        encoding used to decode each line
    """

    def __init__(self, stream, encoding):
        self._stream = stream
        self._encoding = encoding

    def __iter__(self) -> Generator:
        for line in self._stream:
            yield line.decode(self._encoding)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exception, traceback):
        # Always close the stream. Returning False lets any exception raised
        # inside the with block propagate after the stream is closed.
        self._stream.close()
        return False
[docs]
def clean_xml(path: str, encoding: str = "utf-8") -> Path:
    """Removes restricted characters from XML file

    The cleaned copy is written next to the original with "_clean" added to
    the file name. The file is processed in chunks, so it is never held in
    memory in full.

    Parameters
    ----------
    path : str
        path to the XML file to clean
    encoding : str
        encoding for reading/writing XML

    Returns
    -------
    pathlib.Path
        path to clean XML file
    """
    # From https://stackoverflow.com/a/64570125
    illegal_unichrs = [
        (0x00, 0x08),
        (0x0B, 0x0C),
        (0x0E, 0x1F),
        (0x7F, 0x84),
        (0x86, 0x9F),
        (0xFDD0, 0xFDDF),
        (0xFFFE, 0xFFFF),
    ]
    if sys.maxunicode >= 0x10000:
        illegal_unichrs.extend(
            [
                (0x1FFFE, 0x1FFFF),
                (0x2FFFE, 0x2FFFF),
                (0x3FFFE, 0x3FFFF),
                (0x4FFFE, 0x4FFFF),
                (0x5FFFE, 0x5FFFF),
                (0x6FFFE, 0x6FFFF),
                (0x7FFFE, 0x7FFFF),
                (0x8FFFE, 0x8FFFF),
                (0x9FFFE, 0x9FFFF),
                (0xAFFFE, 0xAFFFF),
                (0xBFFFE, 0xBFFFF),
                (0xCFFFE, 0xCFFFF),
                (0xDFFFE, 0xDFFFF),
                (0xEFFFE, 0xEFFFF),
                (0xFFFFE, 0xFFFFF),
                (0x10FFFE, 0x10FFFF),
            ]
        )
    # Compile the character class once rather than for every chunk
    illegal_ranges = [rf"{chr(low)}-{chr(high)}" for (low, high) in illegal_unichrs]
    illegal_xml_chars_re = re.compile("[" + "".join(illegal_ranges) + "]")

    path = Path(path)
    output = path.parent / f"{path.stem}_clean{path.suffix}"

    # The source is opened first so that an unreadable path fails before
    # the output file is touched
    with open(path, encoding=encoding) as source:
        try:
            with open(output, "w", encoding=encoding) as target:
                while True:
                    chunk = source.read(65536)
                    if not chunk:
                        break
                    # The pattern matches single characters, so chunk
                    # boundaries cannot split a match
                    target.write(illegal_xml_chars_re.sub("", chunk))
        except BaseException:
            # Do not leave behind a truncated file that could be mistaken
            # for a finished one
            if output.exists():
                output.unlink()
            raise
    return output
[docs]
def write_csv(records: list["EMuRecord"], path: str, **kwargs) -> None:
    """Writes records to CSV

    Each record is flattened and the columns are the union of the keys
    found across all flattened records.

    Parameters
    ----------
    records : list-like
        list of EMuRecords to be written
    path : str
        path to write the CSV file
    kwargs :
        any keyword argument accepted by open(). The encoding defaults to
        utf-8-sig and newline to an empty string.
    """
    rows = [flatten(record) for record in records]

    # Collect every key in the order it is first seen
    seen = {}
    for row in rows:
        for key in row:
            seen.setdefault(key, 1)

    # Reorder keys to account for varying grid lengths. Keys that differ
    # only by a numeric index between dots are grouped together.
    by_pattern = {}
    for key in seen:
        pattern = re.sub(r"\.\d+\.", ".x.", key)
        by_pattern.setdefault(pattern, []).append(key)

    def _padded(name):
        # Zero-pad each segment so that numeric indexes sort numerically
        return ".".join([part.zfill(8) for part in name.split(".")])

    fieldnames = []
    for names in by_pattern.values():
        fieldnames.extend(sorted(names, key=_padded))

    kwargs.setdefault("encoding", "utf-8-sig")
    kwargs.setdefault("newline", "")
    with open(path, "w", **kwargs) as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(({k: r.get(k) for k in fieldnames} for r in rows))
[docs]
def write_import(*args, **kwargs) -> None:
    """Writes records to an EMu import file

    Alias for write_xml(). All arguments are passed through unchanged.
    """
    result = write_xml(*args, **kwargs)
    return result
[docs]
def write_xml(records, path, **kwargs) -> None:
    """Writes records to an EMu import file

    Parameters
    ----------
    records : list-like
        list of EMuRecords to be imported
    path : str
        path to write the import file
    kwargs :
        any keyword argument accepted by the to_xml() method of the record class

    Raises
    ------
    ValueError
        if there are no records or a record cannot be converted to XML
    """
    if not records:
        raise ValueError("No records to write to XML file")

    # The module of the first record names the table
    root = etree.Element("table")
    root.set("name", records[0].module)
    root.addprevious(etree.Comment(" Data "))

    for i, rec in enumerate(records, 1):
        try:
            # The record is copied before it is converted
            node = rec.copy().to_xml(root, **kwargs)
            node.addprevious(etree.Comment(f" Row {i} "))
        except AttributeError as exc:
            # Chain the original error so the cause is not lost
            raise ValueError(f"Could not convert record to XML: {rec}") from exc

    root.getroottree().write(
        path, pretty_print=True, xml_declaration=True, encoding="utf-8"
    )
[docs]
def write_group(
    records: list["EMuRecord"], path: str, irn: int = None, name: str = None
) -> None:
    """Writes an import for the egroups module

    Parameters
    ----------
    records : list[EMuRecord]
        list of EMuRecords, each of which specifies an irn
    path : str
        path to write the import file
    irn : int
        the irn of an existing egroups record (updates only)
    name : str
        the name of the group

    Raises
    ------
    ValueError
        if neither irn nor name is given, or if there are no records
    """
    if not irn and not name:
        raise ValueError("Must specify at least one of irn or name for a group")
    # Fail the same way write_xml() does instead of with an IndexError
    if not records:
        raise ValueError("No records to write to group")

    # The group record uses the class and module of the first record
    first = records[0]
    group = first.__class__(
        {
            "GroupType": "Static",
            "Module": first.module,
            "Keys_tab": [rec["irn"] for rec in records],
        },
        module="egroups",
    )
    if irn:
        group["irn"] = irn
    if name:
        group["GroupName"] = name
    write_import([group], path)