mctp_rs/message_type/mctp_control.rs

1use crate::{
2    MctpMedium, MctpMessageHeaderTrait, MctpMessageTrait,
3    MctpPacketError::{self, HeaderParseError},
4    error::{MctpPacketResult, ProtocolError},
5    mctp_command_code::MctpControlCommandCode,
6    mctp_completion_code::MctpCompletionCode,
7};
8
9#[derive(Debug, Default, PartialEq, Eq, Clone)]
10#[cfg_attr(feature = "defmt", derive(defmt::Format))]
11pub struct MctpControlHeader {
12    pub request_bit: bool,  // bit 7
13    pub datagram_bit: bool, // bit 6
14    pub instance_id: u8,    // bits 4-0
15    pub command_code: MctpControlCommandCode,
16    pub completion_code: MctpCompletionCode,
17}
18
/// Body of an MCTP control message supported by this module.
///
/// Each data-carrying variant holds the raw fixed-size payload that follows the control
/// header on the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MctpControl {
    /// `Set Endpoint ID` request payload (2 bytes).
    SetEndpointIdRequest([u8; 2]),
    /// `Set Endpoint ID` response payload (3 bytes following the header).
    SetEndpointIdResponse([u8; 3]),
    /// `Get Endpoint ID` request; carries no payload.
    GetEndpointIdRequest,
    /// `Get Endpoint ID` response payload (3 bytes following the header).
    GetEndpointIdResponse([u8; 3]),
}
27
28impl MctpMessageHeaderTrait for MctpControlHeader {
29    fn serialize<M: MctpMedium>(self, buffer: &mut [u8]) -> MctpPacketResult<usize, M> {
30        if buffer.len() < 3 {
31            return Err(crate::MctpPacketError::SerializeError(
32                "buffer too small for mctp control header",
33            ));
34        }
35
36        check_request_and_completion_code(self.request_bit, self.completion_code)?;
37
38        buffer[0] = (self.request_bit as u8) << 7
39            | (self.datagram_bit as u8) << 6
40            | (self.instance_id & 0b0001_1111);
41        buffer[1] = self.command_code as u8;
42        buffer[2] = self.completion_code.into();
43        Ok(3)
44    }
45
46    fn deserialize<M: MctpMedium>(buffer: &[u8]) -> MctpPacketResult<(Self, &[u8]), M> {
47        if buffer.len() < 3 {
48            return Err(HeaderParseError("buffer too small for mctp control header"));
49        }
50
51        let request_bit = buffer[0] & 0b1000_0000 != 0;
52        let datagram_bit = buffer[0] & 0b0100_0000 != 0;
53        let instance_id = buffer[0] & 0b0001_1111;
54        let command_code = MctpControlCommandCode::try_from(buffer[1])
55            .map_err(|_| HeaderParseError("invalid mctp command code"))?;
56        let completion_code = MctpCompletionCode::try_from(buffer[2])
57            .map_err(|_| HeaderParseError("invalid mctp completion code"))?;
58
59        check_request_and_completion_code(request_bit, completion_code)?;
60
61        Ok((
62            MctpControlHeader {
63                request_bit,
64                datagram_bit,
65                instance_id,
66                command_code,
67                completion_code,
68            },
69            &buffer[3..],
70        ))
71    }
72}
73
74fn check_request_and_completion_code<M: MctpMedium>(
75    request_bit: bool,
76    completion_code: MctpCompletionCode,
77) -> MctpPacketResult<(), M> {
78    if request_bit && completion_code != MctpCompletionCode::Success {
79        return Err(MctpPacketError::ProtocolError(
80            ProtocolError::CompletionCodeOnRequestMessage(completion_code),
81        ));
82    }
83    Ok(())
84}
85
86impl<'buf> MctpMessageTrait<'buf> for MctpControl {
87    type Header = MctpControlHeader;
88    const MESSAGE_TYPE: u8 = 0x00;
89
90    fn serialize<M: MctpMedium>(self, buffer: &mut [u8]) -> MctpPacketResult<usize, M> {
91        match self {
92            Self::SetEndpointIdRequest(data) => copy_and_check_len(buffer, data),
93            Self::SetEndpointIdResponse(data) => copy_and_check_len(buffer, data),
94            Self::GetEndpointIdRequest => copy_and_check_len(buffer, []),
95            Self::GetEndpointIdResponse(data) => copy_and_check_len(buffer, data),
96        }
97    }
98
99    fn deserialize<M: MctpMedium>(
100        header: &Self::Header,
101        buffer: &'buf [u8],
102    ) -> MctpPacketResult<Self, M> {
103        let message = match (header.request_bit, header.command_code) {
104            (true, MctpControlCommandCode::SetEndpointId) => {
105                Self::SetEndpointIdRequest(try_into_array(buffer)?)
106            }
107            (true, MctpControlCommandCode::GetEndpointId) => Self::GetEndpointIdRequest,
108            (false, MctpControlCommandCode::SetEndpointId) => {
109                Self::SetEndpointIdResponse(try_into_array(buffer)?)
110            }
111            (false, MctpControlCommandCode::GetEndpointId) => {
112                Self::GetEndpointIdResponse(try_into_array(buffer)?)
113            }
114            _ => {
115                return Err(HeaderParseError("invalid mctp control command code"));
116            }
117        };
118        Ok(message)
119    }
120}
121
122fn copy_and_check_len<const N: usize, M: MctpMedium>(
123    buffer: &mut [u8],
124    data: [u8; N],
125) -> MctpPacketResult<usize, M> {
126    if buffer.len() < N {
127        return Err(crate::MctpPacketError::SerializeError(
128            "buffer too small for mctp control message",
129        ));
130    }
131    buffer[..N].copy_from_slice(&data);
132    Ok(N)
133}
134
135fn try_into_array<const N: usize, M: MctpMedium>(buffer: &[u8]) -> MctpPacketResult<[u8; N], M> {
136    if buffer.len() < N {
137        return Err(HeaderParseError(
138            "buffer too small for mctp control message",
139        ));
140    }
141    Ok(buffer[..N].try_into().unwrap())
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::ProtocolError, test_util::TestMedium};

    /// Builds a header with zeroed flags for the given command/direction/completion code.
    fn header(
        request_bit: bool,
        command_code: MctpControlCommandCode,
        completion_code: MctpCompletionCode,
    ) -> MctpControlHeader {
        MctpControlHeader {
            request_bit,
            datagram_bit: false,
            instance_id: 0,
            command_code,
            completion_code,
        }
    }

    #[test]
    fn header_serialize_deserialize_happy_path() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0b1_1111,
            command_code: MctpControlCommandCode::GetEndpointId,
            completion_code: MctpCompletionCode::Success,
        };

        let mut buf = [0u8; 3];
        let size = header.clone().serialize::<TestMedium>(&mut buf).unwrap();
        assert_eq!(size, 3);
        assert_eq!(
            buf,
            [
                0b1000_0000 | 0b0001_1111, // rq=1, d=0, instance id=0x1F
                MctpControlCommandCode::GetEndpointId as u8,
                u8::from(MctpCompletionCode::Success),
            ]
        );

        let (parsed, rest) = MctpControlHeader::deserialize::<TestMedium>(&buf).unwrap();
        assert_eq!(parsed, header);
        assert_eq!(rest.len(), 0);
    }

    #[test]
    fn header_serialize_error_on_completion_code_in_request() {
        let header = header(
            true,
            MctpControlCommandCode::SetEndpointId,
            MctpCompletionCode::Error,
        );

        let mut buf = [0u8; 3];
        let err = header.serialize::<TestMedium>(&mut buf).unwrap_err();
        match err {
            MctpPacketError::ProtocolError(ProtocolError::CompletionCodeOnRequestMessage(code)) => {
                assert_eq!(code, MctpCompletionCode::Error)
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn header_serialize_error_on_short_buffer() {
        let header = header(
            false,
            MctpControlCommandCode::GetEndpointId,
            MctpCompletionCode::Success,
        );

        let mut buf = [0u8; 2];
        let err = header.serialize::<TestMedium>(&mut buf).unwrap_err();
        match err {
            MctpPacketError::SerializeError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control header")
            }
            other => panic!("unexpected error: {:?}", other),
        }
        // Nothing may be written when the length check fails.
        assert_eq!(buf, [0u8; 2]);
    }

    #[test]
    fn header_deserialize_error_on_short_buffer() {
        let err = MctpControlHeader::deserialize::<TestMedium>(&[0u8; 2]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control header")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn header_deserialize_error_on_completion_code_in_request() {
        let buf = [
            0b1000_0000, // rq=1, d=0, instance id=0
            MctpControlCommandCode::SetEndpointId as u8,
            u8::from(MctpCompletionCode::Error),
        ];

        let err = MctpControlHeader::deserialize::<TestMedium>(&buf).unwrap_err();
        match err {
            MctpPacketError::ProtocolError(ProtocolError::CompletionCodeOnRequestMessage(code)) => {
                assert_eq!(code, MctpCompletionCode::Error)
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[rstest::rstest]
    #[case(MctpControlCommandCode::SetEndpointId, false, MctpControl::SetEndpointIdResponse([0xAA, 0xBB, 0xCC]), &[0xAA, 0xBB, 0xCC])]
    #[case(MctpControlCommandCode::SetEndpointId, true, MctpControl::SetEndpointIdRequest([0xAA, 0xBB]), &[0xAA, 0xBB])]
    #[case(MctpControlCommandCode::GetEndpointId, false, MctpControl::GetEndpointIdResponse([0xAA, 0xBB, 0xCC]), &[0xAA, 0xBB, 0xCC])]
    #[case(MctpControlCommandCode::GetEndpointId, true, MctpControl::GetEndpointIdRequest, &[])]
    fn message_serialize_deserialize_happy_path(
        #[case] command_code: MctpControlCommandCode,
        #[case] request_bit: bool,
        #[case] message: MctpControl,
        #[case] expected: &[u8],
    ) {
        let mut buf = [0u8; 1024];
        let size = message.clone().serialize::<TestMedium>(&mut buf).unwrap();
        assert_eq!(size, expected.len());
        assert_eq!(&buf[..size], expected);

        let header = header(request_bit, command_code, MctpCompletionCode::Success);

        let parsed = MctpControl::deserialize::<TestMedium>(&header, &buf).unwrap();
        assert_eq!(parsed, message);
    }

    #[test]
    fn message_serialize_error_on_short_buffer() {
        let mut buf = [0u8; 2];
        let err = MctpControl::SetEndpointIdResponse([0xAA, 0xBB, 0xCC])
            .serialize::<TestMedium>(&mut buf)
            .unwrap_err();
        match err {
            MctpPacketError::SerializeError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control message")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn message_deserialize_error_on_short_payload() {
        let header = header(
            false,
            MctpControlCommandCode::GetEndpointId,
            MctpCompletionCode::Success,
        );

        let err = MctpControl::deserialize::<TestMedium>(&header, &[0xAA, 0xBB]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control message")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn message_deserialize_error_on_invalid_command_for_header() {
        // request message with unsupported command code should error
        let header = header(
            true,
            MctpControlCommandCode::Reserved,
            MctpCompletionCode::Success,
        );

        let err = MctpControl::deserialize::<TestMedium>(&header, &[]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "invalid mctp control command code")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }
}