Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

ShAlgebra.cpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright (c) 2003 University of Waterloo Computer Graphics Laboratory
00004 // Project administrator: Michael D. McCool
00005 // Authors: Zheng Qin, Stefanus Du Toit, Kevin Moule, Tiberiu S. Popa,
00006 //          Michael D. McCool
00007 // 
00008 // This software is provided 'as-is', without any express or implied
00009 // warranty. In no event will the authors be held liable for any damages
00010 // arising from the use of this software.
00011 // 
00012 // Permission is granted to anyone to use this software for any purpose,
00013 // including commercial applications, and to alter it and redistribute it
00014 // freely, subject to the following restrictions:
00015 // 
00016 // 1. The origin of this software must not be misrepresented; you must
00017 // not claim that you wrote the original software. If you use this
00018 // software in a product, an acknowledgment in the product documentation
00019 // would be appreciated but is not required.
00020 // 
00021 // 2. Altered source versions must be plainly marked as such, and must
00022 // not be misrepresented as being the original software.
00023 // 
00024 // 3. This notice may not be removed or altered from any source
00025 // distribution.
00027 #include <map>
00028 #include <string>
00029 #include <sstream>
00030 #include <algorithm>
00031 #include "ShAlgebra.hpp"
00032 #include "ShCtrlGraph.hpp"
00033 #include "ShDebug.hpp"
00034 #include "ShError.hpp"
00035 #include "ShOptimizations.hpp"
00036 #include "ShInternals.hpp"
00037 #include "ShEnvironment.hpp"
00038 #include "ShContext.hpp"
00039 #include "ShManipulator.hpp"
00040 #include "ShFixedManipulator.hpp"
00041 
00042 namespace SH {
00043 
00044 ShProgram connect(ShProgram pa, ShProgram pb)
00045 {
00046   ShProgramNodePtr a = pa.node();
00047   ShProgramNodePtr b = pb.node();
00048   
00049   if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" );
00050   if( !a ) return b;
00051   if( !b ) return a;
00052   
00053   int aosize = a->outputs.size();
00054   int bisize = b->inputs.size();
00055   std::string rtarget;
00056 
00057   if (a->target().empty()) {
00058     rtarget = b->target(); // A doesn't have a target. Use b's.
00059   } else {
00060     if (b->target().empty() || a->target() == b->target()) {
00061       rtarget = a->target(); // A has a target, b doesn't
00062     } else {
00063       SH_DEBUG_WARN("Connecting two different targets. Using empty target for result.");
00064       rtarget = ""; // Connecting different targets.
00065     }
00066   }
00067 
00068   ShProgramNodePtr program = new ShProgramNode(rtarget);
00069 
00070   ShCtrlGraphNodePtr heada, taila, headb, tailb;
00071 
00072   a->ctrlGraph->copy(heada, taila);
00073   b->ctrlGraph->copy(headb, tailb);
00074 
00075   taila->append(headb);
00076 
00077   ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb);
00078   program->ctrlGraph = new_graph;
00079 
00080   program->inputs = a->inputs;
00081 
00082 // push back extra inputs from b if aosize < bisize
00083   if(aosize < bisize) {
00084     ShProgramNode::VarList::const_iterator II = b->inputs.begin();
00085     for(int i = 0; i < aosize; ++i, ++II); 
00086     for(; II != b->inputs.end(); ++II) {
00087       program->inputs.push_back(*II);
00088     }
00089   }
00090   program->outputs = b->outputs;
00091 
00092   // push back extra outputs from a if aosize > bisize
00093   if(aosize > bisize) { 
00094     ShProgramNode::VarList::const_iterator II = a->outputs.begin();
00095     for(int i = 0; i < bisize; ++i, ++II); 
00096     for(; II != a->outputs.end(); ++II) {
00097       program->outputs.push_back(*II);
00098     }
00099   }
00100   
00101   ShVarMap varMap;
00102 
00103   ShContext::current()->enter(program);
00104   
00105   ShProgramNode::VarList::const_iterator I, J;  
00106 
00107   ShProgramNode::VarList InOutInputs;
00108   ShProgramNode::VarList InOutOutputs;
00109 
00110   // replace outputs and inputs connected together by temps 
00111   for (I = a->outputs.begin(), J = b->inputs.begin(); 
00112       I != a->outputs.end() && J != b->inputs.end(); ++I, ++J) { 
00113     if((*I)->size() != (*J)->size()) {
00114       std::ostringstream err;
00115       err << "Cannot smash variables "  
00116           << (*I)->nameOfType() << " " << (*I)->name() << " and " 
00117           << (*J)->nameOfType() << " " << (*J)->name() << " with different sizes" << std::endl;
00118       err << "while connecting outputs: ";
00119       ShProgramNode::print(err, a->outputs) << std::endl;
00120       err << "to inputs: ";
00121       ShProgramNode::print(err, b->inputs) << std::endl;
00122       ShContext::current()->exit();
00123       shError(ShAlgebraException(err.str()));
00124       return ShProgram(ShProgramNodePtr(0));
00125     }
00126     ShVariableNodePtr n = (*I)->clone(SH_TEMP);
00127     varMap[*I] = n;
00128     varMap[*J] = n;
00129 
00130     if((*I)->kind() == SH_INOUT) InOutInputs.push_back((*I)); 
00131     if((*J)->kind() == SH_INOUT) InOutOutputs.push_back((*J)); 
00132   }
00133 
00134   // Change connected InOut variables to either Input or Output only
00135   // (since they have been connected and turned into temps internally)
00136   ShCtrlGraphNodePtr graphEntry;
00137   for (I = InOutInputs.begin(); I != InOutInputs.end(); ++I) {
00138     if(!graphEntry) graphEntry = program->ctrlGraph->prependEntry();
00139     ShVariableNodePtr newInput((*I)->clone(SH_INPUT)); 
00140 
00141     std::replace(program->inputs.begin(), program->inputs.end(),
00142         (*I), newInput);
00143     program->inputs.pop_back();
00144 
00145     graphEntry->block->addStatement(ShStatement(
00146         ShVariable(varMap[*I]), SH_OP_ASN, ShVariable(newInput)));
00147   }
00148 
00149   ShCtrlGraphNodePtr graphExit;
00150   for (I = InOutOutputs.begin(); I != InOutOutputs.end(); ++I) {
00151     if(!graphExit) graphExit = program->ctrlGraph->appendExit();
00152     ShVariableNodePtr newOutput((*I)->clone(SH_OUTPUT));
00153     
00154     std::replace(program->outputs.begin(), program->outputs.end(),
00155         (*I), newOutput);
00156     program->outputs.pop_back();
00157 
00158     graphExit->block->addStatement(ShStatement(
00159         ShVariable(newOutput), SH_OP_ASN, ShVariable(varMap[*I])));
00160   }
00161 
00162   ShContext::current()->exit();
00163 
00164   ShVariableReplacer replacer(varMap);
00165   program->ctrlGraph->dfs(replacer);
00166 
00167   optimize(program);
00168   
00169   program->collectVariables();
00170   return program;
00171 }
00172 
00173 ShProgram combine(ShProgram pa, ShProgram pb)
00174 {
00175   ShProgramNodePtr a = pa.node();
00176   ShProgramNodePtr b = pb.node();
00177   
00178   std::string rtarget;
00179   if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" );
00180   if (!a) return b;
00181   if (!b) return a;
00182 
00183   if (a->target().empty()) {
00184     rtarget = b->target(); // A doesn't have a target. Use b's.
00185   } else {
00186     if (b->target().empty() || a->target() == b->target()) {
00187       rtarget = a->target(); // A has a target, b doesn't
00188     } else { 
00189       rtarget = ""; // Connecting different targets.
00190     }
00191   }
00192 
00193   ShProgramNodePtr program = new ShProgramNode(rtarget);
00194 
00195   ShCtrlGraphNodePtr heada, taila, headb, tailb;
00196 
00197   a->ctrlGraph->copy(heada, taila);
00198   b->ctrlGraph->copy(headb, tailb);
00199 
00200   taila->append(headb);
00201 
00202   ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb);
00203   program->ctrlGraph = new_graph;
00204 
00205   program->inputs = a->inputs;
00206   program->inputs.insert(program->inputs.end(), b->inputs.begin(), b->inputs.end());
00207   program->outputs = a->outputs;
00208   program->outputs.insert(program->outputs.end(), b->outputs.begin(), b->outputs.end());
00209 
00210   optimize(program);
00211  
00212   program->collectVariables();
00213   
00214   return program;
00215 }
00216 
00217 // Duplicates to inputs with matching name/type
00218 ShProgram mergeNames(ShProgram p)
00219 {
00220   typedef std::pair<std::string, int> InputType;
00221   typedef std::map< InputType, int > FirstOccurenceMap;  // position of first occurence of an input type
00222   typedef std::vector< std::vector<int> > Duplicates;
00223   FirstOccurenceMap firsts;
00224   // dups[i] stores the set of positions that have matching input types with position i.
00225   // The whole set is stored in the smallest i position.
00226   Duplicates dups( p.node()->inputs.size(), std::vector<int>()); 
00227 
00228   std::size_t i = 0;
00229   for(ShProgramNode::VarList::const_iterator I = p.node()->inputs.begin();
00230       I != p.node()->inputs.end(); ++I, ++i) {
00231     InputType it( (*I)->name(), (*I)->size() );
00232     if( firsts.find( it ) != firsts.end() ) { // duplicate
00233       dups[ firsts[it] ].push_back(i); 
00234     } else {
00235       firsts[it] = i;
00236       dups[i].push_back(i);
00237     }
00238   }
00239   std::vector<int> swizzle;
00240   ShFixedManipulator duplicator;
00241   for(i = 0; i < dups.size(); ++i) {
00242     if( dups[i].empty() ) continue;
00243     for(std::size_t j = 0; j < dups[i].size(); ++j) swizzle.push_back(dups[i][j]);
00244     if( duplicator ) duplicator = duplicator & shDup(dups[i].size());
00245     else duplicator = shDup(dups[i].size());
00246   }
00247   ShProgram result = p << shSwizzle(swizzle);
00248   if( duplicator ) result = result << duplicator;
00249   return result.node(); 
00250 }
00251 
00252 ShProgram namedCombine(ShProgram a, ShProgram b) {
00253   return mergeNames(combine(a, b));
00254 }
00255 
00256 ShProgram namedConnect(ShProgram pa, ShProgram pb, bool keepExtra)
00257 {
00258   ShProgramNodeCPtr a = pa.node();
00259   ShProgramNodeCPtr b = pb.node();
00260   // positions of a pair of matched a output and b input 
00261   typedef std::map<int, int> MatchedChannelMap; 
00262 
00263   std::vector<bool> aMatch(a->outputs.size(), false);
00264   std::vector<bool> bMatch(b->inputs.size(), false);
00265   MatchedChannelMap mcm;
00266   std::size_t i, j;
00267   ShProgramNode::VarList::const_iterator I, J;
00268 
00269   i = 0;
00270   for(I = a->outputs.begin(); I != a->outputs.end(); ++I, ++i) {
00271     j = 0;
00272     for(J = b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00273       if(bMatch[j]) continue;
00274       if((*I)->name() != (*J)->name()) continue;
00275       if((*I)->size() != (*J)->size()) {
00276         SH_DEBUG_WARN("Named connect matched channel name " << (*I)->name() 
00277             << " but output size " << (*I)->size() << " != " << " input size " << (*J)->size() );
00278         continue;
00279       }
00280       mcm[i] = j;
00281       aMatch[i] = true;
00282       bMatch[j] = true;
00283     }
00284   }
00285 
00286   std::vector<int> swiz(b->inputs.size(), 0); 
00287   for(MatchedChannelMap::iterator mcmit = mcm.begin(); mcmit != mcm.end(); ++mcmit) {
00288     swiz[mcmit->second] = mcmit->first;
00289   }
00290 
00291   // swizzle unmatched inputs and make a pass them through properly
00292   ShProgram passer = SH_BEGIN_PROGRAM() {} SH_END;
00293   int newInputIdx = a->outputs.size(); // index of next new input added to a
00294   for(j = 0, J= b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00295     if( !bMatch[j] ) {
00296       ShProgram passOne = SH_BEGIN_PROGRAM() {
00297         ShVariable var((*J)->clone(SH_INOUT));
00298       } SH_END;
00299       passer = passer & passOne; 
00300       swiz[j] = newInputIdx++;
00301     }
00302   }
00303   ShProgram aPass = combine(pa, passer);
00304 
00305   if( keepExtra ) {
00306     for(i = 0; i < aMatch.size(); ++i) {
00307       if( !aMatch[i] ) swiz.push_back(i);
00308     }
00309   }
00310    
00311   return mergeNames(pb << ( shSwizzle(swiz) << aPass )); 
00312 }
00313 
00314 ShProgram renameInput(ShProgram a,
00315                       const std::string& oldName, const std::string& newName) {
00316   ShProgram renamer = SH_BEGIN_PROGRAM() {
00317     for(ShProgramNode::VarList::const_iterator I = a.node()->inputs.begin();
00318         I != a.node()->inputs.end(); ++I) {
00319       ShVariable var((*I)->clone(SH_INOUT));
00320 
00321       if (!(*I)->has_name()) continue;
00322       std::string name = (*I)->name();
00323       if( name == oldName ) {
00324         var.name(newName);
00325       } else {
00326         var.name(name);
00327       }
00328     }
00329   } SH_END;
00330   return connect(renamer, a);
00331 }
00332 
00333 // TODO factor out common code from renameInput, renameOutput
00334 ShProgram renameOutput(ShProgram a,
00335                        const std::string& oldName, const std::string& newName) {
00336   ShProgram renamer = SH_BEGIN_PROGRAM() {
00337     for(ShProgramNode::VarList::const_iterator I = a.node()->outputs.begin();
00338         I != a.node()->outputs.end(); ++I) {
00339       ShVariable var((*I)->clone(SH_INOUT));
00340 
00341       if (!(*I)->has_name()) continue;
00342       std::string name = (*I)->name();
00343       if( name == oldName ) {
00344         var.name(newName);
00345       } else {
00346         var.name(name);
00347       }
00348     }
00349   } SH_END;
00350   return connect(a, renamer);
00351 }
00352 
00353 ShProgram namedAlign(ShProgram a, ShProgram b) {
00354   ShManipulator<std::string> ordering;
00355 
00356   for(ShProgramNode::VarList::const_iterator I = b.node()->inputs.begin();
00357       I != b.node()->inputs.end(); ++I) {
00358     ordering((*I)->name());
00359   }
00360 
00361   return ordering << a; 
00362 }
00363 
00364 ShProgram operator<<(ShProgram a, ShProgram b)
00365 {
00366   return connect(b,a);
00367 }
00368 
00369 ShProgram operator>>(ShProgram a, ShProgram b)
00370 {
00371   return connect(a,b);
00372 }
00373 
00374 ShProgram operator&(ShProgram a, ShProgram b)
00375 {
00376   return combine(a, b);
00377 }
00378 
00379 ShProgram operator>>(ShProgram p, const ShVariable &var) { 
00380   return replaceUniform(p, var);
00381 }
00382 
00383 ShProgram replaceUniform(ShProgram a, const ShVariable& v)
00384 {
00385   if(!v.uniform()) {
00386     shError(ShAlgebraException("Cannot replace non-uniform variable"));
00387   }
00388 
00389   ShProgram program(a.node()->clone()); 
00390   
00391   ShVarMap varMap;
00392 
00393   ShContext::current()->enter(program.node());
00394 
00395   // make a new input
00396   ShVariableNodePtr newInput(v.node()->clone(SH_INPUT)); 
00397   varMap[v.node()] = newInput;
00398 
00399   ShContext::current()->exit();
00400 
00401   ShVariableReplacer replacer(varMap);
00402   program.node()->ctrlGraph->dfs(replacer);
00403 
00404   optimize(program);
00405   
00406   program.node()->collectVariables();
00407 
00408   return program;
00409 }
00410 
00411 }

Generated on Mon Jan 24 18:36:29 2005 for Sh by  doxygen 1.4.1