#include "ShOptimizations.hpp"
#include <cstdlib>
#include <fstream>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "ShBitSet.hpp"
#include "ShCtrlGraph.hpp"
#include "ShDebug.hpp"
#include "ShEvaluate.hpp"
#include "ShContext.hpp"
#include "ShSyntax.hpp"
00013
00014 namespace {
00015
00016 using namespace SH;
00017
00018
00019
00020 struct BraInstInserter {
00021 void operator()(ShCtrlGraphNodePtr node)
00022 {
00023 if (!node) return;
00024
00025 for (std::vector<ShCtrlGraphBranch>::const_iterator I = node->successors.begin();
00026 I != node->successors.end(); ++I) {
00027 if (!node->block) node->block = new ShBasicBlock();
00028 node->block->addStatement(ShStatement(I->cond, SH_OP_OPTBRA, I->cond));
00029 }
00030 }
00031 };
00032
00033 struct BraInstRemover {
00034 void operator()(ShCtrlGraphNodePtr node)
00035 {
00036 if (!node) return;
00037 ShBasicBlockPtr block = node->block;
00038 if (!block) return;
00039
00040 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end();) {
00041 if (I->op == SH_OP_OPTBRA) {
00042 I = block->erase(I);
00043 continue;
00044 }
00045 ++I;
00046 }
00047 }
00048 };
00049
00050
00051 struct Straightener {
00052 Straightener(const ShCtrlGraphPtr& graph, bool& changed)
00053 : graph(graph), changed(changed)
00054 {
00055 }
00056
00057 void operator()(const ShCtrlGraphNodePtr& node)
00058 {
00059 if (!node) return;
00060 if (!node->follower) return;
00061 if (node == graph->entry()) return;
00062 if (node->follower == graph->exit()) return;
00063 if (!node->successors.empty()) return;
00064 if (node->follower->predecessors.size() > 1) return;
00065
00066 if (!node->block) node->block = new ShBasicBlock();
00067 if (!node->follower->block) node->follower->block = new ShBasicBlock();
00068
00069 for (ShBasicBlock::ShStmtList::iterator I = node->follower->block->begin(); I != node->follower->block->end(); ++I) {
00070 node->block->addStatement(*I);
00071 }
00072 node->successors = node->follower->successors;
00073
00074
00075
00076 for (std::vector<ShCtrlGraphBranch>::iterator I = node->follower->successors.begin();
00077 I != node->follower->successors.end(); ++I) {
00078 replacePredecessors(I->node, node->follower.object(), node.object());
00079 }
00080 if (node->follower->follower) replacePredecessors(node->follower->follower, node->follower.object(), node.object());
00081
00082 node->follower = node->follower->follower;
00083
00084 changed = true;
00085 }
00086
00087 void replacePredecessors(ShCtrlGraphNodePtr node,
00088 ShCtrlGraphNode* old,
00089 ShCtrlGraphNode* replacement)
00090 {
00091 for (ShCtrlGraphNode::ShPredList::iterator I = node->predecessors.begin(); I != node->predecessors.end(); ++I) {
00092 if (*I == old) {
00093 *I = replacement;
00094 break;
00095 }
00096 }
00097 }
00098
00099 ShCtrlGraphPtr graph;
00100 bool& changed;
00101 };
00102
00103 typedef std::queue<ShStatement*> DeadCodeWorkList;
00104
00105 struct InitLiveCode {
00106 InitLiveCode(DeadCodeWorkList& w)
00107 : w(w)
00108 {
00109 }
00110
00111 void operator()(ShCtrlGraphNodePtr node) {
00112 if (!node) return;
00113 ShBasicBlockPtr block = node->block;
00114 if (!block) return;
00115
00116 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00117 if (I->dest.node()->kind() != SH_TEMP
00118 || I->dest.node()->uniform()
00119 || I->op == SH_OP_KIL
00120
00121
00122
00123
00124 || I->op == SH_OP_OPTBRA) {
00125 I->marked = true;
00126 w.push(&(*I));
00127 continue;
00128 }
00129 I->marked = false;
00130 }
00131 }
00132
00133 DeadCodeWorkList& w;
00134 };
00135
00136 struct DeadCodeRemover {
00137 DeadCodeRemover(bool& changed)
00138 : changed(changed)
00139 {
00140 }
00141
00142 void operator()(ShCtrlGraphNodePtr node) {
00143 if (!node) return;
00144 ShBasicBlockPtr block = node->block;
00145 if (!block) return;
00146
00147 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end();) {
00148 if (!I->marked) {
00149 changed = true;
00150 I = block->erase(I);
00151 continue;
00152 }
00153 ++I;
00154 }
00155 }
00156
00157 bool& changed;
00158 };
00159
00160
00161
00162
00163 struct CopyPropagator {
00164 CopyPropagator(bool& changed)
00165 : changed(changed)
00166 {
00167 }
00168
00169 void operator()(const ShCtrlGraphNodePtr& node) {
00170 if (!node) return;
00171 ShBasicBlockPtr block = node->block;
00172
00173 if (!block) return;
00174 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00175 for (int i = 0; i < opInfo[I->op].arity; i++) copyValue(I->src[i]);
00176 removeACP(I->dest);
00177
00178 if (I->op == SH_OP_ASN
00179 && I->dest.node() != I->src[0].node()
00180 && I->dest.node()->kind() == SH_TEMP
00181 && I->dest.swizzle().identity()
00182 && I->src[0].swizzle().identity()) {
00183 m_acp.push_back(std::make_pair(I->dest, I->src[0]));
00184 }
00185 }
00186 m_acp.clear();
00187 }
00188
00189 void removeACP(const ShVariable& var)
00190 {
00191 for (ACP::iterator I = m_acp.begin(); I != m_acp.end();) {
00192 if (I->first.node() == var.node() || I->second.node() == var.node()) {
00193 I = m_acp.erase(I);
00194 continue;
00195 }
00196 ++I;
00197 }
00198 }
00199
00200
00201 void copyValue(ShVariable& var)
00202 {
00203 for (ACP::const_iterator I = m_acp.begin(); I != m_acp.end(); ++I) {
00204 if (I->first.node() == var.node()) {
00205 changed = true;
00206 var = ShVariable(I->second.node(), var.swizzle(),
00207 var.neg() ^ (I->first.neg() ^ I->second.neg()));
00208 break;
00209 }
00210 }
00211 }
00212
00213 typedef std::list< std::pair<ShVariable, ShVariable> > ACP;
00214 ACP m_acp;
00215
00216 bool& changed;
00217 };
00218
00219
00221 bool inRHS(const ShVariableNodePtr& node,
00222 const ShStatement& stmt)
00223 {
00224 for (int i = 0; i < opInfo[stmt.op].arity; i++) {
00225 if (stmt.src[i].node() == node) return true;
00226 }
00227
00228 return false;
00229 }
00230
00231 struct ForwardSubst {
00232 ForwardSubst(bool& changed)
00233 : changed(changed)
00234 {
00235 }
00236
00237 void operator()(const ShCtrlGraphNodePtr& node) {
00238 if (!node) return;
00239 ShBasicBlockPtr block = node->block;
00240 if (!block) return;
00241 for (ShBasicBlock::ShStmtList::iterator I = block->begin();
00242 I != block->end(); ++I) {
00243 substitute(*I);
00244
00245 removeAME(I->dest.node());
00246
00247 if (!inRHS(I->dest.node(), *I)
00248 && I->dest.node()->kind() == SH_TEMP
00249 && I->dest.swizzle().identity()) {
00250 m_ame.push_back(*I);
00251 }
00252 }
00253 m_ame.clear();
00254 }
00255
00256 void substitute(ShStatement& stmt)
00257 {
00258 if (stmt.op != SH_OP_ASN) return;
00259 if (stmt.src[0].neg()) return;
00260 if (stmt.src[0].node()->kind() != SH_TEMP) return;
00261 if (!stmt.src[0].swizzle().identity()) return;
00262
00263 for (AME::const_iterator I = m_ame.begin(); I != m_ame.end(); ++I) {
00264 if (I->dest.node() == stmt.src[0].node()) {
00265 ShVariable v = stmt.dest;
00266 stmt = *I;
00267 stmt.dest = v;
00268 changed = true;
00269 break;
00270 }
00271 }
00272 }
00273
00274 void removeAME(const ShVariableNodePtr& node)
00275 {
00276 for (AME::iterator I = m_ame.begin(); I != m_ame.end();) {
00277 if (I->dest.node() == node || inRHS(node, *I)) {
00278 I = m_ame.erase(I);
00279 continue;
00280 }
00281 ++I;
00282 }
00283 }
00284
00285 bool& changed;
00286 typedef std::list<ShStatement> AME;
00287 AME m_ame;
00288 };
00289
00290 }
00291
00292 namespace SH {
00293
00294 void insert_branch_instructions(ShProgram& p)
00295 {
00296 BraInstInserter r;
00297 p.node()->ctrlGraph->dfs(r);
00298 }
00299
00300 void remove_branch_instructions(ShProgram& p)
00301 {
00302 BraInstRemover r;
00303 p.node()->ctrlGraph->dfs(r);
00304 }
00305
00306 void straighten(ShProgram& p, bool& changed)
00307 {
00308 Straightener s(p.node()->ctrlGraph, changed);
00309 p.node()->ctrlGraph->dfs(s);
00310 }
00311
00312 void remove_dead_code(ShProgram& p, bool& changed)
00313 {
00314 DeadCodeWorkList w;
00315
00316 ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00317
00318 InitLiveCode init(w);
00319 graph->dfs(init);
00320
00321 while (!w.empty()) {
00322 ShStatement* stmt = w.front(); w.pop();
00323 ValueTracking* vt = stmt->get_info<ValueTracking>();
00324 if (!vt) continue;
00325
00326 for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00327
00328 for (ValueTracking::TupleUseDefChain::iterator C = vt->defs[i].begin();
00329 C != vt->defs[i].end(); ++C) {
00330 for (ValueTracking::UseDefChain::iterator I = C->begin(); I != C->end(); ++I) {
00331 if (I->stmt->marked) continue;
00332 I->stmt->marked = true;
00333 w.push(I->stmt);
00334 }
00335 }
00336 }
00337 }
00338
00339 DeadCodeRemover r(changed);
00340 graph->dfs(r);
00341 }
00342
00343 void copy_propagate(ShProgram& p, bool& changed)
00344 {
00345 CopyPropagator c(changed);
00346 p.node()->ctrlGraph->dfs(c);
00347 }
00348
00349 void forward_substitute(ShProgram& p, bool& changed)
00350 {
00351 ForwardSubst f(changed);
00352 p.node()->ctrlGraph->dfs(f);
00353 }
00354
00355 void optimize(ShProgram& p, int level)
00356 {
00357 if (level <= 0) return;
00358
00359 #ifdef SH_DEBUG_OPTIMIZER
00360 int pass = 0;
00361 SH_DEBUG_PRINT("Begin optimization for program with target " << p.node()->target());
00362 #endif
00363
00364 bool changed;
00365 do {
00366
00367 #ifdef SH_DEBUG_OPTIMIZER
00368 SH_DEBUG_PRINT("---Optimizer pass " << pass << " BEGIN---");
00369 std::ostringstream s;
00370 s << "opt_" << pass;
00371 std::string filename = s.str() + ".dot";
00372 std::ofstream out(filename.c_str());
00373 p.node()->ctrlGraph->graphvizDump(out);
00374 out.close();
00375 std::string cmdline = std::string("dot -Tps -o ") + s.str() + ".ps " + s.str() + ".dot";
00376 system(cmdline.c_str());
00377 #endif
00378
00379 changed = false;
00380
00381 if (!ShContext::current()->optimization_disabled("copy propagation")) {
00382 copy_propagate(p, changed);
00383 }
00384 if (!ShContext::current()->optimization_disabled("forward substitution")) {
00385 forward_substitute(p, changed);
00386 }
00387
00388 p.node()->ctrlGraph->computePredecessors();
00389
00390 if (!ShContext::current()->optimization_disabled("straightening")) {
00391 straighten(p, changed);
00392 }
00393
00394 insert_branch_instructions(p);
00395
00396 if (level >= 2 &&
00397 !ShContext::current()->optimization_disabled("propagation")) {
00398 add_value_tracking(p);
00399 propagate_constants(p);
00400 }
00401
00402 if (!ShContext::current()->optimization_disabled("deadcode")) {
00403 add_value_tracking(p);
00404 remove_dead_code(p, changed);
00405 }
00406
00407 remove_branch_instructions(p);
00408
00409 #ifdef SH_DEBUG_OPTIMIZER
00410 SH_DEBUG_PRINT("---Optimizer pass " << pass << " END---");
00411 pass++;
00412 #endif
00413 } while (changed);
00414 }
00415
00416 void optimize(const ShProgramNodePtr& n, int level)
00417 {
00418 ShProgram p(n);
00419 optimize(p, level);
00420 }
00421
00422 void optimize(ShProgram& p)
00423 {
00424 optimize(p, ShContext::current()->optimization());
00425 }
00426
00427 void optimize(const ShProgramNodePtr& n)
00428 {
00429 ShProgram p(n);
00430 optimize(p);
00431 }
00432
00433 }