From dda3bf6c9a3125a892bff2ef2e581bb7d62f7921 Mon Sep 17 00:00:00 2001 From: Chris Dodd Date: Mon, 16 Sep 2024 23:39:24 +0000 Subject: [PATCH] Support for [lsb+:width] slices - allows for non-const lsb slices (width must still be const) --- backends/dpdk/dpdkArch.cpp | 2 +- frontends/common/constantFolding.cpp | 47 +++++++++++- frontends/common/constantFolding.h | 3 +- frontends/p4/fromv1.0/converters.cpp | 7 +- frontends/p4/fromv1.0/programStructure.cpp | 4 +- frontends/p4/strengthReduction.cpp | 27 ++++++- frontends/p4/strengthReduction.h | 3 +- frontends/p4/toP4/toP4.cpp | 1 + .../p4/typeChecking/readOnlyTypeInference.cpp | 3 +- frontends/p4/typeChecking/typeCheckExpr.cpp | 76 ++++++++++++++++++- frontends/p4/typeChecking/typeChecker.h | 9 ++- frontends/p4/typeChecking/typeInference.cpp | 3 +- frontends/parsers/p4/p4parser.ypp | 13 +++- ir/dbprint-expression.cpp | 9 ++- ir/expression.cpp | 9 ++- ir/expression.def | 38 ++++++++-- lib/flat_map.h | 1 + midend/def_use.cpp | 9 ++- midend/expandLookahead.cpp | 2 +- testdata/p4_16_samples/forloop5a.p4 | 30 ++++++++ .../p4_16_samples_outputs/forloop5a-first.p4 | 28 +++++++ .../forloop5a-frontend.p4 | 30 ++++++++ .../p4_16_samples_outputs/forloop5a-midend.p4 | 29 +++++++ testdata/p4_16_samples_outputs/forloop5a.p4 | 28 +++++++ .../p4_16_samples_outputs/forloop5a.p4-stderr | 0 25 files changed, 374 insertions(+), 37 deletions(-) create mode 100644 testdata/p4_16_samples/forloop5a.p4 create mode 100644 testdata/p4_16_samples_outputs/forloop5a-first.p4 create mode 100644 testdata/p4_16_samples_outputs/forloop5a-frontend.p4 create mode 100644 testdata/p4_16_samples_outputs/forloop5a-midend.p4 create mode 100644 testdata/p4_16_samples_outputs/forloop5a.p4 create mode 100644 testdata/p4_16_samples_outputs/forloop5a.p4-stderr diff --git a/backends/dpdk/dpdkArch.cpp b/backends/dpdk/dpdkArch.cpp index 7c081300223..4de3c0cb6ca 100644 --- a/backends/dpdk/dpdkArch.cpp +++ b/backends/dpdk/dpdkArch.cpp @@ -611,7 +611,7 @@ const IR::Node *AlignHdrMetaField::preorder(IR::Member *m) { // two different headers have field with same name. if (memVec.headerStr != hdrStrName) continue; auto mem = new IR::Member(m->expr, IR::ID(memVec.modifiedName)); - auto sliceMem = new IR::Slice(mem->clone(), memVec.msb, memVec.lsb); + auto sliceMem = new IR::AbsSlice(mem->clone(), memVec.msb, memVec.lsb); return sliceMem; } } diff --git a/frontends/common/constantFolding.cpp b/frontends/common/constantFolding.cpp index 6cf142553cb..8a9b56a97b8 100644 --- a/frontends/common/constantFolding.cpp +++ b/frontends/common/constantFolding.cpp @@ -622,7 +622,7 @@ static bool overflowWidth(const IR::Node *node, int width) { return false; } -const IR::Node *DoConstantFolding::postorder(IR::Slice *e) { +const IR::Node *DoConstantFolding::postorder(IR::AbsSlice *e) { const IR::Expression *msb = getConstant(e->e1); const IR::Expression *lsb = getConstant(e->e2); if (msb == nullptr) { @@ -678,6 +678,51 @@ const IR::Node *DoConstantFolding::postorder(IR::Slice *e) { return new IR::Constant(e->srcInfo, resultType, value, cbase->base, true); } +const IR::Node *DoConstantFolding::postorder(IR::PlusSlice *e) { + auto *e0 = getConstant(e->e0); + auto *lsb = getConstant(e->e1); + auto *width = getConstant(e->e2); + if (!width) { + if (typesKnown) + error(ErrorType::ERR_EXPECTED, + "%1%: slice indexes must be compile-time constants", e->e2); + return e; + } + + if (!e0 || !lsb) return e; + + auto clsb = lsb->to(); + if (clsb == nullptr) { + error(ErrorType::ERR_EXPECTED, "%1%: expected an integer value", lsb); + return e; + } + auto cwidth = width->to(); + if (cwidth == nullptr) { + error(ErrorType::ERR_EXPECTED, "%1%: expected an integer value", width); + return e; + } + auto cbase = e0->to(); + if (cbase == nullptr) { + error(ErrorType::ERR_EXPECTED, "%1%: expected an integer value", e->e0); + return e; + } + + int w = cwidth->asInt(); + int l = clsb->asInt(); + if (l < 0) { + ::P4::error(ErrorType::ERR_EXPECTED, "%1%: expected slice indexes to be non-negative", + e->e2); + return e; + } + if (overflowWidth(e, l) || overflowWidth(e, l+w)) return e; + big_int value = cbase->value >> l; + big_int mask = 1; + mask = (mask << w) - 1; + value = value & mask; + auto resultType = IR::Type_Bits::get(w); + return new IR::Constant(e->srcInfo, resultType, value, cbase->base, true); +} + const IR::Node *DoConstantFolding::postorder(IR::Member *e) { if (!typesKnown) return e; auto orig = getOriginal(); diff --git a/frontends/common/constantFolding.h b/frontends/common/constantFolding.h index 9d1f9a678f6..c4ca107ca23 100644 --- a/frontends/common/constantFolding.h +++ b/frontends/common/constantFolding.h @@ -140,7 +140,8 @@ class DoConstantFolding : public Transform, public ResolutionContext { const IR::Node *postorder(IR::LNot *e) override; const IR::Node *postorder(IR::LAnd *e) override; const IR::Node *postorder(IR::LOr *e) override; - const IR::Node *postorder(IR::Slice *e) override; + const IR::Node *postorder(IR::AbsSlice *e) override; + const IR::Node *postorder(IR::PlusSlice *e) override; const IR::Node *postorder(IR::Add *e) override; const IR::Node *postorder(IR::AddSat *e) override; const IR::Node *postorder(IR::Sub *e) override; diff --git a/frontends/p4/fromv1.0/converters.cpp b/frontends/p4/fromv1.0/converters.cpp index 1670171bdaf..86e8a67b052 100644 --- a/frontends/p4/fromv1.0/converters.cpp +++ b/frontends/p4/fromv1.0/converters.cpp @@ -51,8 +51,8 @@ const IR::Node *ExpressionConverter::postorder(IR::Mask *expression) { auto range = Util::findOnes(value); if (range.lowIndex == 0 && range.highIndex >= exp->type->width_bits() - 1U) return exp; if (value != range.value) return new IR::BAnd(expression->srcInfo, exp, cst); - return new IR::Slice(exp, new IR::Constant(expression->srcInfo, range.highIndex), - new IR::Constant(expression->srcInfo, range.lowIndex)); + return new IR::AbsSlice(exp, new IR::Constant(expression->srcInfo, range.highIndex), + new IR::Constant(expression->srcInfo, range.lowIndex)); } const IR::Node *ExpressionConverter::postorder(IR::Constant *expression) { @@ -110,8 +110,7 @@ const IR::Node *ExpressionConverter::postorder(IR::Primitive *primitive) { auto typeargs = new IR::Vector(); typeargs->push_back(IR::Type_Bits::get(aval + bval)); auto lookahead = new IR::MethodCallExpression(method, typeargs); - auto result = new IR::Slice(primitive->srcInfo, lookahead, new IR::Constant(bval - 1), - new IR::Constant(0)); + auto result = new IR::AbsSlice(primitive->srcInfo, lookahead, bval - 1, 0); result->type = IR::Type_Bits::get(bval); return result; } else if (primitive->name == "valid") { diff --git a/frontends/p4/fromv1.0/programStructure.cpp b/frontends/p4/fromv1.0/programStructure.cpp index f98f94891dc..31213aeb64c 100644 --- a/frontends/p4/fromv1.0/programStructure.cpp +++ b/frontends/p4/fromv1.0/programStructure.cpp @@ -1169,8 +1169,8 @@ const IR::Statement *ProgramStructure::sliceAssign(const IR::Primitive *primitiv if (cst->value == range.value) { auto h = new IR::Constant(range.highIndex); auto l = new IR::Constant(range.lowIndex); - left = new IR::Slice(left->srcInfo, left, h, l); - right = new IR::Slice(right->srcInfo, right, h, l); + left = new IR::AbsSlice(left->srcInfo, left, h, l); + right = new IR::AbsSlice(right->srcInfo, right, h, l); return assign(primitive->srcInfo, left, right, nullptr); } } diff --git a/frontends/p4/strengthReduction.cpp b/frontends/p4/strengthReduction.cpp index fe4b559065b..d4fda8a26a6 100644 --- a/frontends/p4/strengthReduction.cpp +++ b/frontends/p4/strengthReduction.cpp @@ -345,7 +345,7 @@ const IR::Node *DoStrengthReduction::postorder(IR::ArrayIndex *expr) { return expr; } -const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { +const IR::Node *DoStrengthReduction::postorder(IR::AbsSlice *expr) { int shift_amt = 0; const IR::Expression *shift_of = nullptr; if (auto sh = expr->e0->to()) { @@ -407,8 +407,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { return new IR::Concat( expr->srcInfo, expr->type, // type of slice is calculated by its constructor - new IR::Slice(cat->left->srcInfo, cat->left, expr->getH() - rwidth, 0), - new IR::Slice(cat->right->srcInfo, cat->right, rwidth - 1, expr->getL())); + new IR::AbsSlice(cat->left->srcInfo, cat->left, expr->getH() - rwidth, 0), + new IR::AbsSlice(cat->right->srcInfo, cat->right, rwidth - 1, expr->getL())); } } @@ -426,7 +426,7 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { } // out-of-bound error has been caught in type checking - if (auto sl = expr->e0->to()) { + if (auto sl = expr->e0->to()) { int delta = sl->getL(); expr->e0 = sl->e0; if (delta != 0) { @@ -448,4 +448,23 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { return expr; } +const IR::Node *DoStrengthReduction::postorder(IR::PlusSlice *expr) { + if (expr->e1->is() && expr->e2->is()) { + auto *rv = new IR::AbsSlice(expr->srcInfo, expr->e0, expr->getH(), expr->getL()); + return postorder(rv); + } + if (auto sh = expr->e0->to()) { + if (!sh->left->type->is()) return expr; + if (sh->left->type->to()->isSigned) return expr; + expr->e0 = sh->left; + expr->e1 = new IR::Add(sh->srcInfo, expr->e1, sh->right); + } + if (auto sh = expr->e0->to()) { + if (!sh->left->type->is()) return expr; + expr->e0 = sh->left; + expr->e1 = new IR::Sub(sh->srcInfo, expr->e1, sh->right); + } + return expr; +} + } // namespace P4 diff --git a/frontends/p4/strengthReduction.h b/frontends/p4/strengthReduction.h index 242c60f043c..8220ac0fdbc 100644 --- a/frontends/p4/strengthReduction.h +++ b/frontends/p4/strengthReduction.h @@ -103,7 +103,8 @@ class DoStrengthReduction final : public Transform { const IR::Node *postorder(IR::Div *expr) override; const IR::Node *postorder(IR::Mod *expr) override; const IR::Node *postorder(IR::Mux *expr) override; - const IR::Node *postorder(IR::Slice *expr) override; + const IR::Node *postorder(IR::AbsSlice *expr) override; + const IR::Node *postorder(IR::PlusSlice *expr) override; const IR::Node *postorder(IR::Mask *expr) override; const IR::Node *postorder(IR::Range *expr) override; const IR::Node *postorder(IR::Concat *expr) override; diff --git a/frontends/p4/toP4/toP4.cpp b/frontends/p4/toP4/toP4.cpp index cad9e003b5f..1faef1e540f 100644 --- a/frontends/p4/toP4/toP4.cpp +++ b/frontends/p4/toP4/toP4.cpp @@ -780,6 +780,7 @@ bool ToP4::preorder(const IR::Slice *slice) { builder.append("["); expressionPrecedence = DBPrint::Prec_Low; visit(slice->e1); + if (slice->is()) builder.append("+"); builder.append(":"); expressionPrecedence = DBPrint::Prec_Low; visit(slice->e2); diff --git a/frontends/p4/typeChecking/readOnlyTypeInference.cpp b/frontends/p4/typeChecking/readOnlyTypeInference.cpp index 46d47675d96..bb85bc67496 100644 --- a/frontends/p4/typeChecking/readOnlyTypeInference.cpp +++ b/frontends/p4/typeChecking/readOnlyTypeInference.cpp @@ -115,7 +115,8 @@ DEFINE_POSTORDER(IR::UPlus) DEFINE_POSTORDER(IR::Cmpl) DEFINE_POSTORDER(IR::Cast) DEFINE_POSTORDER(IR::Mux) -DEFINE_POSTORDER(IR::Slice) +DEFINE_POSTORDER(IR::AbsSlice) +DEFINE_POSTORDER(IR::PlusSlice) DEFINE_POSTORDER(IR::PathExpression) DEFINE_POSTORDER(IR::Member) DEFINE_POSTORDER(IR::TypeNameExpression) diff --git a/frontends/p4/typeChecking/typeCheckExpr.cpp b/frontends/p4/typeChecking/typeCheckExpr.cpp index fbea83a6fa5..65583cff50f 100644 --- a/frontends/p4/typeChecking/typeCheckExpr.cpp +++ b/frontends/p4/typeChecking/typeCheckExpr.cpp @@ -1347,7 +1347,7 @@ const IR::Node *TypeInferenceBase::postorder(const IR::PathExpression *expressio return expression; } -const IR::Node *TypeInferenceBase::postorder(const IR::Slice *expression) { +const IR::Node *TypeInferenceBase::postorder(const IR::AbsSlice *expression) { if (done()) return expression; const IR::Type *type = getType(expression->e0); if (type == nullptr) return expression; @@ -1359,7 +1359,7 @@ const IR::Node *TypeInferenceBase::postorder(const IR::Slice *expression) { return expression; } - IR::Slice *cloned = nullptr; + IR::AbsSlice *cloned = nullptr; auto e1type = getType(expression->e1); if (e1type && e1type->is()) { auto ei = EnumInstance::resolve(expression->e1, typeMap); @@ -1452,6 +1452,78 @@ const IR::Node *TypeInferenceBase::postorder(const IR::Slice *expression) { return expression; } +const IR::Node *TypeInferenceBase::postorder(const IR::PlusSlice *expression) { + if (done()) return expression; + const IR::Type *type = getType(expression->e0); + if (type == nullptr) return expression; + + if (auto se = type->to()) type = getTypeType(se->type); + + if (!type->is()) { + typeError("%1%: bit extraction only defined for bit<> types", expression); + return expression; + } + + IR::PlusSlice *cloned = nullptr; + auto e1type = getType(expression->e1); + if (e1type && e1type->is()) { + auto ei = EnumInstance::resolve(expression->e1, typeMap); + CHECK_NULL(ei); + if (auto sei = ei->to(); sei && expression->e1 != sei->value) { + cloned = expression->clone(); + cloned->e1 = sei->value; + } + } + auto e2type = getType(expression->e2); + if (e2type && e2type->is()) { + auto ei = EnumInstance::resolve(expression->e2, typeMap); + CHECK_NULL(ei); + auto sei = ei->to(); + if (sei == nullptr) { + typeError("%1%: slice bit index values must be constants", expression->e2); + return expression; + } + + if (expression->e1 != sei->value) { + cloned = (cloned ? cloned : expression->clone()); + cloned->e2 = sei->value; + } + } + if (cloned) expression = cloned; + + if (!expression->e2->is()) { + typeError("%1%: slice bit index values must be constants", expression->e2); + return expression; + } + auto width = expression->e2->checkedTo(); + if (!width->fitsInt()) { + typeError("%1%: width too large", width); + return expression; + } + int w = width->asInt(); + if (w < 0) { + typeError("%1%: negative width %2%", expression, width); + return expression; + } + + const IR::Type *resultType = IR::Type_Bits::get(type->srcInfo, w, false); + resultType = canonicalize(resultType); + if (resultType == nullptr) return expression; + setType(getOriginal(), resultType); + setType(expression, resultType); + if (isLeftValue(expression->e0)) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } + if (isCompileTimeConstant(expression->e0) && isCompileTimeConstant(expression->e1)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + const IR::Node *TypeInferenceBase::postorder(const IR::Dots *expression) { if (done()) return expression; setType(expression, IR::Type_Any::get()); diff --git a/frontends/p4/typeChecking/typeChecker.h b/frontends/p4/typeChecking/typeChecker.h index 58ea57ab552..b171595c54f 100644 --- a/frontends/p4/typeChecking/typeChecker.h +++ b/frontends/p4/typeChecking/typeChecker.h @@ -304,7 +304,8 @@ class TypeInferenceBase : public virtual Visitor, public ResolutionContext { const IR::Node *postorder(const IR::Cmpl *expression); const IR::Node *postorder(const IR::Cast *expression); const IR::Node *postorder(const IR::Mux *expression); - const IR::Node *postorder(const IR::Slice *expression); + const IR::Node *postorder(const IR::AbsSlice *expression); + const IR::Node *postorder(const IR::PlusSlice *expression); const IR::Node *postorder(const IR::PathExpression *expression); const IR::Node *postorder(const IR::Member *expression); const IR::Node *postorder(const IR::TypeNameExpression *expression); @@ -446,7 +447,8 @@ class ReadOnlyTypeInference : public virtual Inspector, public TypeInferenceBase void postorder(const IR::Cmpl *expression) override; void postorder(const IR::Cast *expression) override; void postorder(const IR::Mux *expression) override; - void postorder(const IR::Slice *expression) override; + void postorder(const IR::AbsSlice *expression) override; + void postorder(const IR::PlusSlice *expression) override; void postorder(const IR::PathExpression *expression) override; void postorder(const IR::Member *expression) override; void postorder(const IR::TypeNameExpression *expression) override; @@ -581,7 +583,8 @@ class TypeInference : public virtual Transform, public TypeInferenceBase { const IR::Node *postorder(IR::Cmpl *expression) override; const IR::Node *postorder(IR::Cast *expression) override; const IR::Node *postorder(IR::Mux *expression) override; - const IR::Node *postorder(IR::Slice *expression) override; + const IR::Node *postorder(IR::AbsSlice *expression) override; + const IR::Node *postorder(IR::PlusSlice *expression) override; const IR::Node *postorder(IR::PathExpression *expression) override; const IR::Node *postorder(IR::Member *expression) override; const IR::Node *postorder(IR::TypeNameExpression *expression) override; diff --git a/frontends/p4/typeChecking/typeInference.cpp b/frontends/p4/typeChecking/typeInference.cpp index c008dbedf00..fde261ec239 100644 --- a/frontends/p4/typeChecking/typeInference.cpp +++ b/frontends/p4/typeChecking/typeInference.cpp @@ -105,7 +105,8 @@ DEFINE_POSTORDER(IR::UPlus) DEFINE_POSTORDER(IR::Cmpl) DEFINE_POSTORDER(IR::Cast) DEFINE_POSTORDER(IR::Mux) -DEFINE_POSTORDER(IR::Slice) +DEFINE_POSTORDER(IR::AbsSlice) +DEFINE_POSTORDER(IR::PlusSlice) DEFINE_POSTORDER(IR::PathExpression) DEFINE_POSTORDER(IR::Member) DEFINE_POSTORDER(IR::TypeNameExpression) diff --git a/frontends/parsers/p4/p4parser.ypp b/frontends/parsers/p4/p4parser.ypp index 22ba7c5e7cd..c2b31f901d4 100644 --- a/frontends/parsers/p4/p4parser.ypp +++ b/frontends/parsers/p4/p4parser.ypp @@ -1549,7 +1549,8 @@ lvalue | THIS { $$ = new IR::This(@1); } | lvalue dot_name %prec DOT { $$ = new IR::Member(@1 + @2, $1, *$2); } | lvalue "[" expression "]" { $$ = new IR::ArrayIndex(@1 + @4, $1, $3); } - | lvalue "[" expression ":" expression "]" { $$ = new IR::Slice(@1 + @6, $1, $3, $5); } + | lvalue "[" expression ":" expression "]" { $$ = new IR::AbsSlice(@1 + @6, $1, $3, $5); } + | lvalue "[" expression "+" ":" expression "]" { $$ = new IR::PlusSlice(@1 + @7, $1, $3, $6); } | "(" lvalue ")" { $$ = $2; } ; @@ -1562,7 +1563,10 @@ expression | THIS { $$ = new IR::This(@1); } | prefixedNonTypeName { $$ = new IR::PathExpression($1); } | expression "[" expression "]" { $$ = new IR::ArrayIndex(@1 + @4, $1, $3); } - | expression "[" expression ":" expression "]" { $$ = new IR::Slice(@1 + @6, $1, $3, $5); } + | expression "[" expression ":" expression "]" + { $$ = new IR::AbsSlice(@1 + @6, $1, $3, $5); } + | expression "[" expression "+" ":" expression "]" + { $$ = new IR::PlusSlice(@1 + @7, $1, $3, $6); } | "{" expressionList optTrailingComma "}" { $$ = new IR::ListExpression(@1 + @4, *$2); } | INVALID { $$ = new IR::Invalid(@1, IR::Type::Unknown::get()); } | "{" kvList optTrailingComma "}" { $$ = new IR::StructExpression( @@ -1626,7 +1630,10 @@ nonBraceExpression | THIS { $$ = new IR::This(@1); } | prefixedNonTypeName { $$ = new IR::PathExpression($1); } | nonBraceExpression "[" expression "]" { $$ = new IR::ArrayIndex(@1 + @4, $1, $3); } - | nonBraceExpression "[" expression ":" expression "]" { $$ = new IR::Slice(@1 + @6, $1, $3, $5); } + | nonBraceExpression "[" expression ":" expression "]" + { $$ = new IR::AbsSlice(@1 + @6, $1, $3, $5); } + | nonBraceExpression "[" expression "+" ":" expression "]" + { $$ = new IR::PlusSlice(@1 + @7, $1, $3, $6); } | "(" expression ")" { $$ = $2; } | "!" expression %prec PREFIX { $$ = new IR::LNot(@1 + @2, $2); } | "~" expression %prec PREFIX { $$ = new IR::Cmpl(@1 + @2, $2); } diff --git a/ir/dbprint-expression.cpp b/ir/dbprint-expression.cpp index 07914eae2aa..82428184b03 100644 --- a/ir/dbprint-expression.cpp +++ b/ir/dbprint-expression.cpp @@ -75,13 +75,20 @@ ALL_UNARY_OPS(UNOP_DBPRINT) } ALL_BINARY_OPS(BINOP_DBPRINT) -void IR::Slice::dbprint(std::ostream &out) const { +void IR::AbsSlice::dbprint(std::ostream &out) const { int prec = getprec(out); out << setprec(Prec_Postfix) << e0 << "[" << setprec(Prec_Low) << e1 << ":" << setprec(Prec_Low) << e2 << setprec(prec) << ']'; if (prec == 0) out << ';'; } +void IR::PlusSlice::dbprint(std::ostream &out) const { + int prec = getprec(out); + out << setprec(Prec_Postfix) << e0 << "[" << setprec(Prec_Low) << e1 << "+:" + << setprec(Prec_Low) << e2 << setprec(prec) << ']'; + if (prec == 0) out << ';'; +} + void IR::Primitive::dbprint(std::ostream &out) const { const char *sep = ""; int prec = getprec(out); diff --git a/ir/expression.cpp b/ir/expression.cpp index c9314b5cc05..30977d4e28c 100644 --- a/ir/expression.cpp +++ b/ir/expression.cpp @@ -39,13 +39,18 @@ const IR::Expression *IR::Slice::make(const IR::Expression *e, unsigned lo, unsi if (hi >= src_width) hi = src_width - 1; if (lo == 0 && hi == src_width - 1) return e; } - if (auto sl = e->to()) { + if (auto sl = e->to()) { lo += sl->getL(); hi += sl->getL(); BUG_CHECK(lo >= sl->getL() && hi <= sl->getH(), "MakeSlice slice on slice type mismatch"); e = sl->e0; } - return new IR::Slice(e, hi, lo); + if (auto sl = e->to()) { + auto *e2 = sl->e2; + if (lo > 0) e2 = new IR::Add(e2, new IR::Constant(lo)); + return new IR::PlusSlice(sl->e1, e2, hi - lo + 1); + } + return new IR::AbsSlice(e, hi, lo); } int IR::Member::offset_bits() const { diff --git a/ir/expression.def b/ir/expression.def index d2e7255ac2b..ea03326bd62 100644 --- a/ir/expression.def +++ b/ir/expression.def @@ -325,7 +325,14 @@ class TypeNameExpression : Expression { "%1% unexpected type in TypeNameExpression", typeName); } } -class Slice : Operation_Ternary { +abstract Slice : Operation_Ternary { + virtual unsigned getH() const = 0; + virtual unsigned getL() const = 0; + // make a slice, folding slices on slices and slices on constants automatically + static Expression make(Expression a, unsigned hi, unsigned lo); +} + +class AbsSlice : Slice { precedence = DBPrint::Prec_Postfix; stringOp = "[:]"; toString{ return absl::StrCat(e0->toString().string_view(), @@ -334,17 +341,36 @@ class Slice : Operation_Ternary { // After type checking e1 and e2 will be constants unsigned getH() const { return e1->to()->asUnsigned(); } unsigned getL() const { return e2->to()->asUnsigned(); } - Slice(Expression a, int hi, int lo) - : Operation_Ternary(IR::Type::Bits::get(hi-lo+1), a, new Constant(hi), new Constant(lo)) {} - Slice(Util::SourceInfo si, Expression a, int hi, int lo) - : Operation_Ternary(si, IR::Type::Bits::get(hi-lo+1), a, new Constant(hi), new Constant(lo)) {} - Slice { + AbsSlice(Expression a, int hi, int lo) + : Slice(IR::Type::Bits::get(hi-lo+1), a, new Constant(hi), new Constant(lo)) {} + AbsSlice(Util::SourceInfo si, Expression a, int hi, int lo) + : Slice(si, IR::Type::Bits::get(hi-lo+1), a, new Constant(hi), new Constant(lo)) {} + AbsSlice { if (type->is() && e1 && e1->is() && e2 && e2->is()) type = IR::Type::Bits::get(getH() - getL() + 1); } // make a slice, folding slices on slices and slices on constants automatically static Expression make(Expression a, unsigned hi, unsigned lo); } +class PlusSlice : Slice { + precedence = DBPrint::Prec_Postfix; + stringOp = "[+:]"; + toString{ return absl::StrCat(e0->toString().string_view(), + "[", e1->toString().string_view(), "+:", + e2->toString().string_view(), "]"); } + unsigned getH() const { + BUG_CHECK(e1->is(), "non-const PlusSlice not handled"); + return e1->to()->asUnsigned() + e2->to()->asUnsigned() - 1; } + unsigned getL() const { + BUG_CHECK(e1->is(), "non-const PlusSlice not handled"); + return e1->to()->asUnsigned(); } + PlusSlice(Expression a, Expression lo, int width) + : Slice(IR::Type::Bits::get(width), a, lo, new Constant(width)) {} + PlusSlice { + if (type->is() && e2 && e2->is()) + type = IR::Type::Bits::get(e2->to()->asUnsigned()); } +} + class Member : Operation_Unary { precedence = DBPrint::Prec_Postfix; ID member; diff --git a/lib/flat_map.h b/lib/flat_map.h index ff707378adb..d43060c520c 100644 --- a/lib/flat_map.h +++ b/lib/flat_map.h @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include namespace P4 { diff --git a/midend/def_use.cpp b/midend/def_use.cpp index 00fdd802563..8b2e9bf2f95 100644 --- a/midend/def_use.cpp +++ b/midend/def_use.cpp @@ -506,7 +506,7 @@ const IR::Expression *ComputeDefUse::do_read(def_info_t &di, const IR::Expressio } else { BUG("%s: Member of unexpected type %s", m, m->expr->type); } - } else if (auto *sl = ctxt->node->to()) { + } else if (auto *sl = ctxt->node->to()) { le_bitrange range(sl->getL(), sl->getH()); for (auto it = di.slices_overlap_begin(range); it != di.slices.end() && range.overlaps(it->first); ++it) { @@ -600,13 +600,15 @@ const IR::Expression *ComputeDefUse::do_write(def_info_t &di, const IR::Expressi } else { BUG("%s: Member of unexpected type %s", m, m->expr->type); } - } else if (auto *sl = ctxt->node->to()) { + } else if (auto *sl = ctxt->node->to()) { le_bitrange range(sl->getL(), sl->getH()); di.live.clrrange(range.lo, range.size()); if (!di.live) di.defs.clear(); di.erase_slice(range); e = do_write(di.slices[range], sl, ctxt->parent); return e; + } else if (auto *sl = ctxt->node->to()) { + // writes an unknown part of the expression, rest (still) live } else if (auto *ai = ctxt->node->to()) { if (auto idx = ai->right->to()) { int i = idx->asInt(); @@ -619,8 +621,9 @@ const IR::Expression *ComputeDefUse::do_write(def_info_t &di, const IR::Expressi e = ai; } return e; + } else { + di.defs.clear(); } - di.defs.clear(); di.defs.insert(getLoc(e)); di.fields.clear(); di.slices.clear(); diff --git a/midend/expandLookahead.cpp b/midend/expandLookahead.cpp index 92e98a2319f..3a7a54badea 100644 --- a/midend/expandLookahead.cpp +++ b/midend/expandLookahead.cpp @@ -45,7 +45,7 @@ void DoExpandLookahead::expand( unsigned size = type->width_bits(); if (size == 0) return; const IR::Expression *expression = - new IR::Slice(bitvector->clone(), *offset - 1, *offset - size); + new IR::AbsSlice(bitvector->clone(), *offset - 1, *offset - size); auto tb = type->to(); if (!tb || tb->isSigned) expression = new IR::Cast(type, expression); *offset -= size; diff --git a/testdata/p4_16_samples/forloop5a.p4 b/testdata/p4_16_samples/forloop5a.p4 new file mode 100644 index 00000000000..089a512bdc6 --- /dev/null +++ b/testdata/p4_16_samples/forloop5a.p4 @@ -0,0 +1,30 @@ +#include +control generic(inout M m); +package top(generic c); + + +header t1 { + bit<32> x; + bit<32> y; +} + +struct headers_t { + t1 t1; +} + +control c(inout headers_t hdrs) { + action a0() { + bit<32> result = 0; + for (bit<8> i = 0; i < 32; i = i + 8) { + result = result << 8; + result = result + (bit<32>)hdrs.t1.x[i+:8] + (bit<32>)hdrs.t1.y[i+:8]; + } + hdrs.t1.x = result; + } + + apply { + a0(); + } +} + +top(c()) main; diff --git a/testdata/p4_16_samples_outputs/forloop5a-first.p4 b/testdata/p4_16_samples_outputs/forloop5a-first.p4 new file mode 100644 index 00000000000..d2f380743cf --- /dev/null +++ b/testdata/p4_16_samples_outputs/forloop5a-first.p4 @@ -0,0 +1,28 @@ +#include + +control generic(inout M m); +package top(generic c); +header t1 { + bit<32> x; + bit<32> y; +} + +struct headers_t { + t1 t1; +} + +control c(inout headers_t hdrs) { + action a0() { + bit<32> result = 32w0; + for (bit<8> i = 8w0; i < 8w32; i = i + 8w8) { + result = result << 8; + result = result + (bit<32>)hdrs.t1.x[i+:8] + (bit<32>)hdrs.t1.y[i+:8]; + } + hdrs.t1.x = result; + } + apply { + a0(); + } +} + +top(c()) main; diff --git a/testdata/p4_16_samples_outputs/forloop5a-frontend.p4 b/testdata/p4_16_samples_outputs/forloop5a-frontend.p4 new file mode 100644 index 00000000000..3e1f9bcccd9 --- /dev/null +++ b/testdata/p4_16_samples_outputs/forloop5a-frontend.p4 @@ -0,0 +1,30 @@ +#include + +control generic(inout M m); +package top(generic c); +header t1 { + bit<32> x; + bit<32> y; +} + +struct headers_t { + t1 t1; +} + +control c(inout headers_t hdrs) { + @name("c.result") bit<32> result_0; + @name("c.i") bit<8> i_0; + @name("c.a0") action a0() { + result_0 = 32w0; + for (i_0 = 8w0; i_0 < 8w32; i_0 = i_0 + 8w8) { + result_0 = result_0 << 8; + result_0 = result_0 + (bit<32>)hdrs.t1.x[i_0+:8] + (bit<32>)hdrs.t1.y[i_0+:8]; + } + hdrs.t1.x = result_0; + } + apply { + a0(); + } +} + +top(c()) main; diff --git a/testdata/p4_16_samples_outputs/forloop5a-midend.p4 b/testdata/p4_16_samples_outputs/forloop5a-midend.p4 new file mode 100644 index 00000000000..8caeea27a13 --- /dev/null +++ b/testdata/p4_16_samples_outputs/forloop5a-midend.p4 @@ -0,0 +1,29 @@ +#include + +control generic(inout M m); +package top(generic c); +header t1 { + bit<32> x; + bit<32> y; +} + +struct headers_t { + t1 t1; +} + +control c(inout headers_t hdrs) { + @name("c.a0") action a0() { + hdrs.t1.x = ((((bit<32>)hdrs.t1.x[7:0] + (bit<32>)hdrs.t1.y[7:0] << 8) + (bit<32>)hdrs.t1.x[15:8] + (bit<32>)hdrs.t1.y[15:8] << 8) + (bit<32>)hdrs.t1.x[23:16] + (bit<32>)hdrs.t1.y[23:16] << 8) + (bit<32>)hdrs.t1.x[31:24] + (bit<32>)hdrs.t1.y[31:24]; + } + @hidden table tbl_a0 { + actions = { + a0(); + } + const default_action = a0(); + } + apply { + tbl_a0.apply(); + } +} + +top(c()) main; diff --git a/testdata/p4_16_samples_outputs/forloop5a.p4 b/testdata/p4_16_samples_outputs/forloop5a.p4 new file mode 100644 index 00000000000..7f19884653f --- /dev/null +++ b/testdata/p4_16_samples_outputs/forloop5a.p4 @@ -0,0 +1,28 @@ +#include + +control generic(inout M m); +package top(generic c); +header t1 { + bit<32> x; + bit<32> y; +} + +struct headers_t { + t1 t1; +} + +control c(inout headers_t hdrs) { + action a0() { + bit<32> result = 0; + for (bit<8> i = 0; i < 32; i = i + 8) { + result = result << 8; + result = result + (bit<32>)hdrs.t1.x[i+:8] + (bit<32>)hdrs.t1.y[i+:8]; + } + hdrs.t1.x = result; + } + apply { + a0(); + } +} + +top(c()) main; diff --git a/testdata/p4_16_samples_outputs/forloop5a.p4-stderr b/testdata/p4_16_samples_outputs/forloop5a.p4-stderr new file mode 100644 index 00000000000..e69de29bb2d