@@ -139,16 +139,9 @@ def inlinetest_namespace() -> Dict[str, Any]:
139139## InlineTest
140140######################################################################
141141class InlineTest :
142- # https://docs.python.org/3/tutorial/stdlib.html
143- import_libraries = [
144- "import re" ,
145- "import unittest" ,
146- "from unittest.mock import patch" ,
147- ]
148-
149142 def __init__ (self ):
150143 self .assume_stmts = []
151- self .assume_node : ast .If = None
144+ self .assume_node : ast .If = None
152145 self .check_stmts = []
153146 self .given_stmts = []
154147 self .previous_stmts = []
@@ -168,42 +161,49 @@ def __init__(self):
168161
169162 def to_test (self ):
170163 if self .prev_stmt_type == PrevStmtType .CondExpr :
171- if self .assume_stmts == []:
164+ if self .assume_stmts == []:
172165 return "\n " .join (
173- self .import_libraries
174- + [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
175- + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
166+ [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
167+ + [
168+ ExtractInlineTest .node_to_source_code (n )
169+ for n in self .check_stmts
170+ ]
176171 )
177172 else :
178- body_nodes = [n for n in self .given_stmts ] + [n for n in self .previous_stmts ] + [n for n in self .check_stmts ]
173+ body_nodes = (
174+ [n for n in self .given_stmts ]
175+ + [n for n in self .previous_stmts ]
176+ + [n for n in self .check_stmts ]
177+ )
179178 assume_statement = self .assume_stmts [0 ]
180179 assume_node = self .build_assume_node (assume_statement , body_nodes )
181- return "\n " .join (
182- self .import_libraries
183- + ExtractInlineTest .node_to_source_code (assume_node )
184-
185- )
186-
180+ return "\n " .join (ExtractInlineTest .node_to_source_code (assume_node ))
187181
188182 else :
189183 if self .assume_stmts == []:
190184 return "\n " .join (
191- self .import_libraries
192- + [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
193- + [ExtractInlineTest .node_to_source_code (n ) for n in self .previous_stmts ]
194- + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
185+ [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
186+ + [
187+ ExtractInlineTest .node_to_source_code (n )
188+ for n in self .previous_stmts
189+ ]
190+ + [
191+ ExtractInlineTest .node_to_source_code (n )
192+ for n in self .check_stmts
193+ ]
195194 )
196195 else :
197- body_nodes = [n for n in self .given_stmts ] + [n for n in self .previous_stmts ] + [n for n in self .check_stmts ]
196+ body_nodes = (
197+ [n for n in self .given_stmts ]
198+ + [n for n in self .previous_stmts ]
199+ + [n for n in self .check_stmts ]
200+ )
198201 assume_statement = self .assume_stmts [0 ]
199202 assume_node = self .build_assume_node (assume_statement , body_nodes )
200- return "\n " .join (
201- self .import_libraries
202- + [ExtractInlineTest .node_to_source_code (assume_node )]
203- )
204-
203+ return "\n " .join ([ExtractInlineTest .node_to_source_code (assume_node )])
204+
205205 def build_assume_node (self , assumption_node , body_nodes ):
206- return ast .If (assumption_node , body_nodes ,[])
206+ return ast .If (assumption_node , body_nodes , [])
207207
208208 def __repr__ (self ):
209209 if self .test_name :
@@ -216,8 +216,7 @@ def is_empty(self) -> bool:
216216
217217 def __eq__ (self , other ):
218218 return (
219- self .import_libraries == other .import_libraries
220- and self .assume_stmts == other .assume_stmts
219+ self .assume_stmts == other .assume_stmts
221220 and self .given_stmts == other .given_stmts
222221 and self .previous_stmts == other .previous_stmts
223222 and self .check_stmts == other .check_stmts
@@ -481,7 +480,10 @@ def parse_constructor(self, node):
481480 elif (
482481 keyword .arg == self .arg_timeout_str
483482 and isinstance (keyword .value , ast .Constant )
484- and (isinstance (keyword .value .value , float ) or isinstance (keyword .value .value , int ))
483+ and (
484+ isinstance (keyword .value .value , float )
485+ or isinstance (keyword .value .value , int )
486+ )
485487 ):
486488 if keyword .value .value <= 0.0 :
487489 raise MalformedException (
@@ -606,7 +608,10 @@ def parse_constructor(self, node):
606608 elif (
607609 keyword .arg == self .arg_timeout_str
608610 and isinstance (keyword .value , ast .Num )
609- and (isinstance (keyword .value .n , float ) or isinstance (keyword .value .n , int ))
611+ and (
612+ isinstance (keyword .value .n , float )
613+ or isinstance (keyword .value .n , int )
614+ )
610615 ):
611616 if keyword .value .n <= 0.0 :
612617 raise MalformedException (
@@ -1023,7 +1028,9 @@ def parse_fail(self, node):
10231028 if len (node .args ) == 0 :
10241029 self .build_fail ()
10251030 else :
1026- raise MalformedException ("inline test: fail() does not expect any arguments" )
1031+ raise MalformedException (
1032+ "inline test: fail() does not expect any arguments"
1033+ )
10271034
10281035 def parse_group (self , node ):
10291036 if (
@@ -1090,15 +1097,12 @@ def parse_inline_test(self, node):
10901097
10911098 # "assume_true(...) or assume_false(...)
10921099 inline_test_call_index = 1
1093- if ( len (inline_test_calls ) >= 2 ) :
1100+ if len (inline_test_calls ) >= 2 :
10941101 call = inline_test_calls [1 ]
1095- if (
1096- isinstance (call .func , ast .Attribute )
1097- and call .func .attr == self .assume
1098- ):
1102+ if isinstance (call .func , ast .Attribute ) and call .func .attr == self .assume :
10991103 self .parse_assume (call )
11001104 inline_test_call_index += 1
1101-
1105+
11021106 # "given(a, 1)"
11031107 for call in inline_test_calls [inline_test_call_index :]:
11041108 if (
@@ -1169,6 +1173,7 @@ def node_to_source_code(node):
11691173 ast .fix_missing_locations (node )
11701174 return ast_unparse (node )
11711175
1176+
11721177######################################################################
11731178## InlineTest Finder
11741179######################################################################
0 commit comments