{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Test.Cardano.Ledger.Shelley.Serialisation.GoldenUtils (
  checkEncoding,
  checkEncodingCBOR,
  checkEncodingCBORAnnotated,
  ToTokens (..),
  roundTripFailure,
  checkEncodingCBORDecodeFailure,
)
where

import Cardano.Ledger.Binary (
  Annotator,
  DecCBOR (..),
  DecoderError,
  EncCBOR (..),
  EncCBORGroup (..),
  Encoding,
  ToCBOR (..),
  Tokens (..),
  Version,
  decodeFullAnnotator,
  decodeFullDecoder,
  decodeTerm,
  encCBOR,
  fromPlainEncoding,
  serialize,
  serialize',
 )

-- ToExpr (CBOR.Term) instance
import qualified Codec.CBOR.Encoding as CBOR (Encoding (..))
import Control.Exception (throwIO)
import Control.Monad (unless)
import qualified Data.ByteString.Lazy as BSL (ByteString)
import Data.String (fromString)
import GHC.Stack
import qualified Prettyprinter as Pretty
import Test.Cardano.Ledger.Binary.TreeDiff (ansiDocToString, diffExpr)
import Test.Tasty (TestTree)
import Test.Tasty.HUnit (Assertion, assertFailure, testCase, (@?=))

-- | Assert that @action x@ succeeds and returns a value equal to @x@.
--
-- A 'Left' fails the test with a message showing the input and the
-- 'DecoderError'; a 'Right' result is compared against @x@ with '@?='
-- (decoded value as "actual", original as "expected").
expectDecodingSuccess :: (HasCallStack, Show a, Eq a) => (a -> Either DecoderError a) -> a -> IO ()
expectDecodingSuccess action x =
  case action x of
    Left e ->
      assertFailure $ "could not decode serialization of " ++ show x ++ ", " ++ show e
    Right y -> y @?= x

-- | Assert that @action x@ fails to decode.
--
-- A 'Left' (of any 'DecoderError') passes; a 'Right' fails the test with a
-- message showing the input that unexpectedly decoded.
expectDecodingFailure :: (HasCallStack, Show a) => (a -> Either DecoderError a) -> a -> IO ()
expectDecodingFailure action x =
  case action x of
    Left _ -> pure ()
    Right _ -> assertFailure $ "Did not expect successful decoding of " ++ show x

-- | Encode a value with @encode@, serialize it at protocol version @v@, and
-- feed the resulting lazy bytes straight back into @decode@.
roundtrip ::
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Either DecoderError a
roundtrip v encode decode = decode . serialize v . encode

-- | Assertion that serializing @x@ with @encode@ at version @v@ and decoding
-- the bytes with @decode@ yields a value equal to @x@.
--
-- 'HasCallStack' is required so that a decode failure reports the location of
-- the caller (e.g. the golden test that referenced this function) rather than
-- stopping inside this module.
roundTripSuccess ::
  (HasCallStack, Show a, Eq a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Assertion
roundTripSuccess v encode decode x =
  expectDecodingSuccess (roundtrip v encode decode) x

-- | Assertion that serializing @x@ with @encode@ at version @v@ and decoding
-- the bytes with @decode@ FAILS (any 'DecoderError' is accepted).
--
-- 'HasCallStack' is required so that an unexpected successful decode reports
-- the caller's location rather than stopping inside this module.
roundTripFailure ::
  (HasCallStack, Show a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  a ->
  Assertion
roundTripFailure v encode decode x =
  expectDecodingFailure (roundtrip v encode decode) x

-- | Golden test with explicit encoder/decoder: checks that @encode x@ at
-- version @v@ serializes to exactly the bytes described by @t@, and that
-- decoding those bytes with @decode@ gives back @x@ (via 'roundTripSuccess').
checkEncoding ::
  (HasCallStack, Show a, Eq a) =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncoding v encode decode name x t =
  checkEncodingWithRoundtrip v encode decode roundTripSuccess name x t

-- | Build a test case named @"golden_serialize_" <> name@ which:
--
--   1. serializes the expected tokens @t@ and the actual encoding @encode x@
--      at version @v@ and, if the bytes differ, fails with a tree diff of the
--      two byte strings decoded as generic CBOR terms;
--   2. otherwise runs the supplied round-trip assertion
--      @roundTrip v encode decode x@.
--
-- If either byte string cannot be decoded as a CBOR term while building the
-- diff, the 'DecoderError' is thrown with 'throwIO'.
checkEncodingWithRoundtrip ::
  HasCallStack =>
  Version ->
  (a -> Encoding) ->
  (BSL.ByteString -> Either DecoderError a) ->
  (Version -> (a -> Encoding) -> (BSL.ByteString -> Either DecoderError a) -> a -> Assertion) ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingWithRoundtrip v encode decode roundTrip name x t =
  testCase testName $ do
    unless (expectedBinary == actualBinary) $ do
      expectedTerms <- getTerms "expected" expectedBinary
      actualTerms <- getTerms "actual" actualBinary
      assertFailure . ansiDocToString $
        Pretty.vsep
          [ "Serialization did not match:"
          , Pretty.indent 2 $ diffExpr expectedTerms actualTerms
          ]
    roundTrip v encode decode x
  where
    -- Decode raw bytes into a generic CBOR term so mismatches can be diffed
    -- structurally; @lbl@ names the side in any decoder error.
    getTerms lbl = either throwIO pure . decodeFullDecoder v lbl decodeTerm
    expectedBinary = serialize v t
    actualBinary = serialize v $ encode x
    testName = "golden_serialize_" <> name

-- | Golden test for a value that must serialize (via 'encCBOR') to the bytes
-- described by @t@ but must NOT decode back successfully with 'decCBOR'
-- (checked by 'roundTripFailure').
checkEncodingCBORDecodeFailure ::
  (HasCallStack, DecCBOR a, EncCBOR a, Show a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBORDecodeFailure v name x t =
  let -- Full decoder at version v; the test name doubles as the decoder label.
      d = decodeFullDecoder v (fromString name) decCBOR
   in checkEncodingWithRoundtrip v encCBOR d roundTripFailure name x t

-- | Golden test for a type with 'EncCBOR' / 'DecCBOR' instances: checks that
-- 'encCBOR' at version @v@ produces exactly the bytes described by @t@ and
-- that 'decCBOR' decodes them back to @x@ (checked by 'roundTripSuccess').
checkEncodingCBOR ::
  (HasCallStack, DecCBOR a, EncCBOR a, Show a, Eq a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBOR v name x t =
  let -- Full decoder at version v; the test name doubles as the decoder label.
      d = decodeFullDecoder v (fromString name) decCBOR
   in checkEncodingWithRoundtrip v encCBOR d roundTripSuccess name x t

-- | Golden test for a type that is encoded with 'ToCBOR' and decoded through
-- an 'Annotator'. The value is encoded with @fromPlainEncoding . toCBOR@ and
-- decoded with 'decodeFullAnnotator'; the round trip must give back @x@.
--
-- The expected tokens @t@ are serialized at version @v@ and wrapped as a
-- single 'TkEncoded' token before comparison.
checkEncodingCBORAnnotated ::
  (HasCallStack, DecCBOR (Annotator a), ToCBOR a, Show a, Eq a) =>
  Version ->
  String ->
  a ->
  ToTokens ->
  TestTree
checkEncodingCBORAnnotated v name x t =
  let -- Annotator-based full decoder; the test name doubles as the label.
      d = decodeFullAnnotator v (fromString name) decCBOR
   in checkEncodingWithRoundtrip v (fromPlainEncoding . toCBOR) d roundTripSuccess name x annTokens
  where
    annTokens = T $ TkEncoded $ serialize' v t

-- | Description of an expected CBOR byte stream for golden tests. Pieces are
-- combined with 'Plus' (or '<>') and turned into an 'Encoding' by the
-- 'EncCBOR' instance for this type.
data ToTokens where
  -- | Raw tokens, in the same @Tokens -> Tokens@ representation that the
  -- @CBOR.Encoding@ constructor wraps.
  T :: (Tokens -> Tokens) -> ToTokens
  -- | Any value that is encoded with its 'EncCBOR' instance.
  S :: forall a. EncCBOR a => a -> ToTokens
  -- | Any value that is encoded as a group with its 'EncCBORGroup' instance.
  G :: forall a. EncCBORGroup a => a -> ToTokens
  -- | Concatenation: the left piece's tokens followed by the right piece's.
  Plus :: ToTokens -> ToTokens -> ToTokens

-- | Turn a 'ToTokens' description into an 'Encoding'.
instance EncCBOR ToTokens where
  -- Raw tokens are wrapped as a plain cborg encoding and lifted.
  encCBOR (T xs) = fromPlainEncoding (CBOR.Encoding xs)
  encCBOR (S s) = encCBOR s
  encCBOR (G g) = encCBORGroup g
  -- Concatenate the two halves in order.
  encCBOR (Plus a b) = encCBOR a <> encCBOR b

-- | Combining two descriptions simply builds a 'Plus' node; encoding the
-- result emits the left tokens followed by the right ones.
instance Semigroup ToTokens where
  (<>) = Plus

-- | The empty description contributes no tokens (raw token function 'id').
instance Monoid ToTokens where
  mempty = T id