Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

ShTransformer.cpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright (c) 2003 University of Waterloo Computer Graphics Laboratory
00004 // Project administrator: Michael D. McCool
00005 // Authors: Zheng Qin, Stefanus Du Toit, Kevin Moule, Tiberiu S. Popa,
00006 //          Michael D. McCool
00007 // 
00008 // This software is provided 'as-is', without any express or implied
00009 // warranty. In no event will the authors be held liable for any damages
00010 // arising from the use of this software.
00011 // 
00012 // Permission is granted to anyone to use this software for any purpose,
00013 // including commercial applications, and to alter it and redistribute it
00014 // freely, subject to the following restrictions:
00015 // 
00016 // 1. The origin of this software must not be misrepresented; you must
00017 // not claim that you wrote the original software. If you use this
00018 // software in a product, an acknowledgment in the product documentation
00019 // would be appreciated but is not required.
00020 // 
00021 // 2. Altered source versions must be plainly marked as such, and must
00022 // not be misrepresented as being the original software.
00023 // 
00024 // 3. This notice may not be removed or altered from any source
00025 // distribution.
#include <algorithm>
#include <list>
#include <map>
#include <vector>
#include "ShContext.hpp"
#include "ShError.hpp"
#include "ShDebug.hpp"
#include "ShVariant.hpp"
#include "ShVariableNode.hpp"
#include "ShInternals.hpp"
#include "ShTransformer.hpp"
00037 
00038 namespace SH {
00039 
00040 ShTransformer::ShTransformer(const ShProgramNodePtr& program)
00041   : m_program(program), m_changed(false)
00042 {
00043 }
00044 
00045 ShTransformer::~ShTransformer()
00046 {
00047 }
00048 
00049 bool ShTransformer::changed()  { return m_changed; }
00050 
00051 // Variable splitting, marks statements for which some variable is split 
00052 struct VariableSplitter {
00053 
00054   VariableSplitter(int maxTuple, ShTransformer::VarSplitMap& splits, bool& changed)
00055     : maxTuple(maxTuple), splits(splits), changed(changed) {}
00056 
00057   void operator()(ShCtrlGraphNodePtr node) {
00058     if (!node) return;
00059     ShBasicBlockPtr block = node->block;
00060     if (!block) return;
00061     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00062       splitVars(*I);
00063     }
00064   }
00065 
00066   // this must be called BEFORE running a DFS on the program
00067   // to split temporaries (otherwise the stupid hack marked (#) does not work)
00068   void splitVarList(ShProgramNode::VarList &vars) {
00069     for(ShProgramNode::VarList::iterator I = vars.begin();
00070         I != vars.end();) {
00071       if(split(*I)) {
00072         // (#) erase the stuff that split added to the end of the var list
00073         // TODO check if this is actually no longer needed. 
00074         //vars.resize(vars.size() - splits[*I].size());
00075 
00076         vars.insert(I, splits[*I].begin(), splits[*I].end());
00077         I = vars.erase(I); 
00078       } else ++I;
00079     }
00080   }
00081   
00082   void splitVars(ShStatement& stmt) {
00083     stmt.marked = false;
00084     if(stmt.dest.node()) stmt.marked = split(stmt.dest.node()) || stmt.marked;
00085     for(int i = 0; i < 3; ++i) if(stmt.src[i].node()) stmt.marked = split(stmt.src[i].node()) || stmt.marked; 
00086   }
00087 
00088   // returns true if variable split
00089   // does not add variable to Program's VarList, so this must be handled manually 
00090   // (since this matters only for IN/OUT/INOUT types, splitVarList handles the
00091   // insertions nicely)
00092   bool split(ShVariableNodePtr node)
00093   {
00094     int i, offset;
00095     int n = node->size();
00096     if(n <= maxTuple ) return false;
00097     else if(splits.count(node) > 0) return true;
00098     if( node->kind() == SH_TEXTURE || node->kind() == SH_STREAM ) {
00099       shError( ShTransformerException(
00100             "Long tuple support is not implemented for textures or streams"));
00101             
00102     }
00103     changed = true;
00104     ShTransformer::VarNodeVec &nodeVarNodeVec = splits[node];
00105     ShVariableNodePtr newNode;
00106     int* copySwiz = new int[maxTuple];
00107     for(offset = 0; n > 0; offset += maxTuple, n -= maxTuple) {
00108       ShProgramNodePtr prev = ShContext::current()->parsing();
00109       // @todo type should not be necessary any more
00110       //if(node->uniform()) ShContext::current()->exit(); 
00111 
00112       int newSize = n < maxTuple ? n : maxTuple;
00113       newNode = node->clone(SH_BINDINGTYPE_END, newSize, 
00114           SH_VALUETYPE_END, SH_SEMANTICTYPE_END, false); 
00115 
00116       // @todo type should not be necessary any more
00117       // if(node->uniform()) ShContext::current()->enter(0);
00118 
00119       if( node->hasValues() ) { 
00120         // @todo type set up dependent uniforms here 
00121         for(i = 0; i < newSize; ++i) copySwiz[i] = offset + i;
00122         ShVariantCPtr subVariant = node->getVariant()->get(false,
00123             ShSwizzle(node->size(), newSize, copySwiz));
00124         newNode->setVariant(subVariant);
00125       }
00126       nodeVarNodeVec.push_back( newNode );
00127     }
00128         delete [] copySwiz;
00129     return true;
00130   }
00131 
00132   int maxTuple;
00133   ShTransformer::VarSplitMap &splits;
00134   bool& changed;
00135 };
00136 
00137 struct StatementSplitter {
00138   typedef std::vector<ShVariable> VarVec;
00139 
00140   StatementSplitter(int maxTuple, ShTransformer::VarSplitMap &splits, bool& changed)
00141     : maxTuple(maxTuple), splits(splits), changed(changed) {}
00142 
00143   void operator()(ShCtrlGraphNodePtr node) {
00144     if (!node) return;
00145     ShBasicBlockPtr block = node->block;
00146     if (!block) return;
00147     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end();) {
00148       splitStatement(block, I);
00149     }
00150   }
00151 
00152   void makeSrcTemps(const ShVariable &v, VarVec &vv, ShBasicBlock::ShStmtList &stmts) {
00153     if( v.size() <= maxTuple && v.node()->size() <= maxTuple ) {
00154       vv.push_back(v);
00155       return;
00156     }
00157     std::size_t i, j, k;
00158     int n;
00159     const ShSwizzle &swiz = v.swizzle();
00160     
00161     // get VarNodeVec for src
00162     ShTransformer::VarNodeVec srcVec;
00163     if(splits.count(v.node()) > 0) {
00164       srcVec = splits[v.node()];
00165     } else srcVec.push_back(v.node());
00166 
00167     // make and assign to a VarVec for temps
00168     for(i = 0, n = v.size(); n > 0; i += maxTuple, n -= maxTuple) {
00169       std::size_t tsize = (int)n < maxTuple ? n : maxTuple;
00170       //TODO  make this smarter so that it reuses variable nodes if it's just reswizlling a src node
00171       // (check that move elimination doesn't do this for us already)
00172       
00173       // TODO check that uniforms don't get screwed
00174       // TODO check that typing works correctly - temporary should
00175       // have same type as the statement's operation type
00176       ShVariable tempVar(resizeCloneNode(v.node(), tsize));
00177       vv.push_back(tempVar);
00178 
00179       int* tempSwiz = new int[tsize];
00180       int* srcSwiz = new int[tsize];
00181       int tempSize;
00182       for(j = 0; j < srcVec.size(); ++j) {
00183         tempSize = 0;
00184         for(k = 0; k < tsize; ++k ) { 
00185           if(swiz[i + k] / maxTuple == (int)j) {
00186             tempSwiz[tempSize] = k;
00187             srcSwiz[tempSize] = swiz[i + k] % maxTuple;
00188             tempSize++;
00189           } 
00190         }
00191         if( tempSize > 0 ) {
00192           ShVariable srcVar(srcVec[j]);
00193           stmts.push_back(ShStatement(tempVar(tempSize, tempSwiz), SH_OP_ASN, srcVar(tempSize, srcSwiz)));
00194         }
00195       }
00196       delete [] tempSwiz;
00197       delete [] srcSwiz;
00198     }
00199   }
00200 
00201   // moves the result to the destination based on the destination swizzle
00202   void movToDest(ShTransformer::VarNodeVec &destVec, const ShSwizzle &destSwiz, 
00203       const VarVec &resultVec, ShBasicBlock::ShStmtList &stmts) {
00204     std::size_t j;
00205     int k;
00206     int offset = 0;
00207     int* swizd = new int[maxTuple];
00208     int* swizr = new int[maxTuple];
00209     int size;
00210     for(VarVec::const_iterator I = resultVec.begin(); I != resultVec.end(); 
00211         offset += I->size(), ++I) {
00212       for(j = 0; j < destVec.size(); ++j) {
00213         size = 0;
00214         for(k = 0; k < I->size(); ++k) {
00215           if( destSwiz[k + offset] / maxTuple == (int)j) {
00216             swizd[size] = destSwiz[k + offset] % maxTuple;
00217             swizr[size] = k;
00218             size++;
00219           }
00220         }
00221         if( size > 0 ) {
00222           ShVariable destVar(destVec[j]);
00223           stmts.push_back(ShStatement(destVar(size, swizd), SH_OP_ASN, (*I)(size, swizr)));
00224         }
00225       }
00226     }
00227     delete [] swizd;
00228     delete [] swizr;
00229   }
00230 
00231   ShVariableNodePtr resizeCloneNode(ShVariableNodePtr node, int newSize) {
00232     return node->clone(SH_TEMP, newSize, SH_VALUETYPE_END, 
00233         SH_SEMANTICTYPE_END, true, false);
00234   }
00235   // works on two assumptions
00236   // 1) special cases for DOT, XPD (and any other future non-componentwise ops) implemented separately
00237   // 2) Everything else is in the form N = [1|N]+ in terms of tuple sizes involved in dest and src
00238   void updateStatement(ShStatement &oldStmt, VarVec srcVec[3], ShBasicBlock::ShStmtList &stmts) {
00239     std::size_t i, j;
00240     ShVariable &dest = oldStmt.dest;
00241     const ShSwizzle &destSwiz = dest.swizzle();
00242     ShTransformer::VarNodeVec destVec;
00243     VarVec resultVec;
00244 
00245     if(splits.count(dest.node()) > 0) {
00246       destVec = splits[dest.node()];
00247     } else destVec.push_back(dest.node());
00248 
00249     switch(oldStmt.op) {
00250       case SH_OP_DOT:
00251         { 
00252           // TODO for large tuples, may want to use another dot to sum up results instead of 
00253           // SH_OP_ADD. For now, do naive method
00254           SH_DEBUG_ASSERT(destSwiz.size() == 1);
00255 
00256           // TODO check that this works correctly for weird types
00257           // (temporaries should have same type as the ShStatement's operation type) 
00258           ShVariable dott = ShVariable(resizeCloneNode(dest.node(), 1));
00259           ShVariable sumt = ShVariable(resizeCloneNode(dest.node(), 1));
00260 
00261           stmts.push_back(ShStatement(sumt, srcVec[0][0], SH_OP_DOT, srcVec[1][0]));
00262           for(i = 1; i < srcVec[0].size(); ++i) {
00263             stmts.push_back(ShStatement(dott, srcVec[0][i], SH_OP_DOT, srcVec[1][i]));
00264             stmts.push_back(ShStatement(sumt, sumt, SH_OP_ADD, dott));
00265           }
00266           resultVec.push_back(sumt);
00267         }
00268         break;
00269       case SH_OP_XPD:
00270         {
00271           SH_DEBUG_ASSERT( srcVec[0].size() == 1 && srcVec[0][0].size() == 3 &&
00272               srcVec[1].size() == 1 && srcVec[1][0].size() == 3); 
00273 
00274           // TODO check typing
00275           ShVariable result = ShVariable(resizeCloneNode(dest.node(), 3));
00276 
00277           stmts.push_back(ShStatement(result, srcVec[0][0], SH_OP_XPD, srcVec[1][0]));
00278           resultVec.push_back(result);
00279         }
00280         break;
00281 
00282       default:
00283         {
00284           int maxi = 0;
00285           if( srcVec[1].size() > srcVec[0].size() ) maxi = 1;
00286           if( srcVec[2].size() > srcVec[maxi].size() ) maxi = 2;
00287           for(i = 0; i < srcVec[maxi].size(); ++i) {
00288             // TODO check typing
00289             ShVariable resultPart(resizeCloneNode(dest.node(), srcVec[maxi][i].size()));
00290 
00291             ShStatement newStmt(resultPart, oldStmt.op);
00292             for(j = 0; j < 3 && !srcVec[j].empty(); ++j) {
00293               newStmt.src[j] = srcVec[j].size() > i ? srcVec[j][i] : srcVec[j][0];
00294             }
00295             stmts.push_back(newStmt);
00296             resultVec.push_back(resultPart);
00297           }
00298         }
00299         break;
00300     }
00301     movToDest(destVec, destSwiz, resultVec, stmts); 
00302   }
00303   
00306   void splitStatement(ShBasicBlockPtr block, ShBasicBlock::ShStmtList::iterator &stit) {
00307     ShStatement &stmt = *stit;
00308     int i;
00309     if(!stmt.marked && stmt.dest.size() <= maxTuple) {
00310       for(i = 0; i < 3; ++i) if(stmt.src[i].size() > maxTuple) break;
00311       if(i == 3) { // nothing needs splitting
00312         ++stit;
00313         return; 
00314       }
00315     }
00316     changed = true;
00317     ShBasicBlock::ShStmtList newStmts;
00318     VarVec srcVec[3];
00319 
00320     for(i = 0; i < 3; ++i) if(stmt.src[i].node()) makeSrcTemps(stmt.src[i], srcVec[i], newStmts);
00321     updateStatement(stmt, srcVec, newStmts);
00322 
00323     // remove old statmeent and splice in new statements
00324     stit = block->erase(stit);
00325     block->splice(stit, newStmts);
00326   }
00327 
00328   int maxTuple;
00329   ShTransformer::VarSplitMap &splits;
00330   bool& changed;
00331 };
00332 
00333 void ShTransformer::splitTuples(int maxTuple, ShTransformer::VarSplitMap &splits) {
00334   SH_DEBUG_ASSERT(maxTuple > 0); 
00335 
00336   VariableSplitter vs(maxTuple, splits, m_changed);
00337   vs.splitVarList(m_program->inputs);
00338   vs.splitVarList(m_program->outputs);
00339   m_program->ctrlGraph->dfs(vs);
00340 
00341 
00342   StatementSplitter ss(maxTuple, splits, m_changed);
00343   m_program->ctrlGraph->dfs(ss);
00344 }
00345 
00346 static int id = 0;
00347 
00348 // Output Convertion to temporaries 
00349 struct InputOutputConvertor {
00350   InputOutputConvertor(const ShProgramNodePtr& program,
00351                        ShVarMap &varMap, bool& changed)
00352     : m_program(program), m_varMap( varMap ), m_changed(changed), m_id(++id)
00353   {}
00354 
00355   void operator()(ShCtrlGraphNodePtr node) {
00356     if (!node) return;
00357     ShBasicBlockPtr block = node->block;
00358     if (!block) return;
00359     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00360       convertIO(*I);
00361     }
00362   }
00363 
00364   // Turn node into a temporary, but do not update var list and do not keep
00365   // uniform
00366   ShVariableNodePtr cloneNode(ShVariableNodePtr node, ShBindingType binding_type=SH_TEMP) {
00367     return node->clone(binding_type, 0, SH_VALUETYPE_END, SH_SEMANTICTYPE_END, false, false);
00368   }
00369 
00370   /* Convert all INOUT nodes that appear in a VarList (use std::for_each with this object)
00371    * (currently InOuts are always converted) */ 
00372   void operator()(ShVariableNodePtr node) {
00373     if (node->kind() != SH_INOUT || m_varMap.count(node) > 0) return;
00374     m_varMap[node] = cloneNode(node);
00375   }
00376 
00377   // Convert inputs, outputs only when they appear in incompatible locations
00378   // (inputs used as dest, outputs used as src)
00379   void convertIO(ShStatement& stmt)
00380   {
00381     if(!stmt.dest.null()) {
00382       const ShVariableNodePtr &oldNode = stmt.dest.node();
00383       if(oldNode->kind() == SH_INPUT) { 
00384         if(m_varMap.count(oldNode) == 0) {
00385           m_varMap[oldNode] = cloneNode(oldNode);
00386         }
00387       }
00388     }
00389     for(int i = 0; i < 3; ++i) {
00390       if(!stmt.src[i].null()) {
00391         const ShVariableNodePtr &oldNode = stmt.src[i].node();
00392         if(oldNode->kind() == SH_OUTPUT) { 
00393           if(m_varMap.count(oldNode) == 0) {
00394             m_varMap[oldNode] = cloneNode(oldNode);
00395           }
00396         }
00397       }
00398     }
00399   }
00400 
00401   void updateGraph() {
00402     if(m_varMap.empty()) return;
00403     m_changed = true;
00404 
00405     // create block after exit
00406     ShCtrlGraphNodePtr oldExit = m_program->ctrlGraph->appendExit(); 
00407     ShCtrlGraphNodePtr oldEntry = m_program->ctrlGraph->prependEntry();
00408 
00409     for(ShVarMap::const_iterator it = m_varMap.begin(); it != m_varMap.end(); ++it) {
00410       // assign temporary to output
00411       ShVariableNodePtr oldNode = it->first; 
00412       if(oldNode->kind() == SH_OUTPUT) {
00413         oldExit->block->addStatement(ShStatement(
00414               ShVariable(oldNode), SH_OP_ASN, ShVariable(it->second)));
00415       } else if(oldNode->kind() == SH_INPUT) {
00416         oldEntry->block->addStatement(ShStatement(
00417               ShVariable(it->second), SH_OP_ASN, ShVariable(oldNode)));
00418       } else if(oldNode->kind() == SH_INOUT) {
00419         // replace INOUT nodes in input/output lists with INPUT and OUTPUT nodes
00420         ShVariableNodePtr newInNode(cloneNode(oldNode, SH_INPUT));
00421         ShVariableNodePtr newOutNode(cloneNode(oldNode, SH_OUTPUT));
00422 
00423         std::replace(m_program->inputs.begin(), m_program->inputs.end(),
00424             oldNode, newInNode);
00425 
00426         std::replace(m_program->outputs.begin(), m_program->outputs.end(),
00427             oldNode, newOutNode);
00428 
00429         // add mov statements to/from temporary 
00430         oldEntry->block->addStatement(ShStatement(
00431               ShVariable(it->second), SH_OP_ASN, ShVariable(newInNode)));
00432         oldExit->block->addStatement(ShStatement(
00433               ShVariable(newOutNode), SH_OP_ASN, ShVariable(it->second)));
00434       }
00435     }
00436   }
00437 
00438   ShProgramNodePtr m_program;
00439   ShVarMap &m_varMap; // maps from outputs used as srcs in computation to their temporary variables
00440   bool& m_changed;
00441   int m_id;
00442 };
00443 
00444 void ShTransformer::convertInputOutput()
00445 {
00446   ShVarMap varMap; // maps from outputs used as srcs in computation to their temporary variables
00447 
00448   InputOutputConvertor ioc(m_program, varMap, m_changed);
00449   std::for_each(m_program->inputs.begin(), m_program->inputs.end(), ioc);
00450   std::for_each(m_program->outputs.begin(), m_program->outputs.end(), ioc);
00451   m_program->ctrlGraph->dfs(ioc);
00452 
00453   ShVariableReplacer vr(varMap);
00454   m_program->ctrlGraph->dfs(vr);
00455 
00456   ioc.updateGraph(); 
00457 }
00458 
00459 struct TextureLookupConverter {
00460   TextureLookupConverter() : changed(false) {}
00461   
00462   void operator()(const ShCtrlGraphNodePtr& node)
00463   {
00464     if (!node) return;
00465     ShBasicBlockPtr block = node->block;
00466     if (!block) return;
00467     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00468       convert(block, I);
00469     }
00470   }
00471 
00472   ShVariableNodePtr cloneNode(ShVariableNodePtr node) {
00473     return node->clone(SH_TEMP, 0, SH_VALUETYPE_END, SH_SEMANTICTYPE_END, true, false);
00474   }
00475 
00476   void convert(ShBasicBlockPtr block, ShBasicBlock::ShStmtList::iterator& I)
00477   {
00478     const ShStatement& stmt = *I;
00479     if (stmt.op != SH_OP_TEX && stmt.op != SH_OP_TEXI) return;
00480     ShTextureNodePtr tn = shref_dynamic_cast<ShTextureNode>(stmt.src[0].node());
00481 
00482     ShBasicBlock::ShStmtList newStmts;
00483     
00484     if (!tn) { SH_DEBUG_ERROR("TEX Instruction from non-texture"); return; }
00485     if (stmt.op == SH_OP_TEX && tn->dims() == SH_TEXTURE_RECT) {
00486       // TODO check typing
00487       //ShVariable tc(new ShVariableNode(SH_TEMP, tn->texSizeVar().size()));
00488       ShVariable tc(cloneNode(tn->texSizeVar().node()));
00489 
00490       newStmts.push_back(ShStatement(tc, stmt.src[1], SH_OP_MUL, tn->texSizeVar()));
00491       newStmts.push_back(ShStatement(stmt.dest, stmt.src[0], SH_OP_TEXI, tc));
00492     } else if (stmt.op == SH_OP_TEXI && tn->dims() != SH_TEXTURE_RECT) {
00493       // TODO check typing
00494       //ShVariable tc(new ShVariableNode(SH_TEMP, tn->texSizeVar().size()));
00495       ShVariable tc(cloneNode(tn->texSizeVar().node()));
00496 
00497       newStmts.push_back(ShStatement(tc, stmt.src[1], SH_OP_DIV, tn->texSizeVar()));
00498       newStmts.push_back(ShStatement(stmt.dest, stmt.src[0], SH_OP_TEX, tc));
00499     } else {
00500       return;
00501     }
00502     I = block->erase(I); // I is pointing one past its original value now
00503     block->splice(I, newStmts);
00504     I--; // Make I point to its original value, it will be inc'd later.
00505     changed = true;
00506     return;
00507   }
00508 
00509   bool changed;
00510 };
00511 
00512 void ShTransformer::convertTextureLookups()
00513 {
00514   TextureLookupConverter conv;
00515   m_program->ctrlGraph->dfs(conv);
00516   if (conv.changed) m_changed = true;
00517 }
00518 
00519 
00520 }
00521 

Generated on Mon Jan 24 18:36:36 2005 for Sh by  doxygen 1.4.1