//===--- RefactoringCallbacks.cpp - Structural query framework ------------===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // //===----------------------------------------------------------------------===// #include "clang/Tooling/RefactoringCallbacks.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/Basic/SourceLocation.h" #include "clang/Lex/Lexer.h" using llvm::StringError; using llvm::make_error; namespace clang { namespace tooling { RefactoringCallback::RefactoringCallback() {} tooling::Replacements &RefactoringCallback::getReplacements() { return Replace; } ASTMatchRefactorer::ASTMatchRefactorer( std::map &FileToReplaces) : FileToReplaces(FileToReplaces) {} void ASTMatchRefactorer::addDynamicMatcher( const ast_matchers::internal::DynTypedMatcher &Matcher, RefactoringCallback *Callback) { MatchFinder.addDynamicMatcher(Matcher, Callback); Callbacks.push_back(Callback); } class RefactoringASTConsumer : public ASTConsumer { public: explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) : Refactoring(Refactoring) {} void HandleTranslationUnit(ASTContext &Context) override { // The ASTMatchRefactorer is re-used between translation units. // Clear the matchers so that each Replacement is only emitted once. for (const auto &Callback : Refactoring.Callbacks) { Callback->getReplacements().clear(); } Refactoring.MatchFinder.matchAST(Context); for (const auto &Callback : Refactoring.Callbacks) { for (const auto &Replacement : Callback->getReplacements()) { llvm::Error Err = Refactoring.FileToReplaces[Replacement.getFilePath()].add( Replacement); if (Err) { llvm::errs() << "Skipping replacement " << Replacement.toString() << " due to this error:\n" << toString(std::move(Err)) << "\n"; } } } } private: ASTMatchRefactorer &Refactoring; }; std::unique_ptr ASTMatchRefactorer::newASTConsumer() { return llvm::make_unique(*this); } static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, StringRef Text) { return tooling::Replacement( Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); } static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, const Stmt &To) { return replaceStmtWithText( Sources, From, Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), Sources, LangOptions())); } ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) : FromId(FromId), ToText(ToText) {} void ReplaceStmtWithText::run( const ast_matchers::MatchFinder::MatchResult &Result) { if (const Stmt *FromMatch = Result.Nodes.getNodeAs(FromId)) { auto Err = Replace.add(tooling::Replacement( *Result.SourceManager, CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText)); // FIXME: better error handling. For now, just print error message in the // release version. if (Err) { llvm::errs() << llvm::toString(std::move(Err)) << "\n"; assert(false); } } } ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) : FromId(FromId), ToId(ToId) {} void ReplaceStmtWithStmt::run( const ast_matchers::MatchFinder::MatchResult &Result) { const Stmt *FromMatch = Result.Nodes.getNodeAs(FromId); const Stmt *ToMatch = Result.Nodes.getNodeAs(ToId); if (FromMatch && ToMatch) { auto Err = Replace.add( replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch)); // FIXME: better error handling. For now, just print error message in the // release version. if (Err) { llvm::errs() << llvm::toString(std::move(Err)) << "\n"; assert(false); } } } ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, bool PickTrueBranch) : Id(Id), PickTrueBranch(PickTrueBranch) {} void ReplaceIfStmtWithItsBody::run( const ast_matchers::MatchFinder::MatchResult &Result) { if (const IfStmt *Node = Result.Nodes.getNodeAs(Id)) { const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); if (Body) { auto Err = Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); // FIXME: better error handling. For now, just print error message in the // release version. if (Err) { llvm::errs() << llvm::toString(std::move(Err)) << "\n"; assert(false); } } else if (!PickTrueBranch) { // If we want to use the 'else'-branch, but it doesn't exist, delete // the whole 'if'. auto Err = Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, "")); // FIXME: better error handling. For now, just print error message in the // release version. if (Err) { llvm::errs() << llvm::toString(std::move(Err)) << "\n"; assert(false); } } } } ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( llvm::StringRef FromId, std::vector Template) : FromId(FromId), Template(std::move(Template)) {} llvm::Expected> ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { std::vector ParsedTemplate; for (size_t Index = 0; Index < ToTemplate.size();) { if (ToTemplate[Index] == '$') { if (ToTemplate.substr(Index, 2) == "$$") { Index += 2; ParsedTemplate.push_back( TemplateElement{TemplateElement::Literal, "$"}); } else if (ToTemplate.substr(Index, 2) == "${") { size_t EndOfIdentifier = ToTemplate.find("}", Index); if (EndOfIdentifier == std::string::npos) { return make_error( "Unterminated ${...} in replacement template near " + ToTemplate.substr(Index), llvm::inconvertibleErrorCode()); } std::string SourceNodeName = ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2); ParsedTemplate.push_back( TemplateElement{TemplateElement::Identifier, SourceNodeName}); Index = EndOfIdentifier + 1; } else { return make_error( "Invalid $ in replacement template near " + ToTemplate.substr(Index), llvm::inconvertibleErrorCode()); } } else { size_t NextIndex = ToTemplate.find('$', Index + 1); ParsedTemplate.push_back( TemplateElement{TemplateElement::Literal, ToTemplate.substr(Index, NextIndex - Index)}); Index = NextIndex; } } return std::unique_ptr( new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); } void ReplaceNodeWithTemplate::run( const ast_matchers::MatchFinder::MatchResult &Result) { const auto &NodeMap = Result.Nodes.getMap(); std::string ToText; for (const auto &Element : Template) { switch (Element.Type) { case TemplateElement::Literal: ToText += Element.Value; break; case TemplateElement::Identifier: { auto NodeIter = NodeMap.find(Element.Value); if (NodeIter == NodeMap.end()) { llvm::errs() << "Node " << Element.Value << " used in replacement template not bound in Matcher \n"; llvm::report_fatal_error("Unbound node in replacement template."); } CharSourceRange Source = CharSourceRange::getTokenRange(NodeIter->second.getSourceRange()); ToText += Lexer::getSourceText(Source, *Result.SourceManager, Result.Context->getLangOpts()); break; } } } if (NodeMap.count(FromId) == 0) { llvm::errs() << "Node to be replaced " << FromId << " not bound in query.\n"; llvm::report_fatal_error("FromId node not bound in MatchResult"); } auto Replacement = tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, Result.Context->getLangOpts()); llvm::Error Err = Replace.add(Replacement); if (Err) { llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() << "! " << llvm::toString(std::move(Err)) << "\n"; llvm::report_fatal_error("Replacement failed"); } } } // end namespace tooling } // end namespace clang