diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index 32fb81fd..0ae383da 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -246,14 +246,44 @@ when lifting individual parameters and results: class LiftOptions: string_encoding: str = 'utf8' memory: Optional[bytearray] = None + addr_type: str = 'i32' + tbl_idx_type: str = 'i32' def equal(lhs, rhs): return lhs.string_encoding == rhs.string_encoding and \ - lhs.memory is rhs.memory + lhs.memory is rhs.memory and \ + lhs.addr_type == rhs.addr_type and \ + lhs.tbl_idx_type == rhs.tbl_idx_type ``` The `equal` static method is used by `task.return` below to dynamically compare equality of just this subset of `canonopt`. +The `addr_type` is `'i32'` when the `memory` canonopt refers to a memory32 +and `'i64'` when it refers to a memory64. The `tbl_idx_type` is `'i32'` by +default and `'i64'` when the `table64` canonopt is present. These two +dimensions are independent (e.g., a 64-bit memory with 32-bit table indices +is valid). + +The following helper functions return the byte size and core value type of +memory pointers and table indices, based on the options: +```python +def ptr_size(opts): + match opts.addr_type: + case 'i32': return 4 + case 'i64': return 8 + +def ptr_type(opts): + return opts.addr_type + +def idx_size(opts): + match opts.tbl_idx_type: + case 'i32': return 4 + case 'i64': return 8 + +def idx_type(opts): + return opts.tbl_idx_type +``` + The `LiftLowerOptions` class contains the subset of [`canonopt`] which are needed when lifting *or* lowering individual parameters and results: ```python @@ -1355,8 +1385,8 @@ class BufferGuestImpl(Buffer): def __init__(self, t, cx, ptr, length): trap_if(length > Buffer.MAX_LENGTH) if t and length > 0: - trap_if(ptr != align_to(ptr, alignment(t))) - trap_if(ptr + length * elem_size(t) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(t, cx.opts))) + trap_if(ptr + length * elem_size(t, cx.opts) > len(cx.opts.memory)) self.cx = cx self.t = t self.ptr = ptr @@ -1374,7 +1404,7 @@ class ReadableBufferGuestImpl(BufferGuestImpl): assert(n <= self.remain()) if self.t: vs = load_list_from_valid_range(self.cx, self.ptr, n, self.t) - self.ptr += n * elem_size(self.t) + self.ptr += n * elem_size(self.t, self.cx.opts) else: vs = n * [()] self.progress += n @@ -1385,7 +1415,7 @@ class WritableBufferGuestImpl(BufferGuestImpl, WritableBuffer): assert(len(vs) <= self.remain()) if self.t: store_list_into_valid_range(self.cx, vs, self.ptr, self.t) - self.ptr += len(vs) * elem_size(self.t) + self.ptr += len(vs) * elem_size(self.t, self.cx.opts) else: assert(all(v == () for v in vs)) self.progress += len(vs) @@ -1860,7 +1890,7 @@ Each value type is assigned an [alignment] which is used by subsequent Canonical ABI definitions. Presenting the definition of `alignment` piecewise, we start with the top-level case analysis: ```python -def alignment(t): +def alignment(t, opts): match despecialize(t): case BoolType() : return 1 case S8Type() | U8Type() : return 1 @@ -1870,11 +1900,11 @@ def alignment(t): case F32Type() : return 4 case F64Type() : return 8 case CharType() : return 4 - case StringType() : return 4 + case StringType() : return ptr_size(opts) case ErrorContextType() : return 4 - case ListType(t, l) : return alignment_list(t, l) - case RecordType(fields) : return alignment_record(fields) - case VariantType(cases) : return alignment_variant(cases) + case ListType(t, l) : return alignment_list(t, l, opts) + case RecordType(fields) : return alignment_record(fields, opts) + case VariantType(cases) : return alignment_variant(cases, opts) case FlagsType(labels) : return alignment_flags(labels) case OwnType() | BorrowType() : return 4 case StreamType() | FutureType() : return 4 @@ -1883,18 +1913,18 @@ def alignment(t): List alignment is the same as tuple alignment when the length is fixed and otherwise uses the alignment of pointers. ```python -def alignment_list(elem_type, maybe_length): +def alignment_list(elem_type, maybe_length, opts): if maybe_length is not None: - return alignment(elem_type) - return 4 + return alignment(elem_type, opts) + return ptr_size(opts) ``` Record alignment is tuple alignment, with the definitions split for reuse below: ```python -def alignment_record(fields): +def alignment_record(fields, opts): a = 1 for f in fields: - a = max(a, alignment(f.t)) + a = max(a, alignment(f.t, opts)) return a ``` @@ -1904,8 +1934,8 @@ covering the number of cases in the variant (with cases numbered in order from compact representations of variants in memory. This smallest integer type is selected by the following function, used above and below: ```python -def alignment_variant(cases): - return max(alignment(discriminant_type(cases)), max_case_alignment(cases)) +def alignment_variant(cases, opts): + return max(alignment(discriminant_type(cases), opts), max_case_alignment(cases, opts)) def discriminant_type(cases): n = len(cases) @@ -1916,11 +1946,11 @@ def discriminant_type(cases): case 2: return U16Type() case 3: return U32Type() -def max_case_alignment(cases): +def max_case_alignment(cases, opts): a = 1 for c in cases: if c.t is not None: - a = max(a, alignment(c.t)) + a = max(a, alignment(c.t, opts)) return a ``` @@ -1946,7 +1976,7 @@ maps well to languages which represent `list`s as random-access arrays. Empty types, such as records with no fields, are not permitted, to avoid complications in source languages. ```python -def elem_size(t): +def elem_size(t, opts): match despecialize(t): case BoolType() : return 1 case S8Type() | U8Type() : return 1 @@ -1956,40 +1986,40 @@ def elem_size(t): case F32Type() : return 4 case F64Type() : return 8 case CharType() : return 4 - case StringType() : return 8 + case StringType() : return 2 * ptr_size(opts) case ErrorContextType() : return 4 - case ListType(t, l) : return elem_size_list(t, l) - case RecordType(fields) : return elem_size_record(fields) - case VariantType(cases) : return elem_size_variant(cases) + case ListType(t, l) : return elem_size_list(t, l, opts) + case RecordType(fields) : return elem_size_record(fields, opts) + case VariantType(cases) : return elem_size_variant(cases, opts) case FlagsType(labels) : return elem_size_flags(labels) case OwnType() | BorrowType() : return 4 case StreamType() | FutureType() : return 4 -def elem_size_list(elem_type, maybe_length): +def elem_size_list(elem_type, maybe_length, opts): if maybe_length is not None: - return maybe_length * elem_size(elem_type) - return 8 + return maybe_length * elem_size(elem_type, opts) + return 2 * ptr_size(opts) -def elem_size_record(fields): +def elem_size_record(fields, opts): s = 0 for f in fields: - s = align_to(s, alignment(f.t)) - s += elem_size(f.t) + s = align_to(s, alignment(f.t, opts)) + s += elem_size(f.t, opts) assert(s > 0) - return align_to(s, alignment_record(fields)) + return align_to(s, alignment_record(fields, opts)) def align_to(ptr, alignment): return math.ceil(ptr / alignment) * alignment -def elem_size_variant(cases): - s = elem_size(discriminant_type(cases)) - s = align_to(s, max_case_alignment(cases)) +def elem_size_variant(cases, opts): + s = elem_size(discriminant_type(cases), opts) + s = align_to(s, max_case_alignment(cases, opts)) cs = 0 for c in cases: if c.t is not None: - cs = max(cs, elem_size(c.t)) + cs = max(cs, elem_size(c.t, opts)) s += cs - return align_to(s, alignment_variant(cases)) + return align_to(s, alignment_variant(cases, opts)) def elem_size_flags(labels): n = len(labels) @@ -2007,8 +2037,8 @@ as a Python value. Presenting the definition of `load` piecewise, we start with the top-level case analysis: ```python def load(cx, ptr, t): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + elem_size(t) <= len(cx.opts.memory)) + assert(ptr == align_to(ptr, alignment(t, cx.opts))) + assert(ptr + elem_size(t, cx.opts) <= len(cx.opts.memory)) match despecialize(t): case BoolType() : return convert_int_to_bool(load_int(cx, ptr, 1)) case U8Type() : return load_int(cx, ptr, 1) @@ -2098,24 +2128,26 @@ def convert_i32_to_char(cx, i): return chr(i) ``` -Strings are loaded from two `i32` values: a pointer (offset in linear memory) -and a number of [code units]. There are three supported string encodings in -[`canonopt`]: [UTF-8], [UTF-16] and `latin1+utf16`. This last options allows a -*dynamic* choice between [Latin-1] and UTF-16, indicated by the high bit of the -second `i32`. String values include their original encoding and length in -tagged code units as a "hint" that enables `store_string` (defined below) to -make better up-front allocation size choices in many cases. Thus, the value -produced by `load_string` isn't simply a Python `str`, but a *tuple* containing -a `str`, the original encoding and the number of source code units. +Strings are loaded from two pointer-sized values: a pointer (offset in linear +memory) and a number of [code units]. There are three supported string +encodings in [`canonopt`]: [UTF-8], [UTF-16] and `latin1+utf16`. This last +option allows a *dynamic* choice between [Latin-1] and UTF-16, indicated by +the high bit of the second pointer-sized value. String values include their +original encoding and length in tagged code units as a "hint" that enables +`store_string` (defined below) to make better up-front allocation size choices +in many cases. Thus, the value produced by `load_string` isn't simply a Python +`str`, but a *tuple* containing a `str`, the original encoding and the number +of source code units. ```python String = tuple[str, str, int] def load_string(cx, ptr) -> String: - begin = load_int(cx, ptr, 4) - tagged_code_units = load_int(cx, ptr + 4, 4) + begin = load_int(cx, ptr, ptr_size(cx.opts)) + tagged_code_units = load_int(cx, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) return load_string_from_range(cx, begin, tagged_code_units) -UTF16_TAG = 1 << 31 +def utf16_tag(opts): + return 1 << (ptr_size(opts) * 8 - 1) def load_string_from_range(cx, ptr, tagged_code_units) -> String: match cx.opts.string_encoding: @@ -2129,8 +2161,8 @@ def load_string_from_range(cx, ptr, tagged_code_units) -> String: encoding = 'utf-16-le' case 'latin1+utf16': alignment = 2 - if bool(tagged_code_units & UTF16_TAG): - byte_length = 2 * (tagged_code_units ^ UTF16_TAG) + if bool(tagged_code_units & utf16_tag(cx.opts)): + byte_length = 2 * (tagged_code_units ^ utf16_tag(cx.opts)) encoding = 'utf-16-le' else: byte_length = tagged_code_units @@ -2160,27 +2192,27 @@ Lists and records are loaded by recursively loading their elements/fields: def load_list(cx, ptr, elem_type, maybe_length): if maybe_length is not None: return load_list_from_valid_range(cx, ptr, maybe_length, elem_type) - begin = load_int(cx, ptr, 4) - length = load_int(cx, ptr + 4, 4) + begin = load_int(cx, ptr, ptr_size(cx.opts)) + length = load_int(cx, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) return load_list_from_range(cx, begin, length, elem_type) def load_list_from_range(cx, ptr, length, elem_type): - trap_if(ptr != align_to(ptr, alignment(elem_type))) - trap_if(ptr + length * elem_size(elem_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(elem_type, cx.opts))) + trap_if(ptr + length * elem_size(elem_type, cx.opts) > len(cx.opts.memory)) return load_list_from_valid_range(cx, ptr, length, elem_type) def load_list_from_valid_range(cx, ptr, length, elem_type): a = [] for i in range(length): - a.append(load(cx, ptr + i * elem_size(elem_type), elem_type)) + a.append(load(cx, ptr + i * elem_size(elem_type, cx.opts), elem_type)) return a def load_record(cx, ptr, fields): record = {} for field in fields: - ptr = align_to(ptr, alignment(field.t)) + ptr = align_to(ptr, alignment(field.t, cx.opts)) record[field.label] = load(cx, ptr, field.t) - ptr += elem_size(field.t) + ptr += elem_size(field.t, cx.opts) return record ``` As a technical detail: the `align_to` in the loop in `load_record` is @@ -2194,12 +2226,12 @@ implementation can build the appropriate index tables at compile-time so that variant-passing is always O(1) and not involving string operations. ```python def load_variant(cx, ptr, cases): - disc_size = elem_size(discriminant_type(cases)) + disc_size = elem_size(discriminant_type(cases), cx.opts) case_index = load_int(cx, ptr, disc_size) ptr += disc_size trap_if(case_index >= len(cases)) c = cases[case_index] - ptr = align_to(ptr, max_case_alignment(cases)) + ptr = align_to(ptr, max_case_alignment(cases, cx.opts)) if c.t is None: return { c.label: None } return { c.label: load(cx, ptr, c.t) } @@ -2291,8 +2323,8 @@ The `store` function defines how to write a value `v` of a given value type `store` piecewise, we start with the top-level case analysis: ```python def store(cx, v, t, ptr): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + elem_size(t) <= len(cx.opts.memory)) + assert(ptr == align_to(ptr, alignment(t, cx.opts))) + assert(ptr + elem_size(t, cx.opts) <= len(cx.opts.memory)) match despecialize(t): case BoolType() : store_int(cx, int(bool(v)), ptr, 1) case U8Type() : store_int(cx, v, ptr, 1) @@ -2405,20 +2437,20 @@ original encoding and number of source [code units]. From this hint data, We start with a case analysis to enumerate all the meaningful encoding combinations, subdividing the `latin1+utf16` encoding into either `latin1` or -`utf16` based on the `UTF16_TAG` flag set by `load_string`: +`utf16` based on the `utf16_tag` flag set by `load_string`: ```python def store_string(cx, v: String, ptr): begin, tagged_code_units = store_string_into_range(cx, v) - store_int(cx, begin, ptr, 4) - store_int(cx, tagged_code_units, ptr + 4, 4) + store_int(cx, begin, ptr, ptr_size(cx.opts)) + store_int(cx, tagged_code_units, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) def store_string_into_range(cx, v: String): src, src_encoding, src_tagged_code_units = v if src_encoding == 'latin1+utf16': - if bool(src_tagged_code_units & UTF16_TAG): + if bool(src_tagged_code_units & utf16_tag(cx.opts)): src_simple_encoding = 'utf16' - src_code_units = src_tagged_code_units ^ UTF16_TAG + src_code_units = src_tagged_code_units ^ utf16_tag(cx.opts) else: src_simple_encoding = 'latin1' src_code_units = src_tagged_code_units @@ -2451,11 +2483,12 @@ The simplest 4 cases above can compute the exact destination size and then copy with a simply loop (that possibly inflates Latin-1 to UTF-16 by injecting a 0 byte after every Latin-1 byte). ```python -MAX_STRING_BYTE_LENGTH = (1 << 31) - 1 +def max_string_byte_length(opts): + return (1 << (ptr_size(opts) * 8 - 1)) - 1 def store_string_copy(cx, src, src_code_units, dst_code_unit_size, dst_alignment, dst_encoding): dst_byte_length = dst_code_unit_size * src_code_units - trap_if(dst_byte_length > MAX_STRING_BYTE_LENGTH) + trap_if(dst_byte_length > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, dst_alignment, dst_byte_length) trap_if(ptr != align_to(ptr, dst_alignment)) trap_if(ptr + dst_byte_length > len(cx.opts.memory)) @@ -2464,8 +2497,8 @@ def store_string_copy(cx, src, src_code_units, dst_code_unit_size, dst_alignment cx.opts.memory[ptr : ptr+len(encoded)] = encoded return (ptr, src_code_units) ``` -The choice of `MAX_STRING_BYTE_LENGTH` constant ensures that the high bit of a -string's number of code units is never set, keeping it clear for `UTF16_TAG`. +The `max_string_byte_length` function ensures that the high bit of a +string's number of code units is never set, keeping it clear for `utf16_tag`. The 2 cases of transcoding into UTF-8 share an algorithm that starts by optimistically assuming that each code unit of the source string fits in a @@ -2481,14 +2514,14 @@ def store_latin1_to_utf8(cx, src, src_code_units): return store_string_to_utf8(cx, src, src_code_units, worst_case_size) def store_string_to_utf8(cx, src, src_code_units, worst_case_size): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) + assert(src_code_units <= max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 1, src_code_units) trap_if(ptr + src_code_units > len(cx.opts.memory)) for i,code_point in enumerate(src): if ord(code_point) < 2**7: cx.opts.memory[ptr + i] = ord(code_point) else: - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(ptr, src_code_units, 1, worst_case_size) trap_if(ptr + worst_case_size > len(cx.opts.memory)) encoded = src.encode('utf-8') @@ -2507,7 +2540,7 @@ if multiple UTF-8 bytes were collapsed into a single 2-byte UTF-16 code unit: ```python def store_utf8_to_utf16(cx, src, src_code_units): worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, worst_case_size) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + worst_case_size > len(cx.opts.memory)) @@ -2531,7 +2564,7 @@ after every Latin-1 byte (iterating in reverse to avoid clobbering later bytes): ```python def store_string_to_latin1_or_utf16(cx, src, src_code_units): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) + assert(src_code_units <= max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, src_code_units) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + src_code_units > len(cx.opts.memory)) @@ -2542,7 +2575,7 @@ def store_string_to_latin1_or_utf16(cx, src, src_code_units): dst_byte_length += 1 else: worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(ptr, src_code_units, 2, worst_case_size) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + worst_case_size > len(cx.opts.memory)) @@ -2555,7 +2588,7 @@ def store_string_to_latin1_or_utf16(cx, src, src_code_units): ptr = cx.opts.realloc(ptr, worst_case_size, 2, len(encoded)) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + len(encoded) > len(cx.opts.memory)) - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG + tagged_code_units = int(len(encoded) / 2) | utf16_tag(cx.opts) return (ptr, tagged_code_units) if dst_byte_length < src_code_units: ptr = cx.opts.realloc(ptr, src_code_units, 2, dst_byte_length) @@ -2577,14 +2610,14 @@ inexpensively fused with the UTF-16 validate+copy loop.) ```python def store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units): src_byte_length = 2 * src_code_units - trap_if(src_byte_length > MAX_STRING_BYTE_LENGTH) + trap_if(src_byte_length > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, src_byte_length) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + src_byte_length > len(cx.opts.memory)) encoded = src.encode('utf-16-le') cx.opts.memory[ptr : ptr+len(encoded)] = encoded if any(ord(c) >= (1 << 8) for c in src): - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG + tagged_code_units = int(len(encoded) / 2) | utf16_tag(cx.opts) return (ptr, tagged_code_units) latin1_size = int(len(encoded) / 2) for i in range(latin1_size): @@ -2612,27 +2645,27 @@ def store_list(cx, v, ptr, elem_type, maybe_length): store_list_into_valid_range(cx, v, ptr, elem_type) return begin, length = store_list_into_range(cx, v, elem_type) - store_int(cx, begin, ptr, 4) - store_int(cx, length, ptr + 4, 4) + store_int(cx, begin, ptr, ptr_size(cx.opts)) + store_int(cx, length, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) def store_list_into_range(cx, v, elem_type): - byte_length = len(v) * elem_size(elem_type) - trap_if(byte_length >= (1 << 32)) - ptr = cx.opts.realloc(0, 0, alignment(elem_type), byte_length) - trap_if(ptr != align_to(ptr, alignment(elem_type))) + byte_length = len(v) * elem_size(elem_type, cx.opts) + trap_if(byte_length >= (1 << (ptr_size(cx.opts) * 8))) + ptr = cx.opts.realloc(0, 0, alignment(elem_type, cx.opts), byte_length) + trap_if(ptr != align_to(ptr, alignment(elem_type, cx.opts))) trap_if(ptr + byte_length > len(cx.opts.memory)) store_list_into_valid_range(cx, v, ptr, elem_type) return (ptr, len(v)) def store_list_into_valid_range(cx, v, ptr, elem_type): for i,e in enumerate(v): - store(cx, e, elem_type, ptr + i * elem_size(elem_type)) + store(cx, e, elem_type, ptr + i * elem_size(elem_type, cx.opts)) def store_record(cx, v, ptr, fields): for f in fields: - ptr = align_to(ptr, alignment(f.t)) + ptr = align_to(ptr, alignment(f.t, cx.opts)) store(cx, v[f.label], f.t, ptr) - ptr += elem_size(f.t) + ptr += elem_size(f.t, cx.opts) ``` Variant values are represented as Python dictionaries containing exactly one @@ -2645,10 +2678,10 @@ indices. ```python def store_variant(cx, v, ptr, cases): case_index, case_value = match_case(v, cases) - disc_size = elem_size(discriminant_type(cases)) + disc_size = elem_size(discriminant_type(cases), cx.opts) store_int(cx, case_index, ptr, disc_size) ptr += disc_size - ptr = align_to(ptr, max_case_alignment(cases)) + ptr = align_to(ptr, max_case_alignment(cases, cx.opts)) c = cases[case_index] if c.t is not None: store(cx, case_value, c.t, ptr) @@ -2752,38 +2785,38 @@ MAX_FLAT_ASYNC_PARAMS = 4 MAX_FLAT_RESULTS = 1 def flatten_functype(opts, ft, context): - flat_params = flatten_types(ft.param_types()) - flat_results = flatten_types(ft.result_type()) + flat_params = flatten_types(ft.param_types(), opts) + flat_results = flatten_types(ft.result_type(), opts) if not opts.async_: if len(flat_params) > MAX_FLAT_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if len(flat_results) > MAX_FLAT_RESULTS: match context: case 'lift': - flat_results = ['i32'] + flat_results = [ptr_type(opts)] case 'lower': - flat_params += ['i32'] + flat_params += [ptr_type(opts)] flat_results = [] return CoreFuncType(flat_params, flat_results) else: match context: case 'lift': if len(flat_params) > MAX_FLAT_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if opts.callback: flat_results = ['i32'] else: flat_results = [] case 'lower': if len(flat_params) > MAX_FLAT_ASYNC_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if len(flat_results) > 0: - flat_params += ['i32'] + flat_params += [ptr_type(opts)] flat_results = ['i32'] return CoreFuncType(flat_params, flat_results) -def flatten_types(ts): - return [ft for t in ts for ft in flatten_type(t)] +def flatten_types(ts, opts): + return [ft for t in ts for ft in flatten_type(t, opts)] ``` As shown here, the core signatures `async` functions use a lower limit on the maximum number of parameters (1) and results (0) passed as scalars before @@ -2792,7 +2825,7 @@ falling back to passing through memory. Presenting the definition of `flatten_type` piecewise, we start with the top-level case analysis: ```python -def flatten_type(t): +def flatten_type(t, opts): match despecialize(t): case BoolType() : return ['i32'] case U8Type() | U16Type() | U32Type() : return ['i32'] @@ -2801,11 +2834,11 @@ def flatten_type(t): case F32Type() : return ['f32'] case F64Type() : return ['f64'] case CharType() : return ['i32'] - case StringType() : return ['i32', 'i32'] + case StringType() : return [ptr_type(opts), ptr_type(opts)] case ErrorContextType() : return ['i32'] - case ListType(t, l) : return flatten_list(t, l) - case RecordType(fields) : return flatten_record(fields) - case VariantType(cases) : return flatten_variant(cases) + case ListType(t, l) : return flatten_list(t, l, opts) + case RecordType(fields) : return flatten_record(fields, opts) + case VariantType(cases) : return flatten_variant(cases, opts) case FlagsType(labels) : return ['i32'] case OwnType() | BorrowType() : return ['i32'] case StreamType() | FutureType() : return ['i32'] @@ -2814,18 +2847,18 @@ def flatten_type(t): List flattening of a fixed-length list uses the same flattening as a tuple (via `flatten_record` below). ```python -def flatten_list(elem_type, maybe_length): +def flatten_list(elem_type, maybe_length, opts): if maybe_length is not None: - return flatten_type(elem_type) * maybe_length - return ['i32', 'i32'] + return flatten_type(elem_type, opts) * maybe_length + return [ptr_type(opts), ptr_type(opts)] ``` Record flattening simply flattens each field in sequence. ```python -def flatten_record(fields): +def flatten_record(fields, opts): flat = [] for f in fields: - flat += flatten_type(f.t) + flat += flatten_type(f.t, opts) return flat ``` @@ -2838,16 +2871,16 @@ case, all flattened variants are passed with the same static set of core types, which may involve, e.g., reinterpreting an `f32` as an `i32` or zero-extending an `i32` into an `i64`. ```python -def flatten_variant(cases): +def flatten_variant(cases, opts): flat = [] for c in cases: if c.t is not None: - for i,ft in enumerate(flatten_type(c.t)): + for i,ft in enumerate(flatten_type(c.t, opts)): if i < len(flat): flat[i] = join(flat[i], ft) else: flat.append(ft) - return flatten_type(discriminant_type(cases)) + flat + return flatten_type(discriminant_type(cases), opts) + flat def join(a, b): if a == b: return a @@ -2938,13 +2971,13 @@ def lift_flat_signed(vi, core_width, t_width): The contents of strings and variable-length lists are stored in memory so lifting these types is essentially the same as loading them from memory; the -only difference is that the pointer and length come from `i32` values instead -of from linear memory. Fixed-length lists are lifted the same way as a +only difference is that the pointer and length come from ptr-sized values +instead of from linear memory. Fixed-length lists are lifted the same way as a tuple (via `lift_flat_record` below). ```python def lift_flat_string(cx, vi): - ptr = vi.next('i32') - packed_length = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) + packed_length = vi.next(ptr_type(cx.opts)) return load_string_from_range(cx, ptr, packed_length) def lift_flat_list(cx, vi, elem_type, maybe_length): @@ -2953,8 +2986,8 @@ def lift_flat_list(cx, vi, elem_type, maybe_length): for i in range(maybe_length): a.append(lift_flat(cx, vi, elem_type)) return a - ptr = vi.next('i32') - length = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) + length = vi.next(ptr_type(cx.opts)) return load_list_from_range(cx, ptr, length, elem_type) ``` @@ -2975,7 +3008,7 @@ reinterprets between the different types appropriately and also traps if the high bits of an `i64` are set for a 32-bit type: ```python def lift_flat_variant(cx, vi, cases): - flat_types = flatten_variant(cases) + flat_types = flatten_variant(cases, cx.opts) assert(flat_types.pop(0) == 'i32') case_index = vi.next('i32') trap_if(case_index >= len(cases)) @@ -3092,14 +3125,14 @@ manually coercing the otherwise-incompatible type pairings allowed by `join`: ```python def lower_flat_variant(cx, v, cases): case_index, case_value = match_case(v, cases) - flat_types = flatten_variant(cases) + flat_types = flatten_variant(cases, cx.opts) assert(flat_types.pop(0) == 'i32') c = cases[case_index] if c.t is None: payload = [] else: payload = lower_flat(cx, case_value, c.t) - for i,(fv,have) in enumerate(zip(payload, flatten_type(c.t))): + for i,(fv,have) in enumerate(zip(payload, flatten_type(c.t, cx.opts))): want = flat_types.pop(0) match (have, want): case ('f32', 'i32') : payload[i] = encode_float_as_i32(fv) @@ -3126,12 +3159,12 @@ parameters or results (given by the `CoreValueIter` `vi`) into a tuple of component-level values with types `ts`. ```python def lift_flat_values(cx, max_flat, vi, ts): - flat_types = flatten_types(ts) + flat_types = flatten_types(ts, cx.opts) if len(flat_types) > max_flat: - ptr = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) tuple_type = TupleType(ts) - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(tuple_type, cx.opts))) + trap_if(ptr + elem_size(tuple_type, cx.opts) > len(cx.opts.memory)) return list(load(cx, ptr, tuple_type).values()) else: return [ lift_flat(cx, vi, t) for t in ts ] @@ -3146,18 +3179,18 @@ out-param: ```python def lower_flat_values(cx, max_flat, vs, ts, out_param = None): cx.inst.may_leave = False - flat_types = flatten_types(ts) + flat_types = flatten_types(ts, cx.opts) if len(flat_types) > max_flat: tuple_type = TupleType(ts) tuple_value = {str(i): v for i,v in enumerate(vs)} if out_param is None: - ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type)) + ptr = cx.opts.realloc(0, 0, alignment(tuple_type, cx.opts), elem_size(tuple_type, cx.opts)) flat_vals = [ptr] else: - ptr = out_param.next('i32') + ptr = out_param.next(ptr_type(cx.opts)) flat_vals = [] - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(tuple_type, cx.opts))) + trap_if(ptr + elem_size(tuple_type, cx.opts) > len(cx.opts.memory)) store(cx, tuple_value, tuple_type, ptr) else: flat_vals = [] @@ -3188,15 +3221,17 @@ present, is validated as such: * `string-encoding=N` - can be passed at most once, regardless of `N`. * `memory` - this is a subtype of `(memory 1)` -* `realloc` - the function has type `(func (param i32 i32 i32 i32) (result i32))` +* `realloc` - the function has type `(func (param T T i32 T) (result T))` + where `T` is `i32` or `i64` determined by `memory` as described [above](#canonopt) * if `realloc` is present, then `memory` must be present * `post-return` - only allowed on [`canon lift`](#canon-lift), which has rules for validation * ๐Ÿ”€ `async` - cannot be present with `post-return` * ๐Ÿ”€,not(๐ŸšŸ) `async` - `callback` must also be present. Note that with the ๐ŸšŸ feature (the "stackful" ABI), this restriction is lifted. -* ๐Ÿ”€ `callback` - the function has type `(func (param i32 i32 i32) (result i32))` - and cannot be present without `async` and is only allowed with +* ๐Ÿ”€ `callback` - the function has type `(func (param i32 i32 T) (result i32))` + where the `T` parameter is the payload address and cannot be present + without `async` and is only allowed with [`canon lift`](#canon-lift) Additionally some options are required depending on lift/lower operations @@ -3714,15 +3749,15 @@ For a canonical definition: (canon context.get $t $i (core func $f)) ``` validation specifies: -* `$t` must be `i32` (for now; see [here][thread-local storage]) +* `$t` must be `i32` or `i64` (see [here][thread-local storage]) * `$i` must be less than `Thread.CONTEXT_LENGTH` (`2`) -* `$f` is given type `(func (result i32))` +* `$f` is given type `(func (result $t))` Calling `$f` invokes the following function, which reads the [thread-local storage] of the [current thread]: ```python def canon_context_get(t, i, thread): - assert(t == 'i32') + assert(t == 'i32' or t == 'i64') assert(i < Thread.CONTEXT_LENGTH) return [thread.context[i]] ``` @@ -3735,15 +3770,15 @@ For a canonical definition: (canon context.set $t $i (core func $f)) ``` validation specifies: -* `$t` must be `i32` (for now; see [here][thread-local storage]) +* `$t` must be `i32` or `i64` (see [here][thread-local storage]) * `$i` must be less than `Thread.CONTEXT_LENGTH` (`2`) -* `$f` is given type `(func (param $v i32))` +* `$f` is given type `(func (param $v $t))` Calling `$f` invokes the following function, which writes to the [thread-local storage] of the [current thread]: ```python def canon_context_set(t, i, thread, v): - assert(t == 'i32') + assert(t == 'i32' or t == 'i64') assert(i < Thread.CONTEXT_LENGTH) thread.context[i] = v return [] @@ -3907,23 +3942,24 @@ For a canonical definition: (canon waitable-set.wait $cancellable? (memory $mem) (core func $f)) ``` validation specifies: -* `$f` is given type `(func (param $si) (param $ptr i32) (result i32))` +* `$f` is given type `(func (param $si i32) (param $ptr T) (result i32))` where + `T` is the address type of `$mem`. Calling `$f` invokes the following function which waits for progress to be made on a `Waitable` in the given waitable set (indicated by index `$si`) and then returning its `EventCode` and writing the payload values into linear memory: ```python -def canon_waitable_set_wait(cancellable, mem, thread, si, ptr): +def canon_waitable_set_wait(cancellable, mem, opts, thread, si, ptr): trap_if(not thread.task.inst.may_leave) trap_if(not thread.task.may_block()) wset = thread.task.inst.handles.get(si) trap_if(not isinstance(wset, WaitableSet)) event = thread.task.wait_until(lambda: True, thread, wset, cancellable) - return unpack_event(mem, thread, ptr, event) + return unpack_event(mem, opts, thread, ptr, event) -def unpack_event(mem, thread, ptr, e: EventTuple): +def unpack_event(mem, opts, thread, ptr, e: EventTuple): event, p1, p2 = e - cx = LiftLowerContext(LiftLowerOptions(memory = mem), thread.task.inst) + cx = LiftLowerContext(LiftLowerOptions(memory = mem, addr_type = opts.addr_type, tbl_idx_type = opts.tbl_idx_type), thread.task.inst) store(cx, p1, U32Type(), ptr) store(cx, p2, U32Type(), ptr + 4) return [event] @@ -3950,13 +3986,14 @@ For a canonical definition: (canon waitable-set.poll $cancellable? (memory $mem) (core func $f)) ``` validation specifies: -* `$f` is given type `(func (param $si i32) (param $ptr i32) (result i32))` +* `$f` is given type `(func (param $si i32) (param $ptr T) (result i32))` where + `T` is the address type of `$mem`. Calling `$f` invokes the following function, which either returns an event that was pending on one of the waitables in the given waitable set (the same way as `waitable-set.wait`) or, if there is none, returns `0`. ```python -def canon_waitable_set_poll(cancellable, mem, thread, si, ptr): +def canon_waitable_set_poll(cancellable, mem, opts, thread, si, ptr): trap_if(not thread.task.inst.may_leave) wset = thread.task.inst.handles.get(si) trap_if(not isinstance(wset, WaitableSet)) @@ -3966,7 +4003,7 @@ def canon_waitable_set_poll(cancellable, mem, thread, si, ptr): event = (EventCode.NONE, 0, 0) else: event = wset.get_pending_event() - return unpack_event(mem, thread, ptr, event) + return unpack_event(mem, opts, thread, ptr, event) ``` If `cancellable` is set, then `waitable-set.poll` will return whether the supertask has already or concurrently requested cancellation. @@ -4167,7 +4204,8 @@ For canonical definitions: ``` In addition to [general validation of `$opts`](#canonopt-validation) validation specifies: -* `$f` is given type `(func (param i32 i32 i32) (result i32))` +* `$f` is given type `(func (param i32 T T) (result T))` where `T` is `i32` or + `i64` determined by the `memory` from `$opts` * `$stream_t` must be a type of the form `(stream $t?)` * If `$t` is present: * [`lower($t)` above](#canonopt-validation) defines required options for `stream.write` @@ -4232,7 +4270,7 @@ context switches. Next, the stream's `state` is updated based on the result being delivered to core wasm so that, once a stream end has been notified that the other end dropped, calling anything other than `stream.drop-*` traps. Lastly, `stream_event` packs the `CopyResult` and number of elements copied up -until this point into a single `i32` payload for core wasm. +until this point into a single `T`-sized payload for core wasm. ```python def stream_event(result, reclaim_buffer): reclaim_buffer() @@ -4282,7 +4320,8 @@ For canonical definitions: ``` In addition to [general validation of `$opts`](#canonopt-validation) validation specifies: -* `$f` is given type `(func (param i32 i32) (result i32))` +* `$f` is given type `(func (param i32 T) (result i32))` where `T` is `i32` or + `i64` determined by the `memory` from `$opts` * `$future_t` must be a type of the form `(future $t?)` * If `$t` is present: * [`lift($t)` above](#canonopt-validation) defines required options for `future.read` @@ -4329,7 +4368,7 @@ state (in which the only valid operation is to call `future.drop-*`) on read/written at most once and futures are only passed to other components in a state where they are ready to be read/written. Another important difference is that, since the buffer length is always implied by the `CopyResult`, the number -of elements copied is not packed in the high 28 bits; they're always zero. +of elements copied is not packed in the high bits; they're always zero. ```python def future_event(result): assert((buffer.remain() == 0) == (result == CopyResult.COMPLETED)) @@ -4505,9 +4544,11 @@ For a canonical definition: (canon thread.new-indirect $ft $ftbl (core func $new_indirect)) ``` validation specifies -* `$ft` must refer to the type `(func (param $c i32))` +* `$ft` must refer to the type `(func (param $c T))` where `T` is `i32` or + `i64` determined by the linear memory of the component instance * `$ftbl` must refer to a table whose element type matches `funcref` -* `$new_indirect` is given type `(func (param $fi i32) (param $c i32) (result i32))` +* `$new_indirect` is given type `(func (param $fi I) (param $c T) (result i32))` + where `I` is `i32` or `i64` determined by `$ftbl`'s table type Calling `$new_indirect` invokes the following function which reads a `funcref` from `$ftbl` (trapping if out-of-bounds, null or the wrong type), calls the @@ -4522,7 +4563,7 @@ class CoreFuncRef: def canon_thread_new_indirect(ft, ftbl: Table[CoreFuncRef], thread, fi, c): trap_if(not thread.task.inst.may_leave) f = ftbl.get(fi) - assert(ft == CoreFuncType(['i32'], [])) + assert(ft == CoreFuncType(['i32'], []) or ft == CoreFuncType(['i64'], [])) trap_if(f.t != ft) def thread_func(thread): [] = call_and_trap_on_throw(f.callee, thread, [c]) @@ -4702,7 +4743,8 @@ For a canonical definition: (canon error-context.new $opts (core func $f)) ``` validation specifies: -* `$f` is given type `(func (param i32 i32) (result i32))` +* `$f` is given type `(func (param T T) (result i32))` where `T` is `i32` or + `i64` determined by the `memory` from `$opts` * `async` is not present * `memory` must be present @@ -4743,7 +4785,8 @@ For a canonical definition: (canon error-context.debug-message $opts (core func $f)) ``` validation specifies: -* `$f` is given type `(func (param i32 i32))` +* `$f` is given type `(func (param i32 T))` where `T` is `i32` or `i64` + determined by the `memory` from `$opts` * `async` is not present * `memory` must be present * `realloc` must be present @@ -4762,8 +4805,9 @@ def canon_error_context_debug_message(opts, thread, i, ptr): store_string(cx, errctx.debug_message, ptr) return [] ``` -Note that `ptr` points to an 8-byte region of memory into which will be stored -the pointer and length of the debug string (allocated via `opts.realloc`). +Note that `ptr` points to a region of memory (8 bytes for memory32, 16 bytes +for memory64) into which will be stored the pointer and length of the debug +string (allocated via `opts.realloc`). ### ๐Ÿ“ `canon error-context.drop` @@ -4793,9 +4837,11 @@ For a canonical definition: (canon thread.spawn-ref shared? $ft (core func $spawn_ref)) ``` validation specifies: -* `$ft` must refer to the type `(shared? (func (param $c i32)))` (see explanation below) +* `$ft` must refer to the type `(shared? (func (param $c T)))` where `T` is + `i32` or `i64` determined by the linear memory of the component instance + (see explanation below) * `$spawn_ref` is given type - `(shared? (func (param $f (ref null $ft)) (param $c i32) (result $e i32)))` + `(shared? (func (param $f (ref null $ft)) (param $c T) (result $e i32)))` When the `shared` immediate is not present, the spawned thread is *cooperative*, only switching at specific program points. When the `shared` @@ -4804,7 +4850,7 @@ parallel with all other threads. > Note: ideally, a thread could be spawned with [arbitrary thread parameters]. > Currently, that would require additional work in the toolchain to support so, -> for simplicity, the current proposal simply fixes a single `i32` parameter +> for simplicity, the current proposal simply fixes a single `T` parameter > type. However, `thread.spawn-ref` could be extended to allow arbitrary thread > parameters in the future, once it's concretely beneficial to the toolchain. > The inclusion of `$ft` ensures backwards compatibility for when arbitrary @@ -4834,12 +4880,14 @@ For a canonical definition: (canon thread.spawn-indirect shared? $ft $tbl (core func $spawn_indirect)) ``` validation specifies: -* `$ft` must refer to the type `(shared? (func (param $c i32)))` is allowed - (see explanation in `thread.spawn-ref` above) +* `$ft` must refer to the type `(shared? (func (param $c T)))` is allowed + where `T` is `i32` or `i64` determined by the linear memory of the component + instance (see explanation in `thread.spawn-ref` above) * `$tbl` must refer to a shared table whose element type matches `(ref null (shared? func))` * `$spawn_indirect` is given type - `(shared? (func (param $i i32) (param $c i32) (result $e i32)))` + `(shared? (func (param $i I) (param $c T) (result $e i32)))` where `I` is + `i32` or `i64` determined by `$tbl`'s table type When the `shared` immediate is not present, the spawned thread is *cooperative*, only switching at specific program points. When the `shared` diff --git a/design/mvp/Concurrency.md b/design/mvp/Concurrency.md index 6c9c5f6f..b357d3f2 100644 --- a/design/mvp/Concurrency.md +++ b/design/mvp/Concurrency.md @@ -151,7 +151,7 @@ use cases mentioned in the [goals](#goals). Until the Core WebAssembly [shared-everything-threads] proposal allows Core WebAssembly function types to be annotated with `shared`, `thread.new-indirect` -can only call non-`shared` functions (via `i32` `(table funcref)` index, just +can only call non-`shared` functions (via `(table funcref)` index, just like `call_indirect`) and thus currently all threads must execute [cooperatively] in a sequentially-interleaved fashion, switching between threads only at explicit program points just like (and implementable via) a @@ -232,7 +232,7 @@ unique ownership of the *readable end* of the future or stream. To get a end pair (via the [`{stream,future}.new`] built-ins) and then pass the readable end elsewhere (e.g., in the above WIT, as a parameter to an imported `pipe.write` or as a result of an exported `transform`). Given the readable or -writable end of a future or stream (represented as an `i32` index into the +writable end of a future or stream (represented as an index into the component instance's handle table), Core WebAssembly can then call a [`{stream,future}.{read,write}`] built-in to synchronously or asynchronously copy into or out of a caller-provided buffer of Core WebAssembly linear (or, @@ -369,9 +369,9 @@ creating and running threads. New threads are created with the [`thread.new-indirect`] built-in. As mentioned [above](#threads-and-tasks), a spawned thread inherits the task of the spawning thread which is why threads and tasks are N:1. `thread.new-indirect` adds a new -thread to the component instance's threads table and returns the `i32` index of +thread to the component instance's threads table and returns the index of this table entry to the Core WebAssembly caller. Like [`pthread_create`], -`thread.new-indirect` takes a Core WebAssembly function (via `i32` index into a +`thread.new-indirect` takes a Core WebAssembly function (via index into a `funcref` table) and a "closure" parameter to pass to the function when called on the new thread. However, unlike `pthread_create`, the new thread is initially in a "suspended" state and must be explicitly "resumed" using one of @@ -413,10 +413,10 @@ Each thread contains a distinct mutable **thread-local storage** array. The current thread's thread-local storage can be read and written from core wasm code by calling the [`context.get`] and [`context.set`] built-ins. -The thread-local storage array's length is currently fixed to contain exactly -2 `i32`s with the goal of allowing this array to be stored inline in whatever -existing runtime data structure is already efficiently reachable from ambient -compiled wasm code. Because module instantiation is declarative in the +The thread-local storage array's length is currently fixed to contain exactly 2 +`i32`s or `i64`s with the goal of allowing this array to be stored inline in +whatever existing runtime data structure is already efficiently reachable from +ambient compiled wasm code. Because module instantiation is declarative in the Component Model, the imported `context.{get,set}` built-ins can be inlined by the core wasm compiler as-if they were instructions, allowing the generated machine code to be a single load or store. This makes thread-local storage a @@ -436,9 +436,8 @@ stackless async ABI is used, returning the "exit" code to the event loop. This non-reuse of thread-local storage between distinct export calls avoids what would otherwise be a likely source of TLS-related memory leaks. -When [memory64] is integrated into the Component Model's Canonical ABI, -`context.{get,set}` will be backwards-compatibly relaxed to allow `i64` -pointers (overlaying the `i32` values like hardware 32/64-bit registers). When +The type of `context.{get,set}` values (`i32` or `i64`) is determined by the +`memory` canonopt (matching the pointer size of the linear memory). When [wasm-gc] is integrated, these integral context values can serve as indices into guest-managed tables of typed GC references. @@ -945,8 +944,8 @@ Other example asynchronous lowered signatures: async func(s1: stream>, s2: list>) -> result, stream> ``` In *both* the sync and async ABIs, a `future` or `stream` in the WIT-level type -translates to a single `i32` in the ABI. This `i32` is an index into the -current component instance's handle table. For example, for the WIT function type: +translates to a single `i32` index into the current component instance's handle +table. For example, for the WIT function type: ```wit async func(f: future) -> future ``` @@ -959,10 +958,10 @@ and the asynchronous ABI has the signature: (func (param $f i32) (param $out-ptr i32) (result i32)) ``` where `$f` is the index of a future (not a pointer to one) while while -`$out-ptr` is a pointer to a linear memory location that will receive an `i32` +`$out-ptr` is a pointer to a linear memory location that will receive a handle index. -For the runtime semantics of this `i32` index, see `lift_stream`, +For the runtime semantics of this handle index, see `lift_stream`, `lift_future`, `lower_stream` and `lower_future` in the [Canonical ABI Explainer]. For a complete description of how async imports work, see [`canon_lower`] in the Canonical ABI Explainer. @@ -1032,18 +1031,19 @@ The `(result i32)` lets the core function return what it wants the runtime to do * If the low 4 bits are `1`, the callee wants to yield, allowing other code to run, but resuming thereafter without waiting on anything else. * If the low 4 bits are `2`, the callee wants to wait for an event to occur in - the waitable set whose index is stored in the high 28 bits. + the waitable set whose index is stored in the remaining high bits. When an async stackless function is exported, a companion "callback" function must also be exported with signature: ```wat -(func (param i32 i32 i32) (result i32)) +(func (param i32 i32 $addr) (result i32)) ``` +where `$addr` is `i32` or `i64` depending on the `memory` canonopt. The `(result i32)` has the same interpretation as the stackless export function and the runtime will repeatedly call the callback until a value of `0` is -returned. The `i32` parameters describe what happened that caused the callback -to be called again. +returned. The first two `i32` parameters describe what happened that caused the +callback to be called again and the `$addr` parameter is the payload address. For a complete description of how async exports work, see [`canon_lift`] in the Canonical ABI Explainer. diff --git a/design/mvp/Explainer.md b/design/mvp/Explainer.md index d8c0bdd1..26661274 100644 --- a/design/mvp/Explainer.md +++ b/design/mvp/Explainer.md @@ -718,11 +718,11 @@ Just like with handles, in the Component Model, async value types are lifted-from and lowered-into `i32` values that index an encapsulated per-component-instance table that is maintained by the canonical ABI built-ins [below](#canonical-definitions). The Component-Model-defined ABI for creating, -writing-to and reading-from `stream` and `future` values is meant to be bound -to analogous source-language features like promises, futures, streams, -iterators, generators and channels so that developers can use these familiar -high-level concepts when working directly with component types, without the -need to manually write low-level async glue code. For languages like C without +writing-to and reading-from `stream` and `future` values is meant to be bound to +analogous source-language features like promises, futures, streams, iterators, +generators and channels so that developers can use these familiar high-level +concepts when working directly with component types, without the need to +manually write low-level async glue code. For languages like C without language-level concurrency support, these ABIs (described in detail in the [Canonical ABI explainer]) can be exposed directly as function imports and used like normal low-level Operation System I/O APIs. @@ -1281,6 +1281,7 @@ canonopt ::= string-encoding=utf8 | (memory ) | (realloc ) | (post-return ) + | table64 | async ๐Ÿ”€ | (callback ) ๐Ÿ”€ ``` @@ -1304,15 +1305,17 @@ validation requires this option to be present (there is no default). The `(realloc ...)` option specifies a core function that is validated to have the following core function type: ```wat -(func (param $originalPtr i32) - (param $originalSize i32) +(func (param $originalPtr $addr) + (param $originalSize $addr) (param $alignment i32) - (param $newSize i32) - (result i32)) + (param $newSize $addr) + (result $addr)) ``` -The Canonical ABI will use `realloc` both to allocate (passing `0` for the -first two parameters) and reallocate. If the Canonical ABI needs `realloc`, -validation requires this option to be present (there is no default). +where `$addr` is `i32` when the `memory` canonopt refers to a 32-bit memory or +`i64` when it refers to a 64-bit memory. The Canonical ABI will use `realloc` +both to allocate (passing `0` for the first two parameters) and reallocate. If +the Canonical ABI needs `realloc`, validation requires this option to be +present (there is no default). The `(post-return ...)` option may only be present in `canon lift` when `async` is not present and specifies a core function to be called with the @@ -1335,9 +1338,10 @@ validated to have the following core function type: ```wat (func (param $ctx i32) (param $event i32) - (param $payload i32) + (param $payload $addr) (result $done i32)) ``` +where `$addr` is determined by the `memory` canonopt as described above. Again, see the [concurrency explainer] for more details. Based on this description of the AST, the [Canonical ABI explainer] gives a @@ -1452,9 +1456,9 @@ canon ::= ... | (canon thread.new-indirect (core func ?)) ๐Ÿงต | (canon thread.switch-to cancellable? (core func ?)) ๐Ÿงต | (canon thread.suspend cancellable? (core func ?)) ๐Ÿงต - | (canon thread.resume-later (core func ?) ๐Ÿงต - | (canon thread.yield-to cancellable? (core func ?) ๐Ÿงต - | (canon thread.yield cancellable? (core func ?) ๐Ÿงต + | (canon thread.resume-later (core func ?)) ๐Ÿงต + | (canon thread.yield-to cancellable? (core func ?)) ๐Ÿงต + | (canon thread.yield cancellable? (core func ?)) ๐Ÿงต | (canon error-context.new * (core func ?)) ๐Ÿ“ | (canon error-context.debug-message * (core func ?)) ๐Ÿ“ | (canon error-context.drop (core func ?)) ๐Ÿ“ @@ -1480,7 +1484,7 @@ component that defined `T`. In the Canonical ABI, `T.rep` is defined to be the `$rep` in the `(type $T (resource (rep $rep) ...))` type definition that defined `T`. While it's designed to allow different types in the future, it is currently -hard-coded to always be `i32`. +hard-coded to always be `i32` or `i64`. For details, see [`canon_resource_new`] in the Canonical ABI explainer. @@ -1554,12 +1558,12 @@ See the [concurrency explainer] for background. | Synopsis | | | -------------------------- | ------------------ | | Approximate WIT signature | `func() -> T` | -| Canonical ABI signature | `[] -> [i32]` | +| Canonical ABI signature | `[] -> [$T]` | The `context.get` built-in returns the `i`th element of the [current thread]'s [thread-local storage] array. Validation currently restricts `i` to be less -than 2 and `t` to be `i32`, but these restrictions may be relaxed in the -future. +than 2 and `T` to be `i32` or `i64` (determined by the `memory` canonopt), but +these restrictions may be relaxed in the future. For details, see [Thread-Local Storage] in the concurrency explainer and [`canon_context_get`] in the Canonical ABI explainer. @@ -1569,12 +1573,12 @@ For details, see [Thread-Local Storage] in the concurrency explainer and | Synopsis | | | -------------------------- | ----------------- | | Approximate WIT signature | `func(v: T)` | -| Canonical ABI signature | `[i32] -> []` | +| Canonical ABI signature | `[$T] -> []` | The `context.set` built-in sets the `i`th element of the [current thread]'s [thread-local storage] array to the value `v`. Validation currently restricts -`i` to be less than 2 and `t` to be `i32`, but these restrictions may be -relaxed in the future. +`i` to be less than 2 and `T` to be `i32` or `i64` (determined by the `memory` +canonopt), but these restrictions may be relaxed in the future. For details, see [Thread-Local Storage] in the concurrency explainer and [`canon_context_set`] in the Canonical ABI explainer. @@ -1673,7 +1677,7 @@ For details, see [Waitables and Waitable Sets] in the concurrency explainer and | Synopsis | | | -------------------------- | ---------------------------------------------- | | Approximate WIT signature | `func(s: waitable-set) -> event` | -| Canonical ABI signature | `[s:i32 payload-addr:i32] -> [event-code:i32]` | +| Canonical ABI signature | `[s:i32 payload-addr:$addr] -> [event-code:i32]` | where `event` is defined in WIT as: ```wit @@ -1738,7 +1742,7 @@ For details, see [Waitables and Waitable Sets] in the concurrency explainer and | Synopsis | | | -------------------------- | ---------------------------------------------- | | Approximate WIT signature | `func(s: waitable-set) -> event` | -| Canonical ABI signature | `[s:i32 payload-addr:i32] -> [event-code:i32]` | +| Canonical ABI signature | `[s:i32 payload-addr:$addr] -> [event-code:i32]` | where `event` is defined as in [`waitable-set.wait`](#-waitable-setwait). @@ -1856,7 +1860,7 @@ For details, see [Streams and Futures] in the concurrency explainer and | -------------------------------------------- | ----------------------------------------------------------------------------------------------- | | Approximate WIT signature for `stream.read` | `func>(e: readable-stream-end, b: writable-buffer?) -> option` | | Approximate WIT signature for `stream.write` | `func>(e: writable-stream-end, b: readable-buffer?) -> option` | -| Canonical ABI signature | `[stream-end:i32 ptr:i32 num:i32] -> [i32]` | +| Canonical ABI signature | `[stream-end:i32 ptr:$addr num:$addr] -> [$addr]` | where `stream-result` is defined in WIT as: ```wit @@ -1912,13 +1916,13 @@ any subsequent operation on the stream other than `stream.drop-{readable,writabl traps. In the Canonical ABI, the `{readable,writable}-stream-end` is passed as an -`i32` index into the component instance's table followed by a pair of `i32`s +`i32` index into the component instance's table followed by a pair of `$addr`s describing the linear memory offset and size-in-elements of the `{readable,writable}-buffer`. The `option` return value is -bit-packed into a single `i32` where: -* `0xffff_ffff` represents `none`. +bit-packed into a single `$addr` where: +* all-ones represents `none`. * Otherwise, the `result` is in the low 4 bits and the `progress` is in the - high 28 bits. + remaining high bits. For details, see [Streams and Futures] in the concurrency explainer and [`canon_stream_read`] in the Canonical ABI explainer. @@ -1929,7 +1933,7 @@ For details, see [Streams and Futures] in the concurrency explainer and | -------------------------------------------- | -------------------------------------------------------------------------------------------------------- | | Approximate WIT signature for `future.read` | `func>(e: readable-future-end, b: writable-buffer?) -> option` | | Approximate WIT signature for `future.write` | `func>(e: writable-future-end, v: readable-buffer?) -> option` | -| Canonical ABI signature | `[readable-future-end:i32 ptr:i32] -> [i32]` | +| Canonical ABI signature | `[readable-future-end:i32 ptr:$addr] -> [i32]` | where `future-{read,write}-result` are defined in WIT as: ```wit @@ -1980,10 +1984,10 @@ called before successfully writing a value. In the Canonical ABI, the `{readable,writable}-future-end` is passed as an `i32` index into the component instance's table followed by a single -`i32` describing the linear memory offset of the +`$addr` describing the linear memory offset of the `{readable,writable}-buffer`. The `option` -return value is bit-packed into the single `i32` return value where -`0xffff_ffff` represents `none`. And, `future-read-result.cancelled` is encoded +return value is bit-packed into the single `i32` return value where all-ones +represents `none`. And, `future-read-result.cancelled` is encoded as the value of `future-write-result.cancelled`, rather than the value implied by the `enum` definition above. @@ -2057,7 +2061,7 @@ For details, see [Thread Built-ins] in the concurrency explainer and | Synopsis | | | -------------------------- | ------------------------------------------------------------- | | Approximate WIT signature | `func(fi: u32, c: FuncT.params[0]) -> thread` | -| Canonical ABI signature | `[fi:i32 c:i32] -> [i32]` | +| Canonical ABI signature | `[fi:$idx c:$addr] -> [i32]` | The `thread.new-indirect` built-in adds a new thread to the current component instance's table, returning the index of the new thread. The function table @@ -2066,9 +2070,9 @@ dynamically checked to match the type `FuncT` (in the same manner as `call_indirect`). Lastly, the indexed function is called in the new thread with `c` as its first and only parameter. -Currently, `FuncT` must be `(func (param i32))` and thus `c` must always be an -`i32`, but this restriction can be loosened in the future as the Canonical -ABI is extended for [memory64] and [GC]. +Currently, `FuncT` must be `(func (param $addr))` and thus `c` must always be +an `$addr`, but this restriction can be loosened in the future as the Canonical +ABI is extended for [GC]. As explained in the [concurrency explainer], a thread created by `thread.new-indirect` is initially in a suspended state and must be resumed @@ -2151,7 +2155,7 @@ For details, see [Thread Built-ins] in the concurrency explainer and | Synopsis | | | -------------------------- | ------------------------------- | | Approximate WIT signature | `func(t: thread)` | -| Canonical ABI signature | `[t:i32] -> [suspend-result]` | +| Canonical ABI signature | `[t:i32] -> [i32]` | The `thread.yield-to` built-in immediately resumes execution of the thread `t`, (trapping if `t` is not in a "suspended" state) leaving the [current thread] in @@ -2201,7 +2205,7 @@ For details, see [Thread Built-ins] in the concurrency explainer and | Synopsis | | | -------------------------- | ------------------------------------------------------------------ | | Approximate WIT signature | `func(f: FuncT, c: FuncT.params[0]) -> bool` | -| Canonical ABI signature | `shared? [f:(ref null (shared (func (param i32))) c:i32] -> [i32]` | +| Canonical ABI signature | `shared? [f:(ref null (shared (func (param $addr))) c:$addr] -> [i32]` | The `thread.spawn-ref` built-in is an optimization, fusing a call to `thread.new_ref` (assuming `thread.new_ref` was added as part of adding a @@ -2216,7 +2220,7 @@ For details, see [`canon_thread_spawn_ref`] in the Canonical ABI explainer. | Synopsis | | | -------------------------- | ------------------------------------------------------------------ | | Approximate WIT signature | `func(i: u32, c: FuncT.params[0]) -> bool` | -| Canonical ABI signature | `shared? [i:i32 c:i32] -> [i32]` | +| Canonical ABI signature | `shared? [i:$idx c:$addr] -> [i32]` | The `thread.spawn-indirect` built-in is an optimization, fusing a call to [`thread.new-indirect`](#-threadnew-indirect) with a call to @@ -2251,7 +2255,7 @@ explainer. | Synopsis | | | -------------------------------- | ---------------------------------------- | | Approximate WIT signature | `func(message: string) -> error-context` | -| Canonical ABI signature | `[ptr:i32 len:i32] -> [i32]` | +| Canonical ABI signature | `[ptr:$addr len:$addr] -> [i32]` | The `error-context.new` built-in returns a new `error-context` value. The given string is non-deterministically transformed to produce the `error-context`'s @@ -2267,14 +2271,14 @@ For details, see [`canon_error_context_new`] in the Canonical ABI explainer. | Synopsis | | | -------------------------------- | --------------------------------------- | | Approximate WIT signature | `func(errctx: error-context) -> string` | -| Canonical ABI signature | `[errctxi:i32 ptr:i32] -> []` | +| Canonical ABI signature | `[errctxi:i32 ptr:$addr] -> []` | The `error-context.debug-message` built-in returns the [debug message](#error-context-type) of the given `error-context`. -In the Canonical ABI, it writes the debug message into `ptr` as an 8-byte -(`ptr`, `length`) pair, according to the Canonical ABI for `string`, given the -`*` immediates. +In the Canonical ABI, it writes the debug message into `ptr` as an 8-byte or +16-byte (`ptr`, `length`) pair, according to the Canonical ABI for `string`, +given the `*` immediates. For details, see [`canon_error_context_debug_message`] in the Canonical ABI explainer. diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index 32db6db0..0fea104d 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -234,10 +234,30 @@ def __init__(self, opts, inst, borrow_scope = None): class LiftOptions: string_encoding: str = 'utf8' memory: Optional[bytearray] = None + addr_type: str = 'i32' + tbl_idx_type: str = 'i32' def equal(lhs, rhs): return lhs.string_encoding == rhs.string_encoding and \ - lhs.memory is rhs.memory + lhs.memory is rhs.memory and \ + lhs.addr_type == rhs.addr_type and \ + lhs.tbl_idx_type == rhs.tbl_idx_type + +def ptr_size(opts): + match opts.addr_type: + case 'i32': return 4 + case 'i64': return 8 + +def ptr_type(opts): + return opts.addr_type + +def idx_size(opts): + match opts.tbl_idx_type: + case 'i32': return 4 + case 'i64': return 8 + +def idx_type(opts): + return opts.tbl_idx_type @dataclass class LiftLowerOptions(LiftOptions): @@ -775,8 +795,8 @@ class BufferGuestImpl(Buffer): def __init__(self, t, cx, ptr, length): trap_if(length > Buffer.MAX_LENGTH) if t and length > 0: - trap_if(ptr != align_to(ptr, alignment(t))) - trap_if(ptr + length * elem_size(t) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(t, cx.opts))) + trap_if(ptr + length * elem_size(t, cx.opts) > len(cx.opts.memory)) self.cx = cx self.t = t self.ptr = ptr @@ -794,7 +814,7 @@ def read(self, n): assert(n <= self.remain()) if self.t: vs = load_list_from_valid_range(self.cx, self.ptr, n, self.t) - self.ptr += n * elem_size(self.t) + self.ptr += n * elem_size(self.t, self.cx.opts) else: vs = n * [()] self.progress += n @@ -805,7 +825,7 @@ def write(self, vs): assert(len(vs) <= self.remain()) if self.t: store_list_into_valid_range(self.cx, vs, self.ptr, self.t) - self.ptr += len(vs) * elem_size(self.t) + self.ptr += len(vs) * elem_size(self.t, self.cx.opts) else: assert(all(v == () for v in vs)) self.progress += len(vs) @@ -1062,7 +1082,7 @@ def contains(t, p): ### Alignment -def alignment(t): +def alignment(t, opts): match despecialize(t): case BoolType() : return 1 case S8Type() | U8Type() : return 1 @@ -1072,28 +1092,28 @@ def alignment(t): case F32Type() : return 4 case F64Type() : return 8 case CharType() : return 4 - case StringType() : return 4 + case StringType() : return ptr_size(opts) case ErrorContextType() : return 4 - case ListType(t, l) : return alignment_list(t, l) - case RecordType(fields) : return alignment_record(fields) - case VariantType(cases) : return alignment_variant(cases) + case ListType(t, l) : return alignment_list(t, l, opts) + case RecordType(fields) : return alignment_record(fields, opts) + case VariantType(cases) : return alignment_variant(cases, opts) case FlagsType(labels) : return alignment_flags(labels) case OwnType() | BorrowType() : return 4 case StreamType() | FutureType() : return 4 -def alignment_list(elem_type, maybe_length): +def alignment_list(elem_type, maybe_length, opts): if maybe_length is not None: - return alignment(elem_type) - return 4 + return alignment(elem_type, opts) + return ptr_size(opts) -def alignment_record(fields): +def alignment_record(fields, opts): a = 1 for f in fields: - a = max(a, alignment(f.t)) + a = max(a, alignment(f.t, opts)) return a -def alignment_variant(cases): - return max(alignment(discriminant_type(cases)), max_case_alignment(cases)) +def alignment_variant(cases, opts): + return max(alignment(discriminant_type(cases), opts), max_case_alignment(cases, opts)) def discriminant_type(cases): n = len(cases) @@ -1104,11 +1124,11 @@ def discriminant_type(cases): case 2: return U16Type() case 3: return U32Type() -def max_case_alignment(cases): +def max_case_alignment(cases, opts): a = 1 for c in cases: if c.t is not None: - a = max(a, alignment(c.t)) + a = max(a, alignment(c.t, opts)) return a def alignment_flags(labels): @@ -1120,7 +1140,7 @@ def alignment_flags(labels): ### Element Size -def elem_size(t): +def elem_size(t, opts): match despecialize(t): case BoolType() : return 1 case S8Type() | U8Type() : return 1 @@ -1130,40 +1150,40 @@ def elem_size(t): case F32Type() : return 4 case F64Type() : return 8 case CharType() : return 4 - case StringType() : return 8 + case StringType() : return 2 * ptr_size(opts) case ErrorContextType() : return 4 - case ListType(t, l) : return elem_size_list(t, l) - case RecordType(fields) : return elem_size_record(fields) - case VariantType(cases) : return elem_size_variant(cases) + case ListType(t, l) : return elem_size_list(t, l, opts) + case RecordType(fields) : return elem_size_record(fields, opts) + case VariantType(cases) : return elem_size_variant(cases, opts) case FlagsType(labels) : return elem_size_flags(labels) case OwnType() | BorrowType() : return 4 case StreamType() | FutureType() : return 4 -def elem_size_list(elem_type, maybe_length): +def elem_size_list(elem_type, maybe_length, opts): if maybe_length is not None: - return maybe_length * elem_size(elem_type) - return 8 + return maybe_length * elem_size(elem_type, opts) + return 2 * ptr_size(opts) -def elem_size_record(fields): +def elem_size_record(fields, opts): s = 0 for f in fields: - s = align_to(s, alignment(f.t)) - s += elem_size(f.t) + s = align_to(s, alignment(f.t, opts)) + s += elem_size(f.t, opts) assert(s > 0) - return align_to(s, alignment_record(fields)) + return align_to(s, alignment_record(fields, opts)) def align_to(ptr, alignment): return math.ceil(ptr / alignment) * alignment -def elem_size_variant(cases): - s = elem_size(discriminant_type(cases)) - s = align_to(s, max_case_alignment(cases)) +def elem_size_variant(cases, opts): + s = elem_size(discriminant_type(cases), opts) + s = align_to(s, max_case_alignment(cases, opts)) cs = 0 for c in cases: if c.t is not None: - cs = max(cs, elem_size(c.t)) + cs = max(cs, elem_size(c.t, opts)) s += cs - return align_to(s, alignment_variant(cases)) + return align_to(s, alignment_variant(cases, opts)) def elem_size_flags(labels): n = len(labels) @@ -1175,8 +1195,8 @@ def elem_size_flags(labels): ### Loading def load(cx, ptr, t): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + elem_size(t) <= len(cx.opts.memory)) + assert(ptr == align_to(ptr, alignment(t, cx.opts))) + assert(ptr + elem_size(t, cx.opts) <= len(cx.opts.memory)) match despecialize(t): case BoolType() : return convert_int_to_bool(load_int(cx, ptr, 1)) case U8Type() : return load_int(cx, ptr, 1) @@ -1245,11 +1265,12 @@ def convert_i32_to_char(cx, i): String = tuple[str, str, int] def load_string(cx, ptr) -> String: - begin = load_int(cx, ptr, 4) - tagged_code_units = load_int(cx, ptr + 4, 4) + begin = load_int(cx, ptr, ptr_size(cx.opts)) + tagged_code_units = load_int(cx, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) return load_string_from_range(cx, begin, tagged_code_units) -UTF16_TAG = 1 << 31 +def utf16_tag(opts): + return 1 << (ptr_size(opts) * 8 - 1) def load_string_from_range(cx, ptr, tagged_code_units) -> String: match cx.opts.string_encoding: @@ -1263,8 +1284,8 @@ def load_string_from_range(cx, ptr, tagged_code_units) -> String: encoding = 'utf-16-le' case 'latin1+utf16': alignment = 2 - if bool(tagged_code_units & UTF16_TAG): - byte_length = 2 * (tagged_code_units ^ UTF16_TAG) + if bool(tagged_code_units & utf16_tag(cx.opts)): + byte_length = 2 * (tagged_code_units ^ utf16_tag(cx.opts)) encoding = 'utf-16-le' else: byte_length = tagged_code_units @@ -1287,36 +1308,36 @@ def lift_error_context(cx, i): def load_list(cx, ptr, elem_type, maybe_length): if maybe_length is not None: return load_list_from_valid_range(cx, ptr, maybe_length, elem_type) - begin = load_int(cx, ptr, 4) - length = load_int(cx, ptr + 4, 4) + begin = load_int(cx, ptr, ptr_size(cx.opts)) + length = load_int(cx, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) return load_list_from_range(cx, begin, length, elem_type) def load_list_from_range(cx, ptr, length, elem_type): - trap_if(ptr != align_to(ptr, alignment(elem_type))) - trap_if(ptr + length * elem_size(elem_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(elem_type, cx.opts))) + trap_if(ptr + length * elem_size(elem_type, cx.opts) > len(cx.opts.memory)) return load_list_from_valid_range(cx, ptr, length, elem_type) def load_list_from_valid_range(cx, ptr, length, elem_type): a = [] for i in range(length): - a.append(load(cx, ptr + i * elem_size(elem_type), elem_type)) + a.append(load(cx, ptr + i * elem_size(elem_type, cx.opts), elem_type)) return a def load_record(cx, ptr, fields): record = {} for field in fields: - ptr = align_to(ptr, alignment(field.t)) + ptr = align_to(ptr, alignment(field.t, cx.opts)) record[field.label] = load(cx, ptr, field.t) - ptr += elem_size(field.t) + ptr += elem_size(field.t, cx.opts) return record def load_variant(cx, ptr, cases): - disc_size = elem_size(discriminant_type(cases)) + disc_size = elem_size(discriminant_type(cases), cx.opts) case_index = load_int(cx, ptr, disc_size) ptr += disc_size trap_if(case_index >= len(cases)) c = cases[case_index] - ptr = align_to(ptr, max_case_alignment(cases)) + ptr = align_to(ptr, max_case_alignment(cases, cx.opts)) if c.t is None: return { c.label: None } return { c.label: load(cx, ptr, c.t) } @@ -1365,8 +1386,8 @@ def lift_async_value(ReadableEndT, cx, i, t): ### Storing def store(cx, v, t, ptr): - assert(ptr == align_to(ptr, alignment(t))) - assert(ptr + elem_size(t) <= len(cx.opts.memory)) + assert(ptr == align_to(ptr, alignment(t, cx.opts))) + assert(ptr + elem_size(t, cx.opts) <= len(cx.opts.memory)) match despecialize(t): case BoolType() : store_int(cx, int(bool(v)), ptr, 1) case U8Type() : store_int(cx, v, ptr, 1) @@ -1438,16 +1459,16 @@ def char_to_i32(c): def store_string(cx, v: String, ptr): begin, tagged_code_units = store_string_into_range(cx, v) - store_int(cx, begin, ptr, 4) - store_int(cx, tagged_code_units, ptr + 4, 4) + store_int(cx, begin, ptr, ptr_size(cx.opts)) + store_int(cx, tagged_code_units, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) def store_string_into_range(cx, v: String): src, src_encoding, src_tagged_code_units = v if src_encoding == 'latin1+utf16': - if bool(src_tagged_code_units & UTF16_TAG): + if bool(src_tagged_code_units & utf16_tag(cx.opts)): src_simple_encoding = 'utf16' - src_code_units = src_tagged_code_units ^ UTF16_TAG + src_code_units = src_tagged_code_units ^ utf16_tag(cx.opts) else: src_simple_encoding = 'latin1' src_code_units = src_tagged_code_units @@ -1475,11 +1496,12 @@ def store_string_into_range(cx, v: String): case 'latin1' : return store_string_copy(cx, src, src_code_units, 1, 2, 'latin-1') case 'utf16' : return store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units) -MAX_STRING_BYTE_LENGTH = (1 << 31) - 1 +def max_string_byte_length(opts): + return (1 << (ptr_size(opts) * 8 - 1)) - 1 def store_string_copy(cx, src, src_code_units, dst_code_unit_size, dst_alignment, dst_encoding): dst_byte_length = dst_code_unit_size * src_code_units - trap_if(dst_byte_length > MAX_STRING_BYTE_LENGTH) + trap_if(dst_byte_length > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, dst_alignment, dst_byte_length) trap_if(ptr != align_to(ptr, dst_alignment)) trap_if(ptr + dst_byte_length > len(cx.opts.memory)) @@ -1497,14 +1519,14 @@ def store_latin1_to_utf8(cx, src, src_code_units): return store_string_to_utf8(cx, src, src_code_units, worst_case_size) def store_string_to_utf8(cx, src, src_code_units, worst_case_size): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) + assert(src_code_units <= max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 1, src_code_units) trap_if(ptr + src_code_units > len(cx.opts.memory)) for i,code_point in enumerate(src): if ord(code_point) < 2**7: cx.opts.memory[ptr + i] = ord(code_point) else: - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(ptr, src_code_units, 1, worst_case_size) trap_if(ptr + worst_case_size > len(cx.opts.memory)) encoded = src.encode('utf-8') @@ -1517,7 +1539,7 @@ def store_string_to_utf8(cx, src, src_code_units, worst_case_size): def store_utf8_to_utf16(cx, src, src_code_units): worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, worst_case_size) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + worst_case_size > len(cx.opts.memory)) @@ -1531,7 +1553,7 @@ def store_utf8_to_utf16(cx, src, src_code_units): return (ptr, code_units) def store_string_to_latin1_or_utf16(cx, src, src_code_units): - assert(src_code_units <= MAX_STRING_BYTE_LENGTH) + assert(src_code_units <= max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, src_code_units) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + src_code_units > len(cx.opts.memory)) @@ -1542,7 +1564,7 @@ def store_string_to_latin1_or_utf16(cx, src, src_code_units): dst_byte_length += 1 else: worst_case_size = 2 * src_code_units - trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH) + trap_if(worst_case_size > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(ptr, src_code_units, 2, worst_case_size) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + worst_case_size > len(cx.opts.memory)) @@ -1555,7 +1577,7 @@ def store_string_to_latin1_or_utf16(cx, src, src_code_units): ptr = cx.opts.realloc(ptr, worst_case_size, 2, len(encoded)) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + len(encoded) > len(cx.opts.memory)) - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG + tagged_code_units = int(len(encoded) / 2) | utf16_tag(cx.opts) return (ptr, tagged_code_units) if dst_byte_length < src_code_units: ptr = cx.opts.realloc(ptr, src_code_units, 2, dst_byte_length) @@ -1565,14 +1587,14 @@ def store_string_to_latin1_or_utf16(cx, src, src_code_units): def store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units): src_byte_length = 2 * src_code_units - trap_if(src_byte_length > MAX_STRING_BYTE_LENGTH) + trap_if(src_byte_length > max_string_byte_length(cx.opts)) ptr = cx.opts.realloc(0, 0, 2, src_byte_length) trap_if(ptr != align_to(ptr, 2)) trap_if(ptr + src_byte_length > len(cx.opts.memory)) encoded = src.encode('utf-16-le') cx.opts.memory[ptr : ptr+len(encoded)] = encoded if any(ord(c) >= (1 << 8) for c in src): - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG + tagged_code_units = int(len(encoded) / 2) | utf16_tag(cx.opts) return (ptr, tagged_code_units) latin1_size = int(len(encoded) / 2) for i in range(latin1_size): @@ -1590,34 +1612,34 @@ def store_list(cx, v, ptr, elem_type, maybe_length): store_list_into_valid_range(cx, v, ptr, elem_type) return begin, length = store_list_into_range(cx, v, elem_type) - store_int(cx, begin, ptr, 4) - store_int(cx, length, ptr + 4, 4) + store_int(cx, begin, ptr, ptr_size(cx.opts)) + store_int(cx, length, ptr + ptr_size(cx.opts), ptr_size(cx.opts)) def store_list_into_range(cx, v, elem_type): - byte_length = len(v) * elem_size(elem_type) - trap_if(byte_length >= (1 << 32)) - ptr = cx.opts.realloc(0, 0, alignment(elem_type), byte_length) - trap_if(ptr != align_to(ptr, alignment(elem_type))) + byte_length = len(v) * elem_size(elem_type, cx.opts) + trap_if(byte_length >= (1 << (ptr_size(cx.opts) * 8))) + ptr = cx.opts.realloc(0, 0, alignment(elem_type, cx.opts), byte_length) + trap_if(ptr != align_to(ptr, alignment(elem_type, cx.opts))) trap_if(ptr + byte_length > len(cx.opts.memory)) store_list_into_valid_range(cx, v, ptr, elem_type) return (ptr, len(v)) def store_list_into_valid_range(cx, v, ptr, elem_type): for i,e in enumerate(v): - store(cx, e, elem_type, ptr + i * elem_size(elem_type)) + store(cx, e, elem_type, ptr + i * elem_size(elem_type, cx.opts)) def store_record(cx, v, ptr, fields): for f in fields: - ptr = align_to(ptr, alignment(f.t)) + ptr = align_to(ptr, alignment(f.t, cx.opts)) store(cx, v[f.label], f.t, ptr) - ptr += elem_size(f.t) + ptr += elem_size(f.t, cx.opts) def store_variant(cx, v, ptr, cases): case_index, case_value = match_case(v, cases) - disc_size = elem_size(discriminant_type(cases)) + disc_size = elem_size(discriminant_type(cases), cx.opts) store_int(cx, case_index, ptr, disc_size) ptr += disc_size - ptr = align_to(ptr, max_case_alignment(cases)) + ptr = align_to(ptr, max_case_alignment(cases, cx.opts)) c = cases[case_index] if c.t is not None: store(cx, case_value, c.t, ptr) @@ -1669,40 +1691,40 @@ def lower_future(cx, v, t): MAX_FLAT_RESULTS = 1 def flatten_functype(opts, ft, context): - flat_params = flatten_types(ft.param_types()) - flat_results = flatten_types(ft.result_type()) + flat_params = flatten_types(ft.param_types(), opts) + flat_results = flatten_types(ft.result_type(), opts) if not opts.async_: if len(flat_params) > MAX_FLAT_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if len(flat_results) > MAX_FLAT_RESULTS: match context: case 'lift': - flat_results = ['i32'] + flat_results = [ptr_type(opts)] case 'lower': - flat_params += ['i32'] + flat_params += [ptr_type(opts)] flat_results = [] return CoreFuncType(flat_params, flat_results) else: match context: case 'lift': if len(flat_params) > MAX_FLAT_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if opts.callback: flat_results = ['i32'] else: flat_results = [] case 'lower': if len(flat_params) > MAX_FLAT_ASYNC_PARAMS: - flat_params = ['i32'] + flat_params = [ptr_type(opts)] if len(flat_results) > 0: - flat_params += ['i32'] + flat_params += [ptr_type(opts)] flat_results = ['i32'] return CoreFuncType(flat_params, flat_results) -def flatten_types(ts): - return [ft for t in ts for ft in flatten_type(t)] +def flatten_types(ts, opts): + return [ft for t in ts for ft in flatten_type(t, opts)] -def flatten_type(t): +def flatten_type(t, opts): match despecialize(t): case BoolType() : return ['i32'] case U8Type() | U16Type() | U32Type() : return ['i32'] @@ -1711,36 +1733,36 @@ def flatten_type(t): case F32Type() : return ['f32'] case F64Type() : return ['f64'] case CharType() : return ['i32'] - case StringType() : return ['i32', 'i32'] + case StringType() : return [ptr_type(opts), ptr_type(opts)] case ErrorContextType() : return ['i32'] - case ListType(t, l) : return flatten_list(t, l) - case RecordType(fields) : return flatten_record(fields) - case VariantType(cases) : return flatten_variant(cases) + case ListType(t, l) : return flatten_list(t, l, opts) + case RecordType(fields) : return flatten_record(fields, opts) + case VariantType(cases) : return flatten_variant(cases, opts) case FlagsType(labels) : return ['i32'] case OwnType() | BorrowType() : return ['i32'] case StreamType() | FutureType() : return ['i32'] -def flatten_list(elem_type, maybe_length): +def flatten_list(elem_type, maybe_length, opts): if maybe_length is not None: - return flatten_type(elem_type) * maybe_length - return ['i32', 'i32'] + return flatten_type(elem_type, opts) * maybe_length + return [ptr_type(opts), ptr_type(opts)] -def flatten_record(fields): +def flatten_record(fields, opts): flat = [] for f in fields: - flat += flatten_type(f.t) + flat += flatten_type(f.t, opts) return flat -def flatten_variant(cases): +def flatten_variant(cases, opts): flat = [] for c in cases: if c.t is not None: - for i,ft in enumerate(flatten_type(c.t)): + for i,ft in enumerate(flatten_type(c.t, opts)): if i < len(flat): flat[i] = join(flat[i], ft) else: flat.append(ft) - return flatten_type(discriminant_type(cases)) + flat + return flatten_type(discriminant_type(cases), opts) + flat def join(a, b): if a == b: return a @@ -1810,8 +1832,8 @@ def lift_flat_signed(vi, core_width, t_width): return i def lift_flat_string(cx, vi): - ptr = vi.next('i32') - packed_length = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) + packed_length = vi.next(ptr_type(cx.opts)) return load_string_from_range(cx, ptr, packed_length) def lift_flat_list(cx, vi, elem_type, maybe_length): @@ -1820,8 +1842,8 @@ def lift_flat_list(cx, vi, elem_type, maybe_length): for i in range(maybe_length): a.append(lift_flat(cx, vi, elem_type)) return a - ptr = vi.next('i32') - length = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) + length = vi.next(ptr_type(cx.opts)) return load_list_from_range(cx, ptr, length, elem_type) def lift_flat_record(cx, vi, fields): @@ -1831,7 +1853,7 @@ def lift_flat_record(cx, vi, fields): return record def lift_flat_variant(cx, vi, cases): - flat_types = flatten_variant(cases) + flat_types = flatten_variant(cases, cx.opts) assert(flat_types.pop(0) == 'i32') case_index = vi.next('i32') trap_if(case_index >= len(cases)) @@ -1917,14 +1939,14 @@ def lower_flat_record(cx, v, fields): def lower_flat_variant(cx, v, cases): case_index, case_value = match_case(v, cases) - flat_types = flatten_variant(cases) + flat_types = flatten_variant(cases, cx.opts) assert(flat_types.pop(0) == 'i32') c = cases[case_index] if c.t is None: payload = [] else: payload = lower_flat(cx, case_value, c.t) - for i,(fv,have) in enumerate(zip(payload, flatten_type(c.t))): + for i,(fv,have) in enumerate(zip(payload, flatten_type(c.t, cx.opts))): want = flat_types.pop(0) match (have, want): case ('f32', 'i32') : payload[i] = encode_float_as_i32(fv) @@ -1943,30 +1965,30 @@ def lower_flat_flags(v, labels): ### Lifting and Lowering Values def lift_flat_values(cx, max_flat, vi, ts): - flat_types = flatten_types(ts) + flat_types = flatten_types(ts, cx.opts) if len(flat_types) > max_flat: - ptr = vi.next('i32') + ptr = vi.next(ptr_type(cx.opts)) tuple_type = TupleType(ts) - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(tuple_type, cx.opts))) + trap_if(ptr + elem_size(tuple_type, cx.opts) > len(cx.opts.memory)) return list(load(cx, ptr, tuple_type).values()) else: return [ lift_flat(cx, vi, t) for t in ts ] def lower_flat_values(cx, max_flat, vs, ts, out_param = None): cx.inst.may_leave = False - flat_types = flatten_types(ts) + flat_types = flatten_types(ts, cx.opts) if len(flat_types) > max_flat: tuple_type = TupleType(ts) tuple_value = {str(i): v for i,v in enumerate(vs)} if out_param is None: - ptr = cx.opts.realloc(0, 0, alignment(tuple_type), elem_size(tuple_type)) + ptr = cx.opts.realloc(0, 0, alignment(tuple_type, cx.opts), elem_size(tuple_type, cx.opts)) flat_vals = [ptr] else: - ptr = out_param.next('i32') + ptr = out_param.next(ptr_type(cx.opts)) flat_vals = [] - trap_if(ptr != align_to(ptr, alignment(tuple_type))) - trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory)) + trap_if(ptr != align_to(ptr, alignment(tuple_type, cx.opts))) + trap_if(ptr + elem_size(tuple_type, cx.opts) > len(cx.opts.memory)) store(cx, tuple_value, tuple_type, ptr) else: flat_vals = [] @@ -2177,14 +2199,14 @@ def canon_resource_rep(rt, thread, i): ### ๐Ÿ”€ `canon context.get` def canon_context_get(t, i, thread): - assert(t == 'i32') + assert(t == 'i32' or t == 'i64') assert(i < Thread.CONTEXT_LENGTH) return [thread.context[i]] ### ๐Ÿ”€ `canon context.set` def canon_context_set(t, i, thread, v): - assert(t == 'i32') + assert(t == 'i32' or t == 'i64') assert(i < Thread.CONTEXT_LENGTH) thread.context[i] = v return [] @@ -2240,24 +2262,24 @@ def canon_waitable_set_new(thread): ### ๐Ÿ”€ `canon waitable-set.wait` -def canon_waitable_set_wait(cancellable, mem, thread, si, ptr): +def canon_waitable_set_wait(cancellable, mem, opts, thread, si, ptr): trap_if(not thread.task.inst.may_leave) trap_if(not thread.task.may_block()) wset = thread.task.inst.handles.get(si) trap_if(not isinstance(wset, WaitableSet)) event = thread.task.wait_until(lambda: True, thread, wset, cancellable) - return unpack_event(mem, thread, ptr, event) + return unpack_event(mem, opts, thread, ptr, event) -def unpack_event(mem, thread, ptr, e: EventTuple): +def unpack_event(mem, opts, thread, ptr, e: EventTuple): event, p1, p2 = e - cx = LiftLowerContext(LiftLowerOptions(memory = mem), thread.task.inst) + cx = LiftLowerContext(LiftLowerOptions(memory = mem, addr_type = opts.addr_type, tbl_idx_type = opts.tbl_idx_type), thread.task.inst) store(cx, p1, U32Type(), ptr) store(cx, p2, U32Type(), ptr + 4) return [event] ### ๐Ÿ”€ `canon waitable-set.poll` -def canon_waitable_set_poll(cancellable, mem, thread, si, ptr): +def canon_waitable_set_poll(cancellable, mem, opts, thread, si, ptr): trap_if(not thread.task.inst.may_leave) wset = thread.task.inst.handles.get(si) trap_if(not isinstance(wset, WaitableSet)) @@ -2267,7 +2289,7 @@ def canon_waitable_set_poll(cancellable, mem, thread, si, ptr): event = (EventCode.NONE, 0, 0) else: event = wset.get_pending_event() - return unpack_event(mem, thread, ptr, event) + return unpack_event(mem, opts, thread, ptr, event) ### ๐Ÿ”€ `canon waitable-set.drop` @@ -2516,7 +2538,7 @@ class CoreFuncRef: def canon_thread_new_indirect(ft, ftbl: Table[CoreFuncRef], thread, fi, c): trap_if(not thread.task.inst.may_leave) f = ftbl.get(fi) - assert(ft == CoreFuncType(['i32'], [])) + assert(ft == CoreFuncType(['i32'], []) or ft == CoreFuncType(['i64'], [])) trap_if(f.t != ft) def thread_func(thread): [] = call_and_trap_on_throw(f.callee, thread, [c]) diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index cd7ee74a..b71cc1be 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -35,7 +35,7 @@ def realloc(self, original_ptr, original_size, alignment, new_size): self.memory[ret : ret + original_size] = self.memory[original_ptr : original_ptr + original_size] return ret -def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, sync_task_return = False, async_ = False): +def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, sync_task_return = False, async_ = False, addr_type = 'i32', tbl_idx_type = 'i32'): opts = CanonicalOptions() opts.memory = memory opts.string_encoding = encoding @@ -44,10 +44,12 @@ def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return opts.sync_task_return = sync_task_return opts.async_ = async_ opts.callback = None + opts.addr_type = addr_type + opts.tbl_idx_type = tbl_idx_type return opts -def mk_cx(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None): - opts = mk_opts(memory, encoding, realloc, post_return) +def mk_cx(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, addr_type = 'i32', tbl_idx_type = 'i32'): + opts = mk_opts(memory, encoding, realloc, post_return, addr_type=addr_type, tbl_idx_type=tbl_idx_type) inst = ComponentInstance(Store()) return LiftLowerContext(opts, inst) @@ -132,7 +134,7 @@ def test_name(): heap = Heap(5*len(cx.opts.memory)) if dst_encoding is None: dst_encoding = cx.opts.string_encoding - cx = mk_cx(heap.memory, dst_encoding, heap.realloc) + cx = mk_cx(heap.memory, dst_encoding, heap.realloc, addr_type=cx.opts.addr_type, tbl_idx_type=cx.opts.tbl_idx_type) lowered_vals = lower_flat(cx, v, lower_t) vi = CoreValueIter(lowered_vals) @@ -243,32 +245,32 @@ def test_nan64(inbits, outbits): test_nan64(0x7ff0000000000000, 0x7ff0000000000000) test_nan64(0x3ff0000000000000, 0x3ff0000000000000) -def test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units): +def test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units, addr_type='i32'): heap = Heap(len(encoded)) heap.memory[:] = encoded[:] - cx = mk_cx(heap.memory, src_encoding) + cx = mk_cx(heap.memory, src_encoding, addr_type=addr_type) v = (s, src_encoding, tagged_code_units) test(StringType(), [0, tagged_code_units], v, cx, dst_encoding) -def test_string(src_encoding, dst_encoding, s): +def test_string(src_encoding, dst_encoding, s, addr_type='i32'): if src_encoding == 'utf8': encoded = s.encode('utf-8') tagged_code_units = len(encoded) - test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units) + test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units, addr_type) elif src_encoding == 'utf16': encoded = s.encode('utf-16-le') tagged_code_units = int(len(encoded) / 2) - test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units) + test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units, addr_type) elif src_encoding == 'latin1+utf16': try: encoded = s.encode('latin-1') tagged_code_units = len(encoded) - test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units) + test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units, addr_type) except UnicodeEncodeError: pass encoded = s.encode('utf-16-le') - tagged_code_units = int(len(encoded) / 2) | UTF16_TAG - test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units) + tagged_code_units = int(len(encoded) / 2) | utf16_tag(LiftLowerOptions(addr_type=addr_type)) + test_string_internal(src_encoding, dst_encoding, s, encoded, tagged_code_units, addr_type) encodings = ['utf8', 'utf16', 'latin1+utf16'] @@ -276,14 +278,15 @@ def test_string(src_encoding, dst_encoding, s): '\u01ffy', 'xy\u01ff', 'a\ud7ffb', 'a\u02ff\u03ff\u04ffbc', '\uf123', '\uf123\uf123abc', 'abcdef\uf123'] -for src_encoding in encodings: - for dst_encoding in encodings: - for s in fun_strings: - test_string(src_encoding, dst_encoding, s) +for addr_type in ['i32', 'i64']: + for src_encoding in encodings: + for dst_encoding in encodings: + for s in fun_strings: + test_string(src_encoding, dst_encoding, s, addr_type) -def test_heap(t, expect, args, byte_array): +def test_heap(t, expect, args, byte_array, addr_type='i32', tbl_idx_type='i32'): heap = Heap(byte_array) - cx = mk_cx(heap.memory) + cx = mk_cx(heap.memory, addr_type=addr_type, tbl_idx_type=tbl_idx_type) test(t, args, expect, cx) # Empty record types are not permitted yet. @@ -309,15 +312,34 @@ def test_heap(t, expect, args, byte_array): test_heap(ListType(StringType()), [mk_str("hi"),mk_str("wat")], [0,2], [16,0,0,0, 2,0,0,0, 21,0,0,0, 3,0,0,0, ord('h'), ord('i'), 0xf,0xf,0xf, ord('w'), ord('a'), ord('t')]) +test_heap(ListType(StringType()), [mk_str("hi"),mk_str("wat")], [0,2], + [32,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0, + 37,0,0,0,0,0,0,0, 3,0,0,0,0,0,0,0, + ord('h'), ord('i'), 0xf,0xf,0xf, ord('w'), ord('a'), ord('t')], + addr_type='i64') test_heap(ListType(ListType(U8Type())), [[3,4,5],[],[6,7]], [0,3], [24,0,0,0, 3,0,0,0, 0,0,0,0, 0,0,0,0, 27,0,0,0, 2,0,0,0, 3,4,5, 6,7]) +test_heap(ListType(ListType(U8Type())), [[3,4,5],[],[6,7]], [0,3], + [48,0,0,0,0,0,0,0, 3,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0, + 51,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0, + 3,4,5, 6,7], + addr_type='i64') test_heap(ListType(ListType(U16Type())), [[5,6]], [0,1], [8,0,0,0, 2,0,0,0, 5,0, 6,0]) +test_heap(ListType(ListType(U16Type())), [[5,6]], [0,1], + [16,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0, + 5,0, 6,0], + addr_type='i64') test_heap(ListType(ListType(U16Type())), None, [0,1], [9,0,0,0, 2,0,0,0, 0, 5,0, 6,0]) +test_heap(ListType(ListType(U16Type())), None, [0,1], + [17,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0, + 0, 5,0, 6,0], + addr_type='i64') test_heap(ListType(ListType(U8Type(),2)), [[1,2],[3,4]], [0,2], [1,2, 3,4]) test_heap(ListType(ListType(U32Type(),2)), [[1,2],[3,4]], [0,2], @@ -369,21 +391,22 @@ def test_heap(t, expect, args, byte_array): test_heap(t, v, [0,2], [0xff,0xff,0xff,0xff, 0,0,0,0]) -def test_flatten(t, params, results): +def test_flatten(t, params, results, addr_type='i32', tbl_idx_type='i32'): + opts = mk_opts(addr_type=addr_type, tbl_idx_type=tbl_idx_type) expect = CoreFuncType(params, results) if len(params) > definitions.MAX_FLAT_PARAMS: - expect.params = ['i32'] + expect.params = [addr_type] if len(results) > definitions.MAX_FLAT_RESULTS: - expect.results = ['i32'] - got = flatten_functype(CanonicalOptions(), t, 'lift') + expect.results = [addr_type] + got = flatten_functype(opts, t, 'lift') assert(got == expect) if len(results) > definitions.MAX_FLAT_RESULTS: - expect.params += ['i32'] + expect.params += [addr_type] expect.results = [] - got = flatten_functype(CanonicalOptions(), t, 'lower') + got = flatten_functype(opts, t, 'lower') assert(got == expect) test_flatten(FuncType([U8Type(),F32Type(),F64Type()],[]), ['i32','f32','f64'], []) @@ -393,11 +416,13 @@ def test_flatten(t, params, results): test_flatten(FuncType([U8Type(),F32Type(),F64Type()],[TupleType([F32Type(),F32Type()])]), ['i32','f32','f64'], ['f32','f32']) test_flatten(FuncType([U8Type(),F32Type(),F64Type()],[F32Type(),F32Type()]), ['i32','f32','f64'], ['f32','f32']) test_flatten(FuncType([U8Type() for _ in range(17)],[]), ['i32' for _ in range(17)], []) +test_flatten(FuncType([U8Type() for _ in range(17)],[]), ['i32' for _ in range(17)], [], addr_type='i64') test_flatten(FuncType([U8Type() for _ in range(17)],[TupleType([U8Type(),U8Type()])]), ['i32' for _ in range(17)], ['i32','i32']) +test_flatten(FuncType([U8Type() for _ in range(17)],[TupleType([U8Type(),U8Type()])]), ['i32' for _ in range(17)], ['i32','i32'], addr_type='i64') def test_roundtrips(): - def test_roundtrip(t, v): + def test_roundtrip(t, v, addr_type='i32', tbl_idx_type='i32'): before = definitions.MAX_FLAT_RESULTS definitions.MAX_FLAT_RESULTS = 16 @@ -408,9 +433,8 @@ def callee(thread, x): return x callee_heap = Heap(1000) - callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) + callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc, addr_type=addr_type, tbl_idx_type=tbl_idx_type) callee_inst = ComponentInstance(store) - lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) got = None def on_start(): @@ -425,17 +449,22 @@ def on_resolve(result): definitions.MAX_FLAT_RESULTS = before - test_roundtrip(S8Type(), -1) - test_roundtrip(TupleType([U16Type(),U16Type()]), mk_tup(3,4)) - test_roundtrip(ListType(StringType()), [mk_str("hello there")]) - test_roundtrip(ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) - test_roundtrip(ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]) - test_roundtrip(VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - StringType()]))]), - {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) + cases = [ + (S8Type(), -1), + (TupleType([U16Type(),U16Type()]), mk_tup(3,4)), + (ListType(StringType()), [mk_str("hello there")]), + (ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]), + (ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]), + (VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + StringType()]))]), + {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}), + ] + for addr_type in ['i32', 'i64']: + for t, v in cases: + test_roundtrip(t, v, addr_type=addr_type) def test_handles(): @@ -449,12 +478,6 @@ def dtor(thread, args): dtor_value = args[0] return [] - store = Store() - rt = ResourceType(ComponentInstance(store), dtor) # usable in imports and exports - inst = ComponentInstance(store) - rt2 = ResourceType(inst, dtor) # only usable in exports - opts = mk_opts() - def host_import(caller, on_start, on_resolve): args = on_start() assert(len(args) == 2) @@ -518,34 +541,39 @@ def core_wasm(thread, args): return [h, h2, h4] - ft = FuncType([ - OwnType(rt), - OwnType(rt), - BorrowType(rt), - BorrowType(rt2) - ],[ - OwnType(rt), - OwnType(rt), - OwnType(rt) - ]) + for tbl_idx_type in ['i32', 'i64']: + store = Store() + rt = ResourceType(ComponentInstance(store), dtor) # usable in imports and exports + inst = ComponentInstance(store) + rt2 = ResourceType(inst, dtor) # only usable in exports + opts = mk_opts(tbl_idx_type=tbl_idx_type) + + ft = FuncType([ + OwnType(rt), + OwnType(rt), + BorrowType(rt), + BorrowType(rt2) + ],[ + OwnType(rt), + OwnType(rt), + OwnType(rt) + ]) - def on_start(): - return [ 42, 43, 44, 13 ] + got = None + def on_resolve(results): + nonlocal got + got = results - got = None - def on_resolve(results): - nonlocal got - got = results + run_lift(opts, inst, ft, core_wasm, lambda: [42, 43, 44, 13], on_resolve) - run_lift(opts, inst, ft, core_wasm, on_start, on_resolve) + assert(len(got) == 3) + assert(got[0] == 46) + assert(got[1] == 43) + assert(got[2] == 45) + assert(len(inst.handles.array) == 5) + assert(all(inst.handles.array[i] is None for i in range(4))) + assert(len(inst.handles.free) == 4) - assert(len(got) == 3) - assert(got[0] == 46) - assert(got[1] == 43) - assert(got[2] == 45) - assert(len(inst.handles.array) == 5) - assert(all(inst.handles.array[i] is None for i in range(4))) - assert(len(inst.handles.free) == 4) definitions.MAX_FLAT_RESULTS = before @@ -617,21 +645,21 @@ def consumer(thread, args): fut1_1.set() waitretp = consumer_heap.realloc(0, 0, 8, 4) - [event] = canon_waitable_set_wait(True, consumer_heap.memory, thread, seti, waitretp) + [event] = canon_waitable_set_wait(True, consumer_heap.memory, LiftLowerOptions(), thread, seti, waitretp) assert(event == EventCode.SUBTASK) assert(consumer_heap.memory[waitretp] == subi1) assert(consumer_heap.memory[waitretp+4] == Subtask.State.RETURNED) [] = canon_subtask_drop(thread, subi1) fut1_2.set() - [event] = canon_waitable_set_wait(True, consumer_heap.memory, thread, seti, waitretp) + [event] = canon_waitable_set_wait(True, consumer_heap.memory, LiftLowerOptions(), thread, seti, waitretp) assert(event == EventCode.SUBTASK) assert(consumer_heap.memory[waitretp] == subi2) assert(consumer_heap.memory[waitretp+4] == Subtask.State.STARTED) assert(consumer_heap.memory[retp] == 13) fut2.set() - [event] = canon_waitable_set_wait(True, consumer_heap.memory, thread, seti, waitretp) + [event] = canon_waitable_set_wait(True, consumer_heap.memory, LiftLowerOptions(), thread, seti, waitretp) assert(event == EventCode.SUBTASK) assert(consumer_heap.memory[waitretp] == subi2) assert(consumer_heap.memory[waitretp+4] == Subtask.State.RETURNED) @@ -852,7 +880,7 @@ def core_consumer(thread, args): assert(ret == CopyResult.COMPLETED) retp = 0 - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp+0] == subi2) assert(consumer_mem[retp+4] == Subtask.State.STARTED) @@ -864,14 +892,14 @@ def core_consumer(thread, args): [ret] = canon_thread_yield(True, thread) assert(ret == 0) retp = 0 - [ret] = canon_waitable_set_poll(True, consumer_mem, thread, seti, retp) + [ret] = canon_waitable_set_poll(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(ret == EventCode.NONE) [ret] = canon_future_write(FutureType(None), consumer_opts, thread, wfut21, 0xdeadbeef) assert(ret == CopyResult.COMPLETED) retp = 0 - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp+0] == subi1) assert(consumer_mem[retp+4] == Subtask.State.RETURNED) @@ -885,14 +913,14 @@ def core_consumer(thread, args): [ret] = canon_thread_yield(True, thread) assert(ret == 0) retp = 0 - [ret] = canon_waitable_set_poll(True, consumer_mem, thread, seti, retp) + [ret] = canon_waitable_set_poll(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(ret == EventCode.NONE) [ret] = canon_future_write(FutureType(None), consumer_opts, thread, wfut13, 0xdeadbeef) assert(ret == CopyResult.COMPLETED) retp = 0 - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp+0] == subi2) assert(consumer_mem[retp+4] == Subtask.State.RETURNED) @@ -909,7 +937,7 @@ def core_consumer(thread, args): assert(ret == CopyResult.COMPLETED) retp = 0 - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp+0] == subi3) assert(consumer_mem[retp+4] == Subtask.State.RETURNED) @@ -967,7 +995,7 @@ def core_caller(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, subi, seti) retp3 = 12 - [event] = canon_waitable_set_wait(True, caller_mem, thread, seti, retp3) + [event] = canon_waitable_set_wait(True, caller_mem, LiftLowerOptions(), thread, seti, retp3) assert(event == EventCode.SUBTASK) assert(caller_mem[retp3+0] == subi) assert(caller_mem[retp3+4] == Subtask.State.RETURNED) @@ -1034,7 +1062,7 @@ def consumer(thread, args): [ret] = canon_thread_yield(True, thread) assert(ret == 0) retp = 8 - [event] = canon_waitable_set_poll(True, consumer_heap.memory, thread, seti, retp) + [event] = canon_waitable_set_poll(True, consumer_heap.memory, LiftLowerOptions(), thread, seti, retp) if event == EventCode.NONE: continue assert(event == EventCode.SUBTASK) @@ -1120,7 +1148,7 @@ def consumer(thread, args): remain = [subi1, subi2] while remain: retp = 8 - [event] = canon_waitable_set_wait(True, consumer_heap.memory, thread, seti, retp) + [event] = canon_waitable_set_wait(True, consumer_heap.memory, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(consumer_heap.memory[retp+4] == Subtask.State.RETURNED) subi = consumer_heap.memory[retp] @@ -1186,14 +1214,14 @@ def core_func(thread, args): fut1.set() retp = lower_heap.realloc(0,0,8,4) - [event] = canon_waitable_set_wait(True, lower_heap.memory, thread, seti, retp) + [event] = canon_waitable_set_wait(True, lower_heap.memory, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(lower_heap.memory[retp] == subi1) assert(lower_heap.memory[retp+4] == Subtask.State.RETURNED) fut2.set() - [event] = canon_waitable_set_wait(True, lower_heap.memory, thread, seti, retp) + [event] = canon_waitable_set_wait(True, lower_heap.memory, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.SUBTASK) assert(lower_heap.memory[retp] == subi2) assert(lower_heap.memory[retp+4] == Subtask.State.RETURNED) @@ -1510,7 +1538,7 @@ def core_func(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rsi1, seti) definitions.throw_it = True - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) ## + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) ## assert(event == EventCode.STREAM_READ) assert(mem[retp+0] == rsi1) result,n = unpack_result(mem[retp+4]) @@ -1526,7 +1554,7 @@ def core_func(thread, args): assert(ret == definitions.BLOCKED) host_import_incoming.set_remain(100) [] = canon_waitable_join(thread, wsi3, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem[retp+0] == wsi3) result,n = unpack_result(mem[retp+4]) @@ -1538,7 +1566,7 @@ def core_func(thread, args): assert(ret == definitions.BLOCKED) dst_stream.set_remain(100) [] = canon_waitable_join(thread, wsi2, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem[retp+0] == wsi2) result,n = unpack_result(mem[retp+4]) @@ -1557,7 +1585,7 @@ def core_func(thread, args): [ret] = canon_stream_read(StreamType(U8Type()), opts, thread, rsi4, 0, 4) assert(ret == definitions.BLOCKED) [] = canon_waitable_join(thread, rsi4, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem[retp+0] == rsi4) result,n = unpack_result(mem[retp+4]) @@ -1681,7 +1709,7 @@ def core_func(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rsi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem[retp+0] == rsi) result,n = unpack_result(mem[retp+4]) @@ -1702,7 +1730,7 @@ def core_func(thread, args): assert(ret == definitions.BLOCKED) dst.set_remain(4) [] = canon_waitable_join(thread, wsi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem[retp+0] == wsi) result,n = unpack_result(mem[retp+4]) @@ -1763,7 +1791,7 @@ def core_func1(thread, args): retp = 16 [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, wsi, seti) - [event] = canon_waitable_set_wait(True, mem1, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem1, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem1[retp+0] == wsi) result,n = unpack_result(mem1[retp+4]) @@ -1774,7 +1802,7 @@ def core_func1(thread, args): fut4.set() - [event] = canon_waitable_set_wait(True, mem1, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem1, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem1[retp+0] == wsi) assert(mem1[retp+4] == 0) @@ -1812,7 +1840,7 @@ def core_func2(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rsi, seti) - [event] = canon_waitable_set_wait(True, mem2, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem2, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem2[retp+0] == rsi) result,n = unpack_result(mem2[retp+4]) @@ -1840,7 +1868,7 @@ def core_func2(thread, args): [ret] = canon_stream_read(StreamType(U8Type()), opts2, thread, rsi, 12345, 0) assert(ret == definitions.BLOCKED) - [event] = canon_waitable_set_wait(True, mem2, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem2, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem2[retp+0] == rsi) p2 = int.from_bytes(mem2[retp+4 : retp+8], 'little', signed=False) @@ -1886,7 +1914,7 @@ def core_func1(thread, args): retp = 16 [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, wsi, seti) - [event] = canon_waitable_set_wait(True, mem1, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem1, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_WRITE) assert(mem1[retp+0] == wsi) result,n = unpack_result(mem1[retp+4]) @@ -1923,7 +1951,7 @@ def core_func2(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rsi, seti) - [event] = canon_waitable_set_wait(True, mem2, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem2, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem2[retp+0] == rsi) result,n = unpack_result(mem2[retp+4]) @@ -2040,7 +2068,7 @@ def core_func(thread, args): host_source.unblock_cancel() [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rsi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.STREAM_READ) assert(mem[retp+0] == rsi) result,n = unpack_result(mem[retp+4]) @@ -2145,7 +2173,7 @@ def core_func(thread, args): [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, rfi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, retp) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, retp) assert(event == EventCode.FUTURE_READ) assert(mem[retp+0] == rfi) assert(mem[retp+4] == CopyResult.COMPLETED) @@ -2211,7 +2239,7 @@ def core_callee1(thread, args): def core_callee2(thread, args): [x] = args [si] = canon_waitable_set_new(thread) - [ret] = canon_waitable_set_wait(True, callee_heap.memory, thread, si, 0) + [ret] = canon_waitable_set_wait(True, callee_heap.memory, LiftLowerOptions(), thread, si, 0) assert(ret == EventCode.TASK_CANCELLED) match x: case 1: @@ -2258,9 +2286,9 @@ def core_callee4(thread, args): except Trap: pass [seti] = canon_waitable_set_new(thread) - [result] = canon_waitable_set_wait(True, callee_heap.memory, thread, seti, 0) + [result] = canon_waitable_set_wait(True, callee_heap.memory, LiftLowerOptions(), thread, seti, 0) assert(result == EventCode.TASK_CANCELLED) - [result] = canon_waitable_set_poll(True, callee_heap.memory, thread, seti, 0) + [result] = canon_waitable_set_poll(True, callee_heap.memory, LiftLowerOptions(), thread, seti, 0) assert(result == EventCode.NONE) [] = canon_task_cancel(thread) return [] @@ -2395,7 +2423,7 @@ def core_caller(thread, args): assert(caller_heap.memory[0] == 13) [] = canon_waitable_join(thread, subi3, seti) retp = 8 - [ret] = canon_waitable_set_wait(True, caller_heap.memory, thread, seti, retp) + [ret] = canon_waitable_set_wait(True, caller_heap.memory, LiftLowerOptions(), thread, seti, retp) assert(ret == EventCode.SUBTASK) assert(caller_heap.memory[retp+0] == subi3) assert(caller_heap.memory[retp+4] == Subtask.State.RETURNED) @@ -2414,7 +2442,7 @@ def core_caller(thread, args): assert(caller_heap.memory[0] == 13) [] = canon_waitable_join(thread, subi4, seti) retp = 8 - [ret] = canon_waitable_set_wait(True, caller_heap.memory, thread, seti, retp) + [ret] = canon_waitable_set_wait(True, caller_heap.memory, LiftLowerOptions(), thread, seti, retp) assert(ret == EventCode.SUBTASK) assert(caller_heap.memory[retp+0] == subi4) assert(caller_heap.memory[retp+4] == Subtask.State.CANCELLED_BEFORE_RETURNED) @@ -2456,7 +2484,7 @@ def core_caller(thread, args): host_fut4.set() [] = canon_waitable_join(thread, subi, seti) waitretp = 4 - [event] = canon_waitable_set_wait(True, caller_heap.memory, thread, seti, waitretp) + [event] = canon_waitable_set_wait(True, caller_heap.memory, LiftLowerOptions(), thread, seti, waitretp) assert(event == EventCode.SUBTASK) assert(caller_heap.memory[waitretp] == subi) assert(caller_heap.memory[waitretp+4] == Subtask.State.CANCELLED_BEFORE_RETURNED) @@ -2472,7 +2500,7 @@ def core_caller(thread, args): host_fut5.set() [] = canon_waitable_join(thread, subi, seti) waitretp = 4 - [event] = canon_waitable_set_wait(True, caller_heap.memory, thread, seti, waitretp) + [event] = canon_waitable_set_wait(True, caller_heap.memory, LiftLowerOptions(), thread, seti, waitretp) assert(event == EventCode.SUBTASK) assert(caller_heap.memory[waitretp] == subi) assert(caller_heap.memory[waitretp+4] == Subtask.State.RETURNED) @@ -2487,7 +2515,7 @@ def core_caller(thread, args): assert(ret == definitions.BLOCKED) [] = canon_waitable_join(thread, subi, seti) - [event] = canon_waitable_set_wait(True, caller_heap.memory, thread, seti, 4) + [event] = canon_waitable_set_wait(True, caller_heap.memory, LiftLowerOptions(), thread, seti, 4) assert(event == EventCode.SUBTASK) assert(caller_heap.memory[0] == 45) assert(caller_heap.memory[4] == subi) @@ -2534,7 +2562,7 @@ def core_func(thread, args): [] = canon_future_drop_readable(FutureType(elemt), thread, rfi) [] = canon_waitable_join(thread, wfi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, 0) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, 0) assert(event == EventCode.FUTURE_WRITE) assert(mem[0] == wfi) assert(mem[4] == CopyResult.COMPLETED) @@ -2554,7 +2582,7 @@ def core_func(thread, args): [] = canon_stream_drop_readable(StreamType(elemt), thread, rsi) [] = canon_waitable_join(thread, wsi, seti) - [event] = canon_waitable_set_wait(True, mem, thread, seti, 0) + [event] = canon_waitable_set_wait(True, mem, LiftLowerOptions(), thread, seti, 0) assert(event == EventCode.STREAM_WRITE) assert(mem[0] == wsi) result,n = unpack_result(mem[4]) @@ -2568,6 +2596,7 @@ def core_func(thread, args): run_lift(sync_opts, inst, ft, core_func, lambda:[], lambda _:()) + def test_async_flat_params(): store = Store() heap = Heap(1000) @@ -2745,14 +2774,14 @@ def core_consumer(thread, args): retp3 = 16 [seti] = canon_waitable_set_new(thread) [] = canon_waitable_join(thread, subi1, seti) - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp3) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp3) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp3] == subi1) assert(consumer_mem[retp3+4] == Subtask.State.RETURNED) assert(consumer_mem[retp1] == 42) [] = canon_waitable_join(thread, subi2, seti) - [event] = canon_waitable_set_wait(True, consumer_mem, thread, seti, retp3) + [event] = canon_waitable_set_wait(True, consumer_mem, LiftLowerOptions(), thread, seti, retp3) assert(event == EventCode.SUBTASK) assert(consumer_mem[retp3] == subi2) assert(consumer_mem[retp3+4] == Subtask.State.RETURNED) @@ -2801,6 +2830,47 @@ def mk_task(supertask, inst): assert(call_might_be_recursive(p_task, c2)) +def test_mixed_table_memory_types(): + store = Store() + rt = ResourceType(ComponentInstance(store), None) + + # Verify alignment and elem_size for mixed configurations + opts64_addr = LiftLowerOptions(addr_type='i64', tbl_idx_type='i32') + assert(alignment(StringType(), opts64_addr) == 8) + assert(elem_size(StringType(), opts64_addr) == 16) + assert(alignment(OwnType(rt), opts64_addr) == 4) + assert(elem_size(OwnType(rt), opts64_addr) == 4) + + opts64_tbl = LiftLowerOptions(addr_type='i32', tbl_idx_type='i64') + assert(alignment(StringType(), opts64_tbl) == 4) + assert(elem_size(StringType(), opts64_tbl) == 8) + assert(alignment(OwnType(rt), opts64_tbl) == 4) + assert(elem_size(OwnType(rt), opts64_tbl) == 4) + + # Round-trip a type exercising both memory pointers and table pointers + before = definitions.MAX_FLAT_RESULTS + definitions.MAX_FLAT_RESULTS = 16 + t = TupleType([ListType(OwnType(rt)), StringType()]) + + def core_wasm(thread, args): + return args + + for addr_type, tbl_idx_type in [('i64','i32'), ('i32','i64')]: + heap = Heap(1000) + inst = ComponentInstance(store) + opts = mk_opts(heap.memory, 'utf8', heap.realloc, addr_type=addr_type, tbl_idx_type=tbl_idx_type) + + ft = FuncType([t], [t]) + v = {'0': [42, 43], '1': mk_str("hello")} + got = None + def on_resolve(results): + nonlocal got + got = results + run_lift(opts, inst, ft, core_wasm, lambda: [v], on_resolve) + assert(got[0] == v) + + definitions.MAX_FLAT_RESULTS = before + test_roundtrips() test_handles() test_async_to_async() @@ -2827,5 +2897,6 @@ def mk_task(supertask, inst): test_threads() test_thread_cancel_callback() test_reentrance() +test_mixed_table_memory_types() print("All tests passed") diff --git a/test/wasm-tools/memory64.wast b/test/wasm-tools/memory64.wast index 0ec55341..a72eb3b3 100644 --- a/test/wasm-tools/memory64.wast +++ b/test/wasm-tools/memory64.wast @@ -42,13 +42,14 @@ (core instance (instantiate $B (with "" (instance (export "" (table $m)))))) ) -(assert_invalid - (component - (import "x" (func $x (param "x" string))) - (core module $A - (memory (export "m") i64 1)) - (core instance $A (instantiate $A)) - (alias core export $A "m" (core memory $m)) - (core func (canon lower (func $x) (memory $m))) +(component + (import "x" (func $x (param "x" string))) + (core module $A + (memory (export "m") i64 1) + (func (export "realloc") (param i64 i64 i64 i64) (result i64) unreachable) ) - "canonical ABI memory is not a 32-bit linear memory") + (core instance $A (instantiate $A)) + (alias core export $A "m" (core memory $m)) + (core func $realloc (alias core export $A "realloc")) + (core func (canon lower (func $x) (memory $m) (realloc (func $realloc)))) +)