Skip to content

Commit 6f28ef3

Browse files
richmckeevercopybara-github
authored andcommitted
Add the basics of generics support behind a "generics" feature flag.
I believe what's left after this is primarily: - Fixing colon refs through generics. Currently ColonRef resolution is too early for this. Static trait functions also need it moved later. - Checking if imports work. - Fixing invocation of impl functions from generified callers in IR converter. PiperOrigin-RevId: 830639806
1 parent 4279936 commit 6f28ef3

27 files changed

+748
-88
lines changed

xls/data_structures/inline_bitmap.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <cstring>
2222
#include <iterator>
2323
#include <optional>
24+
#include <string_view>
2425
#include <utility>
2526

2627
#include "absl/base/casts.h"
@@ -80,6 +81,19 @@ class InlineBitmap {
8081
return result;
8182
}
8283

84+
// Variant of FromBytes for converting a string.
85+
static InlineBitmap FromBytes(std::string_view bytes) {
86+
InlineBitmap result(bytes.size() * 8, false);
87+
// memcpy() requires valid pointers even when the number of bytes copied is
88+
// zero, and an empty absl::Span's data() pointer may not be valid. Guard
89+
// the memcpy with a check that the span is not empty.
90+
if (!result.data_.empty()) {
91+
std::memcpy(result.data_.data(), bytes.data(), bytes.size());
92+
}
93+
result.MaskLastWord();
94+
return result;
95+
}
96+
8397
// Constructs a bitmap of width `bits.size()` using the given bits,
8498
// interpreting index 0 as the *least* significant bit.
8599
//

xls/dslx/fmt/ast_fmt.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,6 +3093,10 @@ absl::StatusOr<DocRef> Formatter::Format(const Module& n) {
30933093
pieces.push_back(arena_.MakeText("#![feature(channel_attributes)]"));
30943094
pieces.push_back(arena_.hard_line());
30953095
break;
3096+
case ModuleAttribute::kGenerics:
3097+
pieces.push_back(arena_.MakeText("#![feature(generics)]"));
3098+
pieces.push_back(arena_.hard_line());
3099+
break;
30963100
}
30973101
}
30983102
pieces.push_back(arena_.hard_line());

xls/dslx/frontend/ast.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,10 @@ TypeVariableTypeAnnotation::TypeVariableTypeAnnotation(
11221122
internal_(internal) {}
11231123

11241124
std::string TypeVariableTypeAnnotation::ToString() const {
1125+
if (!internal_) {
1126+
return type_variable_->ToString();
1127+
}
1128+
11251129
return absl::StrCat("TypeVariableTypeAnnotation: ",
11261130
type_variable_->ToString());
11271131
}

xls/dslx/frontend/module.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ std::string Module::ToString() const {
9494
case ModuleAttribute::kChannelAttributes:
9595
absl::StrAppend(out, "#![feature(channel_attributes)]");
9696
break;
97+
case ModuleAttribute::kGenerics:
98+
absl::StrAppend(out, "#![feature(generics)]");
99+
break;
97100
}
98101
});
99102
return absl::StrCat(header, "\n\n", body);

xls/dslx/frontend/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ enum class ModuleAttribute : uint8_t {
8181

8282
// Enable #channel() attributes for this module.
8383
kChannelAttributes,
84+
85+
kGenerics,
8486
};
8587

8688
// Represents a syntactic module in the AST.

xls/dslx/frontend/parser.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ absl::Status Parser::ParseModuleAttribute() {
310310
} else if (feature == "channel_attributes") {
311311
module_->AddAttribute(ModuleAttribute::kChannelAttributes,
312312
attribute_span);
313+
} else if (feature == "generics") {
314+
module_->AddAttribute(ModuleAttribute::kGenerics, attribute_span);
313315
} else {
314316
return ParseErrorStatus(
315317
attribute_span,
@@ -4114,10 +4116,14 @@ absl::StatusOr<std::vector<ParametricBinding*>> Parser::ParseParametricBindings(
41144116
XLS_RETURN_IF_ERROR(
41154117
DropTokenOrError(TokenKind::kColon, /*start=*/nullptr,
41164118
"Expect type annotation on parametric"));
4117-
XLS_ASSIGN_OR_RETURN(TypeAnnotation * type,
4118-
ParseTypeAnnotation(bindings,
4119-
/*first=*/std::nullopt,
4120-
/*allow_generic_type=*/true));
4119+
XLS_ASSIGN_OR_RETURN(
4120+
TypeAnnotation * type,
4121+
ParseTypeAnnotation(
4122+
bindings,
4123+
/*first=*/std::nullopt,
4124+
/*allow_generic_type=*/
4125+
parse_fn_stubs_ ||
4126+
module_->attributes().contains(ModuleAttribute::kGenerics)));
41214127
if (GenericTypeAnnotation* gta =
41224128
dynamic_cast<GenericTypeAnnotation*>(type)) {
41234129
name_def->set_definer(gta);
@@ -4222,6 +4228,17 @@ absl::StatusOr<ExprOrType> Parser::ParseParametricArg(Bindings& bindings) {
42224228
return ParseCast(bindings, type_annotation);
42234229
}
42244230
}
4231+
if (auto name_ref = std::get_if<NameRef*>(&nocr); name_ref) {
4232+
// In a case like foo<T>(), the `nocr` for `T` is a `NameRef` to the type
4233+
// variable, and what we want to yield is a TVTA for that.
4234+
AnyNameDef any_name_def = (*name_ref)->name_def();
4235+
if (auto name_def = std::get_if<const NameDef*>(&any_name_def);
4236+
name_def && (*name_def)->definer() &&
4237+
(*name_def)->definer()->kind() == AstNodeKind::kTypeAnnotation) {
4238+
return module_->Make<TypeVariableTypeAnnotation>(*name_ref,
4239+
/*internal=*/false);
4240+
}
4241+
}
42254242
// Otherwise, it's a value or an unadorned imported type.
42264243
return ToExprNode(nocr);
42274244
}

xls/dslx/frontend/parser_test.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2584,15 +2584,20 @@ TEST_F(ParserTest, ModuleWithParametric) {
25842584
}
25852585

25862586
TEST_F(ParserTest, ModuleWithGenericType) {
2587-
RoundTrip(R"(fn parametric<T: type>() -> u32 {
2587+
RoundTrip(R"(#![feature(generics)]
2588+
2589+
fn parametric<T: type>() -> u32 {
25882590
zero!<T>()
25892591
})");
25902592
}
25912593

25922594
TEST_F(ParserTest, ModuleWithInvalidGenericType) {
2593-
constexpr std::string_view text = R"(fn non_parametric(T: type) -> u32 {
2595+
constexpr std::string_view text = R"(#![feature(generics)]
2596+
2597+
fn non_parametric(T: type) -> u32 {
25942598
zero!<T>()
2595-
})";
2599+
}
2600+
)";
25962601
Scanner s{file_table_, Fileno(0), std::string{text}};
25972602
Parser parser{"test", &s};
25982603
auto module_status = parser.ParseModule();

xls/dslx/interp_value.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ std::string TagToString(InterpValueTag tag) {
6666
return "token";
6767
case InterpValueTag::kChannelReference:
6868
return "channel_reference";
69+
case InterpValueTag::kTypeReference:
70+
return "type_reference";
6971
}
7072
return absl::StrFormat("<invalid InterpValueTag(%d)>",
7173
static_cast<int64_t>(tag));
@@ -187,7 +189,8 @@ std::string InterpValue::ToString(bool humanize,
187189
tag_ == InterpValueTag::kSBits || tag_ == InterpValueTag::kTuple ||
188190
tag_ == InterpValueTag::kArray || tag_ == InterpValueTag::kEnum ||
189191
tag_ == InterpValueTag::kFunction || tag_ == InterpValueTag::kToken ||
190-
tag_ == InterpValueTag::kChannelReference) {
192+
tag_ == InterpValueTag::kChannelReference ||
193+
tag_ == InterpValueTag::kTypeReference) {
191194
return result;
192195
}
193196
LOG(FATAL) << "Unhandled tag: " << tag_;
@@ -240,6 +243,8 @@ std::string InterpValue::ToStringInternal(bool humanize,
240243
GetChannelReferenceOrDie().GetChannelId().has_value()
241244
? absl::StrCat(*GetChannelReferenceOrDie().GetChannelId())
242245
: "none");
246+
case InterpValueTag::kTypeReference:
247+
return std::get<TypeReference>(payload_).string;
243248
}
244249
return "<INVALID>";
245250
}
@@ -426,6 +431,7 @@ bool InterpValue::Eq(const InterpValue& other) const {
426431
// bit value can be used in any place an enum type is annotated.
427432
case InterpValueTag::kSBits:
428433
case InterpValueTag::kUBits:
434+
case InterpValueTag::kTypeReference:
429435
case InterpValueTag::kEnum: {
430436
return other.HasBits() && GetBitsOrDie() == other.GetBitsOrDie();
431437
}
@@ -472,6 +478,7 @@ bool InterpValue::operator==(const InterpValue& rhs) const { return Eq(rhs); }
472478
}
473479
switch (lhs.tag_) {
474480
case InterpValueTag::kUBits:
481+
case InterpValueTag::kTypeReference:
475482
return MakeBool(ucmp(lhs.GetBitsOrDie(), rhs.GetBitsOrDie()));
476483
case InterpValueTag::kSBits:
477484
return MakeBool(scmp(lhs.GetBitsOrDie(), rhs.GetBitsOrDie()));
@@ -759,10 +766,20 @@ const Bits& InterpValue::GetBitsOrDie() const {
759766
if (std::holds_alternative<Bits>(payload_)) {
760767
return std::get<Bits>(payload_);
761768
}
769+
if (auto* ref = std::get_if<TypeReference>(&payload_)) {
770+
return ref->bits;
771+
}
762772

763773
return std::get<EnumData>(payload_).value;
764774
}
765775

776+
absl::StatusOr<const TypeAnnotation*> InterpValue::GetTypeReference() const {
777+
if (std::holds_alternative<TypeReference>(payload_)) {
778+
return std::get<TypeReference>(payload_).annotation;
779+
}
780+
return absl::InvalidArgumentError("Value does not contain a type reference.");
781+
}
782+
766783
absl::StatusOr<InterpValue::ChannelReference> InterpValue::GetChannelReference()
767784
const {
768785
if (std::holds_alternative<ChannelReference>(payload_)) {
@@ -1031,6 +1048,10 @@ absl::StatusOr<xls::Value> InterpValue::ConvertToIr() const {
10311048
return absl::InvalidArgumentError(absl::StrFormat(
10321049
"Cannot convert channel-reference-typed values to IR."));
10331050
}
1051+
case InterpValueTag::kTypeReference: {
1052+
return absl::InvalidArgumentError(
1053+
absl::StrFormat("Cannot convert type-reference-typed values to IR."));
1054+
}
10341055
}
10351056
LOG(FATAL) << "Unhandled tag: " << tag_;
10361057
}

xls/dslx/interp_value.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ enum class InterpValueTag : uint8_t {
5454
kFunction,
5555
kToken,
5656
kChannelReference,
57+
kTypeReference,
5758
};
5859

5960
std::string TagToString(InterpValueTag tag);
@@ -169,6 +170,14 @@ class InterpValue {
169170
return InterpValue(InterpValueTag::kToken, std::make_shared<TokenData>());
170171
}
171172

173+
static InterpValue MakeTypeReference(const TypeAnnotation* annotation) {
174+
std::string string = annotation->ToString();
175+
return InterpValue(InterpValueTag::kTypeReference,
176+
TypeReference{.annotation = annotation,
177+
.string = string,
178+
.bits = Bits::FromBytes(string)});
179+
}
180+
172181
// Queries
173182

174183
bool IsTuple() const { return tag_ == InterpValueTag::kTuple; }
@@ -187,6 +196,9 @@ class InterpValue {
187196
return IsBuiltinFunction() &&
188197
std::get<Builtin>(GetFunctionOrDie()) == Builtin::kTrace;
189198
}
199+
bool IsTypeReference() const {
200+
return tag_ == InterpValueTag::kTypeReference;
201+
}
190202

191203
bool IsFalse() const { return IsBool() && GetBitsOrDie().IsZero(); }
192204
bool IsTrue() const { return IsBool() && GetBitsOrDie().IsAllOnes(); }
@@ -344,6 +356,7 @@ class InterpValue {
344356
const FnData& GetFunctionOrDie() const { return *GetFunction().value(); }
345357
absl::StatusOr<Bits> GetBits() const;
346358
const Bits& GetBitsOrDie() const;
359+
absl::StatusOr<const TypeAnnotation*> GetTypeReference() const;
347360
absl::StatusOr<ChannelReference> GetChannelReference() const;
348361
ChannelReference GetChannelReferenceOrDie() const {
349362
return std::get<ChannelReference>(payload_);
@@ -370,7 +383,8 @@ class InterpValue {
370383
// apply to enum values as well.
371384
bool HasBits() const {
372385
return std::holds_alternative<Bits>(payload_) ||
373-
std::holds_alternative<EnumData>(payload_);
386+
std::holds_alternative<EnumData>(payload_) ||
387+
std::holds_alternative<TypeReference>(payload_);
374388
}
375389

376390
bool HasValues() const {
@@ -427,8 +441,15 @@ class InterpValue {
427441
//
428442
// TODO(leary): 2020-02-10 When all Python bindings are eliminated we can more
429443
// easily make an interpreter scoped lifetime that InterpValues can live in.
430-
using Payload = std::variant<Bits, EnumData, std::vector<InterpValue>, FnData,
431-
std::shared_ptr<TokenData>, ChannelReference>;
444+
struct TypeReference {
445+
const TypeAnnotation* annotation;
446+
std::string string;
447+
Bits bits;
448+
};
449+
450+
using Payload =
451+
std::variant<Bits, EnumData, std::vector<InterpValue>, FnData,
452+
std::shared_ptr<TokenData>, ChannelReference, TypeReference>;
432453

433454
InterpValue(InterpValueTag tag, Payload payload)
434455
: tag_(tag), payload_(std::move(payload)) {}

xls/dslx/ir_convert/function_converter.cc

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3252,21 +3252,25 @@ absl::Status FunctionConverter::HandleFunction(
32523252
VLOG(5) << "Resolving parametric binding: "
32533253
<< parametric_binding->ToString();
32543254

3255-
std::optional<InterpValue> parametric_value =
3256-
GetParametricBinding(parametric_binding->identifier());
3257-
XLS_RET_CHECK(parametric_value.has_value());
3258-
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> parametric_type,
3259-
ResolveType(parametric_binding->name_def()));
3260-
XLS_RET_CHECK(!parametric_type->IsMeta());
3261-
XLS_ASSIGN_OR_RETURN(TypeDim parametric_width_ctd,
3262-
parametric_type->GetTotalBitCount());
3263-
XLS_ASSIGN_OR_RETURN(Value param_value,
3264-
InterpValueToValue(*parametric_value));
3265-
const CValue evaluated = DefConst(parametric_binding, param_value);
3266-
const_prefill.SetNamedValue(parametric_binding->name_def()->identifier(),
3267-
evaluated.ir_value);
3268-
XLS_RETURN_IF_ERROR(
3269-
DefAlias(parametric_binding, /*to=*/parametric_binding->name_def()));
3255+
if (!parametric_binding->type_annotation()
3256+
->IsAnnotation<GenericTypeAnnotation>()) {
3257+
std::optional<InterpValue> parametric_value =
3258+
GetParametricBinding(parametric_binding->identifier());
3259+
XLS_RET_CHECK(parametric_value.has_value());
3260+
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> parametric_type,
3261+
ResolveType(parametric_binding->name_def()));
3262+
XLS_RET_CHECK(!parametric_type->IsMeta());
3263+
XLS_ASSIGN_OR_RETURN(TypeDim parametric_width_ctd,
3264+
parametric_type->GetTotalBitCount());
3265+
3266+
XLS_ASSIGN_OR_RETURN(Value param_value,
3267+
InterpValueToValue(*parametric_value));
3268+
const CValue evaluated = DefConst(parametric_binding, param_value);
3269+
const_prefill.SetNamedValue(parametric_binding->name_def()->identifier(),
3270+
evaluated.ir_value);
3271+
XLS_RETURN_IF_ERROR(
3272+
DefAlias(parametric_binding, /*to=*/parametric_binding->name_def()));
3273+
}
32703274
}
32713275

32723276
// If there is foreign function data, all constant values are replaced now.

0 commit comments

Comments
 (0)