00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00027 #include <map>
00028 #include <string>
00029 #include <sstream>
00030 #include <algorithm>
00031 #include "ShAlgebra.hpp"
00032 #include "ShCtrlGraph.hpp"
00033 #include "ShDebug.hpp"
00034 #include "ShError.hpp"
00035 #include "ShOptimizations.hpp"
00036 #include "ShInternals.hpp"
00037 #include "ShEnvironment.hpp"
00038 #include "ShContext.hpp"
00039 #include "ShManipulator.hpp"
00040 #include "ShFixedManipulator.hpp"
00041
00042 namespace SH {
00043
00044 ShProgram connect(ShProgram pa, ShProgram pb)
00045 {
00046 ShProgramNodePtr a = pa.node();
00047 ShProgramNodePtr b = pb.node();
00048
00049 if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" );
00050 if( !a ) return b;
00051 if( !b ) return a;
00052
00053 int aosize = a->outputs.size();
00054 int bisize = b->inputs.size();
00055 std::string rtarget;
00056
00057 if (a->target().empty()) {
00058 rtarget = b->target();
00059 } else {
00060 if (b->target().empty() || a->target() == b->target()) {
00061 rtarget = a->target();
00062 } else {
00063 SH_DEBUG_WARN("Connecting two different targets. Using empty target for result.");
00064 rtarget = "";
00065 }
00066 }
00067
00068 ShProgramNodePtr program = new ShProgramNode(rtarget);
00069
00070 ShCtrlGraphNodePtr heada, taila, headb, tailb;
00071
00072 a->ctrlGraph->copy(heada, taila);
00073 b->ctrlGraph->copy(headb, tailb);
00074
00075 taila->append(headb);
00076
00077 ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb);
00078 program->ctrlGraph = new_graph;
00079
00080 program->inputs = a->inputs;
00081
00082
00083 if(aosize < bisize) {
00084 ShProgramNode::VarList::const_iterator II = b->inputs.begin();
00085 for(int i = 0; i < aosize; ++i, ++II);
00086 for(; II != b->inputs.end(); ++II) {
00087 program->inputs.push_back(*II);
00088 }
00089 }
00090 program->outputs = b->outputs;
00091
00092
00093 if(aosize > bisize) {
00094 ShProgramNode::VarList::const_iterator II = a->outputs.begin();
00095 for(int i = 0; i < bisize; ++i, ++II);
00096 for(; II != a->outputs.end(); ++II) {
00097 program->outputs.push_back(*II);
00098 }
00099 }
00100
00101 ShVarMap varMap;
00102
00103 ShContext::current()->enter(program);
00104
00105 ShProgramNode::VarList::const_iterator I, J;
00106
00107 ShProgramNode::VarList InOutInputs;
00108 ShProgramNode::VarList InOutOutputs;
00109
00110
00111 for (I = a->outputs.begin(), J = b->inputs.begin();
00112 I != a->outputs.end() && J != b->inputs.end(); ++I, ++J) {
00113 if((*I)->size() != (*J)->size()) {
00114 std::ostringstream err;
00115 err << "Cannot smash variables "
00116 << (*I)->nameOfType() << " " << (*I)->name() << " and "
00117 << (*J)->nameOfType() << " " << (*J)->name() << " with different sizes" << std::endl;
00118 err << "while connecting outputs: ";
00119 ShProgramNode::print(err, a->outputs) << std::endl;
00120 err << "to inputs: ";
00121 ShProgramNode::print(err, b->inputs) << std::endl;
00122 ShContext::current()->exit();
00123 shError(ShAlgebraException(err.str()));
00124 return ShProgram(ShProgramNodePtr(0));
00125 }
00126 ShVariableNodePtr n = (*I)->clone(SH_TEMP);
00127 varMap[*I] = n;
00128 varMap[*J] = n;
00129
00130 if((*I)->kind() == SH_INOUT) InOutInputs.push_back((*I));
00131 if((*J)->kind() == SH_INOUT) InOutOutputs.push_back((*J));
00132 }
00133
00134
00135
00136 ShCtrlGraphNodePtr graphEntry;
00137 for (I = InOutInputs.begin(); I != InOutInputs.end(); ++I) {
00138 if(!graphEntry) graphEntry = program->ctrlGraph->prependEntry();
00139 ShVariableNodePtr newInput((*I)->clone(SH_INPUT));
00140
00141 std::replace(program->inputs.begin(), program->inputs.end(),
00142 (*I), newInput);
00143 program->inputs.pop_back();
00144
00145 graphEntry->block->addStatement(ShStatement(
00146 ShVariable(varMap[*I]), SH_OP_ASN, ShVariable(newInput)));
00147 }
00148
00149 ShCtrlGraphNodePtr graphExit;
00150 for (I = InOutOutputs.begin(); I != InOutOutputs.end(); ++I) {
00151 if(!graphExit) graphExit = program->ctrlGraph->appendExit();
00152 ShVariableNodePtr newOutput((*I)->clone(SH_OUTPUT));
00153
00154 std::replace(program->outputs.begin(), program->outputs.end(),
00155 (*I), newOutput);
00156 program->outputs.pop_back();
00157
00158 graphExit->block->addStatement(ShStatement(
00159 ShVariable(newOutput), SH_OP_ASN, ShVariable(varMap[*I])));
00160 }
00161
00162 ShContext::current()->exit();
00163
00164 ShVariableReplacer replacer(varMap);
00165 program->ctrlGraph->dfs(replacer);
00166
00167 optimize(program);
00168
00169 program->collectVariables();
00170 return program;
00171 }
00172
00173 ShProgram combine(ShProgram pa, ShProgram pb)
00174 {
00175 ShProgramNodePtr a = pa.node();
00176 ShProgramNodePtr b = pb.node();
00177
00178 std::string rtarget;
00179 if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" );
00180 if (!a) return b;
00181 if (!b) return a;
00182
00183 if (a->target().empty()) {
00184 rtarget = b->target();
00185 } else {
00186 if (b->target().empty() || a->target() == b->target()) {
00187 rtarget = a->target();
00188 } else {
00189 rtarget = "";
00190 }
00191 }
00192
00193 ShProgramNodePtr program = new ShProgramNode(rtarget);
00194
00195 ShCtrlGraphNodePtr heada, taila, headb, tailb;
00196
00197 a->ctrlGraph->copy(heada, taila);
00198 b->ctrlGraph->copy(headb, tailb);
00199
00200 taila->append(headb);
00201
00202 ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb);
00203 program->ctrlGraph = new_graph;
00204
00205 program->inputs = a->inputs;
00206 program->inputs.insert(program->inputs.end(), b->inputs.begin(), b->inputs.end());
00207 program->outputs = a->outputs;
00208 program->outputs.insert(program->outputs.end(), b->outputs.begin(), b->outputs.end());
00209
00210 optimize(program);
00211
00212 program->collectVariables();
00213
00214 return program;
00215 }
00216
00217
00218 ShProgram mergeNames(ShProgram p)
00219 {
00220 typedef std::pair<std::string, int> InputType;
00221 typedef std::map< InputType, int > FirstOccurenceMap;
00222 typedef std::vector< std::vector<int> > Duplicates;
00223 FirstOccurenceMap firsts;
00224
00225
00226 Duplicates dups( p.node()->inputs.size(), std::vector<int>());
00227
00228 std::size_t i = 0;
00229 for(ShProgramNode::VarList::const_iterator I = p.node()->inputs.begin();
00230 I != p.node()->inputs.end(); ++I, ++i) {
00231 InputType it( (*I)->name(), (*I)->size() );
00232 if( firsts.find( it ) != firsts.end() ) {
00233 dups[ firsts[it] ].push_back(i);
00234 } else {
00235 firsts[it] = i;
00236 dups[i].push_back(i);
00237 }
00238 }
00239 std::vector<int> swizzle;
00240 ShFixedManipulator duplicator;
00241 for(i = 0; i < dups.size(); ++i) {
00242 if( dups[i].empty() ) continue;
00243 for(std::size_t j = 0; j < dups[i].size(); ++j) swizzle.push_back(dups[i][j]);
00244 if( duplicator ) duplicator = duplicator & shDup(dups[i].size());
00245 else duplicator = shDup(dups[i].size());
00246 }
00247 ShProgram result = p << shSwizzle(swizzle);
00248 if( duplicator ) result = result << duplicator;
00249 return result.node();
00250 }
00251
00252 ShProgram namedCombine(ShProgram a, ShProgram b) {
00253 return mergeNames(combine(a, b));
00254 }
00255
00256 ShProgram namedConnect(ShProgram pa, ShProgram pb, bool keepExtra)
00257 {
00258 ShProgramNodeCPtr a = pa.node();
00259 ShProgramNodeCPtr b = pb.node();
00260
00261 typedef std::map<int, int> MatchedChannelMap;
00262
00263 std::vector<bool> aMatch(a->outputs.size(), false);
00264 std::vector<bool> bMatch(b->inputs.size(), false);
00265 MatchedChannelMap mcm;
00266 std::size_t i, j;
00267 ShProgramNode::VarList::const_iterator I, J;
00268
00269 i = 0;
00270 for(I = a->outputs.begin(); I != a->outputs.end(); ++I, ++i) {
00271 j = 0;
00272 for(J = b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00273 if(bMatch[j]) continue;
00274 if((*I)->name() != (*J)->name()) continue;
00275 if((*I)->size() != (*J)->size()) {
00276 SH_DEBUG_WARN("Named connect matched channel name " << (*I)->name()
00277 << " but output size " << (*I)->size() << " != " << " input size " << (*J)->size() );
00278 continue;
00279 }
00280 mcm[i] = j;
00281 aMatch[i] = true;
00282 bMatch[j] = true;
00283 }
00284 }
00285
00286 std::vector<int> swiz(b->inputs.size(), 0);
00287 for(MatchedChannelMap::iterator mcmit = mcm.begin(); mcmit != mcm.end(); ++mcmit) {
00288 swiz[mcmit->second] = mcmit->first;
00289 }
00290
00291
00292 ShProgram passer = SH_BEGIN_PROGRAM() {} SH_END;
00293 int newInputIdx = a->outputs.size();
00294 for(j = 0, J= b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00295 if( !bMatch[j] ) {
00296 ShProgram passOne = SH_BEGIN_PROGRAM() {
00297 ShVariable var((*J)->clone(SH_INOUT));
00298 } SH_END;
00299 passer = passer & passOne;
00300 swiz[j] = newInputIdx++;
00301 }
00302 }
00303 ShProgram aPass = combine(pa, passer);
00304
00305 if( keepExtra ) {
00306 for(i = 0; i < aMatch.size(); ++i) {
00307 if( !aMatch[i] ) swiz.push_back(i);
00308 }
00309 }
00310
00311 return mergeNames(pb << ( shSwizzle(swiz) << aPass ));
00312 }
00313
00314 ShProgram renameInput(ShProgram a,
00315 const std::string& oldName, const std::string& newName) {
00316 ShProgram renamer = SH_BEGIN_PROGRAM() {
00317 for(ShProgramNode::VarList::const_iterator I = a.node()->inputs.begin();
00318 I != a.node()->inputs.end(); ++I) {
00319 ShVariable var((*I)->clone(SH_INOUT));
00320
00321 if (!(*I)->has_name()) continue;
00322 std::string name = (*I)->name();
00323 if( name == oldName ) {
00324 var.name(newName);
00325 } else {
00326 var.name(name);
00327 }
00328 }
00329 } SH_END;
00330 return connect(renamer, a);
00331 }
00332
00333
00334 ShProgram renameOutput(ShProgram a,
00335 const std::string& oldName, const std::string& newName) {
00336 ShProgram renamer = SH_BEGIN_PROGRAM() {
00337 for(ShProgramNode::VarList::const_iterator I = a.node()->outputs.begin();
00338 I != a.node()->outputs.end(); ++I) {
00339 ShVariable var((*I)->clone(SH_INOUT));
00340
00341 if (!(*I)->has_name()) continue;
00342 std::string name = (*I)->name();
00343 if( name == oldName ) {
00344 var.name(newName);
00345 } else {
00346 var.name(name);
00347 }
00348 }
00349 } SH_END;
00350 return connect(a, renamer);
00351 }
00352
00353 ShProgram namedAlign(ShProgram a, ShProgram b) {
00354 ShManipulator<std::string> ordering;
00355
00356 for(ShProgramNode::VarList::const_iterator I = b.node()->inputs.begin();
00357 I != b.node()->inputs.end(); ++I) {
00358 ordering((*I)->name());
00359 }
00360
00361 return ordering << a;
00362 }
00363
00364 ShProgram operator<<(ShProgram a, ShProgram b)
00365 {
00366 return connect(b,a);
00367 }
00368
00369 ShProgram operator>>(ShProgram a, ShProgram b)
00370 {
00371 return connect(a,b);
00372 }
00373
00374 ShProgram operator&(ShProgram a, ShProgram b)
00375 {
00376 return combine(a, b);
00377 }
00378
00379 ShProgram operator>>(ShProgram p, const ShVariable &var) {
00380 return replaceUniform(p, var);
00381 }
00382
00383 ShProgram replaceUniform(ShProgram a, const ShVariable& v)
00384 {
00385 if(!v.uniform()) {
00386 shError(ShAlgebraException("Cannot replace non-uniform variable"));
00387 }
00388
00389 ShProgram program(a.node()->clone());
00390
00391 ShVarMap varMap;
00392
00393 ShContext::current()->enter(program.node());
00394
00395
00396 ShVariableNodePtr newInput(v.node()->clone(SH_INPUT));
00397 varMap[v.node()] = newInput;
00398
00399 ShContext::current()->exit();
00400
00401 ShVariableReplacer replacer(varMap);
00402 program.node()->ctrlGraph->dfs(replacer);
00403
00404 optimize(program);
00405
00406 program.node()->collectVariables();
00407
00408 return program;
00409 }
00410
00411 }