#include <cstdio>
#include <cstdlib>
#include <iostream>

using namespace std;



// Fill f [0..d] with the factorials 0!, 1!, ..., d!, computed as doubles
// by repeated multiplication. f must have room for d + 1 elements; nothing
// is written when d < 0. Note that d! exceeds the double range for d > 170.
void
factorial (double f [] /* d + 1 elements */, int d)
{
    if (d < 0)
    {
        return;
    }

    double running = 1;
    f [0] = running;

    for (int k = 1; k <= d; ++k)
    {
        running *= k;      // running == (k-1)! before, k! after
        f [k] = running;
    }
}


// compute C (d, i) for 0 < i <= d
void
combinations (double c [] /* d + 1 elements */, int d)
{
    double f [d + 1];

    factorial (f, d);

    for (int i = 0; i <= d; i++)
    {
        // C (d, i) 
        c [i] = f [d] / f [i] / f [d - i];
    }
}

// Capacity limits of the fixed-size buffers used by main and its helpers.
constexpr int kMaxM = 64;       // mc / mc2 hold slot counts s = 0..kMaxM
constexpr int kMaxD = 99;       // C holds C (d, 0..d) for d <= kMaxD
constexpr int kMaxTerms = 100;  // parts stored per line; must exceed kMaxM
constexpr int kLineSize = 256;  // bytes read per fgets call

// Parse text as a base-10 integer in [lo, hi] and store it in out.
// Returns false (leaving out untouched) if text is empty, has trailing
// characters, or is out of range.
static bool
parse_int_in_range (const char* text, long lo, long hi, int& out)
{
    char* end = nullptr;
    const long value = strtol (text, &end, 10);

    if (end == text || *end != '\0' || value < lo || value > hi)
    {
        return false;
    }

    out = static_cast<int> (value);
    return true;
}

// Read the whitespace-separated integers at the start of line, stopping at
// the first token that is not an integer. The first `capacity` of them are
// stored in terms. The return value is how many were found in total, so a
// result greater than capacity means some were not stored.
static int
parse_terms (const char* line, int terms [], int capacity)
{
    int count = 0;
    int value = 0;
    int consumed = 0;

    while (sscanf (line, "%d%n", &value, &consumed) == 1)
    {
        if (count < capacity)
        {
            terms [count] = value;
        }

        count++;
        line += consumed;
    }

    return count;
}

// Read partition lines from fp. Each line is presumably one partition of m,
// given as its parts.
//
// For each line with s parts (1 <= s <= m) whose parts all lie in [0, d],
// add the product of C [part] over its parts:
//   - to mc2 [s] if some part equals d,
//   - to mc [s] otherwise.
// Lines with no parts or with more than m parts are skipped. Sums
// accumulate in file order. Returns false if a read error occurred.
//
// NOTE(review): assumes every line fits in kLineSize - 1 bytes; a longer
// line would be split by fgets and counted as two partitions.
static bool
tally_partitions (FILE* fp, int m, int d, const double C [],
                  double mc [], double mc2 [])
{
    char line [kLineSize];

    // Loop on fgets itself rather than on feof: a read error does not set
    // the EOF indicator, so a feof-driven loop may never terminate.
    while (fgets (line, sizeof (line), fp) != nullptr)
    {
        int p [kMaxTerms];
        const int s = parse_terms (line, p, kMaxTerms);

        // Blank lines match no slot count, and neither does a line with more
        // parts than slots. Since m <= kMaxM < kMaxTerms, every part of an
        // accepted line was stored in p.
        if (s == 0 || s > m)
        {
            continue;
        }

        bool bad = false;
        bool dfound = false;

        for (int j = 0; j < s; j++)
        {
            // Parts index C, which only has entries for 0..d.
            if (p [j] < 0 || p [j] > d)
            {
                bad = true;
                break;
            }

            if (p [j] == d)
            {
                dfound = true;
            }
        }

        if (bad)
        {
            continue;
        }

        double n1 = 1;

        for (int j = 0; j < s; j++)
        {
            n1 *= C [p [j]];
        }

        if (dfound)
        {
            mc2 [s] += n1;
        }
        else
        {
            mc [s] += n1;
        }
    }

    return !ferror (fp);
}

// Usage: arrangements <m> <d>
//
// Reads partition file ../partitions/p<m> and writes file m<m>d<d>. The
// output starts with the line "m=<m> d=<d>", followed by one line
// "s = <mc> <mc2>" for every slot count s = 0..m. mc / mc2 are the sums
// tally_partitions accumulates for partitions with s parts, without / with
// a part equal to d.
// Requires 0 <= m <= kMaxM and 0 <= d <= kMaxD.
// Returns 0 on success and -1 on any error.
int
main (int argc, char* argv [])
{
    if (argc != 3)
    {
        std::cerr << "arrangements <m> <d>" << std::endl;
        return -1;
    }

    int m = 0;
    int d = 0;

    // Validate instead of using atoi: atoi cannot report malformed input,
    // and m and d size the fixed buffers below.
    if (!parse_int_in_range (argv [1], 0, kMaxM, m) ||
        !parse_int_in_range (argv [2], 0, kMaxD, d))
    {
        std::cerr << "arrangements: need 0 <= m <= " << kMaxM
                  << " and 0 <= d <= " << kMaxD << std::endl;
        return -1;
    }

    char pfile [64];
    snprintf (pfile, sizeof (pfile), "../partitions/p%d", m);
    std::cout << "reading partition file " << pfile << " for m = " << m << std::endl;
    FILE* fp1 = fopen (pfile, "r");

    if (fp1 == nullptr)
    {
        std::cerr << "arrangements: cannot open " << pfile << std::endl;
        return -1;
    }

    char afile [64];
    snprintf (afile, sizeof (afile), "m%dd%d", m, d);
    std::cout << "generating m = " << m << " d = " << d << " in file " << afile << std::endl;
    FILE* fp2 = fopen (afile, "w");

    if (fp2 == nullptr)
    {
        std::cerr << "arrangements: cannot create " << afile << std::endl;
        fclose (fp1);
        return -1;
    }

    fprintf (fp2, "m=%d d=%d\n", m, d);

    double C [kMaxD + 1];
    combinations (C, d);

    // if s slots are available, how to arrange m nodes
    // if the depth of each slot is d
    // Index 0 (zero slots) is never filled and stays 0.
    double mc [kMaxM + 1] = {0};
    double mc2 [kMaxM + 1] = {0};

    const bool read_ok = tally_partitions (fp1, m, d, C, mc, mc2);
    fclose (fp1);

    if (!read_ok)
    {
        std::cerr << "arrangements: error reading " << pfile << std::endl;
        fclose (fp2);
        return -1;
    }

    for (int s = 0; s <= m; s++)
    {
        fprintf (fp2, "%d = %lg %lg\n", s, mc [s], mc2 [s]);
    }

    // fclose flushes buffered output, so write errors surface here.
    if (fclose (fp2) != 0)
    {
        std::cerr << "arrangements: error writing " << afile << std::endl;
        return -1;
    }

    return 0;
}
