A recent project of mine required a slight deviation from what Microsoft's ASP.NET security library provides out of the box. I wanted to extend SqlMembershipProvider, SqlRoleProvider, and SqlProfileProvider because they gave us a good base, but they needed a few tweaks to fully meet our needs. Instead of rewriting the entire framework, I decided to extend it. One of the privileges OO affords us is the ability to choose the access modifiers of a class and its data. The downside is that those choices reflect the original author's view of how the class will be used: members marked private or internal are out of reach for a subclass that later needs them.
However, with Red Gate's .NET Reflector and System.Reflection, I was able to reuse much of the internal data, which saved a lot of time and effort. The rest of this article shows how I regained access to that data.
Code outline:
Subclass MembershipUser for custom implementation
Subclass SqlMembershipProvider, use reflection to regain access to its internal data, and override the GetUser method as an example
WebUtility, a helper class that reuses System.Web's internal ResourceManager and simplifies ADO.NET parameter and reader handling
SecUtility, a helper class for input validation and for reading configuration values
/// <summary>
/// A <see cref="MembershipUser"/> that additionally carries an application-specific
/// integer reference id (<see cref="CustomRefId"/>).
/// </summary>
public class CustomMembershipUser : MembershipUser
{
    // Assigned only by the public constructor; remains 0 when the protected constructor is used.
    private readonly int _customRefId;

    /// <summary>
    /// Parameterless constructor for derived types; <see cref="CustomRefId"/> will be 0.
    /// </summary>
    protected CustomMembershipUser()
    {
    }

    /// <summary>
    /// Creates a user with all standard membership data plus a custom reference id.
    /// Every argument except <paramref name="customRefId"/> is forwarded unchanged to the
    /// <see cref="MembershipUser"/> base constructor.
    /// </summary>
    /// <param name="customRefId">Application-specific reference id exposed via <see cref="CustomRefId"/>.</param>
    public CustomMembershipUser(
        string providerName,
        string name,
        object providerUserKey,
        string email,
        string passwordQuestion,
        string comment,
        bool isApproved,
        bool isLockedOut,
        DateTime creationDate,
        DateTime lastLoginDate,
        DateTime lastActivityDate,
        DateTime lastPasswordChangedDate,
        DateTime lastLockoutDate,
        int customRefId)
        : base(providerName, name, providerUserKey, email, passwordQuestion, comment,
               isApproved, isLockedOut, creationDate, lastLoginDate, lastActivityDate,
               lastPasswordChangedDate, lastLockoutDate)
    {
        _customRefId = customRefId;
    }

    /// <summary>
    /// The application-specific reference id supplied at construction.
    /// </summary>
    public virtual int CustomRefId
    {
        get { return _customRefId; }
    }
}
/// <summary>
/// A <see cref="SqlMembershipProvider"/> that reads one extra configuration attribute and
/// returns <see cref="CustomMembershipUser"/> instances from its own <c>GetUser</c>.
/// The base provider's private connection string and command timeout are recovered via
/// reflection after <see cref="Initialize"/> runs.
/// </summary>
public class CustomMembershipProvider : SqlMembershipProvider
{
    // Private field names on System.Web.Security.SqlMembershipProvider (discovered with a
    // decompiler). They are implementation details and may differ between framework versions,
    // which is why the lookups below tolerate a missing field.
    private const string BaseConnectionStringField = "_sqlConnectionString";
    private const string BaseCommandTimeoutField = "_CommandTimeout";

    // Custom provider attribute read from the <add .../> element in configuration.
    private const string CustomConfigKey = "CustomKeyName";

    private string _sqlConnectionString = string.Empty;
    private string _customConfigValue = string.Empty;
    private int _commandTimeout = 30;

    /// <summary>
    /// Captures the custom configuration attribute, initializes the base provider, and then
    /// copies the base provider's connection string and command timeout into this instance.
    /// </summary>
    /// <param name="name">Provider name, passed to the base provider.</param>
    /// <param name="config">Provider attributes from configuration.</param>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
    public override void Initialize(string name, NameValueCollection config)
    {
        if (config == null)
        {
            throw new ArgumentNullException("config");
        }

        // Read and remove the custom key before calling the base: the base Initialize consumes
        // the config entries it knows and rejects attributes it does not recognize.
        this._customConfigValue = config[CustomConfigKey];
        config.Remove(CustomConfigKey);

        base.Initialize(name, config);

        // Regain access to internal data. Look on SqlMembershipProvider itself rather than
        // GetType().BaseType, which would point at the wrong class if this provider is subclassed.
        const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
        Type sqlProviderType = typeof(SqlMembershipProvider);

        FieldInfo fi = sqlProviderType.GetField(BaseConnectionStringField, flags);
        if (fi != null)
        {
            this._sqlConnectionString = (string)fi.GetValue(this);
        }

        fi = sqlProviderType.GetField(BaseCommandTimeoutField, flags);
        if (fi != null)
        {
            this._commandTimeout = (int)fi.GetValue(this);
        }
    }

    /// <summary>
    /// Loads a user by name through <c>dbo.aspnet_Membership_GetUserByName</c> and builds a
    /// <see cref="CustomMembershipUser"/> from the first result row.
    /// </summary>
    /// <remarks>
    /// This method <em>hides</em> (<c>new</c>) rather than overrides the base <c>GetUser(string, bool)</c>;
    /// calls made through a <see cref="MembershipProvider"/> or <see cref="SqlMembershipProvider"/>
    /// typed reference still run the base implementation.
    /// The base provider's schema-version check is private and is not performed here.
    /// </remarks>
    /// <param name="username">User name; trimmed, must be non-null, comma-free, at most 256 characters.</param>
    /// <param name="userIsOnline">Passed to the procedure as <c>@UpdateLastActivity</c>.</param>
    /// <returns>The user, or <c>null</c> when the procedure returns no row.</returns>
    public new virtual CustomMembershipUser GetUser(string username, bool userIsOnline)
    {
        SecUtility.CheckParameter(ref username, true, false, true, 0x100, "username");

        using (SqlConnection connection = new SqlConnection(this._sqlConnectionString))
        {
            connection.Open();

            using (SqlCommand command = new SqlCommand("dbo.aspnet_Membership_GetUserByName", connection))
            {
                command.CommandTimeout = this._commandTimeout;
                command.CommandType = CommandType.StoredProcedure;
                command.AddParameter("@ApplicationName", SqlDbType.NVarChar, this.ApplicationName);
                command.AddParameter("@UserName", SqlDbType.NVarChar, username);
                command.AddParameter("@UpdateLastActivity", SqlDbType.Bit, userIsOnline);
                command.AddParameter("@CurrentTimeUtc", SqlDbType.DateTime, DateTime.UtcNow);
                command.AddParameter("@ReturnValue", SqlDbType.Int, null, ParameterDirection.ReturnValue);

                using (SqlDataReader reader = command.ExecuteReader())
                {
                    if (!reader.Read())
                    {
                        return null;
                    }

                    // Dates are converted from UTC (as stored) to local time, matching the
                    // ToLocalTime() calls on every date column.
                    string email = reader.GetNullableString(0);
                    string passwordQuestion = reader.GetNullableString(1);
                    string comment = reader.GetNullableString(2);
                    bool isApproved = reader.GetBoolean(3);
                    DateTime creationDate = reader.GetDateTime(4).ToLocalTime();
                    DateTime lastLoginDate = reader.GetDateTime(5).ToLocalTime();
                    DateTime lastActivityDate = reader.GetDateTime(6).ToLocalTime();
                    DateTime lastPasswordChangedDate = reader.GetDateTime(7).ToLocalTime();
                    Guid providerUserKey = reader.GetGuid(8);
                    bool isLockedOut = reader.GetBoolean(9);
                    DateTime lastLockoutDate = reader.GetDateTime(10).ToLocalTime();

                    // NOTE(review): column 11 presumably comes from a customized version of the
                    // stored procedure returning the reference id as an int — confirm the schema.
                    int customRefId = reader.GetInt32(11);

                    return new CustomMembershipUser(this.Name, username, providerUserKey, email, passwordQuestion, comment,
                        isApproved, isLockedOut, creationDate, lastLoginDate, lastActivityDate, lastPasswordChangedDate,
                        lastLockoutDate, customRefId);
                }
            }
        }
    }
}
/// <summary>
/// Helpers shared by the custom providers: access to System.Web's string resources and
/// extension methods that simplify ADO.NET parameter and data-reader handling.
/// </summary>
internal static class WebUtility
{
    // ResourceManager is thread safe and caches loaded resource sets, so one shared instance
    // replaces the previous allocation of a new manager on every SR access.
    private static readonly System.Resources.ResourceManager s_sr =
        new System.Resources.ResourceManager("System.Web", typeof(System.Web.Security.SqlMembershipProvider).Assembly);

    /// <summary>
    /// Resource manager over System.Web's "System.Web" resources (the assembly that defines
    /// <see cref="System.Web.Security.SqlMembershipProvider"/>).
    /// </summary>
    internal static System.Resources.ResourceManager SR
    {
        get { return s_sr; }
    }

    /// <summary>
    /// Adds a typed parameter to <paramref name="cmd"/>. A null value is stored as
    /// <see cref="DBNull.Value"/> and the parameter is marked nullable.
    /// </summary>
    /// <param name="cmd">Command to add the parameter to.</param>
    /// <param name="paramName">Parameter name, e.g. <c>@UserName</c>.</param>
    /// <param name="dbType">SQL Server type of the parameter.</param>
    /// <param name="objValue">Value to send; null maps to <see cref="DBNull.Value"/>.</param>
    /// <param name="direction">Parameter direction; defaults to input.</param>
    internal static void AddParameter(this SqlCommand cmd, string paramName, SqlDbType dbType, object objValue = null, ParameterDirection direction = ParameterDirection.Input)
    {
        cmd.Parameters.Add(new SqlParameter(paramName, dbType)
        {
            IsNullable = objValue == null,
            Direction = direction,
            Value = objValue ?? DBNull.Value
        });
    }

    /// <summary>
    /// Reads a value-typed parameter value from <paramref name="cmd"/>.
    /// </summary>
    /// <typeparam name="T">Expected value type of the parameter.</typeparam>
    /// <param name="cmd">Command whose parameters are inspected (typically after execution).</param>
    /// <param name="direction">
    /// When <paramref name="parameterName"/> is empty, the first parameter with this direction whose
    /// value is a <typeparamref name="T"/> is used.
    /// </param>
    /// <param name="parameterName">Optional explicit parameter name; takes precedence over <paramref name="direction"/>.</param>
    /// <returns>
    /// The value, or <c>default(T)</c> when no matching parameter is found or the named parameter
    /// holds null / <see cref="DBNull.Value"/>.
    /// </returns>
    /// <exception cref="InvalidCastException">The named parameter holds a non-null value of another type.</exception>
    internal static T GetParameterValue<T>(this SqlCommand cmd, ParameterDirection direction = ParameterDirection.ReturnValue, string parameterName = null)
        where T : struct
    {
        if (!string.IsNullOrEmpty(parameterName))
        {
            object value = cmd.Parameters[parameterName].Value;

            // An unset or NULL output/return value yields default(T), as the lookup below does.
            if (value == null || value == DBNull.Value)
            {
                return default(T);
            }
            return (T)value;
        }

        foreach (SqlParameter parameter in cmd.Parameters)
        {
            // "is T" is false for null and DBNull, so only usable values match.
            if (parameter.Direction == direction && parameter.Value is T)
            {
                return (T)parameter.Value;
            }
        }

        return default(T);
    }

    /// <summary>Returns the string in <paramref name="col"/>, or <see cref="string.Empty"/> when it is NULL.</summary>
    internal static string GetNullableString(this IDataReader reader, int col)
    {
        return reader.IsDBNull(col) ? string.Empty : reader.GetString(col);
    }

    /// <summary>Returns the Guid in <paramref name="col"/>, or <see cref="Guid.Empty"/> when it is NULL.</summary>
    internal static Guid GetNullableGuid(this IDataReader reader, int col)
    {
        return reader.IsDBNull(col) ? Guid.Empty : reader.GetGuid(col);
    }

    /// <summary>Returns the DateTime in <paramref name="col"/>, or <see cref="DateTime.MinValue"/> when it is NULL.</summary>
    internal static DateTime GetNullableDateTime(this IDataReader reader, int col)
    {
        return reader.IsDBNull(col) ? DateTime.MinValue : reader.GetDateTime(col);
    }

    /// <summary>Returns the int in <paramref name="col"/>, or <paramref name="defaultValue"/> (1 unless given) when it is NULL.</summary>
    internal static Int32 GetNullableInt32(this IDataReader reader, int col, int defaultValue = 1)
    {
        return reader.IsDBNull(col) ? defaultValue : reader.GetInt32(col);
    }

    /// <summary>
    /// Looks up resource <paramref name="name"/> and formats it with <paramref name="args"/>
    /// using the current culture. Accepts an array or a list of individual arguments.
    /// </summary>
    internal static string GetString(this System.Resources.ResourceManager sr, string name, params object[] args)
    {
        return string.Format(System.Globalization.CultureInfo.CurrentCulture, sr.GetString(name), args);
    }
}
/// <summary>
/// Input-validation and configuration-reading helpers for the custom providers.
/// Error messages are looked up in System.Web's resources through <c>WebUtility.SR</c>.
/// </summary>
internal static class SecUtility
{
    /// <summary>
    /// Trims <paramref name="param"/> in place (when not null) and validates it.
    /// </summary>
    /// <param name="param">Value to validate; replaced by its trimmed form.</param>
    /// <param name="checkForNull">Throw when the value is null.</param>
    /// <param name="checkIfEmpty">Throw when the trimmed value is empty.</param>
    /// <param name="checkForCommas">Throw when the value contains a comma.</param>
    /// <param name="maxSize">Maximum trimmed length; 0 or less disables the check.</param>
    /// <param name="paramName">Name reported in the exception.</param>
    /// <exception cref="ArgumentNullException">The value is null and <paramref name="checkForNull"/> is set.</exception>
    /// <exception cref="ArgumentException">The empty, length, or comma check fails.</exception>
    internal static void CheckParameter(ref string param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize, string paramName)
    {
        if (param == null)
        {
            if (checkForNull)
            {
                throw new ArgumentNullException(paramName);
            }
            return;
        }

        param = param.Trim();
        if (checkIfEmpty && (param.Length < 1))
        {
            throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_can_not_be_empty"), paramName));
        }
        if ((maxSize > 0) && (param.Length > maxSize))
        {
            throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_too_long"), paramName, maxSize.ToString(CultureInfo.InvariantCulture)));
        }
        if (checkForCommas && param.Contains(","))
        {
            throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_can_not_contain_comma"), paramName));
        }
    }

    /// <summary>
    /// Validates a non-null, non-empty array: each element goes through <see cref="CheckParameter"/>
    /// (and is trimmed in place), and duplicate elements (after trimming) are rejected.
    /// </summary>
    /// <exception cref="ArgumentNullException">The array is null, or an element is null and <paramref name="checkForNull"/> is set.</exception>
    /// <exception cref="ArgumentException">The array is empty, an element fails validation, or an element is duplicated.</exception>
    internal static void CheckArrayParameter(ref string[] param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize, string paramName)
    {
        if (param == null)
        {
            throw new ArgumentNullException(paramName);
        }
        if (param.Length < 1)
        {
            throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_array_empty"), paramName), paramName);
        }

        // HashSet<string> compares ordinally like the former Hashtable, but also accepts a null
        // element (allowed when checkForNull is false); Hashtable.Contains(null) threw an
        // unrelated ArgumentNullException instead.
        var seen = new System.Collections.Generic.HashSet<string>();
        for (int i = param.Length - 1; i >= 0; i--)
        {
            CheckParameter(ref param[i], checkForNull, checkIfEmpty, checkForCommas, maxSize, paramName + "[ " + i.ToString(CultureInfo.InvariantCulture) + " ]");
            if (!seen.Add(param[i]))
            {
                throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Parameter_duplicate_array_element"), paramName), paramName);
            }
        }
    }

    /// <summary>
    /// Non-throwing variant of <see cref="CheckParameter"/>: trims <paramref name="param"/> in place
    /// (when not null) and reports whether every requested check passes.
    /// </summary>
    internal static bool ValidateParameter(ref string param, bool checkForNull, bool checkIfEmpty, bool checkForCommas, int maxSize)
    {
        if (param == null)
        {
            return !checkForNull;
        }
        param = param.Trim();
        return (((!checkIfEmpty || (param.Length >= 1)) && ((maxSize <= 0) || (param.Length <= maxSize))) && (!checkForCommas || !param.Contains(",")));
    }

    /// <summary>
    /// Reads an integer attribute from provider configuration.
    /// </summary>
    /// <param name="config">Provider attributes.</param>
    /// <param name="valueName">Attribute name.</param>
    /// <param name="defaultValue">Returned when the attribute is absent.</param>
    /// <param name="zeroAllowed">Whether 0 is accepted (negative values never are).</param>
    /// <param name="maxValueAllowed">Upper bound; 0 or less disables the check.</param>
    /// <exception cref="ProviderException">The value is not an integer or is out of range.</exception>
    internal static int GetIntValue(NameValueCollection config, string valueName, int defaultValue, bool zeroAllowed, int maxValueAllowed)
    {
        string s = config[valueName];
        if (s == null)
        {
            return defaultValue;
        }

        int num;
        // Configuration values are machine-readable, so parse independently of the current culture.
        if (!int.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out num))
        {
            if (zeroAllowed)
            {
                throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_non_negative_integer"), valueName));
            }
            throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_positive_integer"), valueName));
        }
        if (zeroAllowed && (num < 0))
        {
            throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_non_negative_integer"), valueName));
        }
        if (!zeroAllowed && (num <= 0))
        {
            throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_must_be_positive_integer"), valueName));
        }
        if ((maxValueAllowed > 0) && (num > maxValueAllowed))
        {
            throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Value_too_big"), valueName, maxValueAllowed.ToString(CultureInfo.InvariantCulture)));
        }
        return num;
    }

    /// <summary>
    /// Returns the inline <c>connectionString</c> attribute when present; otherwise resolves the
    /// connection string named by <paramref name="connectionStringAttributeName"/> through the
    /// project helper <c>DataAccess.SqlConnectionHelper</c> (lookup details not shown here).
    /// </summary>
    /// <exception cref="ProviderException">No connection name is configured, or the name cannot be resolved.</exception>
    internal static string GetConnectionString(NameValueCollection config, string connectionStringAttributeName = "connectionStringName")
    {
        string str = config["connectionString"];
        if (string.IsNullOrEmpty(str))
        {
            string str2 = config[connectionStringAttributeName];
            if (string.IsNullOrEmpty(str2))
            {
                throw new ProviderException(WebUtility.SR.GetString("Connection_name_not_specified"));
            }
            bool lookupConnectionString = true;
            bool appLevel = true;
            str = DataAccess.SqlConnectionHelper.GetConnectionString(str2, lookupConnectionString, appLevel);
            if (string.IsNullOrEmpty(str))
            {
                throw new ProviderException(string.Format(CultureInfo.CurrentCulture, WebUtility.SR.GetString("Connection_string_not_found"), str2));
            }
        }
        return str;
    }

    /// <summary>
    /// Best-effort default application name: the hosting virtual path, else the process's main
    /// module name up to its first '.', else "/". Any failure also yields "/".
    /// </summary>
    internal static string GetDefaultAppName()
    {
        try
        {
            string applicationVirtualPath = HostingEnvironment.ApplicationVirtualPath;
            if (string.IsNullOrEmpty(applicationVirtualPath))
            {
                // Process is IDisposable; release its handle once the module name is read.
                using (Process process = Process.GetCurrentProcess())
                {
                    applicationVirtualPath = process.MainModule.ModuleName;
                }
                int index = applicationVirtualPath.IndexOf('.');
                if (index != -1)
                {
                    applicationVirtualPath = applicationVirtualPath.Remove(index);
                }
            }
            if (string.IsNullOrEmpty(applicationVirtualPath))
            {
                return "/";
            }
            return applicationVirtualPath;
        }
        catch
        {
            // Deliberate best-effort fallback.
            return "/";
        }
    }

    /// <summary>
    /// Wraps <paramref name="input"/> in single quotes with embedded quotes doubled; null or
    /// empty input yields <c>''</c>. Prefer SQL parameters over building literals where possible.
    /// </summary>
    internal static string MakeStringLiteral(string input)
    {
        if (string.IsNullOrEmpty(input))
        {
            return "''";
        }
        return ("'" + EscapeSqlStringAsLiteral(input) + "'");
    }

    /// <summary>
    /// Doubles every single quote in <paramref name="input"/> for use inside a SQL string literal.
    /// Throws <see cref="NullReferenceException"/> on null input.
    /// </summary>
    internal static string EscapeSqlStringAsLiteral(string input)
    {
        return input.Replace("'", "''");
    }
}