﻿/*
cap2mod is a simple mod kit for the game Capitalism 2. dbmod allows editing
of the 1STD.SET database. It was written by Adam Milazzo on February 28th,
2013. This source code is released into the public domain.

For more information, feel free to contact me. http://www.adammil.net/
*/

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text;
using AdamMil.IO;
using AdamMil.Utilities;

// this code is based on my reverse-engineering efforts described at
// http://www.adammil.net/blog/v112_Reverse_Engineering_Capitalism_2.html

namespace Cap2Mod.DBMod
{

/// <summary>
/// The data type of a column. Each member's value is the ASCII character stored in the column
/// descriptor's type byte.
/// </summary>
enum ColumnType
{
  /// <summary>Text data (type code 'C').</summary>
  String = 'C',
  /// <summary>Numeric data stored as text (type code 'N').</summary>
  Number = 'N',
  /// <summary>Boolean data stored as 'T' or 'F' (type code 'L').</summary>
  Bool   = 'L'
}

#region Database
/// <summary>
/// An in-memory representation of the table database: a directory of named tables followed by the
/// data of each table.
/// </summary>
sealed class Database
{
  /// <summary>The tables in the database, in file order.</summary>
  public readonly List<Table> Tables = new List<Table>();

  /// <summary>Replaces <see cref="Tables"/> with the tables read from <paramref name="reader"/>.</summary>
  /// <remarks>
  /// The directory is a 16-bit table count followed by count+1 entries, each made of a 9-byte
  /// NUL-padded name and a 32-bit offset. The extra final entry is a terminator whose offset marks the
  /// end of the table data.
  /// </remarks>
  public void Load(BinaryReader reader)
  {
    Tables.Clear();

    int tableCount = reader.ReadUInt16();
    string[] tableNames = new string[tableCount+1];
    int[] offsets = new int[tableCount+1];
    for(int i=0; i<tableNames.Length; i++)
    {
      tableNames[i] = ReadNulTerminatedString(reader, 9);
      offsets[i]    = reader.ReadInt32();
    }

    for(int i=0; i<tableCount; i++)
    {
      // seek to the table's recorded offset rather than assuming the previous table left the reader there
      reader.Skip(offsets[i]-(int)reader.Position);
      Table table = new Table(tableNames[i]);
      table.Load(reader);
      Tables.Add(table);
    }
  }

  /// <summary>Writes the directory and all tables to <paramref name="writer"/>.</summary>
  /// <exception cref="InvalidOperationException">
  /// Thrown if there are too many tables or a table entry is null. These checks run before anything
  /// is written.
  /// </exception>
  public void Save(BinaryWriter writer)
  {
    // check everything we can up front so we don't leave a partially written file behind
    if(Tables.Count > ushort.MaxValue) throw new InvalidOperationException("There are too many tables.");
    foreach(Table table in Tables)
    {
      if(table == null) throw new InvalidOperationException("A table was null.");
    }

    writer.Write((ushort)Tables.Count);

    // table data begins after the 2-byte count and the directory, which has one 13-byte entry
    // (9-byte name + 4-byte offset) per table plus a terminator entry
    int position = (Tables.Count+1)*13 + 2;
    foreach(Table table in Tables)
    {
      WriteNulTerminatedString(writer, table.Name, 9);
      writer.Write(position);
      position += table.GetSize();
    }
    writer.WriteZeros(9); // the terminator entry has an empty name and points to the end of the data
    writer.Write(position);

    foreach(Table table in Tables) table.Save(writer);
  }

  /// <summary>
  /// Reads a fixed-width field of <paramref name="maxLength"/> bytes and returns the ASCII text before
  /// the first NUL byte, or the whole field if it contains no NUL.
  /// </summary>
  internal static string ReadNulTerminatedString(BinaryReader reader, int maxLength)
  {
    byte[] bytes = reader.ReadBytes(maxLength);
    int length = Array.IndexOf(bytes, (byte)0);
    // use the number of bytes actually read rather than maxLength, in case the read came up short
    if(length == -1) length = bytes.Length;
    return Encoding.ASCII.GetString(bytes, 0, length);
  }

  /// <summary>
  /// Writes <paramref name="str"/> as ASCII, padded with NUL bytes to exactly <paramref name="maxLength"/> bytes.
  /// </summary>
  /// <exception cref="ArgumentException">Thrown if the string does not fit in the field.</exception>
  internal static void WriteNulTerminatedString(BinaryWriter writer, string str, int maxLength)
  {
    byte[] bytes = Encoding.ASCII.GetBytes(str);
    // an oversized string would shift every following field, so refuse it rather than corrupt the file
    if(bytes.Length > maxLength)
    {
      throw new ArgumentException("The string \"" + str + "\" is longer than " +
                                  maxLength.ToString(CultureInfo.InvariantCulture) + " bytes.");
    }
    writer.Write(bytes);
    writer.WriteZeros(maxLength - bytes.Length);
  }

  /// <summary>Ensures that <paramref name="name"/> is non-null and at most <paramref name="maxLength"/> ASCII bytes.</summary>
  /// <exception cref="ArgumentNullException">Thrown if the name is null.</exception>
  /// <exception cref="ArgumentException">Thrown if the name is too long.</exception>
  internal static void ValidateName(string name, int maxLength)
  {
    if(name == null) throw new ArgumentNullException();
    if(Encoding.ASCII.GetByteCount(name) > maxLength) throw new ArgumentException("The name is too long.");
  }
}
#endregion

#region Table
/// <summary>A single table: a list of column descriptors and the rows that hold values for them.</summary>
sealed class Table
{
  /// <summary>Creates an empty table with the given name and a default header value of 0x1b9b0103.</summary>
  /// <param name="name">The table name, at most 8 ASCII bytes.</param>
  public Table(string name)
  {
    Header = 0x1b9b0103;
    Name   = name;
  }

  /// <summary>The column descriptors of the table.</summary>
  public readonly List<Column> Columns = new List<Column>();
  /// <summary>The rows of the table.</summary>
  public readonly List<Row> Rows = new List<Row>();

  /// <summary>The table name, which must be non-null and at most 8 ASCII bytes.</summary>
  public string Name
  {
    get { return _name; }
    set
    {
      Database.ValidateName(value, 8);
      _name = value;
    }
  }

  /// <summary>The first four bytes of the table header, preserved as-is when loading and saving.</summary>
  public uint Header { get; set; }

  /// <summary>Determines whether the table has a column with the given name (case-sensitive).</summary>
  public bool HasColumn(string name)
  {
    return Columns.Exists(c => c.Name == name);
  }

  /// <summary>Returns the number of bytes that <see cref="Save"/> writes for this table.</summary>
  internal int GetSize()
  {
    // 32-byte header + 32 bytes per column + 0x0D terminator + rows + 0x1A terminator
    return 32 + Columns.Count*32 + 1 + GetRowLength()*Rows.Count + 1;
  }

  /// <summary>Replaces the columns, rows, and header with the table read from <paramref name="reader"/>.</summary>
  /// <exception cref="FormatException">Thrown if the table header contains an invalid row or column count.</exception>
  internal void Load(BinaryReader reader)
  {
    Columns.Clear();
    Rows.Clear();

    Header = reader.ReadUInt32();

    int rowCount = reader.ReadInt32();
    int headerLength = reader.ReadUInt16(); // 32-byte header + 32 bytes per column + 1 terminator byte
    int rowLength = reader.ReadUInt16();
    reader.Skip(20);

    int columnCount = headerLength/32 - 1;
    // a negative count would otherwise make the loops below run until they fail on garbage data
    if(rowCount < 0 || columnCount < 0)
    {
      throw new FormatException("The header of table " + Name + " is invalid.");
    }

    for(int i=0; i<columnCount; i++)
    {
      Column column = new Column();
      column.Load(reader);
      Columns.Add(column);
    }

    reader.Skip(1); // the byte that ends the column descriptors (written as 0x0D)

    for(int i=0; i<rowCount; i++)
    {
      Row row = new Row();
      row.Load(Columns, reader, rowLength);
      Rows.Add(row);
    }

    // the trailing end-of-table byte is deliberately not read, because some tables in Capitalism Lab
    // don't end with the correct byte
  }

  /// <summary>Validates the table and writes it to <paramref name="writer"/>.</summary>
  /// <exception cref="InvalidOperationException">Thrown if the table's columns or rows are invalid.</exception>
  internal void Save(BinaryWriter writer)
  {
    Validate();
    int rowLength = GetRowLength();

    writer.Write(Header);
    writer.Write(Rows.Count);
    writer.Write((ushort)(Columns.Count*32+33)); // header length: 32-byte header + 32 per column + 0x0D
    writer.Write((ushort)rowLength);
    writer.WriteZeros(20);

    foreach(Column column in Columns) column.Save(writer);
    writer.Write((byte)0x0D); // end of column descriptors

    // a single buffer is reused for every row
    byte[] rowBuffer = new byte[rowLength];
    foreach(Row row in Rows)
    {
      row.Save(Columns, rowBuffer);
      writer.Write(rowBuffer);
    }
    writer.Write((byte)0x1A); // end of table
  }

  /// <summary>Returns the length of a row, which is the end of the furthest-reaching column.</summary>
  int GetRowLength()
  {
    int length = 0;
    foreach(Column column in Columns) length = Math.Max(length, column.Offset+column.Length);
    return length;
  }

  /// <summary>
  /// Ensures that the table can be saved. The header length must fit in 16 bits. Columns and rows must
  /// be non-null. Column names must be set and unique, and no two columns may overlap.
  /// </summary>
  void Validate()
  {
    // the header length is stored as a 16-bit value, which limits the number of columns
    if(Columns.Count*32+33 > ushort.MaxValue)
    {
      throw new InvalidOperationException("Table " + Name + " has too many columns.");
    }

    HashSet<string> columnNames = new HashSet<string>();
    foreach(Column column in Columns)
    {
      if(column == null) throw new InvalidOperationException("A column was null.");
      if(column.Name == null) throw new InvalidOperationException("A column name was null.");
      if(!columnNames.Add(column.Name)) throw new InvalidOperationException("The column " + column.Name + " was duplicated.");
    }

    foreach(Row row in Rows)
    {
      if(row == null) throw new InvalidOperationException("A row was null.");
    }

    // after sorting by offset, any overlap must occur between neighboring columns
    Column[] columns = Columns.ToArray();
    Array.Sort(columns, (a, b) => a.Offset.CompareTo(b.Offset));
    for(int i=0; i<columns.Length-1; i++)
    {
      int end = columns[i].Offset + columns[i].Length;
      if(end > columns[i+1].Offset)
      {
        throw new InvalidOperationException("Column " + columns[i].Name + " overlaps column " + columns[i+1].Name);
      }
    }
  }

  string _name;
}
#endregion

#region Column
/// <summary>Describes a single column (field) of a table.</summary>
sealed class Column
{
  /// <summary>The number of digits written after the decimal point for numeric columns, from 0 to 255.</summary>
  public int Decimals
  {
    get { return _decimals; }
    set
    {
      // stored in a single byte, so 255 is a legal value
      if(value < 0 || value > byte.MaxValue) throw new ArgumentOutOfRangeException("value");
      _decimals = value;
    }
  }

  /// <summary>The width of the column's data within a row, in bytes, from 0 to 255.</summary>
  public int Length
  {
    get { return _length; }
    set
    {
      // stored in a single byte, so 255 is a legal value
      if(value < 0 || value > byte.MaxValue) throw new ArgumentOutOfRangeException("value");
      _length = value;
    }
  }

  /// <summary>The byte offset of the column's data within each row. It must be less than <see cref="short.MaxValue"/>.</summary>
  public int Offset
  {
    get { return _offset; }
    set
    {
      if(value < 0 || value >= short.MaxValue) throw new ArgumentOutOfRangeException("value");
      _offset = value;
    }
  }

  /// <summary>The column name, which must be non-null and at most 10 ASCII bytes.</summary>
  public string Name
  {
    get { return _name; }
    set
    {
      Database.ValidateName(value, 10);
      _name = value;
    }
  }

  /// <summary>The column's data type, read from and written to the descriptor's type byte.</summary>
  public ColumnType Type { get; set; }

  /// <summary>
  /// Reads a 32-byte column descriptor, laid out as follows:
  /// an 11-byte NUL-padded name, a 1-byte type code, a 32-bit offset,
  /// a 1-byte length, a 1-byte decimal count, and 14 skipped bytes.
  /// </summary>
  internal void Load(BinaryReader reader)
  {
    Name     = Database.ReadNulTerminatedString(reader, 11);
    Type     = (ColumnType)reader.ReadByte();
    Offset   = reader.ReadInt32();
    Length   = reader.ReadByte();
    Decimals = reader.ReadByte();
    reader.Skip(14);
  }

  /// <summary>Writes the 32-byte column descriptor in the layout that <see cref="Load"/> reads, zero-filling the last 14 bytes.</summary>
  internal void Save(BinaryWriter writer)
  {
    Database.WriteNulTerminatedString(writer, Name, 11);
    writer.Write((byte)Type);
    writer.Write(Offset);
    writer.Write((byte)Length);
    writer.Write((byte)Decimals);
    writer.WriteZeros(14);
  }

  string _name;
  int _decimals, _length, _offset;
}
#endregion

#region Row
/// <summary>
/// A single row of a table. Values are stored by column name. Loaded values are strings for text
/// columns, doubles for numeric columns, and booleans for logical columns, with null for empty fields.
/// </summary>
sealed class Row
{
  /// <summary>Gets or sets the value of the named column.</summary>
  /// <remarks>
  /// Getting a column that has not been loaded or set throws <see cref="KeyNotFoundException"/>.
  /// When saving, a null or blank value leaves the column empty.
  /// </remarks>
  public object this[string columnName]
  {
    get { return values[columnName]; }
    set { values[columnName] = value; }
  }

  /// <summary>
  /// Reads <paramref name="rowLength"/> raw bytes and decodes the value of each column from them.
  /// The raw bytes are kept so that <see cref="Save"/> can preserve bytes not owned by any column.
  /// </summary>
  internal void Load(IEnumerable<Column> columns, BinaryReader reader, int rowLength)
  {
    rowData = reader.ReadBytes(rowLength);
    foreach(Column column in columns)
    {
      // fields are space-padded, so ignore trailing spaces. An all-space field is empty (null).
      int end = column.Length-1;
      while(end >= 0 && rowData[column.Offset+end] == (byte)' ') end--;
      string strValue = end < 0 ? null : Encoding.ASCII.GetString(rowData, column.Offset, end+1);

      object value;
      if(column.Type == ColumnType.String)
      {
        value = strValue;
      }
      else if(column.Type == ColumnType.Number)
      {
        if(strValue != null) strValue = strValue.Trim();
        // the stock database uses "  . " for some nonexistent numbers
        value = string.IsNullOrEmpty(strValue) || strValue == "." ? null : (object)double.Parse(strValue, CultureInfo.InvariantCulture);
      }
      else if(column.Type == ColumnType.Bool)
      {
        if(strValue != null) strValue = strValue.Trim();
        // compare case-insensitively, matching how Save interprets string values
        value = string.IsNullOrEmpty(strValue) ? null : (object)(char.ToUpperInvariant(strValue[0]) == 'T');
      }
      else
      {
        throw new Exception("Unhandled column type: " + column.Type.ToString());
      }
      values[column.Name] = value;
    }
  }

  /// <summary>Encodes the row into <paramref name="rowBuffer"/>, whose length is the table's current row length.</summary>
  /// <remarks>
  /// If the row was loaded, bytes that belong to no column are preserved from the loaded data where
  /// they fit. Everything else is blanked with spaces. Every column is then either written or blanked,
  /// so nothing previously in <paramref name="rowBuffer"/> survives. This matters because the caller
  /// may reuse the buffer between rows.
  /// </remarks>
  internal void Save(IEnumerable<Column> columns, byte[] rowBuffer)
  {
    bool hasExistingData = rowData != null;
    if(hasExistingData)
    {
      // the columns may have changed since the row was loaded, so the row length can differ.
      // Copy what fits and blank the rest.
      int count = Math.Min(rowData.Length, rowBuffer.Length);
      Array.Copy(rowData, rowBuffer, count);
      ClearRowData(rowBuffer, count, rowBuffer.Length-count);
    }
    else
    {
      ClearRowData(rowBuffer, 0, rowBuffer.Length);
    }

    foreach(Column column in columns)
    {
      object value;
      values.TryGetValue(column.Name, out value);
      string strValue = value == null ? null : FormatValue(column, value);

      if(string.IsNullOrEmpty(strValue))
      {
        // new rows start out blank, but data copied from the loaded row must be explicitly removed
        if(hasExistingData) ClearRowData(rowBuffer, column.Offset, column.Length);
      }
      else
      {
        byte[] bytes = Encoding.ASCII.GetBytes(strValue);
        if(bytes.Length > column.Length) throw new Exception("The value \"" + strValue + "\" is too large for column " + column.Name);
        Array.Copy(bytes, 0, rowBuffer, column.Offset, bytes.Length);
        if(hasExistingData) ClearRowData(rowBuffer, column.Offset+bytes.Length, column.Length-bytes.Length);
      }
    }
  }

  readonly Dictionary<string, object> values = new Dictionary<string, object>();
  byte[] rowData;

  /// <summary>Fills <paramref name="length"/> bytes of <paramref name="rowData"/> starting at <paramref name="index"/> with spaces.</summary>
  static void ClearRowData(byte[] rowData, int index, int length)
  {
    for(int end=index+length; index<end; index++) rowData[index] = (byte)' ';
  }

  /// <summary>
  /// Converts a non-null value to the text stored in <paramref name="column"/>. Returns null or an empty
  /// string if the value represents an empty field.
  /// </summary>
  static string FormatValue(Column column, object value)
  {
    if(column.Type == ColumnType.String)
    {
      string str = value as string;
      if(str != null) return str;
      if(value is double) return TrimNumberString(((double)value).ToString(CultureInfo.InvariantCulture), column);
      return Convert.ToString(value, CultureInfo.InvariantCulture);
    }
    else if(column.Type == ColumnType.Number)
    {
      double dblValue;
      if(value is double)
      {
        dblValue = (double)value;
      }
      else if(value is string)
      {
        string str = (string)value, trimmed = str.Trim();
        // blank means empty, and the stock database uses "  . " for some nonexistent numbers
        if(trimmed.Length == 0 || trimmed == ".") return null;
        if(!InvariantCultureUtility.TryParse(trimmed, out dblValue))
        {
          throw new Exception("The value \"" + str + "\" (for column " + column.Name + ") is not a number.");
        }
      }
      else
      {
        throw new Exception("The value \"" + value.ToString() + "\" (for column " + column.Name + ") is not a number.");
      }

      string format = column.Decimals == 0 ? "f0" : "f" + column.Decimals.ToStringInvariant();
      return TrimNumberString(dblValue.ToString(format, CultureInfo.InvariantCulture), column);
    }
    else if(column.Type == ColumnType.Bool)
    {
      bool boolValue;
      if(value is bool)
      {
        boolValue = (bool)value;
      }
      else if(value is string)
      {
        string str = ((string)value).Trim();
        if(str.Length == 0) return null;
        boolValue = char.ToUpperInvariant(str[0]) == 'T';
      }
      else
      {
        throw new Exception("The value \"" + value.ToString() + "\" (for column " + column.Name + ") is not a boolean.");
      }
      return boolValue ? "T" : "F";
    }
    else
    {
      throw new Exception("Unhandled column type: " + column.Type.ToString());
    }
  }

  /// <summary>
  /// Fits a formatted number into <paramref name="column"/>. Excess digits after the decimal point are
  /// dropped, and the result is right-aligned to match the stock database format.
  /// </summary>
  /// <exception cref="Exception">Thrown if the integer part alone does not fit in the column.</exception>
  static string TrimNumberString(string str, Column column)
  {
    if(str.Length > column.Length)
    {
      // only digits after the decimal point may be dropped
      int period = str.IndexOf('.');
      if(period == -1 || period > column.Length) throw new Exception("The value \"" + str + "\" is too large for column " + column.Name);
      // if the period would be the last character that fits, drop it as well, but keep every integer digit
      str = str.Substring(0, period >= column.Length-1 ? period : column.Length);
    }
    return str.PadLeft(column.Length, ' ');
  }
}
#endregion

} // namespace Cap2Mod.DBMod
