Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frontend: Fix SpecializeGenericTypes #5133

Merged
merged 18 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 125 additions & 76 deletions frontends/p4/specializeGenericTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,80 @@ limitations under the License.

#include "specializeGenericTypes.h"

#include "absl/strings/str_replace.h"
#include "frontends/p4/typeChecking/typeSubstitutionVisitor.h"

namespace P4 {

bool TypeSpecializationMap::same(const TypeSpecialization *spec,
const IR::Type_Specialized *right) const {
if (!spec->specialized->baseType->equiv(*right->baseType)) return false;
BUG_CHECK(spec->argumentTypes->size() == right->arguments->size(),
"Type %1% and %2% specialized with different number of arguments?", spec->specialized,
right);
for (size_t i = 0; i < spec->argumentTypes->size(); i++) {
auto argl = spec->argumentTypes->at(i);
auto argr = typeMap->getType(right->arguments->at(i), true);
if (!typeMap->equivalent(argl, argr)) return false;
const IR::Type_Declaration *TypeSpecializationMap::nextAvailable() {
for (auto &[_, specialization] : map) {
if (specialization.inserted) continue;
auto &insReq = specialization.insertion;
if (insReq.empty()) {
specialization.inserted = true;
return specialization.replacement;
}
}
return true;
return nullptr;
}

void TypeSpecializationMap::add(const IR::Type_Specialized *t, const IR::Type_StructLike *decl,
const IR::Node *insertion, NameGenerator *nameGen) {
auto it = map.find(t);
if (it != map.end()) return;

// First check if we have another specialization with the same
// type arguments, in that case reuse it
for (auto it : map) {
if (same(it.second, t)) {
map.emplace(t, it.second);
LOG3("Found to specialize: " << t << " as previous " << it.second->name);
return;
}
void TypeSpecializationMap::markDefined(const IR::Type_Declaration *tdec) {
const auto name = tdec->name.name;
for (auto &[_, specialization] : map) {
if (specialization.inserted) continue;
specialization.insertion.erase(name);
}
}

cstring name = nameGen->newName(decl->getName().string_view());
LOG3("Found to specialize: " << dbp(t) << "(" << t << ") with name " << name
<< " insert before " << dbp(insertion));
auto argTypes = new IR::Vector<IR::Type>();
for (auto a : *t->arguments) argTypes->push_back(typeMap->getType(a, true));
TypeSpecialization *s = new TypeSpecialization(name, t, decl, insertion, argTypes);
map.emplace(t, s);
void TypeSpecializationMap::fillInsertionSet(const IR::Type_StructLike *decl,
InsertionSet &insertion) {
auto handleName = [&](cstring name) {
LOG4("TSM: " << decl->toString() << " to be inserted after " << name);
insertion.insert(name);
};
// - not using type map, the struct type could be a fresh one, and even if it is not, struct
// fields will always have known type (unlike e.g. expressions)
// - using visitor to handle type names and type specialization inside types like `list` and
// `tuple`.
forAllMatching<IR::Type_Name>(
decl, [&](const IR::Type_Name *tn) { handleName(tn->path->name.name); });
forAllMatching<IR::Type_Specialized>(decl, [&](const IR::Type_Specialized *ts) {
if (const auto *specialization = get(ts)) {
handleName(specialization->name);
}
});
}

TypeSpecialization *TypeSpecializationMap::get(const IR::Type_Specialized *type) const {
for (auto it : map) {
if (same(it.second, type)) return it.second;
void TypeSpecializationMap::add(const IR::Type_Specialized *t, const IR::Type_StructLike *decl,
NameGenerator *nameGen) {
const auto sig = SpecSignature::get(t);
if (!sig) {
return;
}
return nullptr;
if (map.count(*sig)) return;

cstring name = nameGen->newName(sig->name());
LOG3("Found to specialize: " << dbp(t) << " with name " << name << " insert after "
<< dbp(decl));
map.emplace(*sig, TypeSpecialization{name, t, decl, {decl->name.name}, t->arguments});
}

namespace {

// depending on constness of Map returns a const or non-const pointer
template <typename Map>
auto *_get(Map &map, const IR::Type_Specialized *type) {
const auto sig = SpecSignature::get(type);
return sig ? getref(map, *sig) : nullptr;
}
} // namespace

const TypeSpecialization *TypeSpecializationMap::get(const IR::Type_Specialized *type) const {
return _get(map, type);
}

TypeSpecialization *TypeSpecializationMap::get(const IR::Type_Specialized *type) {
return _get(map, type);
}

namespace {
Expand Down Expand Up @@ -92,13 +119,12 @@ class ContainsTypeVariable : public Inspector {
Visitor::profile_t FindTypeSpecializations::init_apply(const IR::Node *node) {
auto rv = Inspector::init_apply(node);
node->apply(nameGen);

return rv;
}

void FindTypeSpecializations::postorder(const IR::Type_Specialized *type) {
auto baseType = specMap->typeMap->getTypeType(type->baseType, true);
auto st = baseType->to<IR::Type_StructLike>();
const auto *baseType = getDeclaration(type->baseType->path, true);
const auto *st = baseType->to<IR::Type_StructLike>();
if (st == nullptr || st->typeParameters->size() == 0)
// nothing to specialize
return;
Expand All @@ -113,55 +139,52 @@ void FindTypeSpecializations::postorder(const IR::Type_Specialized *type) {
// specialized instances of G, e.g., G<bit<32>>.
return;
}
// Find location where the specialization is to be inserted.
// This can be before a Parser, Control, or a toplevel instance declaration
const IR::Node *insert = findContext<IR::P4Parser>();
if (!insert) insert = findContext<IR::Function>();
if (!insert) insert = findContext<IR::P4Control>();
if (!insert) insert = findContext<IR::Type_Declaration>();
if (!insert) insert = findContext<IR::Declaration_Constant>();
if (!insert) insert = findContext<IR::Declaration_Variable>();
if (!insert) insert = findContext<IR::Declaration_Instance>();
if (!insert) insert = findContext<IR::P4Action>();
CHECK_NULL(insert);
specMap->add(type, st, insert, &nameGen);
specMap->add(type, st, &nameGen);
}

///////////////////////////////////////////////////////////////////////////////////////

const IR::Node *CreateSpecializedTypes::postorder(IR::Type_Declaration *type) {
for (auto it : specMap->map) {
if (it.second->declaration->name == type->name) {
auto specialized = it.first;
auto genDecl = type->to<IR::IMayBeGenericType>();
TypeVariableSubstitution ts;
ts.setBindings(type, genDecl->getTypeParameters(), specialized->arguments);
TypeSubstitutionVisitor tsv(specMap->typeMap, &ts);
tsv.setCalledBy(this);
auto renamed = type->apply(tsv)->to<IR::Type_StructLike>()->clone();
cstring name = it.second->name;
auto empty = new IR::TypeParameters();
renamed->name = name;
renamed->typeParameters = empty;
it.second->replacement = postorder(renamed)->to<IR::Type_StructLike>();
LOG3("CST Specializing " << dbp(type) << " with " << ts << " as " << dbp(renamed));
}
void CreateSpecializedTypes::postorder(IR::Type_Specialized *spec) {
if (auto *specialization = specMap->get(spec)) {
const auto *declT = getDeclaration(spec->baseType->path)->to<IR::Type_Declaration>();
BUG_CHECK(declT, "Could not get declaration for %1%", spec);
auto genDecl = declT->to<IR::IMayBeGenericType>();
BUG_CHECK(genDecl, "Not a generic declaration: %1%", declT);
TypeVariableSubstitution ts;
ts.setBindings(declT, genDecl->getTypeParameters(), specialization->argumentTypes);
TypeSubstitutionVisitor tsv(specMap->typeMap, &ts);
tsv.setCalledBy(this);
auto renamed = declT->apply(tsv)->to<IR::Type_StructLike>()->clone();
cstring name = specialization->name;
renamed->name = name;
renamed->typeParameters = new IR::TypeParameters();
specialization->replacement = renamed;
// add additional insertion constraints
specMap->fillInsertionSet(renamed, specialization->insertion);
LOG3("CST: Specializing " << dbp(declT) << " with [" << ts << "] as " << dbp(renamed));
}
return insert(type);
}

const IR::Node *CreateSpecializedTypes::insert(const IR::Node *before) {
auto specs = specMap->getSpecializations(getOriginal());
if (specs == nullptr) return before;
LOG2(specs->size() << " instantiations before " << dbp(before));
specs->push_back(before);
return specs;
void CreateSpecializedTypes::postorder(IR::P4Program *prog) {
IR::Vector<IR::Node> newObjects;
for (const auto *obj : prog->objects) {
newObjects.push_back(obj);
if (const auto *tdec = obj->to<IR::Type_Declaration>()) {
specMap->markDefined(tdec);
while (const auto *addTDec = specMap->nextAvailable()) {
newObjects.push_back(addTDec);
specMap->markDefined(addTDec);
LOG2("CST: Will insert " << dbp(addTDec) << " after " << dbp(newObjects.back()));
}
}
}
prog->objects = newObjects;
}

const IR::Node *ReplaceTypeUses::postorder(IR::Type_Specialized *type) {
auto t = specMap->get(getOriginal<IR::Type_Specialized>());
auto t = specMap->get(type);
if (!t) return type;
CHECK_NULL(t->replacement);
BUG_CHECK(t->replacement, "Missing replacement %1% -> %2%", type, t->name);
LOG3("RTU Replacing " << dbp(type) << " with " << dbp(t->replacement));
return t->replacement->getP4Type();
}
Expand Down Expand Up @@ -189,4 +212,30 @@ const IR::Node *ReplaceTypeUses::postorder(IR::StructExpression *expression) {
return expression;
}

std::string SpecSignature::name() const {
std::stringstream ss;
ss << baseType;
for (const auto &t : arguments) {
ss << "_"
<< absl::StrReplaceAll(t, {{"<", ""}, {">", ""}, {" ", ""}, {",", "_"}, {".", "_"}});
}
return ss.str();
}

std::string toString(const SpecSignature &sig) {
return absl::StrCat(sig.baseType, "<", absl::StrJoin(sig.arguments, ","), ">");
}

std::optional<SpecSignature> SpecSignature::get(const IR::Type_Specialized *spec) {
SpecSignature out;
out.baseType = spec->baseType->path->name;
for (const auto *arg : *spec->arguments) {
if (ContainsTypeVariable::inspect(arg)) {
return {};
}
out.arguments.push_back(arg->toString());
}
return out;
}

} // namespace P4
Loading
Loading