Extend SqlMembershipProvider, but regain internal data

Tuesday, June 7, 2011
A recent project of mine required a slight deviation from what Microsoft's ASP.NET security library provides out of the box.  I wanted to extend the SqlMembershipProvider, SqlRoleProvider, and SqlProfileProvider because they provide a good base, but they needed a few tweaks to fully meet our needs.  Instead of rewriting the entire framework, I decided to extend it.  One of the privileges OO affords a class's author is deciding the access modifiers of the class and its data.  The downside is that those decisions reflect the author's view of what consumers will need: anything marked private or internal is out of reach for those of us extending the class.
However, with Red Gate's .NET Reflector and System.Reflection, I was able to reuse much of that internal data, which saved a lot of time and effort.  Be aware that this approach depends on private member names, which can change between framework versions.  The rest of this article illustrates how I regained this data. 
 
Code outline:
  1. Subclass MembershipUser for a custom implementation
  2. Subclass SqlMembershipProvider, use reflection to regain internal data, and add a sample GetUser method that replaces (hides) the base implementation
  3. WebUtility, a helper class to reuse the internal ResourceManager and to simplify ADO.NET parameter and data-reader handling
  4. SecUtility, a helper class for input validation and reading config values
 
/// <summary>
/// A <see cref="MembershipUser"/> that also carries an application-specific integer reference id.
/// </summary>
public class CustomMembershipUser : MembershipUser
{
	// Backing store for CustomRefId; assigned only by the public constructor.
	private int _customRefId;

	/// <summary>
	/// Parameterless constructor for derived classes; chains to the protected base constructor.
	/// </summary>
	protected CustomMembershipUser()
	{
	}

	/// <summary>
	/// Creates a user. Every argument except <paramref name="customRefId"/> is forwarded unchanged
	/// to the corresponding <see cref="MembershipUser"/> constructor parameter.
	/// </summary>
	/// <param name="customRefId">The application-specific reference id exposed by <see cref="CustomRefId"/>.</param>
	public CustomMembershipUser(string providerName, string name, object providerUserKey, string email,
		string passwordQuestion, string comment, bool isApproved, bool isLockedOut, DateTime creationDate,
		DateTime lastLoginDate, DateTime lastActivityDate, DateTime lastPasswordChangedDate,
		DateTime lastLockoutDate, int customRefId)
		: base(providerName, name, providerUserKey, email, passwordQuestion, comment,
			isApproved, isLockedOut, creationDate, lastLoginDate, lastActivityDate,
			lastPasswordChangedDate, lastLockoutDate)
	{
		_customRefId = customRefId;
	}

	/// <summary>
	/// Gets the application-specific reference id supplied at construction.
	/// </summary>
	public virtual int CustomRefId
	{
		get
		{
			return _customRefId;
		}
	}
}

/// <summary>
/// A <see cref="SqlMembershipProvider"/> that consumes an extra "CustomKeyName" configuration attribute
/// and returns <see cref="CustomMembershipUser"/> instances from its own GetUser method.
/// The base provider's private connection string and command timeout are recovered via reflection
/// after initialization so that custom queries can reuse them.
/// </summary>
public class CustomMembershipProvider : SqlMembershipProvider
{
	// Copied from the base provider's private fields in Initialize. If reflection cannot find
	// those fields, these defaults are kept.
	private string _sqlConnectionString = string.Empty;
	private string _customConfigValue = string.Empty;
	private int _commandTimeout = 30;

	/// <summary>
	/// Initializes the provider from configuration.
	/// </summary>
	/// <param name="name">The friendly name of the provider.</param>
	/// <param name="config">
	/// Provider configuration attributes. The custom "CustomKeyName" entry is read and removed
	/// before the collection is passed to the base provider.
	/// </param>
	/// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
	public override void Initialize(string name, NameValueCollection config)
	{
		// Fail with a clear argument exception instead of a NullReferenceException from the indexer below.
		if (config == null)
		{
			throw new ArgumentNullException("config");
		}

		// Do this before passing to base, which processes (and clears) the config items it handles.
		this._customConfigValue = config["CustomKeyName"] ?? string.Empty;
		config.Remove("CustomKeyName");

		base.Initialize(name, config);

		// Regain access to internal data. Use typeof(SqlMembershipProvider) rather than
		// GetType().BaseType: the latter resolves to the wrong class when this provider is itself
		// subclassed, and the field lookups would then silently fail.
		// NOTE: this depends on private field names of the framework class; if they change,
		// GetField returns null and the defaults above remain in effect.
		const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
		Type sqlProviderType = typeof(SqlMembershipProvider);

		// Base connection string.
		FieldInfo fi = sqlProviderType.GetField("_sqlConnectionString", flags);
		if (fi != null)
		{
			this._sqlConnectionString = fi.GetValue(this) as string ?? string.Empty;
		}

		// Base command timeout; only accepted if the field really holds an int.
		fi = sqlProviderType.GetField("_CommandTimeout", flags);
		if (fi != null)
		{
			object timeout = fi.GetValue(this);
			if (timeout is int)
			{
				this._commandTimeout = (int)timeout;
			}
		}
	}

	/// <summary>
	/// Looks up a user by name with the "dbo.aspnet_Membership_GetUserByName" stored procedure and
	/// maps the first returned row to a <see cref="CustomMembershipUser"/>.
	/// </summary>
	/// <param name="username">The user name. It is trimmed, and must be non-null, comma-free and at most 256 characters.</param>
	/// <param name="userIsOnline">Passed to the procedure as @UpdateLastActivity.</param>
	/// <returns>The user, or null when the procedure returns no rows.</returns>
	/// <remarks>
	/// This method hides (does not override) the base GetUser(string, bool). Callers that go through a
	/// MembershipProvider / SqlMembershipProvider reference still reach the base implementation.
	/// NOTE(review): assumes the stored procedure has been customized to return an int custom
	/// reference id at column ordinal 11. Confirm against the database schema.
	/// </remarks>
	public new virtual CustomMembershipUser GetUser(string username, bool userIsOnline)
	{
		SecUtility.CheckParameter(ref username, true, false, true, 0x100, "username");

		using (SqlConnection connection = new SqlConnection(this._sqlConnectionString))
		{
			connection.Open();
			// The base provider's private schema-version check is not performed here.

			using (SqlCommand command = new SqlCommand("dbo.aspnet_Membership_GetUserByName", connection))
			{
				command.CommandTimeout = this._commandTimeout;
				command.CommandType = CommandType.StoredProcedure;
				command.AddParameter("@ApplicationName", SqlDbType.NVarChar, this.ApplicationName);
				command.AddParameter("@UserName", SqlDbType.NVarChar, username);
				command.AddParameter("@UpdateLastActivity", SqlDbType.Bit, userIsOnline);
				command.AddParameter("@CurrentTimeUtc", SqlDbType.DateTime, DateTime.UtcNow);
				command.AddParameter("@ReturnValue", SqlDbType.Int, null, ParameterDirection.ReturnValue);

				using (SqlDataReader reader = command.ExecuteReader())
				{
					if (!reader.Read())
					{
						return null;
					}

					string email = reader.GetNullableString(0);
					string passwordQuestion = reader.GetNullableString(1);
					string comment = reader.GetNullableString(2);
					bool isApproved = reader.GetBoolean(3);
					// ToLocalTime treats the values read from SQL as UTC.
					DateTime creationDate = reader.GetDateTime(4).ToLocalTime();
					DateTime lastLoginDate = reader.GetDateTime(5).ToLocalTime();
					DateTime lastActivityDate = reader.GetDateTime(6).ToLocalTime();
					DateTime lastPasswordChangedDate = reader.GetDateTime(7).ToLocalTime();
					Guid providerUserKey = reader.GetGuid(8);
					bool isLockedOut = reader.GetBoolean(9);
					DateTime lastLockoutDate = reader.GetDateTime(10).ToLocalTime();
					// Custom column: read as int to match CustomMembershipUser.CustomRefId.
					int customRefId = reader.GetInt32(11);

					return new CustomMembershipUser(this.Name, username, providerUserKey, email, passwordQuestion,
						comment, isApproved, isLockedOut, creationDate, lastLoginDate, lastActivityDate,
						lastPasswordChangedDate, lastLockoutDate, customRefId);
				}
			}
		}
	}
}

/// <summary>
/// Helpers shared by the custom membership classes: access to System.Web's string resources and
/// small ADO.NET conveniences for building commands and reading nullable columns.
/// </summary>
internal static class WebUtility
{
	// A ResourceManager caches the resource sets it loads, so reuse one instance instead of
	// constructing a new manager (and reloading resources) on every access.
	private static readonly System.Resources.ResourceManager s_sr =
		new System.Resources.ResourceManager("System.Web", typeof(System.Web.Security.SqlMembershipProvider).Assembly);

	/// <summary>
	/// Gets the resource manager for the "System.Web" resources in the assembly that defines SqlMembershipProvider.
	/// </summary>
	internal static System.Resources.ResourceManager SR
	{
		get
		{
			return s_sr;
		}
	}

	/// <summary>
	/// Adds a parameter to <paramref name="cmd"/>. A null <paramref name="objValue"/> is sent as
	/// <see cref="DBNull.Value"/> and marks the parameter nullable.
	/// </summary>
	internal static void AddParameter(this SqlCommand cmd, string paramName, SqlDbType dbType, object objValue = null, ParameterDirection direction = ParameterDirection.Input)
	{
		bool isNull = objValue == null;
		cmd.Parameters.Add(new SqlParameter(paramName, dbType)
		{
			IsNullable = isNull,
			Direction = direction,
			Value = isNull ? DBNull.Value : objValue
		});
	}

	/// <summary>
	/// Reads a parameter value from <paramref name="cmd"/>.
	/// </summary>
	/// <typeparam name="T">The expected value type.</typeparam>
	/// <param name="direction">When no name is given, the first parameter with this direction whose value is a <typeparamref name="T"/> is used.</param>
	/// <param name="parameterName">When given, that parameter's value is cast to <typeparamref name="T"/> (throws if it is not one).</param>
	/// <returns>The value, or default(<typeparamref name="T"/>) when no matching parameter is found (the original author's note says this used to be -1).</returns>
	internal static T GetParameterValue<T>(this SqlCommand cmd, ParameterDirection direction = ParameterDirection.ReturnValue, string parameterName = null)
		where T : struct
	{
		if (!string.IsNullOrEmpty(parameterName))
		{
			return (T)cmd.Parameters[parameterName].Value;
		}

		foreach (SqlParameter parameter in cmd.Parameters)
		{
			// 'is T' is already false for null and DBNull, so no separate null check is needed.
			if (parameter.Direction == direction && parameter.Value is T)
			{
				return (T)parameter.Value;
			}
		}

		return default(T);
	}

	/// <summary>Returns the string at <paramref name="col"/>, or <see cref="string.Empty"/> for DBNull.</summary>
	internal static string GetNullableString(this IDataReader reader, int col)
	{
		return reader.IsDBNull(col) ? string.Empty : reader.GetString(col);
	}

	/// <summary>Returns the Guid at <paramref name="col"/>, or <see cref="Guid.Empty"/> for DBNull.</summary>
	internal static Guid GetNullableGuid(this IDataReader reader, int col)
	{
		return reader.IsDBNull(col) ? Guid.Empty : reader.GetGuid(col);
	}

	/// <summary>Returns the DateTime at <paramref name="col"/>, or <see cref="DateTime.MinValue"/> for DBNull.</summary>
	internal static DateTime GetNullableDateTime(this IDataReader reader, int col)
	{
		return reader.IsDBNull(col) ? DateTime.MinValue : reader.GetDateTime(col);
	}

	/// <summary>
	/// Returns the int at <paramref name="col"/>, or <paramref name="defaultValue"/> for DBNull.
	/// Note that the default fallback is 1, not 0.
	/// </summary>
	internal static Int32 GetNullableInt32(this IDataReader reader, int col, int defaultValue = 1)
	{
		return reader.IsDBNull(col) ? defaultValue : reader.GetInt32(col);
	}

	/// <summary>
	/// Looks up resource <paramref name="name"/> and formats it with <paramref name="args"/>.
	/// </summary>
	internal static string GetString(this System.Resources.ResourceManager sr, string name, params object[] args)
	{
		return string.Format(sr.GetString(name), args);
	}
}

/// <summary>
/// Argument-validation and configuration helpers used by the custom membership classes.
/// Validation failures use localized messages from <see cref="WebUtility.SR"/>.
/// </summary>
internal static class SecUtility
{
	/// <summary>
	/// Trims <paramref name="param"/> in place and validates it.
	/// </summary>
	/// <exception cref="ArgumentNullException"><paramref name="param"/> is null and <paramref name="checkForNull"/> is true.</exception>
	/// <exception cref="ArgumentException">
	/// The trimmed value is empty (when <paramref name="checkIfEmpty"/>), longer than <paramref name="maxSize"/>
	/// (when <paramref name="maxSize"/> &gt; 0), or contains a comma (when <paramref name="checkForCommas"/>).
	/// </exception>
	internal static void CheckParameter(ref string param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize, string paramName)
	{
		if (param == null)
		{
			if (checkForNull)
			{
				throw new ArgumentNullException(paramName);
			}
			return;
		}

		param = param.Trim();
		if (checkIfEmpty && (param.Length < 1))
		{
			throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_can_not_be_empty"), paramName));
		}
		if ((maxSize > 0) && (param.Length > maxSize))
		{
			throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_too_long"), paramName, maxSize.ToString(CultureInfo.InvariantCulture)));
		}
		if (checkForCommas && param.Contains(","))
		{
			throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_can_not_contain_comma"), paramName));
		}
	}

	/// <summary>
	/// Validates a non-null, non-empty array with <see cref="CheckParameter"/> applied to each element
	/// (trimming in place), and rejects duplicate elements (ordinal, case-sensitive comparison).
	/// </summary>
	/// <exception cref="ArgumentNullException"><paramref name="param"/> is null, or an element is null and <paramref name="checkForNull"/> is true.</exception>
	/// <exception cref="ArgumentException">The array is empty, an element fails validation, or elements repeat.</exception>
	internal static void CheckArrayParameter(ref string[] param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize, string paramName)
	{
		if (param == null)
		{
			throw new ArgumentNullException(paramName);
		}
		if (param.Length < 1)
		{
			throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_array_empty"), paramName), paramName);
		}

		// HashSet (unlike Hashtable) accepts a null element, which can occur when checkForNull is false.
		var seen = new System.Collections.Generic.HashSet<string>(StringComparer.Ordinal);
		for (int i = param.Length - 1; i >= 0; i--)
		{
			CheckParameter(ref param[i], checkForNull, checkIfEmpty, checkForCommas, maxSize, paramName + "[ " + i.ToString(CultureInfo.InvariantCulture) + " ]");
			if (!seen.Add(param[i]))
			{
				throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_duplicate_array_element"), paramName), paramName);
			}
		}
	}

	/// <summary>
	/// Non-throwing counterpart of <see cref="CheckParameter"/>: trims <paramref name="param"/> in place
	/// and returns whether it satisfies the same rules.
	/// </summary>
	internal static bool ValidateParameter(ref string param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize)
	{
		if (param == null)
		{
			return !checkForNull;
		}
		param = param.Trim();
		return (((!checkIfEmpty || (param.Length >= 1)) && ((maxSize <= 0) || (param.Length <= maxSize))) && (!checkForCommas || !param.Contains(",")));
	}

	/// <summary>
	/// Reads an integer config value (invariant culture), returning <paramref name="defaultValue"/> when absent.
	/// </summary>
	/// <exception cref="ProviderException">
	/// The value is not an integer, is negative (or non-positive when <paramref name="zeroAllowed"/> is false),
	/// or exceeds <paramref name="maxValueAllowed"/> when that limit is &gt; 0.
	/// </exception>
	internal static int GetIntValue(NameValueCollection config, string valueName, int defaultValue, bool zeroAllowed, int maxValueAllowed)
	{
		string s = config[valueName];
		if (s == null)
		{
			return defaultValue;
		}

		// Configuration text is machine-readable, so parse it independently of the current culture.
		int num;
		if (!int.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out num))
		{
			if (zeroAllowed)
			{
				throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_non_negative_integer"), valueName));
			}
			throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_positive_integer"), valueName));
		}
		if (zeroAllowed && (num < 0))
		{
			throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_non_negative_integer"), valueName));
		}
		if (!zeroAllowed && (num <= 0))
		{
			throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_positive_integer"), valueName));
		}
		if ((maxValueAllowed > 0) && (num > maxValueAllowed))
		{
			throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_too_big"), valueName, maxValueAllowed.ToString(CultureInfo.InvariantCulture)));
		}
		return num;
	}

	/// <summary>
	/// Returns config["connectionString"] when set. Otherwise it resolves the name stored under
	/// <paramref name="connectionStringAttributeName"/> through DataAccess.SqlConnectionHelper.GetConnectionString.
	/// That helper is project code that is not shown here.
	/// </summary>
	/// <exception cref="ProviderException">No connection name is configured, or it resolves to an empty string.</exception>
	internal static string GetConnectionString(NameValueCollection config, string connectionStringAttributeName = "connectionStringName")
	{
		string connectionString = config["connectionString"];
		if (!string.IsNullOrEmpty(connectionString))
		{
			return connectionString;
		}

		string connectionName = config[connectionStringAttributeName];
		if (string.IsNullOrEmpty(connectionName))
		{
			throw new ProviderException(WebUtility.SR.GetString("Connection_name_not_specified"));
		}

		bool lookupConnectionString = true;
		bool appLevel = true;
		connectionString = DataAccess.SqlConnectionHelper.GetConnectionString(connectionName, lookupConnectionString, appLevel);
		if (string.IsNullOrEmpty(connectionString))
		{
			throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Connection_string_not_found"), connectionName));
		}
		return connectionString;
	}

	/// <summary>
	/// Returns the hosting application's virtual path. Outside hosting, it returns the process main module
	/// name truncated at its first '.'. It returns "/" when neither is available or any error occurs.
	/// </summary>
	internal static string GetDefaultAppName()
	{
		try
		{
			string appName = HostingEnvironment.ApplicationVirtualPath;
			if (string.IsNullOrEmpty(appName))
			{
				// Process is IDisposable; release the handle once the module name has been read.
				using (Process process = Process.GetCurrentProcess())
				{
					appName = process.MainModule.ModuleName;
				}
				int index = appName.IndexOf('.');
				if (index != -1)
				{
					appName = appName.Remove(index);
				}
			}
			return string.IsNullOrEmpty(appName) ? "/" : appName;
		}
		catch
		{
			// Deliberate best effort: any failure falls back to the root application name.
			return "/";
		}
	}

	/// <summary>
	/// Wraps <paramref name="input"/> in single quotes with embedded quotes doubled; null/empty yields "''".
	/// WARNING: quote-doubling is not a substitute for parameterized SQL; prefer SqlParameter for untrusted input.
	/// </summary>
	internal static string MakeStringLiteral(string input)
	{
		if (string.IsNullOrEmpty(input))
		{
			return "''";
		}
		return ("'" + EscapeSqlStringAsLiteral(input) + "'");
	}

	/// <summary>Doubles every single quote in <paramref name="input"/> (throws NullReferenceException if null).</summary>
	internal static string EscapeSqlStringAsLiteral(string input)
	{
		return input.Replace("'", "''");
	}
}
Tags:
Filed Under: ASP.NET, C#

Add comment




  Country flag

  • Comment
  • Preview
Loading