// Copyright 2017-2018 Alexander Luzgarev
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Numerics;
using System.Text;
namespace MatFileHandler
{
/// <summary>
/// Class for writing .mat files.
/// </summary>
/// <remarks>
/// Each variable is emitted as a MAT-file data element, optionally wrapped in a zlib-compressed
/// <c>MiCompressed</c> element depending on <see cref="MatFileWriterOptions"/>.
/// The byte length of every <c>MiMatrix</c> element is computed beforehand by running the matching
/// <c>FakeWriter</c> method, so any change to the byte layout written here must be mirrored there.
/// </remarks>
public class MatFileWriter
{
    // zlib (RFC 1950) stream header bytes (CMF, FLG) written in front of the raw deflate data.
    private const byte ZlibCmf = 0x78;
    private const byte ZlibFlg = 0x9c;

    // A compressed element's length counts the 2-byte zlib header and the 4-byte Adler-32 trailer.
    private const int ZlibHeaderLength = 2;
    private const int Adler32Length = 4;

    private readonly MatFileWriterOptions _options;

    /// <summary>
    /// Initializes a new instance of the <see cref="MatFileWriter"/> class with a stream and default options.
    /// </summary>
    /// <param name="stream">Output stream.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> is <c>null</c>.</exception>
    public MatFileWriter(Stream stream)
        : this(stream, MatFileWriterOptions.Default)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="MatFileWriter"/> class with a stream.
    /// </summary>
    /// <param name="stream">Output stream.</param>
    /// <param name="options">Options to use for file writing.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> is <c>null</c>.</exception>
    public MatFileWriter(Stream stream, MatFileWriterOptions options)
    {
        Stream = stream ?? throw new ArgumentNullException(nameof(stream));
        _options = options;
    }

    private Stream Stream { get; }

    /// <summary>
    /// Writes a .mat file.
    /// </summary>
    /// <param name="file">A file to write.</param>
    /// <remarks>
    /// The output stream is closed when this method returns, because the <see cref="BinaryWriter"/>
    /// wrapping it is created without <c>leaveOpen</c> and is disposed at the end of the call.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="file"/> is <c>null</c>.</exception>
    /// <exception cref="NotImplementedException">The configured compression usage is not supported.</exception>
    public void Write(IMatFile file)
    {
        if (file is null)
        {
            throw new ArgumentNullException(nameof(file));
        }

        var header = Header.CreateNewHeader();
        using var writer = new BinaryWriter(Stream);
        WriteHeader(writer, header);
        foreach (var variable in file.Variables)
        {
            switch (_options.UseCompression)
            {
                case CompressionUsage.Always when Stream.CanSeek:
                    WriteCompressedVariableToSeekableStream(writer, variable);
                    break;
                case CompressionUsage.Always:
                    WriteCompressedVariableToUnseekableStream(writer, variable);
                    break;
                case CompressionUsage.Never:
                    WriteVariable(writer, variable);
                    break;
                default:
                    throw new NotImplementedException(
                        $"Compression usage '{_options.UseCompression}' is not supported.");
            }
        }
    }

    /// <summary>
    /// Computes the Adler-32 checksum of all bytes from the current position to the end of <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">Stream to read; it is consumed to its end.</param>
    /// <returns>The checksum as <c>(s2 &lt;&lt; 16) | s1</c>.</returns>
    private static uint CalculateAdler32Checksum(Stream stream)
    {
        const uint modulus = 0xFFF1; // 65521, the largest prime below 2^16.
        uint s1 = 1;
        uint s2 = 0;
        var buffer = new byte[2048];
        int bytesRead;

        // Stream.Read may return fewer bytes than requested before the end of the stream,
        // so only a return value of 0 signals that everything has been consumed.
        while ((bytesRead = stream.Read(buffer, 0, buffer.Length)) > 0)
        {
            for (var i = 0; i < bytesRead; i++)
            {
                s1 = (s1 + buffer[i]) % modulus;
                s2 = (s2 + s1) % modulus;
            }
        }

        return (s2 << 16) | s1;
    }

    /// <summary>
    /// Writes the file header: descriptive text, subsystem data offset, version and endian indicator.
    /// </summary>
    private static void WriteHeader(BinaryWriter writer, Header header)
    {
        writer.Write(Encoding.UTF8.GetBytes(header.Text));
        writer.Write(header.SubsystemDataOffset);
        writer.Write((short)header.Version);
        writer.Write((short)19785); // Magic number, 'IM' (0x4D49 written little-endian).
    }

    /// <summary>
    /// Writes a full 8-byte tag: a 32-bit data type followed by a 32-bit byte count.
    /// </summary>
    private static void WriteTag(BinaryWriter writer, Tag tag)
    {
        writer.Write((int)tag.Type);
        writer.Write(tag.Length);
    }

    /// <summary>
    /// Writes a compact 4-byte tag: a 16-bit data type followed by a 16-bit byte count.
    /// </summary>
    private static void WriteShortTag(BinaryWriter writer, Tag tag)
    {
        writer.Write((short)tag.Type);
        writer.Write((short)tag.Length);
    }

    /// <summary>
    /// Writes <paramref name="count"/> zero bytes; does nothing when <paramref name="count"/> is not positive.
    /// </summary>
    private static void WritePadding(BinaryWriter writer, int count)
    {
        if (count > 0)
        {
            writer.Write(new byte[count]);
        }
    }

    /// <summary>
    /// Writes a data element with the given type and payload.
    /// Payloads longer than 4 bytes get a full tag and are zero-padded to an 8-byte boundary;
    /// payloads of at most 4 bytes use the compact form, where the short tag and the zero-padded
    /// payload together occupy 8 bytes.
    /// </summary>
    private static void WriteDataElement(BinaryWriter writer, DataType type, byte[] data)
    {
        if (data.Length > 4)
        {
            WriteTag(writer, new Tag(type, data.Length));
            writer.Write(data);
            WritePadding(writer, (8 - (data.Length % 8)) % 8);
        }
        else
        {
            WriteShortTag(writer, new Tag(type, data.Length));
            writer.Write(data);
            WritePadding(writer, 4 - data.Length);
        }
    }

    /// <summary>
    /// Writes the dimensions of an array as an <c>MiInt32</c> data element.
    /// </summary>
    private static void WriteDimensions(BinaryWriter writer, int[] dimensions)
    {
        WriteDataElement(writer, DataType.MiInt32, ConvertToByteArray(dimensions));
    }

    /// <summary>
    /// Returns the size in bytes of one element of a supported primitive type.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not a supported element type.</exception>
    private static int GetElementSize<T>()
        where T : struct
    {
        var type = typeof(T);
        if (type == typeof(sbyte) || type == typeof(byte))
        {
            return sizeof(byte);
        }

        if (type == typeof(bool))
        {
            return sizeof(bool);
        }

        if (type == typeof(short) || type == typeof(ushort))
        {
            return sizeof(short);
        }

        if (type == typeof(int) || type == typeof(uint) || type == typeof(float))
        {
            return sizeof(int);
        }

        if (type == typeof(long) || type == typeof(ulong) || type == typeof(double))
        {
            return sizeof(long);
        }

        throw new NotSupportedException();
    }

    /// <summary>
    /// Copies the raw in-memory bytes of a primitive array into a new byte array.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not a supported element type.</exception>
    private static byte[] ConvertToByteArray<T>(T[] data)
        where T : struct
    {
        var elementSize = GetElementSize<T>();
        var buffer = new byte[data.Length * elementSize];
        Buffer.BlockCopy(data, 0, buffer, 0, buffer.Length);
        return buffer;
    }

    /// <summary>
    /// Splits complex values into separate byte arrays of real and imaginary parts.
    /// </summary>
    private static (byte[] real, byte[] imaginary) ConvertToPairOfByteArrays<T>(ComplexOf<T>[] data)
        where T : struct
    {
        return (ConvertToByteArray(data.Select(x => x.Real).ToArray()),
            ConvertToByteArray(data.Select(x => x.Imaginary).ToArray()));
    }

    /// <summary>
    /// Splits <see cref="Complex"/> values into separate byte arrays of real and imaginary parts.
    /// </summary>
    private static (byte[] real, byte[] imaginary) ConvertToPairOfByteArrays(Complex[] data)
    {
        return (ConvertToByteArray(data.Select(x => x.Real).ToArray()),
            ConvertToByteArray(data.Select(x => x.Imaginary).ToArray()));
    }

    /// <summary>
    /// Writes the real parts, then the imaginary parts, as two data elements of the same type.
    /// </summary>
    private static void WriteComplexValues(BinaryWriter writer, DataType type, (byte[] real, byte[] imaginary) data)
    {
        WriteDataElement(writer, type, data.real);
        WriteDataElement(writer, type, data.imaginary);
    }

    /// <summary>
    /// Writes the array flags subelement: class byte, flags byte, then six zero bytes.
    /// </summary>
    private static void WriteArrayFlags(BinaryWriter writer, ArrayFlags flags)
    {
        WriteTag(writer, new Tag(DataType.MiUInt32, 8));
        writer.Write((byte)flags.Class);
        writer.Write((byte)flags.Variable);
        writer.Write(new byte[6]);
    }

    /// <summary>
    /// Writes the array flags subelement of a sparse array: class byte, flags byte,
    /// two zero bytes, then the 32-bit maximum number of non-zero elements.
    /// </summary>
    private static void WriteSparseArrayFlags(BinaryWriter writer, SparseArrayFlags flags)
    {
        WriteTag(writer, new Tag(DataType.MiUInt32, 8));
        writer.Write((byte)flags.ArrayFlags.Class);
        writer.Write((byte)flags.ArrayFlags.Variable);
        writer.Write(new byte[2]);
        writer.Write(flags.NzMax);
    }

    /// <summary>
    /// Writes an array name as an ASCII-encoded <c>MiInt8</c> data element.
    /// </summary>
    private static void WriteName(BinaryWriter writer, string name)
    {
        WriteDataElement(writer, DataType.MiInt8, Encoding.ASCII.GetBytes(name));
    }

    /// <summary>
    /// Writes the values of a numerical (real, complex or logical) array.
    /// </summary>
    /// <exception cref="NotSupportedException">The element type is not supported.</exception>
    private static void WriteNumericalArrayValues(BinaryWriter writer, IArray value)
    {
        switch (value)
        {
            case IArrayOf<sbyte> sbyteArray:
                WriteDataElement(writer, DataType.MiInt8, ConvertToByteArray(sbyteArray.Data));
                break;
            case IArrayOf<byte> byteArray:
                WriteDataElement(writer, DataType.MiUInt8, ConvertToByteArray(byteArray.Data));
                break;
            case IArrayOf<short> shortArray:
                WriteDataElement(writer, DataType.MiInt16, ConvertToByteArray(shortArray.Data));
                break;
            case IArrayOf<ushort> ushortArray:
                WriteDataElement(writer, DataType.MiUInt16, ConvertToByteArray(ushortArray.Data));
                break;
            case IArrayOf<int> intArray:
                WriteDataElement(writer, DataType.MiInt32, ConvertToByteArray(intArray.Data));
                break;
            case IArrayOf<uint> uintArray:
                WriteDataElement(writer, DataType.MiUInt32, ConvertToByteArray(uintArray.Data));
                break;
            case IArrayOf<long> longArray:
                WriteDataElement(writer, DataType.MiInt64, ConvertToByteArray(longArray.Data));
                break;
            case IArrayOf<ulong> ulongArray:
                WriteDataElement(writer, DataType.MiUInt64, ConvertToByteArray(ulongArray.Data));
                break;
            case IArrayOf<float> floatArray:
                WriteDataElement(writer, DataType.MiSingle, ConvertToByteArray(floatArray.Data));
                break;
            case IArrayOf<double> doubleArray:
                WriteDataElement(writer, DataType.MiDouble, ConvertToByteArray(doubleArray.Data));
                break;
            case IArrayOf<bool> boolArray:
                WriteDataElement(writer, DataType.MiUInt8, ConvertToByteArray(boolArray.Data));
                break;
            case IArrayOf<ComplexOf<sbyte>> complexSbyteArray:
                WriteComplexValues(writer, DataType.MiInt8, ConvertToPairOfByteArrays(complexSbyteArray.Data));
                break;
            case IArrayOf<ComplexOf<byte>> complexByteArray:
                WriteComplexValues(writer, DataType.MiUInt8, ConvertToPairOfByteArrays(complexByteArray.Data));
                break;
            case IArrayOf<ComplexOf<short>> complexShortArray:
                WriteComplexValues(writer, DataType.MiInt16, ConvertToPairOfByteArrays(complexShortArray.Data));
                break;
            case IArrayOf<ComplexOf<ushort>> complexUshortArray:
                WriteComplexValues(writer, DataType.MiUInt16, ConvertToPairOfByteArrays(complexUshortArray.Data));
                break;
            case IArrayOf<ComplexOf<int>> complexIntArray:
                WriteComplexValues(writer, DataType.MiInt32, ConvertToPairOfByteArrays(complexIntArray.Data));
                break;
            case IArrayOf<ComplexOf<uint>> complexUintArray:
                WriteComplexValues(writer, DataType.MiUInt32, ConvertToPairOfByteArrays(complexUintArray.Data));
                break;
            case IArrayOf<ComplexOf<long>> complexLongArray:
                WriteComplexValues(writer, DataType.MiInt64, ConvertToPairOfByteArrays(complexLongArray.Data));
                break;
            case IArrayOf<ComplexOf<ulong>> complexUlongArray:
                WriteComplexValues(writer, DataType.MiUInt64, ConvertToPairOfByteArrays(complexUlongArray.Data));
                break;
            case IArrayOf<ComplexOf<float>> complexFloatArray:
                WriteComplexValues(writer, DataType.MiSingle, ConvertToPairOfByteArrays(complexFloatArray.Data));
                break;
            case IArrayOf<Complex> complexDoubleArray:
                WriteComplexValues(writer, DataType.MiDouble, ConvertToPairOfByteArrays(complexDoubleArray.Data));
                break;
            default:
                throw new NotSupportedException();
        }
    }

    /// <summary>
    /// Maps an array's element type to its MAT class and variable flags.
    /// Logical arrays use the uint8 class plus <c>IsLogical</c>; complex arrays add <c>IsComplex</c>.
    /// </summary>
    /// <exception cref="NotSupportedException">The array type is not supported.</exception>
    private static ArrayFlags GetArrayFlags(IArray array, bool isGlobal)
    {
        var variableFlags = isGlobal ? Variable.IsGlobal : 0;
        return array switch
        {
            IArrayOf<sbyte> => new ArrayFlags(ArrayType.MxInt8, variableFlags),
            IArrayOf<byte> => new ArrayFlags(ArrayType.MxUInt8, variableFlags),
            IArrayOf<short> => new ArrayFlags(ArrayType.MxInt16, variableFlags),
            IArrayOf<ushort> => new ArrayFlags(ArrayType.MxUInt16, variableFlags),
            IArrayOf<int> => new ArrayFlags(ArrayType.MxInt32, variableFlags),
            IArrayOf<uint> => new ArrayFlags(ArrayType.MxUInt32, variableFlags),
            IArrayOf<long> => new ArrayFlags(ArrayType.MxInt64, variableFlags),
            IArrayOf<ulong> => new ArrayFlags(ArrayType.MxUInt64, variableFlags),
            IArrayOf<float> => new ArrayFlags(ArrayType.MxSingle, variableFlags),
            IArrayOf<double> => new ArrayFlags(ArrayType.MxDouble, variableFlags),
            IArrayOf<bool> => new ArrayFlags(ArrayType.MxUInt8, variableFlags | Variable.IsLogical),
            IArrayOf<ComplexOf<sbyte>> => new ArrayFlags(ArrayType.MxInt8, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<byte>> => new ArrayFlags(ArrayType.MxUInt8, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<short>> => new ArrayFlags(ArrayType.MxInt16, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<ushort>> => new ArrayFlags(ArrayType.MxUInt16, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<int>> => new ArrayFlags(ArrayType.MxInt32, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<uint>> => new ArrayFlags(ArrayType.MxUInt32, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<long>> => new ArrayFlags(ArrayType.MxInt64, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<ulong>> => new ArrayFlags(ArrayType.MxUInt64, variableFlags | Variable.IsComplex),
            IArrayOf<ComplexOf<float>> => new ArrayFlags(ArrayType.MxSingle, variableFlags | Variable.IsComplex),
            IArrayOf<Complex> => new ArrayFlags(ArrayType.MxDouble, variableFlags | Variable.IsComplex),
            IStructureArray => new ArrayFlags(ArrayType.MxStruct, variableFlags),
            ICellArray => new ArrayFlags(ArrayType.MxCell, variableFlags),
            _ => throw new NotSupportedException(),
        };
    }

    /// <summary>
    /// Builds sparse array flags: the sparse class, the variable flags of the element type,
    /// and <paramref name="nonZero"/> as the maximum number of non-zero elements.
    /// </summary>
    private static SparseArrayFlags GetSparseArrayFlags<T>(ISparseArrayOf<T> array, bool isGlobal, uint nonZero)
        where T : struct
    {
        var flags = GetArrayFlags(array, isGlobal);
        return new SparseArrayFlags
        {
            ArrayFlags = new ArrayFlags
            {
                Class = ArrayType.MxSparse,
                Variable = flags.Variable,
            },
            NzMax = nonZero,
        };
    }

    /// <summary>
    /// Builds the array flags of a character array.
    /// </summary>
    private static ArrayFlags GetCharArrayFlags(bool isGlobal)
    {
        return new ArrayFlags(ArrayType.MxChar, isGlobal ? Variable.IsGlobal : 0);
    }

    /// <summary>
    /// Writes an <c>MiMatrix</c> element around an array's contents.
    /// Empty arrays produce a bare tag of length 0. Otherwise the contents length is obtained by
    /// running <paramref name="lengthCalculator"/> against a <c>FakeWriter</c>, then the tag and
    /// the real contents are written.
    /// </summary>
    private static void WriteWrappingContents<T>(
        BinaryWriter writer,
        T array,
        Action<FakeWriter> lengthCalculator,
        Action<BinaryWriter> writeContents)
        where T : IArray
    {
        if (array.IsEmpty)
        {
            WriteTag(writer, new Tag(DataType.MiMatrix, 0));
            return;
        }

        var fakeWriter = new FakeWriter();
        lengthCalculator(fakeWriter);
        WriteTag(writer, new Tag(DataType.MiMatrix, fakeWriter.Position));
        writeContents(writer);
    }

    /// <summary>
    /// Writes flags, dimensions, name and values of a numerical array.
    /// </summary>
    private static void WriteNumericalArrayContents(BinaryWriter writer, IArray array, string name, bool isGlobal)
    {
        WriteArrayFlags(writer, GetArrayFlags(array, isGlobal));
        WriteDimensions(writer, array.Dimensions);
        WriteName(writer, name);
        WriteNumericalArrayValues(writer, array);
    }

    /// <summary>
    /// Writes a numerical array as a complete <c>MiMatrix</c> element.
    /// </summary>
    private static void WriteNumericalArray(
        BinaryWriter writer,
        IArray numericalArray,
        string name = "",
        bool isGlobal = false)
    {
        WriteWrappingContents(
            writer,
            numericalArray,
            fakeWriter => fakeWriter.WriteNumericalArrayContents(numericalArray, name),
            contentsWriter => WriteNumericalArrayContents(contentsWriter, numericalArray, name, isGlobal));
    }

    /// <summary>
    /// Writes flags, dimensions, name and the UTF-16 code units of a character array.
    /// </summary>
    private static void WriteCharArrayContents(BinaryWriter writer, ICharArray charArray, string name, bool isGlobal)
    {
        WriteArrayFlags(writer, GetCharArrayFlags(isGlobal));
        WriteDimensions(writer, charArray.Dimensions);
        WriteName(writer, name);
        var codeUnits = charArray.String.ToCharArray().Select(c => (ushort)c).ToArray();
        WriteDataElement(writer, DataType.MiUtf16, ConvertToByteArray(codeUnits));
    }

    /// <summary>
    /// Writes a character array as a complete <c>MiMatrix</c> element.
    /// </summary>
    private static void WriteCharArray(BinaryWriter writer, ICharArray charArray, string name, bool isGlobal)
    {
        WriteWrappingContents(
            writer,
            charArray,
            fakeWriter => fakeWriter.WriteCharArrayContents(charArray, name),
            contentsWriter => WriteCharArrayContents(contentsWriter, charArray, name, isGlobal));
    }

    /// <summary>
    /// Writes the row indices, column start indices and non-zero values of a sparse array.
    /// Double and logical values take one element; complex values take two (real, then imaginary).
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not double, Complex or bool.</exception>
    private static void WriteSparseArrayValues<T>(
        BinaryWriter writer, int[] rows, int[] columns, T[] data)
        where T : struct
    {
        WriteDataElement(writer, DataType.MiInt32, ConvertToByteArray(rows));
        WriteDataElement(writer, DataType.MiInt32, ConvertToByteArray(columns));
        switch (data)
        {
            case double[] doubleData:
                WriteDataElement(writer, DataType.MiDouble, ConvertToByteArray(doubleData));
                break;
            case Complex[] complexData:
                WriteComplexValues(writer, DataType.MiDouble, ConvertToPairOfByteArrays(complexData));
                break;
            case bool[] boolData:
                WriteDataElement(writer, DataType.MiUInt8, ConvertToByteArray(boolData));
                break;
            default:
                // Writing nothing here would leave the element shorter than the length
                // already written into its MiMatrix tag.
                throw new NotSupportedException();
        }
    }

    /// <summary>
    /// Converts a sparse array's dictionary into compressed-sparse-column form.
    /// Entries equal to <c>default(T)</c>, or whose column lies outside
    /// <c>[0, Dimensions[1])</c>, are skipped.
    /// </summary>
    /// <returns>
    /// Row indices ordered by column then row, the per-column start offsets
    /// (<c>Dimensions[1] + 1</c> entries), the matching values, and the number of stored values.
    /// </returns>
    private static (int[] rowIndex, int[] columnIndex, T[] data, uint nonZero) PrepareSparseArrayData<T>(
        ISparseArrayOf<T> array)
        where T : struct, IEquatable<T>
    {
        var numberOfColumns = array.Dimensions[1];
        var entries = array.Data
            .Where(entry => entry.Key.column >= 0
                && entry.Key.column < numberOfColumns
                && !entry.Value.Equals(default))
            .OrderBy(entry => entry.Key.column)
            .ThenBy(entry => entry.Key.row)
            .ToArray();

        var rowIndex = new int[entries.Length];
        var data = new T[entries.Length];
        var columnIndex = new int[numberOfColumns + 1];
        for (var i = 0; i < entries.Length; i++)
        {
            rowIndex[i] = entries[i].Key.row;
            data[i] = entries[i].Value;
            columnIndex[entries[i].Key.column + 1]++;
        }

        // Turn per-column counts into cumulative start offsets.
        for (var column = 0; column < numberOfColumns; column++)
        {
            columnIndex[column + 1] += columnIndex[column];
        }

        return (rowIndex, columnIndex, data, (uint)entries.Length);
    }

    /// <summary>
    /// Writes flags, dimensions, name and values of a sparse array.
    /// </summary>
    private static void WriteSparseArrayContents<T>(
        BinaryWriter writer,
        ISparseArrayOf<T> array,
        string name,
        bool isGlobal)
        where T : struct, IEquatable<T>
    {
        var (rows, columns, data, nonZero) = PrepareSparseArrayData(array);
        WriteSparseArrayFlags(writer, GetSparseArrayFlags(array, isGlobal, nonZero));
        WriteDimensions(writer, array.Dimensions);
        WriteName(writer, name);
        WriteSparseArrayValues(writer, rows, columns, data);
    }

    /// <summary>
    /// Writes a sparse array as a complete <c>MiMatrix</c> element.
    /// </summary>
    private static void WriteSparseArray<T>(BinaryWriter writer, ISparseArrayOf<T> sparseArray, string name, bool isGlobal)
        where T : unmanaged, IEquatable<T>
    {
        WriteWrappingContents(
            writer,
            sparseArray,
            fakeWriter => fakeWriter.WriteSparseArrayContents(sparseArray, name),
            contentsWriter => WriteSparseArrayContents(contentsWriter, sparseArray, name, isGlobal));
    }

    /// <summary>
    /// Writes the field name length (longest ASCII name + 1) and the field names, each
    /// zero-padded to that length.
    /// </summary>
    private static void WriteFieldNames(BinaryWriter writer, IEnumerable<string> fieldNames)
    {
        var encodedNames = fieldNames.Select(name => Encoding.ASCII.GetBytes(name)).ToArray();

        // NOTE(review): Enumerable.Max throws InvalidOperationException when there are no fields.
        // Any fix must be mirrored in FakeWriter so the precomputed element length stays in sync.
        var slotLength = encodedNames.Max(name => name.Length) + 1;
        WriteDataElement(writer, DataType.MiInt32, ConvertToByteArray(new[] { slotLength }));

        var buffer = new byte[encodedNames.Length * slotLength];
        for (var index = 0; index < encodedNames.Length; index++)
        {
            Buffer.BlockCopy(encodedNames[index], 0, buffer, index * slotLength, encodedNames[index].Length);
        }

        WriteDataElement(writer, DataType.MiInt8, buffer);
    }

    /// <summary>
    /// Writes every field of every structure element, element by element, as unnamed arrays.
    /// </summary>
    private void WriteStructureArrayValues(BinaryWriter writer, IStructureArray array)
    {
        var fieldNames = array.FieldNames.ToArray();
        for (var i = 0; i < array.Count; i++)
        {
            foreach (var name in fieldNames)
            {
                WriteArray(writer, array[name, i]);
            }
        }
    }

    /// <summary>
    /// Writes flags, dimensions, name, field names and values of a structure array.
    /// </summary>
    private void WriteStructureArrayContents(BinaryWriter writer, IStructureArray array, string name, bool isGlobal)
    {
        WriteArrayFlags(writer, GetArrayFlags(array, isGlobal));
        WriteDimensions(writer, array.Dimensions);
        WriteName(writer, name);
        WriteFieldNames(writer, array.FieldNames);
        WriteStructureArrayValues(writer, array);
    }

    /// <summary>
    /// Writes a structure array as a complete <c>MiMatrix</c> element.
    /// </summary>
    private void WriteStructureArray(
        BinaryWriter writer,
        IStructureArray structureArray,
        string name,
        bool isGlobal)
    {
        WriteWrappingContents(
            writer,
            structureArray,
            fakeWriter => fakeWriter.WriteStructureArrayContents(structureArray, name),
            contentsWriter => WriteStructureArrayContents(contentsWriter, structureArray, name, isGlobal));
    }

    /// <summary>
    /// Writes every cell of a cell array as an unnamed array.
    /// </summary>
    private void WriteCellArrayValues(BinaryWriter writer, ICellArray array)
    {
        for (var i = 0; i < array.Count; i++)
        {
            WriteArray(writer, array[i]);
        }
    }

    /// <summary>
    /// Writes flags, dimensions, name and cells of a cell array.
    /// </summary>
    private void WriteCellArrayContents(BinaryWriter writer, ICellArray array, string name, bool isGlobal)
    {
        WriteArrayFlags(writer, GetArrayFlags(array, isGlobal));
        WriteDimensions(writer, array.Dimensions);
        WriteName(writer, name);
        WriteCellArrayValues(writer, array);
    }

    /// <summary>
    /// Writes a cell array as a complete <c>MiMatrix</c> element.
    /// </summary>
    private void WriteCellArray(BinaryWriter writer, ICellArray cellArray, string name, bool isGlobal)
    {
        WriteWrappingContents(
            writer,
            cellArray,
            fakeWriter => fakeWriter.WriteCellArrayContents(cellArray, name),
            contentsWriter => WriteCellArrayContents(contentsWriter, cellArray, name, isGlobal));
    }

    /// <summary>
    /// Dispatches an array to the writer for its kind. Anything that is not a char, sparse,
    /// cell or structure array is treated as a numerical array.
    /// </summary>
    private void WriteArray(BinaryWriter writer, IArray array, string variableName = "", bool isGlobal = false)
    {
        switch (array)
        {
            case ICharArray charArray:
                WriteCharArray(writer, charArray, variableName, isGlobal);
                break;
            case ISparseArrayOf<double> doubleSparseArray:
                WriteSparseArray(writer, doubleSparseArray, variableName, isGlobal);
                break;
            case ISparseArrayOf<Complex> complexSparseArray:
                WriteSparseArray(writer, complexSparseArray, variableName, isGlobal);
                break;
            case ISparseArrayOf<bool> boolSparseArray:
                WriteSparseArray(writer, boolSparseArray, variableName, isGlobal);
                break;
            case ICellArray cellArray:
                WriteCellArray(writer, cellArray, variableName, isGlobal);
                break;
            case IStructureArray structureArray:
                WriteStructureArray(writer, structureArray, variableName, isGlobal);
                break;
            default:
                WriteNumericalArray(writer, array, variableName, isGlobal);
                break;
        }
    }

    /// <summary>
    /// Writes a variable, uncompressed, under its name and global flag.
    /// </summary>
    private void WriteVariable(BinaryWriter writer, IVariable variable)
    {
        WriteArray(writer, variable.Value, variable.Name, variable.IsGlobal);
    }

    /// <summary>
    /// Writes the two zlib header bytes that precede raw deflate data.
    /// </summary>
    private static void WriteZlibHeader(BinaryWriter writer)
    {
        writer.Write(ZlibCmf);
        writer.Write(ZlibFlg);
    }

    /// <summary>
    /// Writes a 32-bit value most-significant byte first, as the zlib Adler-32 trailer requires.
    /// Unlike <c>BitConverter</c>, this does not depend on the host's byte order.
    /// </summary>
    private static void WriteBigEndianUInt32(BinaryWriter writer, uint value)
    {
        writer.Write((byte)(value >> 24));
        writer.Write((byte)(value >> 16));
        writer.Write((byte)(value >> 8));
        writer.Write((byte)value);
    }

    /// <summary>
    /// Writes a variable as an <c>MiCompressed</c> element directly into a seekable stream.
    /// A placeholder tag is written first and patched with the real length once compression finishes.
    /// </summary>
    private void WriteCompressedVariableToSeekableStream(BinaryWriter writer, IVariable variable)
    {
        // BinaryWriter.BaseStream flushes the writer, so positions read through it are exact.
        var tagPosition = writer.BaseStream.Position;
        WriteTag(writer, new Tag(DataType.MiCompressed, 0));
        WriteZlibHeader(writer);

        uint adler32;
        var deflateStart = writer.BaseStream.Position;
        using (var compressionStream = new DeflateStream(writer.BaseStream, CompressionMode.Compress, leaveOpen: true))
        {
            using var checksumStream = new ChecksumCalculatingStream(compressionStream);
            using var internalWriter = new BinaryWriter(checksumStream, Encoding.UTF8, leaveOpen: true);
            WriteVariable(internalWriter, variable);
            adler32 = checksumStream.GetCrc();
        }

        var deflateEnd = writer.BaseStream.Position;
        WriteBigEndianUInt32(writer, adler32);
        var elementEnd = writer.BaseStream.Position;

        var compressedLength = (int)(deflateEnd - deflateStart) + ZlibHeaderLength + Adler32Length;
        writer.BaseStream.Position = tagPosition;
        WriteTag(writer, new Tag(DataType.MiCompressed, compressedLength));

        // Return to the end of this element rather than to the end of the stream,
        // which can differ if the stream already held longer content.
        writer.BaseStream.Position = elementEnd;
    }

    /// <summary>
    /// Writes a variable as an <c>MiCompressed</c> element into a non-seekable stream.
    /// The variable is serialized and compressed in memory first so the length is known before the tag.
    /// </summary>
    private void WriteCompressedVariableToUnseekableStream(BinaryWriter writer, IVariable variable)
    {
        using var compressedStream = new MemoryStream();
        uint adler32;
        using (var originalStream = new MemoryStream())
        {
            using (var internalWriter = new BinaryWriter(originalStream, Encoding.UTF8, leaveOpen: true))
            {
                WriteVariable(internalWriter, variable);
            }

            originalStream.Position = 0;
            adler32 = CalculateAdler32Checksum(originalStream);
            originalStream.Position = 0;

            // Disposing the DeflateStream flushes the final deflate block into compressedStream.
            using (var compressionStream = new DeflateStream(compressedStream, CompressionMode.Compress, leaveOpen: true))
            {
                originalStream.CopyTo(compressionStream);
            }
        }

        compressedStream.Position = 0;
        WriteTag(writer, new Tag(DataType.MiCompressed, (int)(compressedStream.Length + ZlibHeaderLength + Adler32Length)));
        WriteZlibHeader(writer);
        compressedStream.CopyTo(writer.BaseStream);
        WriteBigEndianUInt32(writer, adler32);
    }
}
}