{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimax Linkage in C\n",
    "\n",
    "* This is part of the C code of [protoclust](https://cran.r-project.org/web/packages/protoclust/index.html), an R package that implements the minimax linkage. \n",
    "* 這份 code 要跑起來要:\n",
    "    1. 把所有 `Rprintf` 取代成 `printf`\n",
    "    2. `#include <stdio.h>`\n",
    "    3. 自己寫 `rsort_with_index()` 和 `R_isort()`\n",
    "        * `void R_isort (int* x, int n)` 就只是 sort on integers\n",
    "        * `void rsort_with_index (double* x, int* index, int n)` sorts on x, and applies the same permutation to index. NAs are sorted last,看[這裡](https://colinfay.me/writing-r-extensions/the-r-api-entry-points-for-c-code.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/*\n",
    "Minimax linkage agglomerative clustering.\n",
    "*/\n",
    "\n",
    "// #include <R.h>\n",
    "#include <stdlib.h>\n",
    "#define BIGGEST 1e200\n",
    "#define min(a,b) (((a) < (b)) ? (a) : (b))\n",
    "#define max(a,b) (((a) > (b)) ? (a) : (b))\n",
    "// offset of entry (i,j) inside a packed strict lower triangle; the arguments and\n",
    "// the whole expansion are parenthesized so the macro is safe in larger expressions\n",
    "#define lt(i,j,n) ((n)*(j) - (j)*((j)+1)/2 + (i) - (j) - 1) // note: must have i > j\n",
    "\n",
    "// one element of a cluster, kept in a singly linked list\n",
    "typedef struct node {\n",
    "  int i;\n",
    "  struct node *next;\n",
    "} Cluster;\n",
    "\n",
    "void printLT(double *d,int n,int* clustLabel);\n",
    "void printCluster(Cluster * G);\n",
    "void printMatrix(double *dmax,int n,int *clustLabel);\n",
    "double maxlink(double * d, int n, Cluster * G, Cluster * H);\n",
    "int findNN(double *dd, int n, int *clustLabel, double *dmax, Cluster** clusters, int *inchain, int i);\n",
    "double minimaxlink(double *dmax, int n,\n",
    "                   Cluster *G, Cluster *tG, int iG, Cluster *H, int iH, int nGH);\n",
    "int minimaxpoint(double *dmax, int n, Cluster *G, int iG, int nG);\n",
    "double completelink(double *dmax, int n,Cluster *G, int iG, Cluster *H, int iH);\n",
    "\n",
    "// Minimax linkage clustering driven by a nearest-neighbor chain.\n",
    "//   d       - input distances, n*(n-1)/2 values addressed with lt(i,j,n)\n",
    "//   ndim    - points to n, the number of objects\n",
    "//   verbose - nonzero prints progress through Rprintf\n",
    "//   merge   - output, pair k is stored at merge[2*k] and merge[2*k+1];\n",
    "//             negative entries are leaves, positive entries are earlier merges\n",
    "//   height  - output, the n-1 merge heights after sorting\n",
    "//   order   - output, the n leaves (1-based) in left-to-right order\n",
    "//   protos  - output, 1-based prototype of each merge\n",
    "// Calls rsort_with_index() and R_isort(), which must be supplied by R or by hand.\n",
    "void hier(double *d, int *ndim, int *verbose, int *merge, double *height, int *order, int *protos) {\n",
    "  int i, j, k, ii, imerge, jmerge, reverse;\n",
    "  int n = *ndim;\n",
    "\n",
    "  // nothing to merge with fewer than two objects; without this guard the code\n",
    "  // below would declare zero-length arrays and read imerge before it is set\n",
    "  if(n < 2) {\n",
    "    if(n == 1)\n",
    "      order[0] = 1;\n",
    "    return;\n",
    "  }\n",
    "\n",
    "  // the casts on malloc keep the code valid under a C++ kernel as well as in C\n",
    "  double *dd = (double *)malloc((n*(n-1)/2)*sizeof(double));\n",
    "  // dd contains the inter-cluster distances as a lower triangular matrix\n",
    "\n",
    "  double *dmax = (double *)malloc((n*n)*sizeof(double));\n",
    "  // dmax[n*i + j] = max_{l\\in C_j}d(i,l) where C_j is the jth cluster\n",
    "\n",
    "  int *un_protos = (int *)malloc(n*sizeof(int));\n",
    "  int *un_merge = (int *)malloc((n*2)*sizeof(int));\n",
    "  // versions of the outputs \"protos\" and \"merge\"... except unordered by height\n",
    "  // these will be ordered based on the nearest-neighbor chain order\n",
    "  // at the end, we will convert from un_protos to protos and likewise for merge\n",
    "\n",
    "\n",
    "  // each cluster will be stored as a linked list between clusters[i] and tails[i]\n",
    "  Cluster** clusters = (Cluster **)malloc(n*sizeof(Cluster *));\n",
    "  Cluster** tails = (Cluster **)malloc(n*sizeof(Cluster *));\n",
    "\n",
    "  int clustLabel[n];\n",
    "  int clustSize[n];\n",
    "\n",
    "  int nnchain[n];\n",
    "  int nn,end = -1; // index of end of nnchain\n",
    "  int inchain[n];// indicates whether element i is in current chain\n",
    "\n",
    "  // initialize n singleton clusters\n",
    "  for(i = 0; i < n; i++) {\n",
    "    clusters[i] = (Cluster *)malloc(sizeof(Cluster));\n",
    "    tails[i] = clusters[i];\n",
    "    clusters[i]->i = i;\n",
    "    clusters[i]->next = NULL;\n",
    "    clustLabel[i] = -(i+1);// leaves are negative (as in R's hclust)\n",
    "    clustSize[i] = 1;\n",
    "    nnchain[i] = -1;\n",
    "    inchain[i] = 0;\n",
    "  }\n",
    "  for(i = 0; i < n; i++)\n",
    "    dmax[n*i + i] = 0;\n",
    "  for(j = 0; j < n-1; j++) {\n",
    "    for(i = j+1; i < n; i++) {\n",
    "      ii = lt(i,j,n);\n",
    "      dmax[n*i + j] = d[ii];\n",
    "      dmax[n*j + i] = d[ii];\n",
    "    }\n",
    "  }\n",
    "\n",
    "  for(ii = 0; ii < n*(n-1)/2; ii++)\n",
    "    dd[ii] = d[ii];\n",
    "\n",
    "  for(k = 0; k < n-1; k++) {\n",
    "    // check if chain is empty\n",
    "    if(end == -1) {\n",
    "      // start a new chain\n",
    "      for(i = 0; i < n; i++)\n",
    "        if(clustLabel[i] != 0)\n",
    "          break;\n",
    "      nnchain[0] = i;\n",
    "      if(*verbose)\n",
    "        Rprintf(\"\\nStarting a new chain at %d.\\n\",i+1);\n",
    "      end = 0;\n",
    "    }\n",
    "\n",
    "    // grow nearest neighbor chain until a RNN pair found:\n",
    "    // (for(;;) rather than while(TRUE): TRUE is only defined by R's headers)\n",
    "    for(;;) {\n",
    "      nn = findNN(dd,n,clustLabel,dmax,clusters,inchain,nnchain[end]);\n",
    "      if(end > 0) {\n",
    "        if(nn == nnchain[end-1]) {\n",
    "          // reached a RNN pair\n",
    "          break;\n",
    "        }\n",
    "        inchain[nnchain[end-1]] = 1;\n",
    "        // we purposely delay this update to allow findNN to\n",
    "        // select a cluster visited a step earlier\n",
    "        // (which happens for RNN pairs)\n",
    "      }\n",
    "      nnchain[++end] = nn;\n",
    "    }\n",
    "    if(*verbose) {\n",
    "      Rprintf(\" NN-Chain: \");\n",
    "      for(i = 0; i <= end; i++)\n",
    "        Rprintf(\"%d \",nnchain[i]+1);\n",
    "      Rprintf(\"\\n\");\n",
    "    }\n",
    "    if(nnchain[end] < nnchain[end-1]) {\n",
    "      imerge = nnchain[end-1];\n",
    "      jmerge = nnchain[end];\n",
    "    }\n",
    "    else {\n",
    "      imerge = nnchain[end];\n",
    "      jmerge = nnchain[end-1];\n",
    "    }\n",
    "\n",
    "    // remove RNN pair from chain\n",
    "    inchain[nnchain[end]] = 0;\n",
    "    inchain[nnchain[end-1]] = 0;\n",
    "\n",
    "    if(end > 1)\n",
    "      inchain[nnchain[end-2]] = 0;// again, so that a RNN can be detected\n",
    "    if(end > 2)\n",
    "      inchain[nnchain[end-3]] = 0;// again, so that a RNN can be detected\n",
    "\n",
    "    nnchain[end] = -1;\n",
    "    nnchain[end-1] = -1;\n",
    "    end -= 2;\n",
    "\n",
    "    // create merged cluster from this RNN pair:\n",
    "    ii = lt(imerge,jmerge,n);\n",
    "    height[k] = dd[ii];\n",
    "    if(*verbose)\n",
    "      Rprintf(\" Merged reciprocal nearest neighbor pair at height %g\\n\",height[k]);\n",
    "    reverse = 0;\n",
    "\n",
    "    if(clustLabel[imerge] > clustLabel[jmerge])\n",
    "      reverse = 1;// put smaller cluster on left\n",
    "\n",
    "    if(clustLabel[imerge] < 0 && clustLabel[jmerge] < 0)\n",
    "      reverse = 1; // since imerge > jmerge\n",
    "\n",
    "    if(reverse) {\n",
    "      un_merge[2*k] = clustLabel[jmerge];\n",
    "      un_merge[2*k+1] = clustLabel[imerge];\n",
    "    }\n",
    "    else {\n",
    "      un_merge[2*k] = clustLabel[imerge];\n",
    "      un_merge[2*k+1] = clustLabel[jmerge];\n",
    "    }\n",
    "\n",
    "    // update the imerge cluster:\n",
    "    clustLabel[imerge] = k + 1;\n",
    "    clustSize[imerge] += clustSize[jmerge];\n",
    "\n",
    "    if(reverse) {\n",
    "      // put jmerge's elements to left of imerge's\n",
    "      tails[jmerge]->next = clusters[imerge];\n",
    "      clusters[imerge] = clusters[jmerge];\n",
    "    }\n",
    "    else {\n",
    "      // put imerge's elements to left of jmerge's\n",
    "      tails[imerge]->next = clusters[jmerge];\n",
    "      tails[imerge] = tails[jmerge];\n",
    "    }\n",
    "\n",
    "    // at jmerge, we no longer have a cluster:\n",
    "    clustLabel[jmerge] = 0;\n",
    "    clustSize[jmerge] = 0;\n",
    "    clusters[jmerge] = NULL;\n",
    "    tails[jmerge] = NULL;\n",
    "\n",
    "    /// update dmax:\n",
    "    for(i = 0; i < n; i++) {\n",
    "      // point i\n",
    "      if(dmax[n*i+imerge] < dmax[n*i+jmerge]) {\n",
    "        dmax[n*i+imerge] = dmax[n*i+jmerge];//i.e. max{dmax(i,imerge),dmax(i,jmerge)}\n",
    "      }\n",
    "    }\n",
    "\n",
    "    // get the minimax prototype for this newly formed cluster\n",
    "    un_protos[k] = minimaxpoint(dmax,n,clusters[imerge],imerge,clustSize[imerge]) + 1;\n",
    "\n",
    "    /// update dd:\n",
    "\n",
    "    // imerge is now a new cluster, so update its distances:\n",
    "    for(j = 0; j < imerge; j++)\n",
    "      if(clustLabel[j] != 0) { // still an active cluster\n",
    "        ii = lt(imerge,j,n);\n",
    "        dd[ii] = minimaxlink(dmax,n,clusters[j],tails[j],j,clusters[imerge],imerge,\n",
    "                             clustSize[j]+clustSize[imerge]);\n",
    "      }\n",
    "\n",
    "    // entries (i,imerge) for consecutive i are adjacent in dd, so ii just steps by one\n",
    "    ii = lt(imerge+1,imerge,n);\n",
    "    for(i = imerge + 1; i < n; i++) {\n",
    "      if(clustLabel[i] != 0) {\n",
    "        dd[ii] = minimaxlink(dmax,n,clusters[i],tails[i],i,clusters[imerge],imerge,\n",
    "                             clustSize[i] + clustSize[imerge]);\n",
    "      }\n",
    "      ii++;\n",
    "    }\n",
    "  }\n",
    "\n",
    "  //// List merges and protos in order of increasing height:\n",
    "  int o[n-1];\n",
    "  for(i = 0; i < n-1; i++)\n",
    "    o[i] = i;\n",
    "\n",
    "  // sort heights and \"o = order(height)\" (in R speak)\n",
    "  rsort_with_index(height,o,n-1);\n",
    "\n",
    "  // if there are ties, want indices ordered (to match R's convention).\n",
    "  int count;\n",
    "\n",
    "  for(i = 0; i < n-1; i++) {\n",
    "    count = 0;\n",
    "    // the bound is tested on every step so that height[n-1] is never read\n",
    "    while((i+count+1) < (n-1) && height[i+count+1]==height[i])\n",
    "      count++;\n",
    "    // heights are constant from i to i+count\n",
    "    if(count > 0) {\n",
    "      R_isort(&o[i],count+1);\n",
    "    }\n",
    "    i += count;\n",
    "  }\n",
    "\n",
    "  for(i = 0; i < n-1; i++) {\n",
    "    protos[i] = un_protos[o[i]];\n",
    "    merge[2*i] = un_merge[2*o[i]];\n",
    "    merge[2*i+1] = un_merge[2*o[i]+1];\n",
    "  }\n",
    "\n",
    "  // shuffling merge rows around messes up the positive indices in merge:\n",
    "  int ranks[n-1];\n",
    "  for(i = 0; i < n-1; i++)\n",
    "    ranks[o[i]] = i;\n",
    "\n",
    "  int mtemp;\n",
    "  for(i = 0; i < n-1; i++) {\n",
    "    for(j = 0; j < 2; j++) {\n",
    "      if(merge[2*i+j] > 0)\n",
    "        merge[2*i+j] = ranks[merge[2*i+j]-1] + 1;\n",
    "      //the -1 and +1 are to match R's indexing\n",
    "    }\n",
    "    if(merge[2*i] > 0 && merge[2*i+1] > 0)\n",
    "      if(merge[2*i] > merge[2*i+1]) {\n",
    "        // hclust has positive rows in increasing order:\n",
    "        mtemp = merge[2*i];\n",
    "        merge[2*i] = merge[2*i+1];\n",
    "        merge[2*i+1] = mtemp;\n",
    "      }\n",
    "  }\n",
    "  // get order by following \"merge\"\n",
    "\n",
    "  // using ranks for a different purpose... ranks[k] will be s.t. clusters[ranks[k]] contains\n",
    "  // the cluster created at step k.\n",
    "  for(k=0; k < n-1; k++)\n",
    "    ranks[k] = 0;\n",
    "\n",
    "  // split the single remaining list back into n singleton lists\n",
    "  Cluster *cur = clusters[imerge];\n",
    "  for(i = 0; i < n; i++) {\n",
    "    clusters[cur->i] = cur;\n",
    "    tails[cur->i] = clusters[cur->i];\n",
    "    cur = cur->next;\n",
    "  }\n",
    "  for(i = 0; i < n; i++)\n",
    "    clusters[i]->next = NULL;\n",
    "\n",
    "  // replay the merges in their final (height-sorted) order\n",
    "  for(k = 0; k < n-1; k++) {\n",
    "    if(merge[2*k] < 0)\n",
    "      imerge = -merge[2*k] - 1;\n",
    "    else\n",
    "      imerge = ranks[merge[2*k] - 1];\n",
    "    if(merge[2*k+1] < 0)\n",
    "      jmerge = -merge[2*k+1] - 1;\n",
    "    else\n",
    "      jmerge = ranks[merge[2*k+1] - 1];\n",
    "\n",
    "    tails[imerge]->next = clusters[jmerge];\n",
    "    tails[imerge] = tails[jmerge];\n",
    "    clusters[jmerge] = NULL;\n",
    "    tails[jmerge] = NULL;\n",
    "    ranks[k] = imerge;\n",
    "  }\n",
    "\n",
    "  cur = clusters[imerge];\n",
    "  for(i = 0; i < n; i++) {\n",
    "    order[i] = cur->i+1;\n",
    "    cur = cur->next;\n",
    "  }\n",
    "\n",
    "  // all n nodes now hang off clusters[imerge]; release them one by one\n",
    "  cur = clusters[imerge];\n",
    "  Cluster *curnext;\n",
    "  for(i = 0; i < n; i++) {\n",
    "    curnext = cur->next;\n",
    "    free(cur);\n",
    "    cur = curnext;\n",
    "  }\n",
    "  free(dd);\n",
    "  free(dmax);\n",
    "  free(un_protos);\n",
    "  free(un_merge);\n",
    "  free(clusters);\n",
    "  free(tails);\n",
    "}\n",
    "\n",
    "// returns the nearest neighbor cluster of cluster i that is not\n",
    "// already in the chain; ties in dd are broken by the smaller complete-linkage\n",
    "// distance. Returns -1 when no cluster qualifies.\n",
    "int findNN(double *dd, int n, int *clustLabel, double *dmax, Cluster **clusters, int *inchain, int i)\n",
    "{\n",
    "  int j,ii;\n",
    "  double mind = BIGGEST;\n",
    "  double dcomp, mincomplete = 0;\n",
    "  int nn = -1; // defined result even if every other cluster is skipped\n",
    "  for(j = 0; j < i; j++)\n",
    "  {\n",
    "    if(clustLabel[j] == 0 || inchain[j]==1)\n",
    "      continue;\n",
    "    ii = lt(i,j,n);\n",
    "\n",
    "    if(dd[ii] < mind)\n",
    "    {\n",
    "      mind = dd[ii];\n",
    "      nn = j;\n",
    "      mincomplete = 0;// reset mincomplete\n",
    "    }\n",
    "    else if(dd[ii]==mind)\n",
    "    {\n",
    "      if(mincomplete==0)\n",
    "      {\n",
    "        // this is the first duplicate\n",
    "        mincomplete = completelink(dmax, n, clusters[nn], nn, clusters[i], i);\n",
    "      }\n",
    "      dcomp = completelink(dmax, n, clusters[j], j, clusters[i], i);\n",
    "      if(dcomp < mincomplete)\n",
    "      {\n",
    "        mincomplete = dcomp;\n",
    "        nn = j;\n",
    "      }\n",
    "    }\n",
    "  }\n",
    "\n",
    "  // same scan for j > i; lt() needs the larger index first\n",
    "  for(j = i+1; j < n; j++)\n",
    "  {\n",
    "    if(clustLabel[j] == 0 || inchain[j]==1)\n",
    "      continue;\n",
    "    ii = lt(j,i,n);\n",
    "    if(dd[ii] < mind)\n",
    "    {\n",
    "      mind = dd[ii];\n",
    "      nn = j;\n",
    "      mincomplete = 0;// reset mincomplete\n",
    "    }\n",
    "    else if(dd[ii]==mind)\n",
    "    {\n",
    "      if(mincomplete==0)\n",
    "      {\n",
    "        // this is the first duplicate\n",
    "        mincomplete = completelink(dmax, n, clusters[nn], nn, clusters[i], i);\n",
    "      }\n",
    "      dcomp = completelink(dmax, n, clusters[j], j, clusters[i], i);\n",
    "      if(dcomp < mincomplete)\n",
    "      {\n",
    "        mincomplete = dcomp;\n",
    "        nn = j;\n",
    "      }\n",
    "    }\n",
    "  }\n",
    "\n",
    "  return nn;\n",
    "}\n",
    "\n",
    "\n",
    "// Returns the minimax distance between clusters G (tail tG, column iG) and\n",
    "// H (column iH); nGH is the number of points in G and H together.\n",
    "double minimaxlink(double *dmax, int n,\n",
    "                   Cluster *G, Cluster *tG, int iG, Cluster *H, int iH,int nGH)\n",
    "{\n",
    "  //printf(\"Inside dmax: (%d,%d)\\n\",iG,iH);\n",
    "  //printCluster(G);\n",
    "  //printCluster(H);\n",
    "\n",
    "  // temporarily combine clusters\n",
    "  tG->next = H;\n",
    "  int i;\n",
    "  double dcur;\n",
    "  double dmm = BIGGEST;\n",
    "\n",
    "  // a running minimum replaces the former nGH-sized stack array; same result\n",
    "  Cluster *cur1 = G;\n",
    "  for(i = 0; i < nGH; i++)\n",
    "  {\n",
    "    // largest distance from this point to any member of G or H\n",
    "    if(dmax[n*cur1->i + iG] > dmax[n*cur1->i + iH])\n",
    "      dcur = dmax[n*cur1->i + iG];\n",
    "    else\n",
    "      dcur = dmax[n*cur1->i + iH];\n",
    "    if(dcur < dmm)\n",
    "      dmm = dcur;\n",
    "    cur1 = cur1->next;\n",
    "  }\n",
    "\n",
    "  // uncombine the clusters\n",
    "  tG->next = NULL;\n",
    "  return dmm;\n",
    "}\n",
    "\n",
    "// Finds the minimax point of the cluster G (column iG of dmax, nG points);\n",
    "// returns its 0-based index, or -1 if no point has a distance below BIGGEST\n",
    "int minimaxpoint(double *dmax, int n, Cluster *G, int iG, int nG)\n",
    "{\n",
    "  int mm = -1;\n",
    "  Cluster *cur1 = G;\n",
    "  double dmm = BIGGEST;\n",
    "  int i;\n",
    "  for(i = 0; i < nG; i++)\n",
    "  {\n",
    "    // printf(\"dmax[%d,%d]=%.2g\\n\",cur1->i,iG,dmax[n*cur1->i+iG]);\n",
    "    if(dmax[n*cur1->i+iG] < dmm)\n",
    "    {\n",
    "      dmm = dmax[n*cur1->i+iG];\n",
    "      mm = cur1->i;\n",
    "    }\n",
    "    cur1 = cur1->next;\n",
    "  }\n",
    "  return mm;\n",
    "}\n",
    "\n",
    "// complete linkage d(G,H)\n",
    "double completelink(double *dmax, int n, Cluster *G, int iG, Cluster *H, int iH)\n",
    "{\n",
    "  //printf(\"\\nComplete linkage:\\n\");\n",
    "  //printCluster(G);\n",
    "  //printCluster(H);\n",
    "\n",
    "  double dmm = 0;\n",
    "  Cluster *cur = G;\n",
    "  while(cur != NULL)\n",
    "  {\n",
    "    // dmax(g,H)\n",
    "    if(dmax[n*cur->i + iH] > dmm)\n",
    "      dmm = dmax[n*cur->i + iH];\n",
    "    cur = cur->next;\n",
    "  }\n",
    "  cur = H;\n",
    "  while(cur != NULL)\n",
    "  {\n",
    "    // dmax(h,G)\n",
    "    if(dmax[n*cur->i + iG] > dmm)\n",
    "      dmm = dmax[n*cur->i + iG];\n",
    "    cur = cur->next;\n",
    "  }\n",
    "  // printf(\"%.8g\\n\",dmm);\n",
    "  return dmm;\n",
    "}\n",
    "\n",
    "// prints a lower triangular matrix\n",
    "// (rows and columns whose clustLabel is 0 are shown as *)\n",
    "void printLT(double *d,int n,int* clustLabel)\n",
    "{\n",
    "  int i,j;\n",
    "\n",
    "  for(j = 0; j < n; j++)\n",
    "    Rprintf(\"\\t%d\",j);\n",
    "  Rprintf(\"\\n\");\n",
    "  for(i = 1; i < n; i++)\n",
    "  {\n",
    "    Rprintf(\"%d\\t\",i);\n",
    "    if(clustLabel[i]==0)\n",
    "    {\n",
    "      for(j = 0; j < i; j++)\n",
    "        Rprintf(\"*\\t\");\n",
    "      Rprintf(\"\\n\");\n",
    "      continue;\n",
    "    }\n",
    "    for(j = 0; j < i; j++)\n",
    "    {\n",
    "      if(clustLabel[j]==0)\n",
    "        Rprintf(\"*\\t\");\n",
    "      else\n",
    "        Rprintf(\"%.2g\\t\",d[lt(i,j,n)]);\n",
    "    }\n",
    "    Rprintf(\"\\n\");\n",
    "  }\n",
    "}\n",
    "\n",
    "// prints the member indices of cluster G on one line\n",
    "void printCluster(Cluster * G)\n",
    "{\n",
    "  while(G!=NULL)\n",
    "  {\n",
    "    Rprintf(\"%d\\t\",G->i);\n",
    "    G = G->next;\n",
    "  }\n",
    "  Rprintf(\"\\n\");\n",
    "}\n",
    "\n",
    "// prints the n x n matrix dmax; columns whose clustLabel is 0 are shown as *\n",
    "void printMatrix(double *dmax,int n,int *clustLabel)\n",
    "{\n",
    "  int i, j;\n",
    "  for(i = 0; i < n; i++)\n",
    "  {\n",
    "    for(j = 0; j < n; j++)\n",
    "    {\n",
    "      if(clustLabel[j]==0)\n",
    "        Rprintf(\"*\\t\");\n",
    "      else\n",
    "        Rprintf(\"%.2g\\t\",dmax[n*i + j]);\n",
    "    }\n",
    "    Rprintf(\"\\n\");\n",
    "  }\n",
    "}\n",
    "\n",
    "/*\n",
    "int main(int argc, char** argv)\n",
    "{\n",
    "  return 1;\n",
    "}\n",
    "*/\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "C++17",
   "language": "C++17",
   "name": "xcpp17"
  },
  "language_info": {
   "codemirror_mode": "text/x-c++src",
   "file_extension": ".cpp",
   "mimetype": "text/x-c++src",
   "name": "c++",
   "version": "17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}