11import sys
22import warnings
33from pathlib import Path
4- from typing import Iterator
4+ from typing import Iterator , Tuple
55
66import pytest
77
@@ -60,10 +60,30 @@ def verify_output(capsys: pytest.CaptureFixture[str], filename: str) -> None:
6060
6161def verify_output_str (output : str , filename : str ) -> None :
6262 expected = Path (filename ).read_text (encoding = "utf-8" )
63+ # Verify the input size has the same unit
64+ output_input_size , output_input_unit = get_input_size_and_unit (output )
65+ expected_input_size , expected_input_unit = get_input_size_and_unit (expected )
66+ assert output_input_unit == expected_input_unit
67+
68+ # Sometime it does not have the same exact value, depending on torch version.
69+ # We assume the variation cannot be too large.
70+ if output_input_size != 0 :
71+ assert abs (output_input_size - expected_input_size )/ output_input_size < 1e-2
72+
73+ if output_input_size != expected_input_size :
74+ # In case of a difference, replace the expected input size.
75+ expected = replace_input_size (expected , expected_input_unit , expected_input_size , output_input_size )
6376 assert output == expected
6477 for category in (ColumnSettings .NUM_PARAMS , ColumnSettings .MULT_ADDS ):
6578 assert_sum_column_totals_match (output , category )
6679
80+ def replace_input_size (output : str , unit : str , old_value : float , new_value : float ) -> str :
81+ return output .replace (f"Input size { unit } : { old_value :.2f} " , f"Input size { unit } : { new_value :.2f} " )
82+
83+ def get_input_size_and_unit (output_str : str ) -> Tuple [float , str ]:
84+ input_size = float (output_str .split ('Input size' )[1 ].split (':' )[1 ].split ('\n ' )[0 ].strip ())
85+ input_unit = output_str .split ('Input size' )[1 ].split (':' )[0 ].strip ()
86+ return input_size , input_unit
6787
6888def get_column_value_for_row (line : str , offset : int ) -> int :
6989 """Helper function for getting the column totals."""
@@ -88,12 +108,23 @@ def assert_sum_column_totals_match(output: str, category: ColumnSettings) -> Non
88108 if offset == - 1 :
89109 return
90110 layers = lines [1 ].split ("\n " )
91- calculated_total = sum (get_column_value_for_row (line , offset ) for line in layers )
111+ calculated_total = float ( sum (get_column_value_for_row (line , offset ) for line in layers ) )
92112 results = lines [2 ].split ("\n " )
93113
94114 if category == ColumnSettings .NUM_PARAMS :
95115 total_params = results [0 ].split (":" )[1 ].replace ("," , "" )
96- assert calculated_total == int (total_params )
116+ splitted_results = results [0 ].split ('(' )
117+ if len (splitted_results ) > 1 :
118+ units = splitted_results [1 ][0 ]
119+ if units == 'T' :
120+ calculated_total /= 1e12
121+ elif units == 'G' :
122+ calculated_total /= 1e9
123+ elif units == 'M' :
124+ calculated_total /= 1e6
125+ elif units == 'k' :
126+ calculated_total /= 1e3
127+ assert calculated_total == float (total_params )
97128 elif category == ColumnSettings .MULT_ADDS :
98129 total_mult_adds = results [- 1 ].split (":" )[1 ].replace ("," , "" )
99130 assert float (
0 commit comments