stm32.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. # Copyright 2024 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import array
  15. import logging
  16. import os
  17. import struct
  18. import time
  19. from intelhex import IntelHex
  20. LOG = logging.getLogger(__name__)
  21. class STM32FlashProgrammer(object):
  22. # CPUID register
  23. CPUID_ADDR = 0xE000ED00
  24. # Flash constants
  25. FLASH_BASE_ADDR = 0x08000000
  26. # Flash key register (FLASH_KEYR)
  27. FLASH_KEYR_ADDR = 0x40023C04
  28. FLASH_KEYR_VAL1 = 0x45670123
  29. FLASH_KEYR_VAL2 = 0xCDEF89AB
  30. # Flash status register (FLASH_SR)
  31. FLASH_SR_ADDR = 0x40023C0C
  32. FLASH_SR_BSY = (1 << 16)
  33. # Flash control register (FLASH_CR)
  34. FLASH_CR_ADDR = 0x40023C10
  35. FLASH_CR_PG = (1 << 0)
  36. FLASH_CR_SER = (1 << 1)
  37. FLASH_CR_SNB_OFFSET = 3
  38. FLASH_CR_PSIZE_8BIT = (0x0 << 8)
  39. FLASH_CR_PSIZE_16BIT = (0x1 << 8)
  40. FLASH_CR_PSIZE_32BIT = (0x2 << 8)
  41. FLASH_CR_STRT = (1 << 16)
  42. # Debug halting control and status register (DHCSR)
  43. DHCSR_ADDR = 0xE000EDF0
  44. DHCSR_DBGKEY_VALUE = 0xA05F0000
  45. DHCSR_HALT = (1 << 0)
  46. DHCSR_DEBUGEN = (1 << 1)
  47. DHCSR_S_REGRDY = (1 << 16)
  48. DHCSR_S_LOCKUP = (1 << 19)
  49. # Application interrupt and reset control register (AIRCR)
  50. AIRCR_ADDR = 0xE000ED0C
  51. AIRCR_VECTKEY_VALUE = 0x05FA0000
  52. AIRCR_SYSRESETREQ = (1 << 2)
  53. # Debug Core Register Selector Register (DCRSR)
  54. DCRSR_ADDR = 0xE000EDF4
  55. DCRSR_WRITE = (1 << 16)
  56. # Debug Core Register Data register (DCRDR)
  57. DCRDR_ADDR = 0xE000EDF8
  58. # Debug Exception and Monitor Control register (DEMCR)
  59. DEMCR_ADDR = 0xE000EDFC
  60. DEMCR_RESET_CATCH = (1 << 0)
  61. DEMCR_TRCENA = (1 << 24)
  62. # Program Counter Sample Register (PCSR)
  63. PCSR_ADDR = 0xE000101C
  64. # Loader addresses
  65. PBLLDR_HEADER_ADDR = 0x20000400
  66. PBLLDR_HEADER_OFFSET = PBLLDR_HEADER_ADDR + 0x4
  67. PBLLDR_HEADER_LENGTH = PBLLDR_HEADER_ADDR + 0x8
  68. PBLLDR_DATA_ADDR = 0x20000800
  69. PBLLDR_DATA_MAX_LENGTH = 0x20000
  70. PBLLDR_STATE_WAIT = 0
  71. PBLLDR_STATE_WRITE = 1
  72. PBLLDR_STATE_CRC = 2
  73. # SRAM base addr
  74. SRAM_BASE_ADDR = 0x20000000
  75. def __init__(self, driver):
  76. self._driver = driver
  77. self._step_start_time = 0
  78. self.FLASH_SECTOR_SIZES = [x*1024 for x in self.FLASH_SECTOR_SIZES]
  79. def __enter__(self):
  80. try:
  81. self.connect()
  82. return self
  83. except:
  84. self.close()
  85. raise
  86. def __exit__(self, exc, value, trace):
  87. self.close()
  88. def _fatal(self, message):
  89. raise Exception('FATAL ERROR: {}'.format(message))
  90. def _start_step(self, msg):
  91. LOG.info(msg)
  92. self._step_start_time = time.time()
  93. def _end_step(self, msg, no_time=False, num_bytes=None):
  94. total_time = round(time.time() - self._step_start_time, 2)
  95. if not no_time:
  96. msg += ' in {}s'.format(total_time)
  97. if num_bytes:
  98. kibps = round(num_bytes / 1024.0 / total_time, 2)
  99. msg += ' ({} KiB/s)'.format(kibps)
  100. LOG.info(msg)
  101. def connect(self):
  102. self._start_step('Connecting...')
  103. # connect and check the IDCODE
  104. if self._driver.connect() != self.IDCODE:
  105. self._fatal('Invalid IDCODE')
  106. # check the CPUID register
  107. if self._driver.read_memory_address(self.CPUID_ADDR) != self.CPUID_VALUE:
  108. self._fatal('Invalid CPU ID')
  109. self._end_step('Connected', no_time=True)
  110. def halt_core(self):
  111. # halt the core immediately
  112. dhcsr_value = self.DHCSR_DBGKEY_VALUE | self.DHCSR_DEBUGEN | self.DHCSR_HALT
  113. self._driver.write_memory_address(self.DHCSR_ADDR, dhcsr_value)
  114. def resume_core(self):
  115. # resume the core
  116. dhcsr_value = self.DHCSR_DBGKEY_VALUE
  117. self._driver.write_memory_address(self.DHCSR_ADDR, dhcsr_value)
  118. def reset_core(self, halt=False):
  119. if self._driver.read_memory_address(self.DHCSR_ADDR) & self.DHCSR_S_LOCKUP:
  120. # halt the core first to clear the lockup
  121. LOG.info('Clearing lockup condition')
  122. self.halt_core()
  123. # enable reset vector catch
  124. demcr_value = 0
  125. if halt:
  126. demcr_value |= self.DEMCR_RESET_CATCH
  127. self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value)
  128. self._driver.read_memory_address(self.DHCSR_ADDR)
  129. # reset the core
  130. aircr_value = self.AIRCR_VECTKEY_VALUE | self.AIRCR_SYSRESETREQ
  131. self._driver.write_memory_address(self.AIRCR_ADDR, aircr_value)
  132. if halt:
  133. self.halt_core()
  134. def unlock_flash(self):
  135. # unlock the flash
  136. self._driver.write_memory_address(self.FLASH_KEYR_ADDR, self.FLASH_KEYR_VAL1)
  137. self._driver.write_memory_address(self.FLASH_KEYR_ADDR, self.FLASH_KEYR_VAL2)
  138. def _poll_register(self, timeout=0.5):
  139. end_time = time.time() + timeout
  140. while end_time > time.time():
  141. val = self._driver.read_memory_address(self.DHCSR_ADDR)
  142. if val & self.DHCSR_S_REGRDY:
  143. break
  144. else:
  145. raise Exception('Register operation was not confirmed')
  146. def write_register(self, reg, val):
  147. self._driver.write_memory_address(self.DCRDR_ADDR, val)
  148. reg |= self.DCRSR_WRITE
  149. self._driver.write_memory_address(self.DCRSR_ADDR, reg)
  150. self._poll_register()
  151. def read_register(self, reg):
  152. self._driver.write_memory_address(self.DCRSR_ADDR, reg)
  153. self._poll_register()
  154. return self._driver.read_memory_address(self.DCRDR_ADDR)
  155. def erase_flash(self, flash_offset, length):
  156. self._start_step('Erasing...')
  157. def overlap(a1, a2, b1, b2):
  158. return max(a1, b1) < min(a2, b2)
  159. # find all the sectors which we need to erase
  160. erase_sectors = []
  161. for i, size in enumerate(self.FLASH_SECTOR_SIZES):
  162. addr = self.FLASH_BASE_ADDR + sum(self.FLASH_SECTOR_SIZES[:i])
  163. if overlap(flash_offset, flash_offset+length, addr, addr+size):
  164. erase_sectors += [i]
  165. if not erase_sectors:
  166. self._fatal('Could not find sectors to erase!')
  167. # erase the sectors
  168. for sector in erase_sectors:
  169. # start the erase
  170. reg_value = (sector << self.FLASH_CR_SNB_OFFSET)
  171. reg_value |= self.FLASH_CR_PSIZE_8BIT
  172. reg_value |= self.FLASH_CR_STRT
  173. reg_value |= self.FLASH_CR_SER
  174. self._driver.write_memory_address(self.FLASH_CR_ADDR, reg_value)
  175. # wait for the erase to finish
  176. while self._driver.read_memory_address(self.FLASH_SR_ADDR) & self.FLASH_SR_BSY:
  177. time.sleep(0)
  178. self._end_step('Erased')
  179. def close(self):
  180. self._driver.close()
  181. def _write_loader_state(self, state):
  182. self._driver.write_memory_address(self.PBLLDR_HEADER_ADDR, state)
  183. def _wait_loader_state(self, wanted_state, timeout=3):
  184. end_time = time.time() + timeout
  185. state = -1
  186. while time.time() < end_time:
  187. time.sleep(0)
  188. state = self._driver.read_memory_address(self.PBLLDR_HEADER_ADDR)
  189. if state == wanted_state:
  190. break
  191. else:
  192. raise Exception("Timed out waiting for loader state %d, got %d" % (wanted_state, state))
  193. @staticmethod
  194. def _chunks(l, n):
  195. for i in xrange(0, len(l), n):
  196. yield l[i:i+n], len(l[i:i+n]), i
  197. def execute_loader(self):
  198. # reset and halt the core
  199. self.reset_core(halt=True)
  200. with open(os.path.join(os.path.dirname(__file__), "loader.bin")) as f:
  201. loader_bin = f.read()
  202. # load loader binary into SRAM
  203. self._driver.write_memory_bulk(self.SRAM_BASE_ADDR, array.array('B', loader_bin))
  204. # set PC based on value in loader
  205. reg_sp, = struct.unpack("<I", loader_bin[:4])
  206. self.write_register(13, reg_sp)
  207. # set PC to new reset handler
  208. pc, = struct.unpack('<I', loader_bin[4:8])
  209. self.write_register(15, pc)
  210. # unlock flash
  211. self.unlock_flash()
  212. self.resume_core()
  213. @staticmethod
  214. def generate_crc(data):
  215. length = len(data)
  216. lookup_table = [0, 47, 94, 113, 188, 147, 226, 205, 87, 120, 9, 38, 235, 196, 181, 154]
  217. crc = 0
  218. for i in xrange(length*2):
  219. nibble = data[i / 2]
  220. if i % 2 == 0:
  221. nibble >>= 4
  222. index = nibble ^ (crc >> 4)
  223. crc = lookup_table[index & 0xf] ^ ((crc << 4) & 0xf0)
  224. return crc
  225. def read_crc(self, addr, length):
  226. self._driver.write_memory_address(self.PBLLDR_HEADER_OFFSET, addr)
  227. self._driver.write_memory_address(self.PBLLDR_HEADER_LENGTH, length)
  228. self._write_loader_state(self.PBLLDR_STATE_CRC)
  229. self._wait_loader_state(self.PBLLDR_STATE_WAIT)
  230. return self._driver.read_memory_address(self.PBLLDR_DATA_ADDR) & 0xFF
  231. def load_hex(self, hex_path):
  232. self._start_step("Loading binary: %s" % hex_path)
  233. ih = IntelHex(hex_path)
  234. offset = ih.minaddr()
  235. data = ih.tobinarray()
  236. self.load_bin(offset, data)
  237. self._end_step("Loaded binary", num_bytes=len(data))
  238. def load_bin(self, offset, data):
  239. while len(data) % 4 != 0:
  240. data.append(0xFF)
  241. length = len(data)
  242. # prepare the flash for programming
  243. self.erase_flash(offset, length)
  244. cr_value = self.FLASH_CR_PSIZE_8BIT | self.FLASH_CR_PG
  245. self._driver.write_memory_address(self.FLASH_CR_ADDR, cr_value)
  246. # set the base address
  247. self._wait_loader_state(self.PBLLDR_STATE_WAIT)
  248. self._driver.write_memory_address(self.PBLLDR_HEADER_OFFSET, offset)
  249. for chunk, chunk_length, pos in self._chunks(data, self.PBLLDR_DATA_MAX_LENGTH):
  250. LOG.info("Written %d/%d", pos, length)
  251. self._driver.write_memory_address(self.PBLLDR_HEADER_LENGTH, chunk_length)
  252. self._driver.write_memory_bulk(self.PBLLDR_DATA_ADDR, chunk)
  253. self._write_loader_state(self.PBLLDR_STATE_WRITE)
  254. self._wait_loader_state(self.PBLLDR_STATE_WAIT)
  255. expected_crc = self.generate_crc(data)
  256. actual_crc = self.read_crc(offset, length)
  257. if actual_crc != expected_crc:
  258. raise Exception("Bad CRC, expected %d, found %d" % (expected_crc, actual_crc))
  259. LOG.info("CRC-8 matched: %d", actual_crc)
  260. def profile(self, duration):
  261. LOG.info('Collecting %f second(s) worth of samples...', duration)
  262. # ensure DWT is enabled so we can get PC samples from PCSR
  263. demcr_value = self._driver.read_memory_address(self.DEMCR_ADDR)
  264. self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value | self.DEMCR_TRCENA)
  265. # take the samples
  266. samples = self._driver.continuous_read(self.PCSR_ADDR, duration)
  267. # restore the original DEMCR value
  268. self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value)
  269. # process the samples
  270. pcs = dict()
  271. for sample in samples:
  272. sample = '0x%08x' % sample
  273. pcs[sample] = pcs.get(sample, 0) + 1
  274. LOG.info('Collected %d samples!', len(samples))
  275. return pcs