// Copyright (C) 2004-2007 MySQL AB
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as published by
// the Free Software Foundation
//
// There are special exceptions to the terms and conditions of the GPL 
// as it is applied to this software. View the full text of the 
// exception in file EXCEPTIONS in the directory of this software 
// distribution.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA 

using System;
using System.ComponentModel;
using System.Data.Common;
using System.Data;
using System.Text;
using MySql.Data.Types;
using System.Globalization;

namespace MySql.Data.MySqlClient
{
    /// <include file='docs/MySqlCommandBuilder.xml' path='docs/class/*'/>
#if !PocketPC
    [ToolboxItem(false)]
    [DesignerCategory("Code")]
#endif
    public sealed class MySqlCommandBuilder : DbCommandBuilder
    {
        // Cached ";SELECT last_insert_id() AS `col`" suffix appended to generated
        // inserts when ReturnGeneratedIdentifiers is on. null = not computed yet,
        // String.Empty = computed, but the select has no auto-increment column.
        private string finalSelect;
        private bool returnGeneratedIds;

        #region Constructors

        /// <include file='docs/MySqlCommandBuilder.xml' path='docs/Ctor/*'/>
        public MySqlCommandBuilder()
        {
            QuotePrefix = QuoteSuffix = "`";
        }

        /// <include file='docs/MySqlCommandBuilder.xml' path='docs/Ctor2/*'/>
        public MySqlCommandBuilder(MySqlDataAdapter adapter)
            : this()
        {
            DataAdapter = adapter;
        }

        #endregion

        #region Properties

        /// <include file='docs/mysqlcommandBuilder.xml' path='docs/DataAdapter/*'/>
        public new MySqlDataAdapter DataAdapter
        {
            get { return (MySqlDataAdapter)base.DataAdapter; }
            set { base.DataAdapter = value; }
        }

        /// <summary>
        /// The parameter prefix character configured on the connection of the
        /// adapter's select command; used when naming generated parameters.
        /// </summary>
        private char ParameterMarker
        {
            get
            {
                return DataAdapter.SelectCommand.Connection.ParameterMarker;
            }
        }

        /// <summary>
        /// Indicates whether the command builder should generate a SELECT statement
        /// to populate any autogenerated fields.  We provide this property rather
        /// than rely on the MySqlCommand.UpdatedRowSource property since a user should
        /// still be able to write a custom insert command and not have our work interfere.
        /// </summary>
        public bool ReturnGeneratedIdentifiers
        {
            get { return returnGeneratedIds; }
            set { returnGeneratedIds = value; }
        }

        #endregion

        #region Public Methods

        /// <summary>
        /// Retrieves parameter information from the stored procedure specified 
        /// in the MySqlCommand and populates the Parameters collection of the 
        /// specified MySqlCommand object. Any parameters already present in the
        /// collection are removed first.
        /// Requires MySQL server 5.0 or later, since earlier servers do not
        /// support stored procedures.
        /// </summary>
        /// <param name="command">The MySqlCommand referencing the stored 
        /// procedure from which the parameter information is to be derived. 
        /// If the procedure name is not schema-qualified, the connection's
        /// current database is assumed.
        /// The derived parameters are added to the Parameters collection of the 
        /// MySqlCommand.</param>
        /// <exception cref="ArgumentNullException"><paramref name="command"/> is null.</exception>
        /// <exception cref="MySqlException">The connected server is older than MySQL 5.0.</exception>
        /// <exception cref="InvalidOperationException">The command text is not 
        /// a valid stored procedure name.</exception>
        public static void DeriveParameters(MySqlCommand command)
        {
            if (command == null)
                throw new ArgumentNullException("command");

            if (!command.Connection.driver.Version.isAtLeast(5, 0, 0))
                throw new MySqlException("DeriveParameters is not supported on MySQL versions " +
                    "prior to 5.0");

            // retrieve the proc definition from the cache; unqualified names are
            // resolved against the connection's current database.
            string spName = command.CommandText;
            if (spName.IndexOf('.') == -1)
                spName = command.Connection.Database + "." + spName;
            DataSet ds = command.Connection.ProcedureCache.GetProcedure(command.Connection, spName);

            DataTable parameters = ds.Tables["Procedure Parameters"];
            DataTable procTable = ds.Tables["Procedures"];

            // SQL_MODE belongs to the procedure, not to each parameter, so read it
            // once instead of on every loop iteration.
            bool realAsFloat = false;
            if (procTable != null && procTable.Rows.Count > 0)
                realAsFloat = procTable.Rows[0]["SQL_MODE"].ToString().IndexOf(
                    "REAL_AS_FLOAT", StringComparison.Ordinal) != -1;

            command.Parameters.Clear();
            foreach (DataRow row in parameters.Rows)
            {
                MySqlParameter p = new MySqlParameter();
                p.ParameterName = row["PARAMETER_NAME"].ToString();
                p.Direction = GetDirection(row["PARAMETER_MODE"].ToString(),
                    row["IS_RESULT"].ToString());
                bool unsigned = row["FLAGS"].ToString().IndexOf(
                    "UNSIGNED", StringComparison.Ordinal) != -1;
                p.MySqlDbType = MetaData.NameToType(row["DATA_TYPE"].ToString(),
                    unsigned, realAsFloat, command.Connection);
                if (!row["CHARACTER_MAXIMUM_LENGTH"].Equals(DBNull.Value))
                    p.Size = (int)row["CHARACTER_MAXIMUM_LENGTH"];
                if (!row["NUMERIC_PRECISION"].Equals(DBNull.Value))
                    p.Precision = (byte)row["NUMERIC_PRECISION"];
                if (!row["NUMERIC_SCALE"].Equals(DBNull.Value))
                    p.Scale = (byte)(int)row["NUMERIC_SCALE"];
                command.Parameters.Add(p);
            }
        }

        /// <summary>
        /// Maps the PARAMETER_MODE / IS_RESULT values from the procedure cache to a
        /// ParameterDirection. A result parameter wins over any mode; "IN" and "OUT"
        /// map to Input/Output, and anything else is treated as InputOutput.
        /// </summary>
        private static ParameterDirection GetDirection(string direction, string is_result)
        {
            if (is_result == "YES")
                return ParameterDirection.ReturnValue;
            else if (direction == "IN")
                return ParameterDirection.Input;
            else if (direction == "OUT")
                return ParameterDirection.Output;
            return ParameterDirection.InputOutput;
        }

        /// <summary>
        /// Gets the automatically generated MySqlCommand object required to perform deletions on the database. 
        /// </summary>
        /// <returns>The automatically generated MySqlCommand object required to perform deletions. </returns>
        public new MySqlCommand GetDeleteCommand()
        {
            return (MySqlCommand)base.GetDeleteCommand();
        }

        /// <summary>
        /// Gets the automatically generated MySqlCommand object required to perform updates on the database. 
        /// </summary>
        /// <returns>The automatically generated MySqlCommand object required to perform updates. </returns>
        public new MySqlCommand GetUpdateCommand()
        {
            return (MySqlCommand)base.GetUpdateCommand();
        }

        /// <summary>
        /// Gets the automatically generated MySqlCommand object required to perform inserts on the database. 
        /// </summary>
        /// <returns>The automatically generated MySqlCommand object required to perform inserts. </returns>
        public new MySqlCommand GetInsertCommand()
        {
            return (MySqlCommand)GetInsertCommand(false);
        }

        /// <include file='docs/MySqlCommandBuilder.xml' path='docs/RefreshSchema/*'/>
        public override void RefreshSchema()
        {
            base.RefreshSchema();
            // The auto-increment column may differ under the new schema.
            finalSelect = null;
        }
        #endregion


        /// <summary>
        /// Turns a column name into a valid parameter name by stripping spaces and
        /// replacing characters that cannot appear in a parameter name
        /// (/ - ( ) % &lt; &gt; .) with textual substitutes.
        /// </summary>
        /// <param name="columnName">The source column name.</param>
        /// <returns>The sanitized parameter name.</returns>
        protected override string GetParameterName(string columnName)
        {
            StringBuilder sb = new StringBuilder(columnName);
            sb.Replace(" ", "");
            sb.Replace("/", "_per_");
            sb.Replace("-", "_");
            sb.Replace(")", "_cb_");
            sb.Replace("(", "_ob_");
            sb.Replace("%", "_pct_");
            sb.Replace("<", "_lt_");
            sb.Replace(">", "_gt_");
            sb.Replace(".", "_pt_");
            return sb.ToString();
        }

        protected override void ApplyParameterInfo(DbParameter parameter, DataRow row,
            StatementType statementType, bool whereClause)
        {
            ((MySqlParameter)parameter).MySqlDbType = (MySqlDbType)row["ProviderType"];
        }

        protected override string GetParameterName(int parameterOrdinal)
        {
            return FormatOrdinalParameter(parameterOrdinal);
        }

        protected override string GetParameterPlaceholder(int parameterOrdinal)
        {
            return FormatOrdinalParameter(parameterOrdinal);
        }

        /// <summary>
        /// Builds the "&lt;marker&gt;p&lt;ordinal&gt;" name shared by generated parameter
        /// names and their placeholders, always formatting the ordinal with the
        /// invariant culture.
        /// </summary>
        private string FormatOrdinalParameter(int parameterOrdinal)
        {
            return String.Format(CultureInfo.InvariantCulture, "{0}p{1}",
                ParameterMarker, parameterOrdinal);
        }

        protected override void SetRowUpdatingHandler(DbDataAdapter adapter)
        {
            // The base class calls this with the old adapter (still current) to
            // unhook, and with the new adapter (not yet current) to hook up.
            if (adapter != base.DataAdapter)
                ((MySqlDataAdapter)adapter).RowUpdating += new MySqlRowUpdatingEventHandler(RowUpdating);
            else
                ((MySqlDataAdapter)adapter).RowUpdating -= new MySqlRowUpdatingEventHandler(RowUpdating);
        }

        private void RowUpdating(object sender, MySqlRowUpdatingEventArgs args)
        {
            // Let the base builder generate/refresh the command for this row first.
            RowUpdatingHandler(args);

            if (args.StatementType != StatementType.Insert) return;

            // Rows that will be skipped or already carry errors are not executed,
            // so there is no command to augment.
            if (args.Status != UpdateStatus.Continue || args.Command == null) return;

            // Only touch the insert when asked to; previously a cached finalSelect
            // leaked into inserts after the property was switched off.
            if (!ReturnGeneratedIdentifiers) return;

            if (args.Command.UpdatedRowSource != UpdateRowSource.None)
                throw new InvalidOperationException(
                    Resources.MixingUpdatedRowSource);

            if (finalSelect == null)
                CreateFinalSelect();

            // No auto-increment column in the select: nothing to read back.
            if (finalSelect.Length == 0) return;

            args.Command.UpdatedRowSource = UpdateRowSource.FirstReturnedRecord;
            args.Command.CommandText += finalSelect;
        }

        /// <summary>
        /// We only need to return the single auto generated column since the base
        /// ADO.Net classes will take care of mapping it onto the datarow for us.
        /// Sets <c>finalSelect</c> to String.Empty when the select command has no
        /// auto-increment column, so callers can skip appending anything.
        /// </summary>
        private void CreateFinalSelect()
        {
            DataTable dt = GetSchemaTable(DataAdapter.SelectCommand);

            string autoIncrementColumn = null;
            if (dt != null)
            {
                foreach (DataRow row in dt.Rows)
                {
                    // IsAutoIncrement may be DBNull; only an explicit true counts.
                    object isAutoIncrement = row["IsAutoIncrement"];
                    if (!(isAutoIncrement is bool) || !(bool)isAutoIncrement)
                        continue;

                    autoIncrementColumn = row["ColumnName"].ToString();
                    break;
                }
            }

            if (autoIncrementColumn == null)
            {
                // Avoid emitting the malformed ";SELECT last_insert_id() AS ".
                finalSelect = String.Empty;
                return;
            }

            StringBuilder select = new StringBuilder(";SELECT last_insert_id() AS ");
            // Backticks inside a quoted identifier are escaped by doubling them.
            select.Append('`').Append(autoIncrementColumn.Replace("`", "``")).Append('`');
            finalSelect = select.ToString();
        }
    }
}
