1414if t .TYPE_CHECKING :
1515 from jinja2 import Environment
1616
17+ from sqlmesh .dbt .basemodel import Dependencies
1718 from sqlmesh .dbt .model import ModelConfig
1819 from sqlmesh .dbt .seed import SeedConfig
1920 from sqlmesh .dbt .source import SourceConfig
@@ -40,7 +41,7 @@ class DbtContext:
4041 _models : t .Dict [str , ModelConfig ] = field (default_factory = dict )
4142 _seeds : t .Dict [str , SeedConfig ] = field (default_factory = dict )
4243 _sources : t .Dict [str , SourceConfig ] = field (default_factory = dict )
43- _refs : t .Dict [str , str ] = field (default_factory = dict )
44+ _refs : t .Dict [str , t . Union [ ModelConfig , SeedConfig ] ] = field (default_factory = dict )
4445
4546 _target : t .Optional [TargetConfig ] = None
4647
@@ -128,9 +129,12 @@ def add_sources(self, sources: t.Dict[str, SourceConfig]) -> None:
128129 self ._jinja_environment = None
129130
130131 @property
131- def refs (self ) -> t .Dict [str , str ]:
132+ def refs (self ) -> t .Dict [str , t . Union [ ModelConfig , SeedConfig ] ]:
132133 if not self ._refs :
133- self ._refs = {k : v .model_name for k , v in {** self ._seeds , ** self ._models }.items ()} # type: ignore
134+ # Refs can be called with or without package name.
135+ for model in {** self ._seeds , ** self ._models }.values (): # type: ignore
136+ self ._refs [model .name ] = model
137+ self ._refs [model .config_name ] = model
134138 return self ._refs
135139
136140 @property
@@ -162,14 +166,46 @@ def jinja_environment(self) -> Environment:
162166
163167 @property
164168 def jinja_globals (self ) -> t .Dict [str , JinjaGlobalAttribute ]:
165- refs : t .Dict [str , t .Union [ModelConfig , SeedConfig ]] = {** self .models , ** self .seeds }
166169 output : t .Dict [str , JinjaGlobalAttribute ] = {
167170 "vars" : AttributeDict (self .variables ),
168- "refs" : AttributeDict ({k : v .relation_info for k , v in refs .items ()}),
171+ "refs" : AttributeDict ({k : v .relation_info for k , v in self . refs .items ()}),
169172 "sources" : AttributeDict ({k : v .relation_info for k , v in self .sources .items ()}),
170173 }
171174 if self .project_name is not None :
172175 output ["project_name" ] = self .project_name
173176 if self ._target is not None :
174177 output ["target" ] = self ._target .attribute_dict ()
175178 return output
179+
180+ def context_for_dependencies (self , dependencies : Dependencies ) -> DbtContext :
181+ from sqlmesh .dbt .model import ModelConfig
182+ from sqlmesh .dbt .seed import SeedConfig
183+
184+ dependency_context = self .copy ()
185+
186+ models = {}
187+ seeds = {}
188+ sources = {}
189+
190+ for ref in dependencies .refs :
191+ model = self .refs .get (ref )
192+ if model :
193+ if isinstance (model , SeedConfig ):
194+ seeds [ref ] = t .cast (SeedConfig , model )
195+ else :
196+ models [ref ] = t .cast (ModelConfig , model )
197+ else :
198+ raise ConfigError (f"Model '{ ref } ' was not found." )
199+
200+ for source in dependencies .sources :
201+ if source in self .sources :
202+ sources [source ] = self .sources [source ]
203+ else :
204+ raise ConfigError (f"Source '{ source } ' was not found." )
205+
206+ dependency_context .sources = sources
207+ dependency_context .seeds = seeds
208+ dependency_context .models = models
209+ dependency_context ._refs = {** dependency_context ._seeds , ** dependency_context ._models } # type: ignore
210+
211+ return dependency_context
0 commit comments