stylesheet.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) 2010-2019 openpyxl
  2. from warnings import warn
  3. from openpyxl.descriptors.serialisable import Serialisable
  4. from openpyxl.descriptors import (
  5. Typed,
  6. )
  7. from openpyxl.descriptors.sequence import NestedSequence
  8. from openpyxl.descriptors.excel import ExtensionList
  9. from openpyxl.utils.indexed_list import IndexedList
  10. from openpyxl.xml.constants import ARC_STYLE, SHEET_MAIN_NS
  11. from openpyxl.xml.functions import fromstring
  12. from .builtins import styles
  13. from .colors import ColorList, COLOR_INDEX
  14. from .differential import DifferentialStyle
  15. from .table import TableStyleList
  16. from .borders import Border
  17. from .fills import Fill
  18. from .fonts import Font
  19. from .numbers import (
  20. NumberFormatList,
  21. BUILTIN_FORMATS,
  22. BUILTIN_FORMATS_MAX_SIZE,
  23. BUILTIN_FORMATS_REVERSE,
  24. is_date_format,
  25. builtin_format_code
  26. )
  27. from .named_styles import (
  28. _NamedCellStyleList
  29. )
  30. from .cell_style import CellStyle, CellStyleList
  31. class Stylesheet(Serialisable):
  32. tagname = "styleSheet"
  33. numFmts = Typed(expected_type=NumberFormatList)
  34. fonts = NestedSequence(expected_type=Font, count=True)
  35. fills = NestedSequence(expected_type=Fill, count=True)
  36. borders = NestedSequence(expected_type=Border, count=True)
  37. cellStyleXfs = Typed(expected_type=CellStyleList)
  38. cellXfs = Typed(expected_type=CellStyleList)
  39. cellStyles = Typed(expected_type=_NamedCellStyleList)
  40. dxfs = NestedSequence(expected_type=DifferentialStyle, count=True)
  41. tableStyles = Typed(expected_type=TableStyleList, allow_none=True)
  42. colors = Typed(expected_type=ColorList, allow_none=True)
  43. extLst = Typed(expected_type=ExtensionList, allow_none=True)
  44. __elements__ = ('numFmts', 'fonts', 'fills', 'borders', 'cellStyleXfs',
  45. 'cellXfs', 'cellStyles', 'dxfs', 'tableStyles', 'colors')
  46. def __init__(self,
  47. numFmts=None,
  48. fonts=(),
  49. fills=(),
  50. borders=(),
  51. cellStyleXfs=None,
  52. cellXfs=None,
  53. cellStyles=None,
  54. dxfs=(),
  55. tableStyles=None,
  56. colors=None,
  57. extLst=None,
  58. ):
  59. if numFmts is None:
  60. numFmts = NumberFormatList()
  61. self.numFmts = numFmts
  62. self.number_formats = IndexedList()
  63. self.fonts = fonts
  64. self.fills = fills
  65. self.borders = borders
  66. if cellStyleXfs is None:
  67. cellStyleXfs = CellStyleList()
  68. self.cellStyleXfs = cellStyleXfs
  69. if cellXfs is None:
  70. cellXfs = CellStyleList()
  71. self.cellXfs = cellXfs
  72. if cellStyles is None:
  73. cellStyles = _NamedCellStyleList()
  74. self.cellStyles = cellStyles
  75. self.dxfs = dxfs
  76. self.tableStyles = tableStyles
  77. self.colors = colors
  78. self.cell_styles = self.cellXfs._to_array()
  79. self.alignments = self.cellXfs.alignments
  80. self.protections = self.cellXfs.prots
  81. self._normalise_numbers()
  82. self.named_styles = self._merge_named_styles()
  83. @classmethod
  84. def from_tree(cls, node):
  85. # strip all attribs
  86. attrs = dict(node.attrib)
  87. for k in attrs:
  88. del node.attrib[k]
  89. return super(Stylesheet, cls).from_tree(node)
  90. def _merge_named_styles(self):
  91. """
  92. Merge named style names "cellStyles" with their associated styles
  93. "cellStyleXfs"
  94. """
  95. named_styles = self.cellStyles.names
  96. for style in named_styles:
  97. self._expand_named_style(style)
  98. return named_styles
  99. def _expand_named_style(self, named_style):
  100. """
  101. Bind format definitions for a named style from the associated style
  102. record
  103. """
  104. xf = self.cellStyleXfs[named_style.xfId]
  105. named_style.font = self.fonts[xf.fontId]
  106. named_style.fill = self.fills[xf.fillId]
  107. named_style.border = self.borders[xf.borderId]
  108. if xf.numFmtId < BUILTIN_FORMATS_MAX_SIZE:
  109. formats = BUILTIN_FORMATS
  110. else:
  111. formats = self.custom_formats
  112. if xf.numFmtId in formats:
  113. named_style.number_format = formats[xf.numFmtId]
  114. if xf.alignment:
  115. named_style.alignment = xf.alignment
  116. if xf.protection:
  117. named_style.protection = xf.protection
  118. def _split_named_styles(self, wb):
  119. """
  120. Convert NamedStyle into separate CellStyle and Xf objects
  121. """
  122. for style in wb._named_styles:
  123. self.cellStyles.cellStyle.append(style.as_name())
  124. self.cellStyleXfs.xf.append(style.as_xf())
  125. @property
  126. def custom_formats(self):
  127. return dict([(n.numFmtId, n.formatCode) for n in self.numFmts.numFmt])
  128. def _normalise_numbers(self):
  129. """
  130. Rebase custom numFmtIds with a floor of 164 when reading stylesheet
  131. And index datetime formats
  132. """
  133. date_formats = set()
  134. custom = self.custom_formats
  135. formats = self.number_formats
  136. for idx, style in enumerate(self.cell_styles):
  137. if style.numFmtId in custom:
  138. fmt = custom[style.numFmtId]
  139. if fmt in BUILTIN_FORMATS_REVERSE: # remove builtins
  140. style.numFmtId = BUILTIN_FORMATS_REVERSE[fmt]
  141. else:
  142. style.numFmtId = formats.add(fmt) + BUILTIN_FORMATS_MAX_SIZE
  143. else:
  144. fmt = builtin_format_code(style.numFmtId)
  145. if is_date_format(fmt):
  146. # Create an index of which styles refer to datetimes
  147. date_formats.add(idx)
  148. self.date_formats = date_formats
  149. def to_tree(self, tagname=None, idx=None, namespace=None):
  150. tree = super(Stylesheet, self).to_tree(tagname, idx, namespace)
  151. tree.set("xmlns", SHEET_MAIN_NS)
  152. return tree
  153. def apply_stylesheet(archive, wb):
  154. """
  155. Add styles to workbook if present
  156. """
  157. try:
  158. src = archive.read(ARC_STYLE)
  159. except KeyError:
  160. return wb
  161. node = fromstring(src)
  162. stylesheet = Stylesheet.from_tree(node)
  163. wb._borders = IndexedList(stylesheet.borders)
  164. wb._fonts = IndexedList(stylesheet.fonts)
  165. wb._fills = IndexedList(stylesheet.fills)
  166. wb._differential_styles.styles = stylesheet.dxfs
  167. wb._number_formats = stylesheet.number_formats
  168. wb._protections = stylesheet.protections
  169. wb._alignments = stylesheet.alignments
  170. wb._table_styles = stylesheet.tableStyles
  171. # need to overwrite openpyxl defaults in case workbook has different ones
  172. wb._cell_styles = stylesheet.cell_styles
  173. wb._named_styles = stylesheet.named_styles
  174. wb._date_formats = stylesheet.date_formats
  175. for ns in wb._named_styles:
  176. ns.bind(wb)
  177. if not wb._named_styles:
  178. normal = styles['Normal']
  179. wb.add_named_style(normal)
  180. warn("Workbook contains no default style, apply openpyxl's default")
  181. if stylesheet.colors is not None:
  182. wb._colors = stylesheet.colors.index
  183. def write_stylesheet(wb):
  184. stylesheet = Stylesheet()
  185. stylesheet.fonts = wb._fonts
  186. stylesheet.fills = wb._fills
  187. stylesheet.borders = wb._borders
  188. stylesheet.dxfs = wb._differential_styles.styles
  189. from .numbers import NumberFormat
  190. fmts = []
  191. for idx, code in enumerate(wb._number_formats, BUILTIN_FORMATS_MAX_SIZE):
  192. fmt = NumberFormat(idx, code)
  193. fmts.append(fmt)
  194. stylesheet.numFmts.numFmt = fmts
  195. xfs = []
  196. for style in wb._cell_styles:
  197. xf = CellStyle.from_array(style)
  198. if style.alignmentId:
  199. xf.alignment = wb._alignments[style.alignmentId]
  200. if style.protectionId:
  201. xf.protection = wb._protections[style.protectionId]
  202. xfs.append(xf)
  203. stylesheet.cellXfs = CellStyleList(xf=xfs)
  204. stylesheet._split_named_styles(wb)
  205. stylesheet.tableStyles = wb._table_styles
  206. return stylesheet.to_tree()