bitcoin/consensus/
serde.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Serde serialization via consensus encoding
4//!
5//! This provides functions for (de)serializing any type as consensus-encoded bytes.
6//! For human-readable formats it serializes as a string with a consumer-supplied encoding, for
7//! binary formats it serializes as a sequence of bytes (not `serialize_bytes` to avoid allocations).
8//!
9//! The string encoding has to be specified using a marker type implementing the encoding strategy.
//! This module provides hex encoding via `Hex<hex::Upper>` and `Hex<hex::Lower>`.
11
12use core::fmt;
13use core::marker::PhantomData;
14
15use io::Write;
16use serde::de::{SeqAccess, Unexpected, Visitor};
17use serde::ser::SerializeSeq;
18use serde::{Deserializer, Serializer};
19
20use super::encode::Error as ConsensusError;
21use super::{Decodable, Encodable};
22use crate::consensus::{DecodeError, IterReader};
23
24/// Hex-encoding strategy
25pub struct Hex<Case = hex::Lower>(PhantomData<Case>)
26where
27    Case: hex::Case;
28
29impl<C: hex::Case> Default for Hex<C> {
30    fn default() -> Self { Hex(Default::default()) }
31}
32
33impl<C: hex::Case> ByteEncoder for Hex<C> {
34    type Encoder = hex::Encoder<C>;
35}
36
37/// Implements hex encoding.
38pub mod hex {
39    use core::fmt;
40    use core::marker::PhantomData;
41
42    use hex::buf_encoder::BufEncoder;
43
44    /// Marker for upper/lower case type-level flags ("type-level enum").
45    ///
46    /// You may use this trait in bounds only.
47    pub trait Case: sealed::Case {}
48    impl<T: sealed::Case> Case for T {}
49
50    /// Marker for using lower-case hex encoding.
51    pub enum Lower {}
52    /// Marker for using upper-case hex encoding.
53    pub enum Upper {}
54
55    mod sealed {
56        pub trait Case {
57            /// Internal detail, don't depend on it!!!
58            const INTERNAL_CASE: hex::Case;
59        }
60
61        impl Case for super::Lower {
62            const INTERNAL_CASE: hex::Case = hex::Case::Lower;
63        }
64
65        impl Case for super::Upper {
66            const INTERNAL_CASE: hex::Case = hex::Case::Upper;
67        }
68    }
69
70    // We just guessed at a reasonably sane value.
71    const HEX_BUF_SIZE: usize = 512;
72
73    /// Hex byte encoder.
74    // We wrap `BufEncoder` to not leak internal representation.
75    pub struct Encoder<C: Case>(BufEncoder<{ HEX_BUF_SIZE }>, PhantomData<C>);
76
77    impl<C: Case> From<super::Hex<C>> for Encoder<C> {
78        fn from(_: super::Hex<C>) -> Self { Encoder(BufEncoder::new(), Default::default()) }
79    }
80
81    impl<C: Case> super::EncodeBytes for Encoder<C> {
82        fn encode_chunk<W: fmt::Write>(&mut self, writer: &mut W, mut bytes: &[u8]) -> fmt::Result {
83            while !bytes.is_empty() {
84                if self.0.is_full() {
85                    self.flush(writer)?;
86                }
87                bytes = self.0.put_bytes_min(bytes, C::INTERNAL_CASE);
88            }
89            Ok(())
90        }
91
92        fn flush<W: fmt::Write>(&mut self, writer: &mut W) -> fmt::Result {
93            writer.write_str(self.0.as_str())?;
94            self.0.clear();
95            Ok(())
96        }
97    }
98
99    // Newtypes to hide internal details.
100
101    /// Error returned when a hex string decoder can't be created.
102    #[derive(Debug, Clone, PartialEq, Eq)]
103    pub struct DecodeInitError(hex::OddLengthStringError);
104
105    /// Error returned when a hex string contains invalid characters.
106    #[derive(Debug, Clone, PartialEq, Eq)]
107    pub struct DecodeError(hex::InvalidCharError);
108
109    /// Hex decoder state.
110    pub struct Decoder<'a>(hex::HexSliceToBytesIter<'a>);
111
112    impl<'a> Decoder<'a> {
113        fn new(s: &'a str) -> Result<Self, DecodeInitError> {
114            match hex::HexToBytesIter::new(s) {
115                Ok(iter) => Ok(Decoder(iter)),
116                Err(error) => Err(DecodeInitError(error)),
117            }
118        }
119    }
120
121    impl<'a> Iterator for Decoder<'a> {
122        type Item = Result<u8, DecodeError>;
123
124        fn next(&mut self) -> Option<Self::Item> {
125            self.0.next().map(|result| result.map_err(DecodeError))
126        }
127    }
128
129    impl<'a, C: Case> super::ByteDecoder<'a> for super::Hex<C> {
130        type InitError = DecodeInitError;
131        type DecodeError = DecodeError;
132        type Decoder = Decoder<'a>;
133
134        fn from_str(s: &'a str) -> Result<Self::Decoder, Self::InitError> { Decoder::new(s) }
135    }
136
137    impl super::IntoDeError for DecodeInitError {
138        fn into_de_error<E: serde::de::Error>(self) -> E {
139            E::invalid_length(self.0.length(), &"an even number of ASCII-encoded hex digits")
140        }
141    }
142
143    impl super::IntoDeError for DecodeError {
144        fn into_de_error<E: serde::de::Error>(self) -> E {
145            use serde::de::Unexpected;
146
147            const EXPECTED_CHAR: &str = "an ASCII-encoded hex digit";
148
149            match self.0.invalid_char() {
150                c if c.is_ascii() => E::invalid_value(Unexpected::Char(c as _), &EXPECTED_CHAR),
151                c => E::invalid_value(Unexpected::Unsigned(c.into()), &EXPECTED_CHAR),
152            }
153        }
154    }
155}
156
157struct DisplayWrapper<'a, T: 'a + Encodable, E>(&'a T, PhantomData<E>);
158
159impl<'a, T: 'a + Encodable, E: ByteEncoder> fmt::Display for DisplayWrapper<'a, T, E> {
160    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161        let mut writer = IoWrapper::<'_, _, E::Encoder>::new(f, E::default().into());
162        self.0.consensus_encode(&mut writer).map_err(|error| {
163            #[cfg(debug_assertions)]
164            {
165                if error.kind() != io::ErrorKind::Other
166                    || error.get_ref().is_some()
167                    || !writer.writer.was_error
168                {
169                    panic!(
170                        "{} returned an unexpected error: {:?}",
171                        core::any::type_name::<T>(),
172                        error
173                    );
174                }
175            }
176            fmt::Error
177        })?;
178        let result = writer.actually_flush();
179        if result.is_err() {
180            writer.writer.assert_was_error::<E>();
181        }
182        result
183    }
184}
185
/// `fmt::Write` wrapper that, when debug assertions are enabled, records whether any write to
/// the inner writer has failed.
///
/// The encoding adapters use it to check, in debug builds only, two things: an error reported
/// by an encoder really came from the underlying writer, and nothing keeps writing after a
/// failure. Without debug assertions no state is kept and every check is a no-op.
struct ErrorTrackingWriter<W: fmt::Write> {
    writer: W,
    // Set as soon as any checked result is an error; never reset.
    #[cfg(debug_assertions)]
    was_error: bool,
}

impl<W: fmt::Write> ErrorTrackingWriter<W> {
    /// Wraps `writer` with no error recorded.
    fn new(writer: W) -> Self {
        ErrorTrackingWriter {
            writer,
            #[cfg(debug_assertions)]
            was_error: false,
        }
    }

    /// Panics (debug builds only) if an error was already recorded. `fun` names the method
    /// being called and appears in the panic message.
    #[track_caller]
    fn assert_no_error(&self, fun: &str) {
        #[cfg(debug_assertions)]
        {
            if self.was_error {
                panic!("`{}` called on errored writer", fun);
            }
        }
        // `fun` is only read by the debug-only check above; keep release builds free of an
        // `unused_variables` warning.
        #[cfg(not(debug_assertions))]
        let _ = fun;
    }

    /// Panics (debug builds only) if no error was recorded, blaming `Offender` for reporting
    /// an error the writer never produced.
    fn assert_was_error<Offender>(&self) {
        #[cfg(debug_assertions)]
        {
            if !self.was_error {
                panic!("{} returned an error unexpectedly", core::any::type_name::<Offender>());
            }
        }
    }

    /// Records an error if `was` is true; an already-recorded error is never cleared.
    fn set_error(&mut self, was: bool) {
        #[cfg(debug_assertions)]
        {
            self.was_error |= was;
        }
        // Only debug builds track errors; avoid an unused-parameter warning otherwise.
        #[cfg(not(debug_assertions))]
        let _ = was;
    }

    /// Records whether `result` is an error and returns it unchanged.
    fn check_err<T, E>(&mut self, result: Result<T, E>) -> Result<T, E> {
        self.set_error(result.is_err());
        result
    }
}

impl<W: fmt::Write> fmt::Write for ErrorTrackingWriter<W> {
    fn write_str(&mut self, s: &str) -> fmt::Result {
        self.assert_no_error("write_str");
        let result = self.writer.write_str(s);
        self.check_err(result)
    }

    fn write_char(&mut self, c: char) -> fmt::Result {
        self.assert_no_error("write_char");
        let result = self.writer.write_char(c);
        self.check_err(result)
    }
}
246
247struct IoWrapper<'a, W: fmt::Write, E: EncodeBytes> {
248    writer: ErrorTrackingWriter<&'a mut W>,
249    encoder: E,
250}
251
252impl<'a, W: fmt::Write, E: EncodeBytes> IoWrapper<'a, W, E> {
253    fn new(writer: &'a mut W, encoder: E) -> Self {
254        IoWrapper { writer: ErrorTrackingWriter::new(writer), encoder }
255    }
256
257    fn actually_flush(&mut self) -> fmt::Result { self.encoder.flush(&mut self.writer) }
258}
259
260impl<'a, W: fmt::Write, E: EncodeBytes> Write for IoWrapper<'a, W, E> {
261    fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
262        match self.encoder.encode_chunk(&mut self.writer, bytes) {
263            Ok(()) => Ok(bytes.len()),
264            Err(fmt::Error) => {
265                self.writer.assert_was_error::<E>();
266                Err(io::Error::from(io::ErrorKind::Other))
267            }
268        }
269    }
270    // we intentionally ignore flushes because we will do a single flush at the end.
271    fn flush(&mut self) -> io::Result<()> { Ok(()) }
272}
273
/// Provides an instance of a byte-to-string encoder.
///
/// This is basically a type constructor for places where value arguments are not accepted,
/// such as the generic [`With::serialize`]. This module builds the encoder state from
/// `Self::default()` through the `From` conversion.
pub trait ByteEncoder: Default {
    /// The encoder state.
    type Encoder: EncodeBytes + From<Self>;
}
282
/// Transforms given bytes and writes the result to a writer.
///
/// The encoder is allowed to be buffered (and probably should be); buffered output is written
/// out by [`flush`](EncodeBytes::flush).
/// The writer is passed to each call instead of being stored in the encoder, which avoids the
/// need for generic associated types (GATs).
pub trait EncodeBytes {
    /// Transforms the provided slice and writes the result to the writer.
    ///
    /// This is similar to the `write_all` method on `io::Write`.
    fn encode_chunk<W: fmt::Write>(&mut self, writer: &mut W, bytes: &[u8]) -> fmt::Result;

    /// Writes the data in the buffer (if any) to the writer.
    fn flush<W: fmt::Write>(&mut self, writer: &mut W) -> fmt::Result;
}
296
/// Provides an instance of a string-to-byte decoder.
///
/// This is basically a type constructor for places where value arguments are not accepted,
/// such as the generic [`With::deserialize`].
pub trait ByteDecoder<'a> {
    /// Error returned when the decoder can't be created.
    ///
    /// This is typically returned when the string length is invalid.
    type InitError: IntoDeError + fmt::Debug;

    /// Error returned when decoding fails.
    ///
    /// This is typically returned when the input string contains malformed characters.
    type DecodeError: IntoDeError + fmt::Debug;

    /// The decoder state: an iterator yielding each decoded byte or a decoding error.
    type Decoder: Iterator<Item = Result<u8, Self::DecodeError>>;

    /// Constructs the decoder from a string.
    fn from_str(s: &'a str) -> Result<Self::Decoder, Self::InitError>;
}
318
/// Converts an error into any type implementing `serde::de::Error`.
pub trait IntoDeError {
    /// Performs the conversion.
    fn into_de_error<E: serde::de::Error>(self) -> E;
}
324
325struct BinWriter<S: SerializeSeq> {
326    serializer: S,
327    error: Option<S::Error>,
328}
329
330impl<S: SerializeSeq> Write for BinWriter<S> {
331    fn write(&mut self, buf: &[u8]) -> io::Result<usize> { self.write_all(buf).map(|_| buf.len()) }
332
333    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
334        for byte in buf {
335            if let Err(error) = self.serializer.serialize_element(byte) {
336                self.error = Some(error);
337                return Err(io::ErrorKind::Other.into());
338            }
339        }
340        Ok(())
341    }
342
343    fn flush(&mut self) -> io::Result<()> { Ok(()) }
344}
345
346struct DisplayExpected<D: fmt::Display>(D);
347
348impl<D: fmt::Display> serde::de::Expected for DisplayExpected<D> {
349    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
350        fmt::Display::fmt(&self.0, formatter)
351    }
352}
353
354// not a trait impl because we panic on some variants
355fn consensus_error_into_serde<E: serde::de::Error>(error: ConsensusError) -> E {
356    match error {
357        ConsensusError::Io(error) => panic!("unexpected IO error {:?}", error),
358        ConsensusError::OversizedVectorAllocation { requested, max } => E::custom(format_args!(
359            "the requested allocation of {} items exceeds maximum of {}",
360            requested, max
361        )),
362        ConsensusError::InvalidChecksum { expected, actual } => E::invalid_value(
363            Unexpected::Bytes(&actual),
364            &DisplayExpected(format_args!(
365                "checksum {:02x}{:02x}{:02x}{:02x}",
366                expected[0], expected[1], expected[2], expected[3]
367            )),
368        ),
369        ConsensusError::NonMinimalVarInt =>
370            E::custom(format_args!("compact size was not encoded minimally")),
371        ConsensusError::ParseFailed(msg) => E::custom(msg),
372        ConsensusError::UnsupportedSegwitFlag(flag) =>
373            E::invalid_value(Unexpected::Unsigned(flag.into()), &"segwit version 1 flag"),
374    }
375}
376
377impl<E> DecodeError<E>
378where
379    E: serde::de::Error,
380{
381    fn unify(self) -> E {
382        match self {
383            DecodeError::Other(error) => error,
384            DecodeError::TooManyBytes => E::custom(format_args!("got more bytes than expected")),
385            DecodeError::Consensus(error) => consensus_error_into_serde(error),
386        }
387    }
388}
389
390impl<E> IntoDeError for DecodeError<E>
391where
392    E: IntoDeError,
393{
394    fn into_de_error<DE: serde::de::Error>(self) -> DE {
395        match self {
396            DecodeError::Other(error) => error.into_de_error(),
397            DecodeError::TooManyBytes => DE::custom(format_args!("got more bytes than expected")),
398            DecodeError::Consensus(error) => consensus_error_into_serde(error),
399        }
400    }
401}
402
403/// Helper for `#[serde(with = "")]`.
404///
405/// To (de)serialize a field using consensus encoding you can write e.g.:
406///
407/// ```
408/// # use actual_serde::{Serialize, Deserialize};
409/// use bitcoin::Transaction;
410/// use bitcoin::consensus;
411///
412/// #[derive(Serialize, Deserialize)]
413/// # #[serde(crate = "actual_serde")]
414/// pub struct MyStruct {
415///     #[serde(with = "consensus::serde::With::<consensus::serde::Hex>")]
416///     tx: Transaction,
417/// }
418/// ```
419pub struct With<E>(PhantomData<E>);
420
421impl<E> With<E> {
422    /// Serializes the value as consensus-encoded
423    pub fn serialize<T: Encodable, S: Serializer>(
424        value: &T,
425        serializer: S,
426    ) -> Result<S::Ok, S::Error>
427    where
428        E: ByteEncoder,
429    {
430        if serializer.is_human_readable() {
431            serializer.collect_str(&DisplayWrapper::<'_, _, E>(value, Default::default()))
432        } else {
433            let serializer = serializer.serialize_seq(None)?;
434            let mut writer = BinWriter { serializer, error: None };
435
436            let result = value.consensus_encode(&mut writer);
437            match (result, writer.error) {
438                (Ok(_), None) => writer.serializer.end(),
439                (Ok(_), Some(error)) =>
440                    panic!("{} silently ate an IO error: {:?}", core::any::type_name::<T>(), error),
441                (Err(io_error), Some(ser_error))
442                    if io_error.kind() == io::ErrorKind::Other && io_error.get_ref().is_none() =>
443                    Err(ser_error),
444                (Err(io_error), ser_error) => panic!(
445                    "{} returned an unexpected IO error: {:?} serialization error: {:?}",
446                    core::any::type_name::<T>(),
447                    io_error,
448                    ser_error
449                ),
450            }
451        }
452    }
453
454    /// Deserializes the value as consensus-encoded
455    pub fn deserialize<'d, T: Decodable, D: Deserializer<'d>>(
456        deserializer: D,
457    ) -> Result<T, D::Error>
458    where
459        for<'a> E: ByteDecoder<'a>,
460    {
461        if deserializer.is_human_readable() {
462            deserializer.deserialize_str(HRVisitor::<_, E>(Default::default()))
463        } else {
464            deserializer.deserialize_seq(BinVisitor(Default::default()))
465        }
466    }
467}
468
469struct HRVisitor<T: Decodable, D: for<'a> ByteDecoder<'a>>(PhantomData<fn() -> (T, D)>);
470
471impl<'de, T: Decodable, D: for<'a> ByteDecoder<'a>> Visitor<'de> for HRVisitor<T, D> {
472    type Value = T;
473
474    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
475        formatter.write_str("bytes encoded as a hex string")
476    }
477
478    fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<T, E> {
479        let decoder = D::from_str(s).map_err(IntoDeError::into_de_error)?;
480        IterReader::new(decoder).decode().map_err(IntoDeError::into_de_error)
481    }
482}
483
484struct BinVisitor<T: Decodable>(PhantomData<fn() -> T>);
485
486impl<'de, T: Decodable> Visitor<'de> for BinVisitor<T> {
487    type Value = T;
488
489    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
490        formatter.write_str("a sequence of bytes")
491    }
492
493    fn visit_seq<S: SeqAccess<'de>>(self, s: S) -> Result<T, S::Error> {
494        IterReader::new(SeqIterator(s, Default::default())).decode().map_err(DecodeError::unify)
495    }
496}
497
498struct SeqIterator<'a, S: serde::de::SeqAccess<'a>>(S, PhantomData<&'a ()>);
499
500impl<'a, S: serde::de::SeqAccess<'a>> Iterator for SeqIterator<'a, S> {
501    type Item = Result<u8, S::Error>;
502
503    fn next(&mut self) -> Option<Self::Item> { self.0.next_element::<u8>().transpose() }
504}