diff --git a/changelog.d/7068-refactor-special-map.changed b/changelog.d/7068-refactor-special-map.changed new file mode 100644 index 00000000000..78948ead746 --- /dev/null +++ b/changelog.d/7068-refactor-special-map.changed @@ -0,0 +1 @@ +Refactor special_map to simplify logic and reduce memory allocations \ No newline at end of file diff --git a/clarity/src/vm/analysis/type_checker/v2_05/tests/mod.rs b/clarity/src/vm/analysis/type_checker/v2_05/tests/mod.rs index 05fae0997ba..e697129e55e 100644 --- a/clarity/src/vm/analysis/type_checker/v2_05/tests/mod.rs +++ b/clarity/src/vm/analysis/type_checker/v2_05/tests/mod.rs @@ -925,6 +925,9 @@ fn test_lists() { "(map hash160 (+ u1 u2))", "(len 1)", "(map + (list 1 2 3 4 5) (list true true true true true))", + "(map + (list) (list 1))", + "(map pow (list) (list 1))", + "(define-private (add (x int) (y int)) (+ x y)) (map add (list) (list 1))", ]; let bad_expected = [ StaticCheckErrorKind::TypeError(Box::new(BoolType), Box::new(IntType)), @@ -944,8 +947,25 @@ fn test_lists() { StaticCheckErrorKind::ExpectedSequence(Box::new(UIntType)), StaticCheckErrorKind::ExpectedSequence(Box::new(IntType)), StaticCheckErrorKind::TypeError(Box::new(IntType), Box::new(BoolType)), + StaticCheckErrorKind::UnionTypeError( + vec![IntType, UIntType], + Box::new(TypeSignature::NoType), + ), + StaticCheckErrorKind::UnionTypeError( + vec![IntType, UIntType], + Box::new(TypeSignature::NoType), + ), + StaticCheckErrorKind::TypeError( + Box::new(TypeSignature::IntType), + Box::new(TypeSignature::NoType), + ), ]; + assert_eq!( + good.len(), + expected.len(), + "Good tests should match related expected count" + ); for (good_test, expected) in good.iter().zip(expected.iter()) { assert_eq!( expected, @@ -953,6 +973,11 @@ fn test_lists() { ); } + assert_eq!( + bad.len(), + bad_expected.len(), + "Bad tests should match related expected count" + ); for (bad_test, expected) in bad.iter().zip(bad_expected.iter()) { assert_eq!(*expected, *type_check_helper(bad_test).unwrap_err().err); } diff --git a/clarity/src/vm/analysis/type_checker/v2_1/tests/mod.rs b/clarity/src/vm/analysis/type_checker/v2_1/tests/mod.rs index de7bd5d2b22..b383ccd4d78 100644 --- a/clarity/src/vm/analysis/type_checker/v2_1/tests/mod.rs +++ b/clarity/src/vm/analysis/type_checker/v2_1/tests/mod.rs @@ -1520,6 +1520,9 @@ fn test_lists() { "(map hash160 (+ u1 u2))", "(len 1)", "(map + (list 1 2 3 4 5) (list true true true true true))", + "(map + (list) (list 1))", + "(map pow (list) (list 1))", + "(define-private (add (x int) (y int)) (+ x y)) (map add (list) (list 1))", ]; let bad_expected = [ StaticCheckErrorKind::TypeError(Box::new(BoolType), Box::new(IntType)), @@ -1539,8 +1542,25 @@ fn test_lists() { StaticCheckErrorKind::ExpectedSequence(Box::new(UIntType)), StaticCheckErrorKind::ExpectedSequence(Box::new(IntType)), StaticCheckErrorKind::TypeError(Box::new(IntType), Box::new(BoolType)), + StaticCheckErrorKind::UnionTypeError( + vec![IntType, UIntType], + Box::new(TypeSignature::NoType), + ), + StaticCheckErrorKind::UnionTypeError( + vec![IntType, UIntType], + Box::new(TypeSignature::NoType), + ), + StaticCheckErrorKind::TypeError( + Box::new(TypeSignature::IntType), + Box::new(TypeSignature::NoType), + ), ]; + assert_eq!( + good.len(), + expected.len(), + "Good tests should match related expected count" + ); for (good_test, expected) in good.iter().zip(expected.iter()) { assert_eq!( expected, @@ -1548,6 +1568,11 @@ fn test_lists() { ); } + assert_eq!( + bad.len(), + bad_expected.len(), + "Bad tests should match related expected count" + ); for (bad_test, expected) in bad.iter().zip(bad_expected.iter()) { assert_eq!(*expected, *type_check_helper(bad_test).unwrap_err().err); } diff --git a/clarity/src/vm/functions/sequences.rs b/clarity/src/vm/functions/sequences.rs index 5d327a8781a..1a8054de2ed 100644 --- a/clarity/src/vm/functions/sequences.rs +++ b/clarity/src/vm/functions/sequences.rs @@ -204,10 +204,8 @@ pub fn special_map( ))?; let function = lookup_function(function_name, exec_state, invoke_ctx)?; - // Let's consider a function f (f a b c ...) - // We will first re-arrange our sequences [a0, a1, ...] [b0, b1, ...] [c0, c1, ...] ... - // To get something like: [a0, b0, c0, ...] [a1, b1, c1, ...] - let mut mapped_func_args: Vec> = vec![]; + // Evaluate each sequence argument into an iterator and record its length. + let mut args_iterators = Vec::with_capacity(args.len() - 1); let mut min_args_len = usize::MAX; for map_arg in args[1..].iter() { let sequence = @@ -219,38 +217,31 @@ pub fn special_map( )) .into()); }; - let seq_len = seq.len(); - min_args_len = min_args_len.min(seq_len); - for (apply_index, element_result) in seq.into_iter().enumerate() { - let value = element_result.map_err(|_| { - VmInternalError::Expect( - "ERROR: Invalid sequence data successfully constructed".into(), - ) - })?; - if apply_index > min_args_len { - break; - } - if apply_index >= mapped_func_args.len() { - mapped_func_args.push(vec![value]); - } else { - mapped_func_args[apply_index].push(value); - } - } + let args_iter = seq.into_iter(); + min_args_len = min_args_len.min(args_iter.len()); + args_iterators.push(args_iter); } - // We can now apply the map - let mut mapped_results = vec![]; - let mut previous_len = None; - for arguments in mapped_func_args.into_iter() { - // Stop iterating when we are done with the shortest sequence - if let Some(previous_len) = previous_len { - if previous_len != arguments.len() { - break; - } - } else { - previous_len = Some(arguments.len()); + // Apply the function element-wise, stopping at the shortest sequence. + let mut mapped_results = Vec::with_capacity(min_args_len); + for _ in 0..min_args_len { + let mut call_args = Vec::with_capacity(args_iterators.len()); + for iter in args_iterators.iter_mut() { + let value = iter + .next() + .ok_or_else(|| { + RuntimeCheckErrorKind::Unreachable( + "iterator can't be shorter than min len".into(), + ) + })? + .map_err(|_| { + VmInternalError::Expect( + "ERROR: Invalid sequence data successfully constructed".into(), + ) + })?; + call_args.push(value); } - let res = apply_evaluated(&function, arguments, exec_state, invoke_ctx, context)?; + let res = apply_evaluated(&function, call_args, exec_state, invoke_ctx, context)?; mapped_results.push(res); } diff --git a/clarity/src/vm/tests/sequences.rs b/clarity/src/vm/tests/sequences.rs index 9008437a82b..5e08cac2c0e 100644 --- a/clarity/src/vm/tests/sequences.rs +++ b/clarity/src/vm/tests/sequences.rs @@ -389,9 +389,34 @@ fn test_simple_map_list() { #[test] fn test_variadic_map_list() { + // User functions + // - 1 sequence: 1 empty, result empty. let test = "(define-private (area (w int) (h int)) (* w h)) - (map area (list 5 10 1 2) (list 5 2 30 3))"; + (map area (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: 2 empties, result empty. + let test = "(define-private (area (w int) (h int)) (* w h)) + (map area (list) (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: 1 empty and 1 not, result empty. + let test = "(define-private (area (w int) (h int)) (* w h)) + (map area (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 3 sequences: 1 empty in between, result empty. + let test = "(define-private (area (w int) (h int)) (* w h)) + (map area (list 1 2 3) (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + // - 2 sequences: same length 4, result contains 4 elements. + let test = "(define-private (area (w int) (h int)) (* w h)) + (map area (list 5 10 1 2) (list 5 2 30 3))"; let expected = Value::list_from(vec![ Value::Int(25), Value::Int(20), @@ -401,9 +426,9 @@ fn test_variadic_map_list() { .unwrap(); assert_eq!(expected, execute(test).unwrap().unwrap()); + // - 2 sequences: same length 4 with different int types, result contains 4 elements. let test = "(define-private (u+ (a uint) (b int)) (+ a (to-uint b))) - (map u+ (list u5 u10 u1 u2) (list 5 2 30 3))"; - + (map u+ (list u5 u10 u1 u2) (list 5 2 30 3))"; let expected = Value::list_from(vec![ Value::UInt(10), Value::UInt(12), @@ -413,20 +438,66 @@ fn test_variadic_map_list() { .unwrap(); assert_eq!(expected, execute(test).unwrap().unwrap()); - let test = "(map + (list 5 10) (list 5 2 30 3))"; + // Native functions + // - 1 sequence: 1 empty, result empty. + let test = "(map + (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + // - 2 sequences: 2 empties, result empty. + let test = "(map + (list) (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: 1 empty and 1 not, result empty. + let test = "(map + (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 3 sequences: 1 empty in between, result empty. + let test = "(map + (list 1 2 3) (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: last is longer, result contains 2 elements. + let test = "(map + (list 5 10) (list 5 2 30 3))"; let expected = Value::list_from(vec![Value::Int(10), Value::Int(12)]).unwrap(); assert_eq!(expected, execute(test).unwrap().unwrap()); - let test = "(map pow (list 2 2 2 2) (list 1 2 3 4 5 6 7))"; + // - 3 sequences: last is the longest, result contains 2 elements. + let test = "(map + (list 1 2) (list 10 20 30) (list 100 200 300 400))"; + let expected = Value::list_from(vec![Value::Int(111), Value::Int(222)]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); - let expected = Value::list_from(vec![ - Value::Int(2), - Value::Int(4), - Value::Int(8), - Value::Int(16), - ]) - .unwrap(); + // - 3 sequences: last is the shortest, result contains 1 element. + let test = "(map + (list 1 2) (list 10 20 30) (list 100))"; + let expected = Value::list_from(vec![Value::Int(111)]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // Special Functions + // - 1 sequence: 1 empty, result empty. + let test = "(map > (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: 2 empties, result empty. + let test = "(map > (list) (list))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: 1 empty and 1 not, result empty. + let test = "(map > (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 3 sequences: 1 empty in between, result empty. + let test = "(map > (list 1 2 3) (list) (list 10 20))"; + let expected = Value::list_from(vec![]).unwrap(); + assert_eq!(expected, execute(test).unwrap().unwrap()); + + // - 2 sequences: last is longer, result contains 2 elements. + let test = "(map > (list 1 10) (list 2 3 4 5))"; + let expected = Value::list_from(vec![Value::Bool(false), Value::Bool(true)]).unwrap(); assert_eq!(expected, execute(test).unwrap().unwrap()); }