Skip to content

Commit cd6e390

Browse files
ericastorcopybara-github
authored andcommitted
[ir] Extend controlled Stages to track two more signals
We missed some signals that are critical in actual implementation: 1. The external signal `outputs_ready`, which triggers the Stage's consuming I/O (`receive` & `state_read`), and 2. The internal signal `active_inputs_valid`, which (when combined with `inputs_valid`) triggers the Stage's committing I/O (`send` & `next_value`). PiperOrigin-RevId: 836704396
1 parent 6d61761 commit cd6e390

13 files changed

+225
-66
lines changed

xls/ir/block.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,14 @@ std::string Block::DumpIr() const {
351351

352352
const Stage& stage = stages_[stage_idx];
353353
CHECK(stage.IsControlled());
354-
absl::StrAppendFormat(
355-
&res, " controlled_stage(%s%s) {\n", stage.inputs_valid()->GetName(),
356-
stage.contains(stage.outputs_valid())
357-
? ""
358-
: absl::StrCat(", outputs_valid=",
359-
stage.outputs_valid()->GetName()));
354+
absl::StrAppendFormat(&res, " controlled_stage(%s, %s) {\n",
355+
stage.inputs_valid()->GetName(),
356+
stage.outputs_ready()->GetName());
360357
for (Node* node : staged_nodes[stage_idx]) {
361-
absl::StrAppendFormat(&res, " %s%s\n",
362-
node == stage.outputs_valid() ? "ret " : "",
363-
node->ToString());
358+
absl::StrAppend(
359+
&res, " ", (node == stage.outputs_valid() ? "ret " : ""),
360+
(node == stage.active_inputs_valid() ? "active_inputs_valid " : ""),
361+
node->ToString(), "\n");
364362
}
365363
absl::StrAppend(&res, " }\n");
366364
}
@@ -1134,6 +1132,12 @@ absl::StatusOr<bool> Block::RemoveNodeFromStage(Node* node) {
11341132
return false;
11351133
}
11361134
int64_t stage_index = it->second;
1135+
Stage& stage = stages_[stage_index];
1136+
if (stage.active_inputs_valid() == node || stage.outputs_valid() == node) {
1137+
return absl::InvalidArgumentError(absl::StrFormat(
1138+
"Node %s has implicit uses in stage %d and cannot be removed.",
1139+
node->GetName(), stage_index));
1140+
}
11371141
node_to_stage_.erase(it);
11381142
stages_[stage_index].erase(node);
11391143
return true;

xls/ir/block.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ class Block : public FunctionBase {
340340

341341
bool HasImplicitUse(Node* node) const override {
342342
return absl::c_any_of(stages_, [node](const Stage& stage) {
343-
return stage.inputs_valid() == node || stage.outputs_valid() == node;
343+
return stage.inputs_valid() == node || stage.outputs_valid() == node ||
344+
stage.active_inputs_valid() == node ||
345+
stage.outputs_ready() == node;
344346
});
345347
}
346348

xls/ir/block_test.cc

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <optional>
1919
#include <string>
2020
#include <string_view>
21+
#include <utility>
2122
#include <vector>
2223

2324
#include "gmock/gmock.h"
@@ -1006,18 +1007,22 @@ class ScheduledBlockTest : public IrTestBase {
10061007
struct TestBlock {
10071008
ScheduledBlock* block;
10081009
Node* iv0;
1010+
Node* or0;
1011+
Node* aiv0;
10091012
Node* ov0;
10101013
};
10111014
absl::StatusOr<TestBlock> CreateScheduledBlock(Package* p) {
10121015
ScheduledBlockBuilder bb("b", p);
10131016
BValue iv0 = bb.Literal(UBits(1, 1), SourceInfo(), "iv0");
1014-
bb.StartStage(iv0);
1017+
BValue or0 = bb.Literal(UBits(1, 1), SourceInfo(), "or0");
1018+
bb.StartStage(iv0, or0);
10151019
BValue x = bb.InputPort("x", p->GetBitsType(32));
10161020
bb.OutputPort("out", x);
1021+
BValue aiv0 = bb.Literal(UBits(1, 1), SourceInfo(), "aiv0");
10171022
BValue ov0 = bb.Literal(UBits(1, 1), SourceInfo(), "ov0");
1018-
bb.EndStage(ov0);
1023+
bb.EndStage(aiv0, ov0);
10191024
XLS_ASSIGN_OR_RETURN(ScheduledBlock * block, bb.Build());
1020-
return TestBlock{block, iv0.node(), ov0.node()};
1025+
return TestBlock{block, iv0.node(), or0.node(), aiv0.node(), ov0.node()};
10211026
}
10221027

10231028
void ExpectIr(std::string_view got, std::string_view test_name) {
@@ -1044,23 +1049,38 @@ TEST_F(ScheduledBlockTest, StageAddAndClear) {
10441049

10451050
XLS_ASSERT_OK_AND_ASSIGN(
10461051
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1052+
XLS_ASSERT_OK_AND_ASSIGN(
1053+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1054+
XLS_ASSERT_OK_AND_ASSIGN(
1055+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
10471056
XLS_ASSERT_OK_AND_ASSIGN(
10481057
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1049-
block->AddStage(Stage(iv1, ov1));
1058+
Stage stage1(iv1, or1, aiv1, ov1);
1059+
stage1.AddNode(aiv1);
1060+
stage1.AddNode(ov1);
1061+
block->AddStage(std::move(stage1));
1062+
10501063
EXPECT_EQ(block->stages().size(), 2);
10511064

10521065
XLS_ASSERT_OK_AND_ASSIGN(
10531066
Node * iv2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1067+
XLS_ASSERT_OK_AND_ASSIGN(
1068+
Node * or2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1069+
XLS_ASSERT_OK_AND_ASSIGN(
1070+
Node * aiv2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
10541071
XLS_ASSERT_OK_AND_ASSIGN(
10551072
Node * ov2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1056-
block->AddStage(Stage(iv2, ov2));
1073+
Stage stage2(iv2, or2, aiv2, ov2);
1074+
stage2.AddNode(aiv2);
1075+
stage2.AddNode(ov2);
1076+
block->AddStage(std::move(stage2));
10571077
EXPECT_EQ(block->stages().size(), 3);
10581078

10591079
block->ClearStages();
10601080
EXPECT_TRUE(block->stages().empty());
10611081

10621082
// Re-stage nodes to satisfy the verifier on destruction.
1063-
block->AddStage(Stage(tb.iv0, tb.ov0));
1083+
block->AddStage(Stage(tb.iv0, tb.or0, tb.aiv0, tb.ov0));
10641084
for (Node* node : block->nodes()) {
10651085
if (!block->IsStaged(node)) {
10661086
XLS_ASSERT_OK(block->AddNodeToStage(0, node).status());
@@ -1075,9 +1095,16 @@ TEST_F(ScheduledBlockTest, GetStageIndex) {
10751095

10761096
XLS_ASSERT_OK_AND_ASSIGN(
10771097
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1098+
XLS_ASSERT_OK_AND_ASSIGN(
1099+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1100+
XLS_ASSERT_OK_AND_ASSIGN(
1101+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
10781102
XLS_ASSERT_OK_AND_ASSIGN(
10791103
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1080-
block->AddStage(Stage(iv1, ov1)); // Stage 1
1104+
Stage stage1(iv1, or1, aiv1, ov1);
1105+
stage1.AddNode(aiv1);
1106+
stage1.AddNode(ov1);
1107+
block->AddStage(std::move(stage1));
10811108

10821109
XLS_ASSERT_OK_AND_ASSIGN(
10831110
Node * lit1,
@@ -1094,9 +1121,16 @@ TEST_F(ScheduledBlockTest, AddNodeToStage) {
10941121
ScheduledBlock* block = tb.block;
10951122
XLS_ASSERT_OK_AND_ASSIGN(
10961123
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1124+
XLS_ASSERT_OK_AND_ASSIGN(
1125+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1126+
XLS_ASSERT_OK_AND_ASSIGN(
1127+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
10971128
XLS_ASSERT_OK_AND_ASSIGN(
10981129
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1099-
block->AddStage(Stage(iv1, ov1)); // Stage 1
1130+
Stage stage1(iv1, or1, aiv1, ov1);
1131+
stage1.AddNode(aiv1);
1132+
stage1.AddNode(ov1);
1133+
block->AddStage(std::move(stage1));
11001134

11011135
XLS_ASSERT_OK_AND_ASSIGN(
11021136
Node * lit, block->MakeNode<Literal>(SourceInfo(), Value(UBits(10, 32))));
@@ -1114,9 +1148,16 @@ TEST_F(ScheduledBlockTest, MakeNodeInStage) {
11141148
ScheduledBlock* block = tb.block;
11151149
XLS_ASSERT_OK_AND_ASSIGN(
11161150
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1151+
XLS_ASSERT_OK_AND_ASSIGN(
1152+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1153+
XLS_ASSERT_OK_AND_ASSIGN(
1154+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
11171155
XLS_ASSERT_OK_AND_ASSIGN(
11181156
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1119-
block->AddStage(Stage(iv1, ov1)); // Stage 1
1157+
Stage stage1(iv1, or1, aiv1, ov1);
1158+
stage1.AddNode(aiv1);
1159+
stage1.AddNode(ov1);
1160+
block->AddStage(std::move(stage1));
11201161

11211162
XLS_ASSERT_OK_AND_ASSIGN(
11221163
Literal * literal,
@@ -1130,9 +1171,16 @@ TEST_F(ScheduledBlockTest, MakeNodeWithNameInStage) {
11301171
ScheduledBlock* block = tb.block;
11311172
XLS_ASSERT_OK_AND_ASSIGN(
11321173
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1174+
XLS_ASSERT_OK_AND_ASSIGN(
1175+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1176+
XLS_ASSERT_OK_AND_ASSIGN(
1177+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
11331178
XLS_ASSERT_OK_AND_ASSIGN(
11341179
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1135-
block->AddStage(Stage(iv1, ov1)); // Stage 1
1180+
Stage stage1(iv1, or1, aiv1, ov1);
1181+
stage1.AddNode(aiv1);
1182+
stage1.AddNode(ov1);
1183+
block->AddStage(std::move(stage1));
11361184

11371185
XLS_ASSERT_OK_AND_ASSIGN(
11381186
Literal * literal, block->MakeNodeWithNameInStage<Literal>(
@@ -1154,14 +1202,28 @@ TEST_F(ScheduledBlockTest, CloneScheduledBlock) {
11541202
ScheduledBlock* block = tb.block;
11551203
XLS_ASSERT_OK_AND_ASSIGN(
11561204
Node * iv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1205+
XLS_ASSERT_OK_AND_ASSIGN(
1206+
Node * or1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1207+
XLS_ASSERT_OK_AND_ASSIGN(
1208+
Node * aiv1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
11571209
XLS_ASSERT_OK_AND_ASSIGN(
11581210
Node * ov1, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1159-
block->AddStage(Stage(iv1, ov1));
1211+
Stage stage1(iv1, or1, aiv1, ov1);
1212+
stage1.AddNode(aiv1);
1213+
stage1.AddNode(ov1);
1214+
block->AddStage(std::move(stage1));
11601215
XLS_ASSERT_OK_AND_ASSIGN(
11611216
Node * iv2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1217+
XLS_ASSERT_OK_AND_ASSIGN(
1218+
Node * or2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1219+
XLS_ASSERT_OK_AND_ASSIGN(
1220+
Node * aiv2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
11621221
XLS_ASSERT_OK_AND_ASSIGN(
11631222
Node * ov2, block->MakeNode<Literal>(SourceInfo(), Value(UBits(1, 1))));
1164-
block->AddStage(Stage(iv2, ov2));
1223+
Stage stage2(iv2, or2, aiv2, ov2);
1224+
stage1.AddNode(aiv2);
1225+
stage1.AddNode(ov2);
1226+
block->AddStage(std::move(stage2));
11651227

11661228
XLS_ASSERT_OK(block->MakeNodeWithNameInStage<Literal>(
11671229
0, SourceInfo(), Value(UBits(1, 32)), "my_x"));

xls/ir/function_base.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ absl::StatusOr<Stage> Stage::Clone(
9494
if (inputs_valid_ != nullptr) {
9595
XLS_ASSIGN_OR_RETURN(cloned_stage.inputs_valid_, map_node(inputs_valid_));
9696
}
97+
if (outputs_ready_ != nullptr) {
98+
XLS_ASSIGN_OR_RETURN(cloned_stage.outputs_ready_, map_node(outputs_ready_));
99+
}
100+
if (active_inputs_valid_ != nullptr) {
101+
XLS_ASSIGN_OR_RETURN(cloned_stage.active_inputs_valid_,
102+
map_node(active_inputs_valid_));
103+
}
97104
if (outputs_valid_ != nullptr) {
98105
XLS_ASSIGN_OR_RETURN(cloned_stage.outputs_valid_, map_node(outputs_valid_));
99106
}

xls/ir/function_base.h

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,16 @@ class Proc;
5252
// Represents a pipeline stage after scheduling.
5353
class Stage {
5454
public:
55-
Stage(Node* inputs_valid, Node* outputs_valid)
56-
: inputs_valid_(inputs_valid), outputs_valid_(outputs_valid) {
55+
Stage(Node* inputs_valid, Node* outputs_ready, Node* active_inputs_valid,
56+
Node* outputs_valid)
57+
: inputs_valid_(inputs_valid),
58+
outputs_ready_(outputs_ready),
59+
active_inputs_valid_(active_inputs_valid),
60+
outputs_valid_(outputs_valid) {
61+
CHECK_EQ(inputs_valid_ == nullptr, outputs_ready_ == nullptr);
5762
CHECK_EQ(inputs_valid_ == nullptr, outputs_valid_ == nullptr);
5863
}
59-
Stage() : Stage(nullptr, nullptr) {}
64+
Stage() : Stage(nullptr, nullptr, nullptr, nullptr) {}
6065

6166
Stage(const Stage& other) = default;
6267
Stage& operator=(const Stage& other) = default;
@@ -94,19 +99,37 @@ class Stage {
9499
const absl::flat_hash_map<Node*, Node*>& node_mapping) const;
95100

96101
bool IsControlled() const {
97-
return inputs_valid_ != nullptr && outputs_valid_ != nullptr;
102+
return inputs_valid_ != nullptr && outputs_ready_ != nullptr &&
103+
active_inputs_valid_ != nullptr && outputs_valid_ != nullptr;
98104
}
99105

100106
// Returns the node that signals whether it would be valid for this stage to
101-
// execute; i.e., that all passive inputs are updated for the next activation,
102-
// and all active inputs are valid.
107+
// execute; i.e., that all passive inputs are updated for the next activation.
103108
Node* inputs_valid() const { return inputs_valid_; }
104109

105110
void set_inputs_valid(Node* inputs_valid) { inputs_valid_ = inputs_valid; }
106111

107-
// Returns the node that signals whether the stage's outputs are valid to be
108-
// read; i.e., that all logic nodes are updated by the current activation, and
109-
// all active outputs have completed execution.
112+
// Returns the node that signals whether it is safe for this stage to execute;
113+
// i.e., that the receiver for all passive outputs will have space to store
114+
// the data.
115+
Node* outputs_ready() const { return outputs_ready_; }
116+
117+
void set_outputs_ready(Node* outputs_ready) {
118+
outputs_ready_ = outputs_ready;
119+
}
120+
121+
// Returns the node that signals whether all active inputs to this stage are
122+
// valid; i.e., that all active receives & all actively-read state values have
123+
// the correct values for this activation.
124+
Node* active_inputs_valid() const { return active_inputs_valid_; }
125+
126+
void set_active_inputs_valid(Node* active_inputs_valid) {
127+
active_inputs_valid_ = active_inputs_valid;
128+
}
129+
130+
// Returns the node that signals whether it would be safe for this stage to be
131+
// done executing; i.e., that all logic nodes are updated by the current
132+
// activation, and all active outputs have completed execution.
110133
Node* outputs_valid() const { return outputs_valid_; }
111134

112135
void set_outputs_valid(Node* outputs_valid) {
@@ -118,6 +141,8 @@ class Stage {
118141
absl::btree_set<Node*, Node::NodeIdLessThan> logic_;
119142
absl::btree_set<Node*, Node::NodeIdLessThan> active_outputs_;
120143
Node* inputs_valid_ = nullptr;
144+
Node* outputs_ready_ = nullptr;
145+
Node* active_inputs_valid_ = nullptr;
121146
Node* outputs_valid_ = nullptr;
122147
};
123148

xls/ir/ir_parser.cc

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3397,33 +3397,57 @@ absl::Status Parser::ParseControlledStage(
33973397
absl::flat_hash_map<std::string, BValue>* name_to_value, Package* package) {
33983398
XLS_RETURN_IF_ERROR(scanner_.DropKeywordOrError("controlled_stage"));
33993399
XLS_RETURN_IF_ERROR(scanner_.DropTokenOrError(LexicalTokenType::kParenOpen));
3400-
XLS_ASSIGN_OR_RETURN(BValue iv, ParseAndResolveIdentifier(*name_to_value));
3400+
XLS_ASSIGN_OR_RETURN(BValue inputs_valid,
3401+
ParseAndResolveIdentifier(*name_to_value));
3402+
XLS_RETURN_IF_ERROR(scanner_.DropTokenOrError(LexicalTokenType::kComma));
3403+
XLS_ASSIGN_OR_RETURN(BValue outputs_ready,
3404+
ParseAndResolveIdentifier(*name_to_value));
34013405
XLS_RETURN_IF_ERROR(scanner_.DropTokenOrError(LexicalTokenType::kParenClose));
34023406
XLS_RETURN_IF_ERROR(scanner_.DropTokenOrError(LexicalTokenType::kCurlOpen));
34033407

3404-
builder->StartStage(iv);
3408+
builder->StartStage(inputs_valid, outputs_ready);
34053409

3406-
std::optional<BValue> ov;
3410+
std::optional<BValue> active_inputs_valid;
3411+
std::optional<BValue> outputs_valid;
34073412
while (!scanner_.PeekTokenIs(LexicalTokenType::kCurlClose)) {
3413+
bool is_active_inputs_valid =
3414+
scanner_.TryDropKeyword("active_inputs_valid");
34083415
bool is_ret = scanner_.TryDropKeyword("ret");
3416+
if (is_ret && !is_active_inputs_valid) {
3417+
is_active_inputs_valid = scanner_.TryDropKeyword("active_inputs_valid");
3418+
}
34093419
XLS_ASSIGN_OR_RETURN(BValue result, ParseNode(builder, name_to_value));
34103420
(*name_to_value)[result.node()->GetName()] = result;
3421+
if (is_active_inputs_valid) {
3422+
if (active_inputs_valid.has_value()) {
3423+
return absl::InvalidArgumentError(absl::StrFormat(
3424+
"More than one active_inputs_valid found in controlled_stage @ %s",
3425+
result.node()->loc().ToString()));
3426+
}
3427+
active_inputs_valid = result;
3428+
}
34113429
if (is_ret) {
3412-
if (ov.has_value()) {
3430+
if (outputs_valid.has_value()) {
34133431
return absl::InvalidArgumentError(
34143432
absl::StrFormat("More than one ret found in controlled_stage @ %s",
34153433
result.node()->loc().ToString()));
34163434
}
3417-
ov = result;
3435+
outputs_valid = result;
34183436
}
34193437
}
34203438

3421-
if (!ov.has_value()) {
3439+
if (!active_inputs_valid.has_value()) {
3440+
return absl::InvalidArgumentError(absl::StrFormat(
3441+
"No 'active_inputs_valid' node found in controlled_stage @ %s",
3442+
scanner_.PeekTokenOrDie().pos().ToHumanString()));
3443+
}
3444+
if (!outputs_valid.has_value()) {
34223445
return absl::InvalidArgumentError(
3423-
"No 'ret' node found in controlled_stage");
3446+
absl::StrFormat("No 'ret' node found in controlled_stage @ %s",
3447+
scanner_.PeekTokenOrDie().pos().ToHumanString()));
34243448
}
34253449

3426-
builder->EndStage(*ov);
3450+
builder->EndStage(*active_inputs_valid, *outputs_valid);
34273451

34283452
XLS_RETURN_IF_ERROR(scanner_.DropTokenOrError(LexicalTokenType::kCurlClose));
34293453
return absl::OkStatus();

0 commit comments

Comments
 (0)