#include "domain_name.h"
#include <ctype.h>

using namespace System::Text;

namespace ADNS {


	// Decodes the escapes in one presentation-format label and returns the
	// decoded text, or nullptr if the label contains an escape this parser does
	// not accept.
	//
	// Accepted escapes:
	//   \DDD   - exactly three ASCII DECIMAL digits (RFC 1035 section 5.1) giving
	//            an octet value 0-255; replaced by the character with that code.
	//   \<NUL> - a backslash followed by a NUL character; both are removed.
	//
	// The input is scanned once, left to right, so a character produced by an
	// escape (e.g. \092 -> '\') is never re-read as the start of another escape.
	//
	// Returns nullptr for: a nullptr input, a trailing backslash, a \DDD value
	// above 255, or any other backslash sequence.
	String^ parse_escape(String^ s)
	{
		if (s == nullptr)
			return nullptr;

		StringBuilder^ result = gcnew StringBuilder(s->Length);
		int i = 0;

		while (i < s->Length)
		{
			Char c = s[i];

			if (c != '\\')
			{
				result->Append(c);
				++i;
				continue;
			}

			// \DDD : three ASCII digits must follow (bounds checked first, so a
			// short escape at the end of the label cannot index past the string).
			if ((i + 3 < s->Length) &&
				(s[i + 1] >= '0') && (s[i + 1] <= '9') &&
				(s[i + 2] >= '0') && (s[i + 2] <= '9') &&
				(s[i + 3] >= '0') && (s[i + 3] <= '9'))
			{
				int value = (s[i + 1] - '0') * 100 + (s[i + 2] - '0') * 10 + (s[i + 3] - '0');
				if (value > 255)
					return nullptr;   // not a valid octet value
				result->Append(static_cast<Char>(value));
				i += 4;
			}
			else if ((i + 1 < s->Length) && (s[i + 1] == '\0'))
			{
				// "\<NUL>": drop the slash and the terminator.
				i += 2;
			}
			else
			{
				// Trailing backslash or an unsupported escape.
				return nullptr;
			}
		}

		return result->ToString();
	}

	// Creates the root name: a buffer holding only the single 0 octet.
	DOMAIN_NAME::DOMAIN_NAME()
	{
		// A new managed array is zero-filled, so its one octet is already the 0
		// length byte that represents the root.
		name = gcnew array<Byte>(1);
	}

	// Wraps a private copy of an existing wire-format name; later changes to the
	// caller's array do not affect this object. The bytes are not validated.
	// A nullptr domain results in a NullReferenceException.
	DOMAIN_NAME::DOMAIN_NAME(array<Byte>^ domain)
	{
		name = safe_cast<array<Byte>^>(domain->Clone());
	}

	// Builds a wire-format name ([len][label]...[0]) from its presentation form,
	// e.g. "www.example.com." or "www.example.com".
	//
	// The text is lower-cased (culture-invariant) and split on '.'. Empty pieces
	// produced by leading, trailing or doubled dots are ignored. Each label goes
	// through parse_escape() to decode escapes, then is encoded with
	// ASCIIEncoding (characters outside 7-bit ASCII become '?').
	//
	// Throws Exception if a label contains an escape parse_escape rejects, or if
	// a label is longer than 63 octets (RFC 1035 section 2.3.4). The length byte
	// cannot represent such a label correctly.
	DOMAIN_NAME::DOMAIN_NAME(String^ domain)
	{
		ASCIIEncoding^ ae = gcnew ASCIIEncoding();
		int totalsize = 1;   // the terminating 0 octet
		int pos = 0;

		// ToLowerInvariant: culture-sensitive ToLower() maps 'I' to a dotless i
		// under a Turkish culture, which ASCIIEncoding would then emit as '?'.
		array<String^>^ labels = domain->ToLowerInvariant()->Split('.');

		// Decode escapes in every label.
		for (int i = 0; i < labels->Length; ++i)
		{
			String^ decoded = parse_escape(labels[i]);
			if (decoded == nullptr)
			{
				// labels[i] still holds the offending text, so report it.
				throw gcnew Exception("Bad Escape String sent to parse_escape: " + labels[i]);
			}
			labels[i] = decoded;
		}

		// Size the buffer: one length byte plus the data for each non-empty label.
		for (int i = 0; i < labels->Length; ++i)
		{
			if (labels[i]->Length == 0)   // from leading/trailing/doubled "."
				continue;

			int labelsize = ae->GetByteCount(labels[i]);
			if (labelsize > 63)
				throw gcnew Exception("Domain label longer than 63 octets: " + labels[i]);
			totalsize += labelsize + 1;
		}

		// Zero-filled, so the final octet is already the terminating 0.
		name = gcnew array<Byte>(totalsize);

		for (int i = 0; i < labels->Length; ++i)
		{
			if (labels[i]->Length == 0)
				continue;

			array<Byte>^ bytes = ae->GetBytes(labels[i]);
			name[pos++] = static_cast<Byte>(bytes->Length);
			bytes->CopyTo(name, pos);
			pos += bytes->Length;
		}
	}

	// Returns a copy of the wire-format bytes, so callers cannot modify this
	// object's internal buffer through the result.
	array<Byte>^ DOMAIN_NAME::GetName()
	{
		return safe_cast<array<Byte>^>(name->Clone());
	}

	// Counts the labels that precede the first 0 length byte; the root name has
	// none. The walk also stops at the end of the buffer.
	int DOMAIN_NAME::LabelCount()
	{
		// A one-octet buffer is always treated as the root, whatever its value.
		if (name->Length == 1)
			return 0;

		int count = 0;
		for (int offset = 0;
			 (offset < name->Length) && (name[offset] != 0);
			 offset += name[offset] + 1)   // skip the length byte plus the label
		{
			++count;
		}
		return count;
	}
	
	// Returns a new DOMAIN_NAME holding an independent copy of this name's bytes.
	DOMAIN_NAME^ DOMAIN_NAME::Clone()
	{
		return gcnew DOMAIN_NAME(GetName());
	}

	// Appends catname to this name in place, e.g. "www." + "example.com." gives
	// "www.example.com.".
	// This name's terminating 0 octet is overwritten by the first octet of
	// catname, and catname's own terminator ends the result. If this name is the
	// root (or an empty buffer), it simply becomes a copy of catname. A nullptr
	// catname leaves the name unchanged.
	Void DOMAIN_NAME::Concatenate(DOMAIN_NAME^ catname)
	{
		if (catname == nullptr)
			return;

		// GetName() returns a copy, so this is safe even when catname == this.
		array<Byte>^ suffix = catname->GetName();

		if (name->Length <= 1)
		{
			name = suffix;
			return;
		}

		// Index of our terminating 0 octet; the suffix starts exactly there, so
		// no octet of the last existing label is overwritten.
		int terminator = name->Length - 1;
		Array::Resize<Byte>(name, terminator + suffix->Length);
		suffix->CopyTo(name, terminator);
	}

	// Returns a new name with the labels in reverse order, e.g.
	// "www.example.com." -> "com.example.www.". This name is not modified.
	// Each label (length byte + data) is copied as a unit to its mirrored
	// position, and the terminating 0 octet stays at the end.
	DOMAIN_NAME^ DOMAIN_NAME::Reverse()
	{
		// Zero-filled, so the last octet is already the terminator.
		array<Byte>^ reversed = gcnew array<Byte>(name->Length);
		int last = name->Length - 1;   // index of the terminating 0 octet
		int pos = 0;

		while ((pos < name->Length) && (name[pos] > 0))
		{
			int labelsize = name[pos] + 1;   // length byte + label data
			// A label at [pos, pos + labelsize) in the source lands at
			// [last - pos - labelsize, last - pos) in the mirrored output.
			Array::Copy(name, pos, reversed, last - pos - labelsize, labelsize);
			pos += labelsize;
		}

		return gcnew DOMAIN_NAME(reversed);
	}

	// Returns a new name equal to this one with its first (leftmost) n labels
	// removed, e.g. CloneFrom(1) on "www.example.com." gives "example.com.".
	// Removing every label leaves the root name. Returns nullptr when the name
	// has fewer than n labels.
	DOMAIN_NAME^ DOMAIN_NAME::CloneFrom(unsigned short int n)
	{
		if (n > LabelCount())
			return nullptr;

		int start = 0;
		for (int skipped = 0; (start < name->Length) && (skipped < n); ++skipped)
			start += name[start] + 1;   // jump over the length byte and the label

		int remaining = name->Length - start;
		array<Byte>^ tail = gcnew array<Byte>(remaining);
		Array::Copy(name, start, tail, 0, remaining);

		return gcnew DOMAIN_NAME(tail);
	}

	// Returns a new name without the leftmost label ("www.example.com." ->
	// "example.com."), or nullptr when this name has no labels (the root).
	DOMAIN_NAME^ DOMAIN_NAME::LeftChop()
	{
		const unsigned short int labelsToDrop = 1;
		return this->CloneFrom(labelsToDrop);
	}

	// Returns true when a2 holds exactly the same wire-format bytes as this name
	// (same length, same octets). The comparison is case-sensitive. A nullptr a2
	// is never equal.
	bool DOMAIN_NAME::IsEqual(DOMAIN_NAME^ a2)
	{
		if (a2 == nullptr)
			return false;

		array<Byte>^ other = a2->GetName();
		if (other->Length != name->Length)
			return false;

		int idx = 0;
		while ((idx < name->Length) && (name[idx] == other[idx]))
			++idx;

		// Reaching the end means no octet differed.
		return idx == name->Length;
	}
	
	// Returns true if 'sub' lies strictly below this name. That requires sub to
	// have more labels than this name, and its rightmost LabelCount() labels to
	// be byte-for-byte equal to this name (this = "example.com.",
	// sub = "www.example.com." -> true).
	// Equal names and a nullptr sub give false. The final comparison uses
	// IsEqual, so it is case-sensitive.
	bool DOMAIN_NAME::IsSubDomain(DOMAIN_NAME^ sub)
	{
		if (sub == nullptr)
			return false;

		int mine = LabelCount();
		int extra = sub->LabelCount() - mine;

		// A subdomain needs at least one more label than its parent.
		if (extra <= 0)
			return false;

		// Drop sub's extra leftmost labels so both names have the same depth.
		DOMAIN_NAME^ trimmed = sub;
		for (int dropped = 0; dropped < extra; ++dropped)
			trimmed = trimmed->LeftChop();

		return IsEqual(trimmed);
	}

	// Splits the name into its labels. Each list entry is a new array holding
	// one label in wire form: its length byte followed by the label octets.
	// The terminating 0 octet is not included, so the root name yields an
	// empty list.
	List<array<Byte>^>^ DOMAIN_NAME::Split()
	{
		List<array<Byte>^>^ labels = gcnew List<array<Byte>^>(0);

		for (int offset = 0; (offset < name->Length) && (name[offset] != 0); )
		{
			int chunk = name[offset] + 1;   // length byte + label data
			array<Byte>^ label = gcnew array<Byte>(chunk);
			Array::Copy(name, offset, label, 0, chunk);
			labels->Add(label);
			offset += chunk;
		}

		return labels;
	}

	// Returns label number 'labelpos' (0 = leftmost) as a stand-alone name: the
	// label followed by a terminating 0 octet. For example, label 1 of
	// "www.example.com." is "example.".
	// Returns nullptr when labelpos is negative or not less than the number of
	// labels.
	DOMAIN_NAME^ DOMAIN_NAME::GetLabel(int labelpos)
	{
		List<array<Byte>^>^ dlist = Split();

		// Valid indexes are 0 .. Count-1; labelpos == Count must be rejected here
		// rather than reaching the list indexer.
		if ((labelpos < 0) || (labelpos >= dlist->Count))
			return nullptr;

		array<Byte>^ label = dlist[labelpos];

		// One extra octet for the terminator, so the result is a well-formed name
		// (the new array is zero-filled).
		array<Byte>^ terminated = gcnew array<Byte>(label->Length + 1);
		label->CopyTo(terminated, 0);

		return gcnew DOMAIN_NAME(terminated);
	}

	int DOMAIN_NAME::Compare(DOMAIN_NAME^ rd)
	{
		int lc1 = 0;
		int lc2 = 0;
		List<array<Byte>^>^ rd1split;
		List<array<Byte>^>^ rd2split;
		array<Byte>^ a1; ;
		array<Byte>^ a2;
		int i = 0;

		/* see RFC4034 for this algorithm */
		/* this algorithm assumes the names are normalized to case */
		/* only when both are not NULL we can say anything about them */
		
		if (rd == nullptr)
				return -1;

		lc1 = LabelCount();
		lc2 = rd->LabelCount();

		if (lc1 == 0 && lc2 == 0) 
			return 0;
		
		if (lc1 == 0)
			return -1;
	
		if (lc2 == 0) 
			return 1;
		
		//sort by most significant domain label (right most) alphabetically.
		//if equal, move on to the next most significant, and so on.  Fewer 
		//labels means smaller.  

		rd1split = Split(); 
		rd2split = rd->Split();

		lc1--;
		lc2--;

		while ((lc1 >= 0) && (lc2 >= 0))
		{
			a1 = rd1split[lc1];
			a2 = rd2split[lc2];
			//compare array lengths - largest wins
			if (a1[0] > a2[0])
				return 1;
			if (a1[0] < a2[0])
				return -1;
			//if they're the same length, we need to go char by char to compare.
			for (i = 1; i < a1->Length; ++i)
			{
				if (Char::ToLower(a1[i]) > Char::ToLower(a2[i]))
					return 1;
				if (Char::ToLower(a1[i]) < Char::ToLower(a2[i]))
					return -1;
			}
			//apparently they're the same, loop to the next domain.
			
			--lc1;
			--lc2;
		}

		if (lc1 == lc2)
		{
			//if we're here, we're identical
			return 0;
		}
		if (lc1 > lc2)
			return 1;

		return -1;
	}		

	// Returns true if this name is matched by 'wildcard'.
	// If the wildcard's leftmost label is exactly "*" (e.g. "*.example.com."),
	// this name matches when it lies strictly below the rest of the wildcard
	// ("example.com."). Otherwise the two names must compare equal
	// (Compare() == 0). A nullptr wildcard never matches.
	bool DOMAIN_NAME::MatchWildcard(DOMAIN_NAME^ wildcard)
	{
		if (wildcard == nullptr)
			return false;

		array<Byte>^ wildcarddata = wildcard->GetName();

		// LabelCount() > 0 guarantees at least two octets, so index 1 is valid.
		bool starlabel = (wildcard->LabelCount() > 0) &&
						 (wildcarddata[0] == 1) &&
						 (wildcarddata[1] == '*');

		if (!starlabel)
			return Compare(wildcard) == 0;

		// parent->IsSubDomain(x) asks whether x lies below parent, so the chopped
		// wildcard must be the receiver and this name the argument.
		DOMAIN_NAME^ parent = wildcard->LeftChop();
		return (parent != nullptr) && parent->IsSubDomain(this);
	}

	// Locates this name relative to the interval bounded by start and end, using
	// Compare() ordering. Returns:
	//    0  if this name equals end;
	//   -1  if start <= this < end;
	//    1  otherwise.
	// The original notes say the "equals end" case cannot occur for an NSEC
	// check, since the NSEC of end would have been returned instead. An interval
	// that wraps around (start > end) gets no special handling.
	int	DOMAIN_NAME::InInterval(DOMAIN_NAME^ start, DOMAIN_NAME^ end)
	{
		int vsStart = Compare(start);
		int vsEnd = Compare(end);

		if (vsEnd == 0)
			return 0;

		bool atOrAfterStart = (vsStart == 0) || (vsStart == 1);
		bool beforeEnd = (vsEnd == -1);

		return (atOrAfterStart && beforeEnd) ? -1 : 1;
	}

	// Lower-cases the name in place for canonical form (RFC 4034 section 6.2).
	// Only the ASCII letters 'A'-'Z' inside labels are changed; length bytes and
	// all other octets are left untouched.
	Void DOMAIN_NAME::MakeCanonical()
	{
		int pos = 0;

		while ((pos < name->Length) && (name[pos] > 0))
		{
			int labelend = pos + name[pos];   // index of this label's last octet

			// Inclusive upper bound so the label's last octet is processed too.
			// The buffer-length check guards against a truncated final label.
			for (int i = pos + 1; (i <= labelend) && (i < name->Length); ++i)
			{
				Byte c = name[i];
				if ((c >= 'A') && (c <= 'Z'))
					name[i] = static_cast<Byte>(c + ('a' - 'A'));
			}

			pos = labelend + 1;   // next length byte
		}
	}

	// Returns the presentation form of the name, with each label followed by
	// '.', e.g. "www.example.com.". The root name prints as "." and a null
	// buffer as "".
	// Some octets would be ambiguous or lost in plain text: '.', '\', and
	// anything outside printable ASCII ('!'..'~'). These are written as \DDD
	// (three decimal digits), the escape form parse_escape decodes.
	String^ DOMAIN_NAME::Print()
	{
		if (name == nullptr)
			return gcnew String("");

		if (name->Length == 1)
			return gcnew String(".");

		StringBuilder^ output = gcnew StringBuilder();
		List<array<Byte>^>^ dlist = Split();

		for (int i = 0; i < dlist->Count; ++i)
		{
			array<Byte>^ label = dlist[i];

			for (int k = 1; k < label->Length; ++k)   // index 0 is the length byte
			{
				Byte c = label[k];
				if ((c == '.') || (c == '\\') || (c < '!') || (c > '~'))
				{
					// Char casts matter: a plain char would select Append(SByte)
					// and print the numeric value.
					output->Append(static_cast<Char>('\\'));
					output->Append(static_cast<Char>('0' + c / 100));
					output->Append(static_cast<Char>('0' + (c / 10) % 10));
					output->Append(static_cast<Char>('0' + c % 10));
				}
				else
				{
					output->Append(static_cast<Char>(c));
				}
			}

			output->Append(static_cast<Char>('.'));
		}

		return output->ToString();
	}
				
	// Length in octets of the wire-format buffer. For a well-formed name this
	// includes the terminating 0 octet.
	UInt32 DOMAIN_NAME::Size()
	{
		const int octets = name->Length;
		return static_cast<UInt32>(octets);
	}
		

}